diff --git a/src/batdetect2/api/base.py b/src/batdetect2/api/base.py index f56b7b6..cdfc843 100644 --- a/src/batdetect2/api/base.py +++ b/src/batdetect2/api/base.py @@ -5,7 +5,7 @@ from soundevent import data from batdetect2.audio import build_audio_loader from batdetect2.config import BatDetect2Config -from batdetect2.evaluate import Evaluator, build_evaluator +from batdetect2.evaluate import build_evaluator, evaluate from batdetect2.models import Model, build_model from batdetect2.postprocess import build_postprocessor from batdetect2.preprocess import build_preprocessor @@ -14,6 +14,7 @@ from batdetect2.train import train from batdetect2.train.lightning import load_model_from_checkpoint from batdetect2.typing import ( AudioLoader, + EvaluatorProtocol, PostprocessorProtocol, PreprocessorProtocol, TargetProtocol, @@ -28,7 +29,7 @@ class BatDetect2API: audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, postprocessor: PostprocessorProtocol, - evaluator: Evaluator, + evaluator: EvaluatorProtocol, model: Model, ): self.config = config @@ -70,6 +71,27 @@ class BatDetect2API: ) return self + def evaluate( + self, + test_annotations: Sequence[data.ClipAnnotation], + num_workers: Optional[int] = None, + output_dir: data.PathLike = ".", + experiment_name: Optional[str] = None, + run_name: Optional[str] = None, + ): + return evaluate( + self.model, + test_annotations, + targets=self.targets, + audio_loader=self.audio_loader, + preprocessor=self.preprocessor, + config=self.config, + num_workers=num_workers, + output_dir=output_dir, + experiment_name=experiment_name, + run_name=run_name, + ) + @classmethod def from_config(cls, config: BatDetect2Config): targets = build_targets(config=config.targets) @@ -118,8 +140,14 @@ class BatDetect2API: ) @classmethod - def from_checkpoint(cls, path: data.PathLike): - model, config = load_model_from_checkpoint(path) + def from_checkpoint( + cls, + path: data.PathLike, + config: Optional[BatDetect2Config] = None, + ): + model, stored_config = load_model_from_checkpoint(path) + + config = config or stored_config targets = build_targets(config=config.targets) diff --git a/src/batdetect2/cli/evaluate.py b/src/batdetect2/cli/evaluate.py index 20f30b6..6b1afe5 100644 --- a/src/batdetect2/cli/evaluate.py +++ b/src/batdetect2/cli/evaluate.py @@ -10,11 +10,17 @@ from batdetect2.cli.base import cli __all__ = ["evaluate_command"] +DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation" + + @cli.command(name="evaluate") @click.argument("model-path", type=click.Path(exists=True)) @click.argument("test_dataset", type=click.Path(exists=True)) -@click.option("--output-dir", type=click.Path()) -@click.option("--workers", type=int) +@click.option("--config", "config_path", type=click.Path()) +@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR) +@click.option("--experiment-name", type=str) +@click.option("--run-name", type=str) +@click.option("--workers", "num_workers", type=int) @click.option( "-v", "--verbose", @@ -24,13 +30,16 @@ __all__ = ["evaluate_command"] def evaluate_command( model_path: Path, test_dataset: Path, - output_dir: Optional[Path] = None, - workers: Optional[int] = None, + config_path: Optional[Path], + output_dir: Path = DEFAULT_OUTPUT_DIR, + num_workers: Optional[int] = None, + experiment_name: Optional[str] = None, + run_name: Optional[str] = None, verbose: int = 0, ): + from batdetect2.api.base import BatDetect2API + from batdetect2.config import load_full_config from batdetect2.data import load_dataset_from_config - from 
batdetect2.evaluate.evaluate import evaluate - from batdetect2.train.lightning import load_model_from_checkpoint logger.remove() if verbose == 0: @@ -49,16 +58,16 @@ def evaluate_command( num_annotations=len(test_annotations), ) - model, train_config = load_model_from_checkpoint(model_path) + config = None + if config_path is not None: + config = load_full_config(config_path) - df, results = evaluate( - model, + api = BatDetect2API.from_checkpoint(model_path, config=config) + + api.evaluate( test_annotations, - config=train_config, - num_workers=workers, + num_workers=num_workers, + output_dir=output_dir, + experiment_name=experiment_name, + run_name=run_name, ) - - print(results) - - if output_dir: - df.to_csv(output_dir / "results.csv") diff --git a/src/batdetect2/core/registries.py b/src/batdetect2/core/registries.py index 2535b82..9bd2b09 100644 --- a/src/batdetect2/core/registries.py +++ b/src/batdetect2/core/registries.py @@ -2,6 +2,7 @@ import sys from typing import Generic, Protocol, Type, TypeVar from pydantic import BaseModel +from typing_extensions import assert_type if sys.version_info >= (3, 10): from typing import ParamSpec @@ -44,7 +45,6 @@ class Registry(Generic[T_Type, P_Type]): config_cls: Type[T_Config], logic_cls: LogicProtocol[T_Config, T_Type, P_Type], ) -> None: - """A decorator factory to register a new item.""" fields = config_cls.model_fields if "name" not in fields: diff --git a/src/batdetect2/evaluate/__init__.py b/src/batdetect2/evaluate/__init__.py index 211edf9..3e02ed0 100644 --- a/src/batdetect2/evaluate/__init__.py +++ b/src/batdetect2/evaluate/__init__.py @@ -1,9 +1,11 @@ from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config +from batdetect2.evaluate.evaluate import evaluate from batdetect2.evaluate.evaluator import Evaluator, build_evaluator __all__ = [ "EvaluationConfig", "load_evaluation_config", + "evaluate", "Evaluator", "build_evaluator", ] diff --git a/src/batdetect2/evaluate/config.py b/src/batdetect2/evaluate/config.py index 90e3f14..2ed5bf3 100644 --- a/src/batdetect2/evaluate/config.py +++ b/src/batdetect2/evaluate/config.py @@ -11,6 +11,7 @@ from batdetect2.evaluate.metrics import ( MetricConfig, ) from batdetect2.evaluate.plots import PlotConfig +from batdetect2.logging import CSVLoggerConfig, LoggerConfig __all__ = [ "EvaluationConfig", @@ -28,6 +29,7 @@ class EvaluationConfig(BaseConfig): ] ) plots: List[PlotConfig] = Field(default_factory=list) + logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) def load_evaluation_config( diff --git a/src/batdetect2/evaluate/dataset.py b/src/batdetect2/evaluate/dataset.py new file mode 100644 index 0000000..fb3458c --- /dev/null +++ b/src/batdetect2/evaluate/dataset.py @@ -0,0 +1,144 @@ +from typing import List, NamedTuple, Optional, Sequence + +import torch +from loguru import logger +from pydantic import Field +from soundevent import data +from torch.utils.data import DataLoader, Dataset + +from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper +from batdetect2.audio.clips import PaddedClipConfig +from batdetect2.core import BaseConfig +from batdetect2.core.arrays import adjust_width +from batdetect2.preprocess import build_preprocessor +from batdetect2.typing import ( + AudioLoader, + ClipperProtocol, + PreprocessorProtocol, +) + +__all__ = [ + "TestDataset", + "build_test_dataset", + "build_test_loader", +] + + +class TestExample(NamedTuple): + spec: torch.Tensor + idx: torch.Tensor + start_time: torch.Tensor + end_time: torch.Tensor + + +class 
TestDataset(Dataset[TestExample]): + clip_annotations: List[data.ClipAnnotation] + + def __init__( + self, + clip_annotations: Sequence[data.ClipAnnotation], + audio_loader: AudioLoader, + preprocessor: PreprocessorProtocol, + clipper: Optional[ClipperProtocol] = None, + audio_dir: Optional[data.PathLike] = None, + ): + self.clip_annotations = list(clip_annotations) + self.clipper = clipper + self.preprocessor = preprocessor + self.audio_loader = audio_loader + self.audio_dir = audio_dir + + def __len__(self): + return len(self.clip_annotations) + + def __getitem__(self, idx: int) -> TestExample: + clip_annotation = self.clip_annotations[idx] + + if self.clipper is not None: + clip_annotation = self.clipper(clip_annotation) + + clip = clip_annotation.clip + wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir) + wav_tensor = torch.tensor(wav).unsqueeze(0) + spectrogram = self.preprocessor(wav_tensor) + return TestExample( + spec=spectrogram, + idx=torch.tensor(idx), + start_time=torch.tensor(clip.start_time), + end_time=torch.tensor(clip.end_time), + ) + + +class TestLoaderConfig(BaseConfig): + num_workers: int = 0 + clipping_strategy: ClipConfig = Field( + default_factory=lambda: PaddedClipConfig() + ) + + +def build_test_loader( + clip_annotations: Sequence[data.ClipAnnotation], + audio_loader: Optional[AudioLoader] = None, + preprocessor: Optional[PreprocessorProtocol] = None, + config: Optional[TestLoaderConfig] = None, + num_workers: Optional[int] = None, +) -> DataLoader[TestExample]: + logger.info("Building test data loader...") + config = config or TestLoaderConfig() + logger.opt(lazy=True).debug( + "Test data loader config: \n{config}", + config=lambda: config.to_yaml_string(exclude_none=True), + ) + + test_dataset = build_test_dataset( + clip_annotations, + audio_loader=audio_loader, + preprocessor=preprocessor, + config=config, + ) + + num_workers = num_workers or config.num_workers + return DataLoader( + test_dataset, + batch_size=1, + shuffle=False, + num_workers=num_workers, + collate_fn=_collate_fn, + ) + + +def build_test_dataset( + clip_annotations: Sequence[data.ClipAnnotation], + audio_loader: Optional[AudioLoader] = None, + preprocessor: Optional[PreprocessorProtocol] = None, + config: Optional[TestLoaderConfig] = None, +) -> TestDataset: + logger.info("Building training dataset...") + config = config or TestLoaderConfig() + + clipper = build_clipper(config=config.clipping_strategy) + + if audio_loader is None: + audio_loader = build_audio_loader() + + if preprocessor is None: + preprocessor = build_preprocessor() + + return TestDataset( + clip_annotations, + audio_loader=audio_loader, + clipper=clipper, + preprocessor=preprocessor, + ) + + +def _collate_fn(batch: List[TestExample]) -> TestExample: + max_width = max(item.spec.shape[-1] for item in batch) + return TestExample( + spec=torch.stack( + [adjust_width(item.spec, max_width) for item in batch] + ), + idx=torch.stack([item.idx for item in batch]), + start_time=torch.stack([item.start_time for item in batch]), + end_time=torch.stack([item.end_time for item in batch]), + ) diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 1a70b76..2fd723f 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -1,35 +1,44 @@ -from typing import List, Optional, Tuple +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Sequence -import pandas as pd +from lightning import Trainer from soundevent import data from 
batdetect2.audio import build_audio_loader -from batdetect2.evaluate.config import EvaluationConfig -from batdetect2.evaluate.dataframe import extract_matches_dataframe +from batdetect2.evaluate.dataset import build_test_loader from batdetect2.evaluate.evaluator import build_evaluator -from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP +from batdetect2.evaluate.lightning import EvaluationModule +from batdetect2.logging import build_logger from batdetect2.models import Model -from batdetect2.plotting.clips import AudioLoader, PreprocessorProtocol -from batdetect2.postprocess import get_raw_predictions from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets -from batdetect2.train.dataset import ValidationDataset -from batdetect2.train.labels import build_clip_labeler -from batdetect2.train.train import build_val_loader -from batdetect2.typing import ClipLabeller, TargetProtocol + +if TYPE_CHECKING: + from batdetect2.config import BatDetect2Config + from batdetect2.typing import ( + AudioLoader, + PreprocessorProtocol, + TargetProtocol, + ) + +DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations" def evaluate( model: Model, - test_annotations: List[data.ClipAnnotation], - targets: Optional[TargetProtocol] = None, - audio_loader: Optional[AudioLoader] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - labeller: Optional[ClipLabeller] = None, - config: Optional[EvaluationConfig] = None, + test_annotations: Sequence[data.ClipAnnotation], + targets: Optional["TargetProtocol"] = None, + audio_loader: Optional["AudioLoader"] = None, + preprocessor: Optional["PreprocessorProtocol"] = None, + config: Optional["BatDetect2Config"] = None, num_workers: Optional[int] = None, -) -> Tuple[pd.DataFrame, dict]: - config = config or EvaluationConfig() + output_dir: data.PathLike = DEFAULT_OUTPUT_DIR, + experiment_name: Optional[str] = None, + run_name: Optional[str] = None, +): + from batdetect2.config import BatDetect2Config + + config = config or BatDetect2Config() audio_loader = audio_loader or build_audio_loader() @@ -39,60 +48,21 @@ def evaluate( targets = targets or build_targets() - labeller = labeller or build_clip_labeler( - targets, - min_freq=preprocessor.min_freq, - max_freq=preprocessor.max_freq, - ) - - loader = build_val_loader( + loader = build_test_loader( test_annotations, audio_loader=audio_loader, - labeller=labeller, preprocessor=preprocessor, num_workers=num_workers, ) - dataset: ValidationDataset = loader.dataset # type: ignore + evaluator = build_evaluator(config=config.evaluation, targets=targets) - clip_annotations = [] - predictions = [] - - evaluator = build_evaluator(config=config, targets=targets) - - for batch in loader: - outputs = model.detector(batch.spec) - - clip_annotations = [ - dataset.clip_annotations[int(example_idx)] - for example_idx in batch.idx - ] - - predictions = get_raw_predictions( - outputs, - start_times=[ - clip_annotation.clip.start_time - for clip_annotation in clip_annotations - ], - targets=targets, - postprocessor=model.postprocessor, - ) - - clip_annotations.extend(clip_annotations) - predictions.extend(predictions) - - matches = evaluator.evaluate(clip_annotations, predictions) - df = extract_matches_dataframe(matches) - - metrics = [ - DetectionAP(), - ClassificationAP(class_names=targets.class_names), - ] - - results = { - name: value - for metric in metrics - for name, value in metric(matches).items() - } - - return df, results + logger = build_logger( + 
config.evaluation.logger, + log_dir=Path(output_dir), + experiment_name=experiment_name, + run_name=run_name, + ) + module = EvaluationModule(model, evaluator) + trainer = Trainer(logger=logger, enable_checkpointing=False) + return trainer.test(module, loader) diff --git a/src/batdetect2/evaluate/evaluator.py b/src/batdetect2/evaluate/evaluator.py index f0bca83..dbbbee2 100644 --- a/src/batdetect2/evaluate/evaluator.py +++ b/src/batdetect2/evaluate/evaluator.py @@ -11,6 +11,7 @@ from batdetect2.evaluate.plots import build_plotter from batdetect2.targets import build_targets from batdetect2.typing.evaluate import ( ClipEvaluation, + EvaluatorProtocol, MatcherProtocol, MetricsProtocol, PlotterProtocol, @@ -135,7 +136,7 @@ def build_evaluator( matcher: Optional[MatcherProtocol] = None, plots: Optional[List[PlotterProtocol]] = None, metrics: Optional[List[MetricsProtocol]] = None, -) -> Evaluator: +) -> EvaluatorProtocol: config = config or EvaluationConfig() targets = targets or build_targets() matcher = matcher or build_matcher(config.match_strategy) diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py new file mode 100644 index 0000000..625e7fe --- /dev/null +++ b/src/batdetect2/evaluate/lightning.py @@ -0,0 +1,86 @@ +from typing import Sequence + +from lightning import LightningModule +from torch.utils.data import DataLoader + +from batdetect2.evaluate.dataset import TestDataset, TestExample +from batdetect2.evaluate.tables import FullEvaluationTable +from batdetect2.logging import get_image_logger, get_table_logger +from batdetect2.models import Model +from batdetect2.postprocess import to_raw_predictions +from batdetect2.typing import ClipEvaluation, EvaluatorProtocol + + +class EvaluationModule(LightningModule): + def __init__( + self, + model: Model, + evaluator: EvaluatorProtocol, + ): + super().__init__() + + self.model = model + self.evaluator = evaluator + + self.clip_evaluations = [] + + def test_step(self, batch: TestExample): + dataset = self.get_dataset() + clip_annotations = [ + dataset.clip_annotations[int(example_idx)] + for example_idx in batch.idx + ] + + outputs = self.model.detector(batch.spec) + clip_detections = self.model.postprocessor( + outputs, + start_times=[ca.clip.start_time for ca in clip_annotations], + ) + predictions = [ + to_raw_predictions( + clip_dets.numpy(), + targets=self.evaluator.targets, + ) + for clip_dets in clip_detections + ] + + self.clip_evaluations.extend( + self.evaluator.evaluate(clip_annotations, predictions) + ) + + def on_test_epoch_start(self): + self.clip_evaluations = [] + + def on_test_epoch_end(self): + self.log_metrics(self.clip_evaluations) + self.plot_examples(self.clip_evaluations) + self.log_table(self.clip_evaluations) + + def log_table(self, evaluated_clips: Sequence[ClipEvaluation]): + table_logger = get_table_logger(self.logger) # type: ignore + + if table_logger is None: + return + + df = FullEvaluationTable()(evaluated_clips) + table_logger("full_evaluation", df, 0) + + def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]): + plotter = get_image_logger(self.logger) # type: ignore + + if plotter is None: + return + + for figure_name, fig in self.evaluator.generate_plots(evaluated_clips): + plotter(figure_name, fig, self.global_step) + + def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]): + metrics = self.evaluator.compute_metrics(evaluated_clips) + self.log_dict(metrics) + + def get_dataset(self) -> TestDataset: + dataloaders = self.trainer.test_dataloaders + 
assert isinstance(dataloaders, DataLoader) + dataset = dataloaders.dataset + assert isinstance(dataset, TestDataset) + return dataset diff --git a/src/batdetect2/evaluate/metrics.py b/src/batdetect2/evaluate/metrics.py index ea29ba4..efa024a 100644 --- a/src/batdetect2/evaluate/metrics.py +++ b/src/batdetect2/evaluate/metrics.py @@ -15,10 +15,8 @@ import numpy as np from pydantic import Field from sklearn import metrics, preprocessing -from batdetect2.core.configs import BaseConfig -from batdetect2.core.registries import Registry -from batdetect2.typing import MetricsProtocol -from batdetect2.typing.evaluate import ClipEvaluation +from batdetect2.core import BaseConfig, Registry +from batdetect2.typing import ClipEvaluation, MetricsProtocol __all__ = ["DetectionAP", "ClassificationAP"] @@ -375,14 +373,96 @@ metrics_registry.register( ) -class ClipAPConfig(BaseConfig): - name: Literal["clip_ap"] = "clip_ap" +class ClipDetectionAPConfig(BaseConfig): + name: Literal["clip_detection_ap"] = "clip_detection_ap" + ap_implementation: APImplementation = "pascal_voc" + + +class ClipDetectionAP(MetricsProtocol): + def __init__( + self, + implementation: APImplementation, + ): + self.implementation = implementation + self.metric = _ap_impl_mapping[self.implementation] + + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: + y_true = [] + y_score = [] + + for clip_eval in clip_evaluations: + clip_det = [] + clip_scores = [] + + for match in clip_eval.matches: + clip_det.append(match.gt_det) + clip_scores.append(match.pred_score) + + y_true.append(any(clip_det)) + y_score.append(max(clip_scores or [0])) + + return {"clip_detection_ap": self.metric(y_true, y_score)} + + @classmethod + def from_config( + cls, + config: ClipDetectionAPConfig, + class_names: List[str], + ): + return cls(implementation=config.ap_implementation) + + +metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP) + + +class ClipDetectionROCAUCConfig(BaseConfig): + name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc" + + +class ClipDetectionROCAUC(MetricsProtocol): + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: + y_true = [] + y_score = [] + + for clip_eval in clip_evaluations: + clip_det = [] + clip_scores = [] + + for match in clip_eval.matches: + clip_det.append(match.gt_det) + clip_scores.append(match.pred_score) + + y_true.append(any(clip_det)) + y_score.append(max(clip_scores or [0])) + + return { + "clip_detection_ap": float(metrics.roc_auc_score(y_true, y_score)) + } + + @classmethod + def from_config( + cls, + config: ClipDetectionROCAUCConfig, + class_names: List[str], + ): + return cls() + + +metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC) + + +class ClipMulticlassAPConfig(BaseConfig): + name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap" ap_implementation: APImplementation = "pascal_voc" include: Optional[List[str]] = None exclude: Optional[List[str]] = None -class ClipAP(MetricsProtocol): +class ClipMulticlassAP(MetricsProtocol): def __init__( self, class_names: List[str], @@ -454,15 +534,17 @@ class ClipAP(MetricsProtocol): [value for value in class_scores.values() if value != 0] ) return { - "clip_mAP": float(mean_ap), + "clip_multiclass_mAP": float(mean_ap), **{ - f"clip_AP/{class_name}": class_scores[class_name] + f"clip_multiclass_AP/{class_name}": class_scores[class_name] for class_name in self.selected }, } @classmethod - def from_config(cls, config: ClipAPConfig, class_names: 
List[str]): + def from_config( + cls, config: ClipMulticlassAPConfig, class_names: List[str] + ): return cls( implementation=config.ap_implementation, include=config.include, @@ -471,16 +553,16 @@ class ClipAP(MetricsProtocol): ) -metrics_registry.register(ClipAPConfig, ClipAP) +metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP) -class ClipROCAUCConfig(BaseConfig): - name: Literal["clip_roc_auc"] = "clip_roc_auc" +class ClipMulticlassROCAUCConfig(BaseConfig): + name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc" include: Optional[List[str]] = None exclude: Optional[List[str]] = None -class ClipROCAUC(MetricsProtocol): +class ClipMulticlassROCAUC(MetricsProtocol): def __init__( self, class_names: List[str], @@ -548,9 +630,11 @@ class ClipROCAUC(MetricsProtocol): [value for value in class_scores.values() if value != 0] ) return { - "clip_macro_ROC_AUC": float(mean_roc_auc), + "clip_multiclass_macro_ROC_AUC": float(mean_roc_auc), **{ - f"clip_ROC_AUC/{class_name}": class_scores[class_name] + f"clip_multiclass_ROC_AUC/{class_name}": class_scores[ + class_name + ] for class_name in self.selected }, } @@ -558,7 +642,7 @@ class ClipROCAUC(MetricsProtocol): @classmethod def from_config( cls, - config: ClipROCAUCConfig, + config: ClipMulticlassROCAUCConfig, class_names: List[str], ): return cls( @@ -568,7 +652,7 @@ class ClipROCAUC(MetricsProtocol): ) -metrics_registry.register(ClipROCAUCConfig, ClipROCAUC) +metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC) MetricConfig = Annotated[ Union[ @@ -578,8 +662,10 @@ MetricConfig = Annotated[ ClassificationROCAUCConfig, TopClassAPConfig, ClassificationBalancedAccuracyConfig, - ClipAPConfig, - ClipROCAUCConfig, + ClipDetectionAPConfig, + ClipDetectionROCAUCConfig, + ClipMulticlassAPConfig, + ClipMulticlassROCAUCConfig, ], Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/plots.py b/src/batdetect2/evaluate/plots.py index 6680c2e..5ca5092 100644 --- a/src/batdetect2/evaluate/plots.py +++ b/src/batdetect2/evaluate/plots.py @@ -10,19 +10,18 @@ from pydantic import Field from sklearn import metrics from sklearn.preprocessing import label_binarize -from batdetect2.audio import AudioConfig -from batdetect2.core.configs import BaseConfig -from batdetect2.core.registries import Registry -from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader +from batdetect2.audio import AudioConfig, build_audio_loader +from batdetect2.core import BaseConfig, Registry from batdetect2.plotting.gallery import plot_match_gallery from batdetect2.plotting.matches import plot_matches from batdetect2.preprocess import PreprocessingConfig, build_preprocessor -from batdetect2.typing.evaluate import ( +from batdetect2.typing import ( + AudioLoader, ClipEvaluation, MatchEvaluation, PlotterProtocol, + PreprocessorProtocol, ) -from batdetect2.typing.preprocess import AudioLoader __all__ = [ "build_plotter", @@ -431,6 +430,62 @@ class ClassificationROCCurves(PlotterProtocol): plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves) +class ConfusionMatrixConfig(BaseConfig): + name: Literal["confusion_matrix"] = "confusion_matrix" + background_class: str = "noise" + + +class ConfusionMatrix(PlotterProtocol): + def __init__(self, background_class: str, class_names: List[str]): + self.background_class = background_class + self.class_names = class_names + + def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + y_true = [] + y_pred = [] + + for clip_eval in 
clip_evaluations: + for match in clip_eval.matches: + # Ignore generic unclassified targets + if match.gt_det and match.gt_class is None: + continue + + y_true.append( + match.gt_class + if match.gt_class is not None + else self.background_class + ) + + top_class = match.pred_class + y_pred.append( + top_class + if top_class is not None + else self.background_class + ) + + display = metrics.ConfusionMatrixDisplay.from_predictions( + y_true, + y_pred, + labels=[*self.class_names, self.background_class], + ) + + yield "confusion_matrix", display.figure_ + + @classmethod + def from_config( + cls, + config: ConfusionMatrixConfig, + class_names: List[str], + ): + return cls( + background_class=config.background_class, + class_names=class_names, + ) + + +plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix) + + PlotConfig = Annotated[ Union[ ExampleGalleryConfig, @@ -439,6 +494,7 @@ PlotConfig = Annotated[ ClassificationPRCurvesConfig, DetectionROCCurveConfig, ClassificationROCCurvesConfig, + ConfusionMatrixConfig, ], Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/dataframe.py b/src/batdetect2/evaluate/tables.py similarity index 60% rename from src/batdetect2/evaluate/dataframe.py rename to src/batdetect2/evaluate/tables.py index 7398d34..9e36dbf 100644 --- a/src/batdetect2/evaluate/dataframe.py +++ b/src/batdetect2/evaluate/tables.py @@ -1,18 +1,49 @@ -from typing import List +from typing import Annotated, Callable, Literal, Sequence, Union import pandas as pd +from pydantic import Field from soundevent.geometry import compute_bounds -from batdetect2.typing.evaluate import ClipEvaluation +from batdetect2.core import BaseConfig, Registry +from batdetect2.typing import ClipEvaluation + +EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame] -def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame: +tables_registry: Registry[EvaluationTableGenerator, []] = Registry( + "evaluation_table" +) + + +class FullEvaluationTableConfig(BaseConfig): + name: Literal["full_evaluation"] = "full_evaluation" + + +class FullEvaluationTable: + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> pd.DataFrame: + return extract_matches_dataframe(clip_evaluations) + + @classmethod + def from_config(cls, config: FullEvaluationTableConfig): + return cls() + + +tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable) + + +def extract_matches_dataframe( + clip_evaluations: Sequence[ClipEvaluation], +) -> pd.DataFrame: data = [] for clip_evaluation in clip_evaluations: for match in clip_evaluation.matches: gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None - pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None + pred_start_time = pred_low_freq = pred_end_time = ( + pred_high_freq + ) = None sound_event_annotation = match.sound_event_annotation @@ -24,9 +55,12 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.Data ) if match.pred_geometry is not None: - pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = ( - compute_bounds(match.pred_geometry) - ) + ( + pred_start_time, + pred_low_freq, + pred_end_time, + pred_high_freq, + ) = compute_bounds(match.pred_geometry) data.append( { @@ -61,3 +95,14 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.Data df = pd.DataFrame(data) df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore return df + + +EvaluationTableConfig = Annotated[ + 
Union[FullEvaluationTableConfig,], Field(discriminator="name") +] + + +def build_table_generator( + config: EvaluationTableConfig, +) -> EvaluationTableGenerator: + return tables_registry.build(config) diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index b28d3a8..1a74eb9 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -67,7 +67,7 @@ from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets from batdetect2.typing.models import DetectionModel from batdetect2.typing.postprocess import ( - DetectionsTensor, + ClipDetectionsTensor, PostprocessorProtocol, ) from batdetect2.typing.preprocess import PreprocessorProtocol @@ -121,7 +121,7 @@ class Model(torch.nn.Module): self.postprocessor = postprocessor self.targets = targets - def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]: + def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]: spec = self.preprocessor(wav) outputs = self.detector(spec) return self.postprocessor(outputs) diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index c4ad923..8497c00 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -8,13 +8,10 @@ from batdetect2.postprocess.decoding import ( convert_raw_predictions_to_clip_prediction, to_raw_predictions, ) -from batdetect2.postprocess.nms import ( - non_max_suppression, -) +from batdetect2.postprocess.nms import non_max_suppression from batdetect2.postprocess.postprocessor import ( Postprocessor, build_postprocessor, - get_raw_predictions, ) __all__ = [ @@ -25,5 +22,4 @@ __all__ = [ "to_raw_predictions", "load_postprocess_config", "non_max_suppression", - "get_raw_predictions", ] diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index 6802aa8..e7e1635 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -6,7 +6,7 @@ import numpy as np from soundevent import data from batdetect2.typing.postprocess import ( - DetectionsArray, + ClipDetectionsArray, RawPrediction, ) from batdetect2.typing.targets import TargetProtocol @@ -28,7 +28,7 @@ decoding. def to_raw_predictions( - detections: DetectionsArray, + detections: ClipDetectionsArray, targets: TargetProtocol, ) -> List[RawPrediction]: predictions = [] diff --git a/src/batdetect2/postprocess/extraction.py b/src/batdetect2/postprocess/extraction.py index 29e981b..a1b5a65 100644 --- a/src/batdetect2/postprocess/extraction.py +++ b/src/batdetect2/postprocess/extraction.py @@ -15,32 +15,25 @@ precise time-frequency location of each detection. The final output aggregates all extracted information into a structured `xarray.Dataset`. 
""" -from typing import List, Optional, Tuple, Union +from typing import List, Optional import torch -from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression -from batdetect2.typing.postprocess import ( - DetectionsTensor, - ModelOutput, -) +from batdetect2.typing.postprocess import ClipDetectionsTensor __all__ = [ - "extract_prediction_tensor", + "extract_detection_peaks", ] -def extract_prediction_tensor( - output: ModelOutput, +def extract_detection_peaks( + detection_heatmap: torch.Tensor, + size_heatmap: torch.Tensor, + feature_heatmap: torch.Tensor, + classification_heatmap: torch.Tensor, max_detections: int = 200, threshold: Optional[float] = None, - nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, -) -> List[DetectionsTensor]: - detection_heatmap = non_max_suppression( - output.detection_probs.detach(), - kernel_size=nms_kernel_size, - ) - +) -> List[ClipDetectionsTensor]: height = detection_heatmap.shape[-2] width = detection_heatmap.shape[-1] @@ -53,9 +46,9 @@ def extract_prediction_tensor( freqs = freqs.flatten().to(detection_heatmap.device) times = times.flatten().to(detection_heatmap.device) - output_size_preds = output.size_preds.detach() - output_features = output.features.detach() - output_class_probs = output.class_probs.detach() + output_size_preds = size_heatmap.detach() + output_features = feature_heatmap.detach() + output_class_probs = classification_heatmap.detach() predictions = [] for idx, item in enumerate(detection_heatmap): @@ -65,23 +58,25 @@ def extract_prediction_tensor( detection_scores = item.take(indices) detection_freqs = freqs.take(indices) detection_times = times.take(indices) - sizes = output_size_preds[idx, :, detection_freqs, detection_times].T - features = output_features[idx, :, detection_freqs, detection_times].T - class_scores = output_class_probs[ - idx, :, detection_freqs, detection_times - ].T if threshold is not None: mask = detection_scores >= threshold + detection_scores = detection_scores[mask] - sizes = sizes[mask] detection_times = detection_times[mask] detection_freqs = detection_freqs[mask] - features = features[mask] - class_scores = class_scores[mask] + + sizes = output_size_preds[idx, :, detection_freqs, detection_times].T + features = output_features[idx, :, detection_freqs, detection_times].T + class_scores = output_class_probs[ + idx, + :, + detection_freqs, + detection_times, + ].T predictions.append( - DetectionsTensor( + ClipDetectionsTensor( scores=detection_scores, sizes=sizes, features=features, diff --git a/src/batdetect2/postprocess/postprocessor.py b/src/batdetect2/postprocess/postprocessor.py index 5c13bd3..677b315 100644 --- a/src/batdetect2/postprocess/postprocessor.py +++ b/src/batdetect2/postprocess/postprocessor.py @@ -1,29 +1,20 @@ -from typing import List, Optional +from typing import List, Optional, Tuple, Union import torch from loguru import logger -from soundevent import data from batdetect2.postprocess.config import ( PostprocessConfig, ) -from batdetect2.postprocess.decoding import ( - DEFAULT_CLASSIFICATION_THRESHOLD, - convert_raw_prediction_to_sound_event_prediction, - convert_raw_predictions_to_clip_prediction, - to_raw_predictions, -) -from batdetect2.postprocess.extraction import extract_prediction_tensor +from batdetect2.postprocess.extraction import extract_detection_peaks +from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression from batdetect2.postprocess.remapping import map_detection_to_clip from batdetect2.typing import ModelOutput from 
batdetect2.typing.postprocess import ( - BatDetect2Prediction, - DetectionsTensor, + ClipDetectionsTensor, PostprocessorProtocol, - RawPrediction, ) from batdetect2.typing.preprocess import PreprocessorProtocol -from batdetect2.typing.targets import TargetProtocol __all__ = [ "build_postprocessor", @@ -60,24 +51,43 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): max_freq: float, top_k_per_sec: int = 200, detection_threshold: float = 0.01, + nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, ): """Initialize the Postprocessor.""" super().__init__() - self.samplerate = samplerate + + self.output_samplerate = samplerate self.min_freq = min_freq self.max_freq = max_freq self.top_k_per_sec = top_k_per_sec self.detection_threshold = detection_threshold + self.nms_kernel_size = nms_kernel_size + + def forward( + self, + output: ModelOutput, + start_times: Optional[List[float]] = None, + ) -> List[ClipDetectionsTensor]: + detection_heatmap = non_max_suppression( + output.detection_probs.detach(), + kernel_size=self.nms_kernel_size, + ) - def forward(self, output: ModelOutput) -> List[DetectionsTensor]: width = output.detection_probs.shape[-1] - duration = width / self.samplerate + duration = width / self.output_samplerate max_detections = int(self.top_k_per_sec * duration) - detections = extract_prediction_tensor( - output, + detections = extract_detection_peaks( + detection_heatmap, + size_heatmap=output.size_preds, + feature_heatmap=output.features, + classification_heatmap=output.class_probs, max_detections=max_detections, threshold=self.detection_threshold, ) + + if start_times is None: + start_times = [0 for _ in range(len(detections))] + return [ map_detection_to_clip( detection, @@ -88,121 +98,3 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): ) for detection in detections ] - - def get_detections( - self, - output: ModelOutput, - start_times: Optional[List[float]] = None, - ) -> List[DetectionsTensor]: - width = output.detection_probs.shape[-1] - duration = width / self.samplerate - max_detections = int(self.top_k_per_sec * duration) - - detections = extract_prediction_tensor( - output, - max_detections=max_detections, - threshold=self.detection_threshold, - ) - - if start_times is None: - return detections - - width = output.detection_probs.shape[-1] - duration = width / self.samplerate - return [ - map_detection_to_clip( - detection, - start_time=start_time, - end_time=start_time + duration, - min_freq=self.min_freq, - max_freq=self.max_freq, - ) - for detection, start_time in zip(detections, start_times) - ] - - -def get_raw_predictions( - output: ModelOutput, - targets: TargetProtocol, - postprocessor: PostprocessorProtocol, - start_times: Optional[List[float]] = None, -) -> List[List[RawPrediction]]: - """Extract intermediate RawPrediction objects for a batch.""" - detections = postprocessor.get_detections(output, start_times) - return [ - to_raw_predictions(detection.numpy(), targets=targets) - for detection in detections - ] - - -def get_sound_event_predictions( - output: ModelOutput, - targets: TargetProtocol, - postprocessor: PostprocessorProtocol, - clips: List[data.Clip], - classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, -) -> List[List[BatDetect2Prediction]]: - raw_predictions = get_raw_predictions( - output, - targets=targets, - postprocessor=postprocessor, - start_times=[clip.start_time for clip in clips], - ) - return [ - [ - BatDetect2Prediction( - raw=raw, - 
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction( - raw, - recording=clip.recording, - targets=targets, - classification_threshold=classification_threshold, - ), - ) - for raw in predictions - ] - for predictions, clip in zip(raw_predictions, clips) - ] - - -def get_predictions( - output: ModelOutput, - clips: List[data.Clip], - targets: TargetProtocol, - postprocessor: PostprocessorProtocol, - classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, -) -> List[data.ClipPrediction]: - """Perform the full postprocessing pipeline for a batch. - - Takes raw model output and corresponding clips, applies the entire - configured chain (NMS, remapping, extraction, geometry recovery, class - decoding), producing final `soundevent.data.ClipPrediction` objects. - - Parameters - ---------- - output : ModelOutput - Raw output from the neural network model for a batch. - clips : List[data.Clip] - List of `soundevent.data.Clip` objects corresponding to the batch. - - Returns - ------- - List[data.ClipPrediction] - List containing one `ClipPrediction` object for each input clip, - populated with `SoundEventPrediction` objects. - """ - raw_predictions = get_raw_predictions( - output, - targets=targets, - postprocessor=postprocessor, - start_times=[clip.start_time for clip in clips], - ) - return [ - convert_raw_predictions_to_clip_prediction( - prediction, - clip, - targets=targets, - classification_threshold=classification_threshold, - ) - for prediction, clip in zip(raw_predictions, clips) - ] diff --git a/src/batdetect2/postprocess/remapping.py b/src/batdetect2/postprocess/remapping.py index 0fc5db7..1168a96 100644 --- a/src/batdetect2/postprocess/remapping.py +++ b/src/batdetect2/postprocess/remapping.py @@ -20,7 +20,7 @@ import xarray as xr from soundevent.arrays import Dimensions from batdetect2.preprocess import MAX_FREQ, MIN_FREQ -from batdetect2.typing.postprocess import DetectionsTensor +from batdetect2.typing.postprocess import ClipDetectionsTensor __all__ = [ "features_to_xarray", @@ -31,15 +31,15 @@ __all__ = [ def map_detection_to_clip( - detections: DetectionsTensor, + detections: ClipDetectionsTensor, start_time: float, end_time: float, min_freq: float, max_freq: float, -) -> DetectionsTensor: +) -> ClipDetectionsTensor: duration = end_time - start_time bandwidth = max_freq - min_freq - return DetectionsTensor( + return ClipDetectionsTensor( scores=detections.scores, sizes=detections.sizes, features=detections.features, diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py index 7a7207a..029a90a 100644 --- a/src/batdetect2/train/__init__.py +++ b/src/batdetect2/train/__init__.py @@ -14,7 +14,6 @@ from batdetect2.train.augmentations import ( scale_volume, warp_spectrogram, ) -from batdetect2.train.clips import build_clipper, select_subclip from batdetect2.train.config import ( PLTrainerConfig, TrainingConfig, @@ -61,7 +60,6 @@ __all__ = [ "add_echo", "build_augmentations", "build_clip_labeler", - "build_clipper", "build_loss", "build_train_dataset", "build_train_loader", @@ -74,7 +72,6 @@ __all__ = [ "mask_time", "mix_audio", "scale_volume", - "select_subclip", "train", "warp_spectrogram", ] diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index c958c9d..172c5ed 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -11,9 +11,9 @@ from pydantic import Field from soundevent import data from soundevent.geometry import scale_geometry, shift_geometry 
+from batdetect2.audio.clips import get_subclip_annotation from batdetect2.core.arrays import adjust_width from batdetect2.core.configs import BaseConfig, load_config -from batdetect2.train.clips import get_subclip_annotation from batdetect2.typing import AudioLoader, Augmentation __all__ = [ diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 3a0a9a0..19b9751 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -5,13 +5,13 @@ from lightning.pytorch.callbacks import Callback from soundevent import data from torch.utils.data import DataLoader -from batdetect2.evaluate import Evaluator -from batdetect2.postprocess import get_raw_predictions +from batdetect2.logging import get_image_logger +from batdetect2.postprocess import to_raw_predictions from batdetect2.train.dataset import ValidationDataset from batdetect2.train.lightning import TrainingModule -from batdetect2.train.logging import get_image_plotter from batdetect2.typing import ( ClipEvaluation, + EvaluatorProtocol, ModelOutput, RawPrediction, TrainExample, @@ -19,7 +19,7 @@ from batdetect2.typing import ( class ValidationMetrics(Callback): - def __init__(self, evaluator: Evaluator): + def __init__(self, evaluator: EvaluatorProtocol): super().__init__() self.evaluator = evaluator @@ -34,12 +34,12 @@ class ValidationMetrics(Callback): assert isinstance(dataset, ValidationDataset) return dataset - def plot_examples( + def generate_plots( self, pl_module: LightningModule, evaluated_clips: List[ClipEvaluation], ): - plotter = get_image_plotter(pl_module.logger) # type: ignore + plotter = get_image_logger(pl_module.logger) # type: ignore if plotter is None: return @@ -66,7 +66,7 @@ class ValidationMetrics(Callback): ) self.log_metrics(pl_module, clip_evaluations) - self.plot_examples(pl_module, clip_evaluations) + self.generate_plots(pl_module, clip_evaluations) return super().on_validation_epoch_end(trainer, pl_module) @@ -88,8 +88,7 @@ class ValidationMetrics(Callback): batch_idx: int, dataloader_idx: int = 0, ) -> None: - postprocessor = pl_module.model.postprocessor - targets = pl_module.model.targets + model = pl_module.model dataset = self.get_dataset(trainer) clip_annotations = [ @@ -97,15 +96,14 @@ class ValidationMetrics(Callback): for example_idx in batch.idx ] - predictions = get_raw_predictions( + clip_detections = model.postprocessor( outputs, - start_times=[ - clip_annotation.clip.start_time - for clip_annotation in clip_annotations - ], - targets=targets, - postprocessor=postprocessor, + start_times=[ca.clip.start_time for ca in clip_annotations], ) + predictions = [ + to_raw_predictions(clip_dets.numpy(), targets=model.targets) + for clip_dets in clip_detections + ] self._clip_annotations.extend(clip_annotations) self._predictions.extend(predictions) diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index ad4fa27..699b791 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -5,17 +5,9 @@ from soundevent import data from batdetect2.core.configs import BaseConfig, load_config from batdetect2.evaluate.config import EvaluationConfig -from batdetect2.train.augmentations import ( - DEFAULT_AUGMENTATION_CONFIG, - AugmentationsConfig, -) -from batdetect2.train.clips import ( - ClipConfig, - PaddedClipConfig, - RandomClipConfig, -) +from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig +from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig from batdetect2.train.labels 
import LabelConfig -from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig from batdetect2.train.losses import LossConfig __all__ = [ @@ -45,30 +37,6 @@ class PLTrainerConfig(BaseConfig): val_check_interval: Optional[Union[int, float]] = None -class ValLoaderConfig(BaseConfig): - num_workers: int = 0 - - clipping_strategy: ClipConfig = Field( - default_factory=lambda: PaddedClipConfig() - ) - - -class TrainLoaderConfig(BaseConfig): - num_workers: int = 0 - - batch_size: int = 8 - - shuffle: bool = False - - augmentations: AugmentationsConfig = Field( - default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy() - ) - - clipping_strategy: ClipConfig = Field( - default_factory=lambda: PaddedClipConfig() - ) - - class OptimizerConfig(BaseConfig): learning_rate: float = 1e-3 t_max: int = 100 @@ -79,9 +47,8 @@ class TrainingConfig(BaseConfig): val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig) optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) loss: LossConfig = Field(default_factory=LossConfig) - cliping: RandomClipConfig = Field(default_factory=RandomClipConfig) trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig) - logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) + logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig) labels: LabelConfig = Field(default_factory=LabelConfig) validation: EvaluationConfig = Field(default_factory=EvaluationConfig) diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index d1eeb3c..9fe54c9 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -2,18 +2,21 @@ from typing import List, Optional, Sequence import torch from loguru import logger +from pydantic import Field from soundevent import data from torch.utils.data import DataLoader, Dataset -from batdetect2.audio import build_audio_loader +from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper +from batdetect2.audio.clips import PaddedClipConfig +from batdetect2.core import BaseConfig from batdetect2.core.arrays import adjust_width from batdetect2.preprocess import build_preprocessor from batdetect2.train.augmentations import ( + DEFAULT_AUGMENTATION_CONFIG, + AugmentationsConfig, RandomAudioSource, build_augmentations, ) -from batdetect2.train.clips import build_clipper -from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig from batdetect2.train.labels import build_clip_labeler from batdetect2.typing import ( AudioLoader, @@ -144,6 +147,22 @@ class ValidationDataset(Dataset): ) +class TrainLoaderConfig(BaseConfig): + num_workers: int = 0 + + batch_size: int = 8 + + shuffle: bool = False + + augmentations: AugmentationsConfig = Field( + default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy() + ) + + clipping_strategy: ClipConfig = Field( + default_factory=lambda: PaddedClipConfig() + ) + + def build_train_loader( clip_annotations: Sequence[data.ClipAnnotation], audio_loader: Optional[AudioLoader] = None, @@ -178,6 +197,14 @@ def build_train_loader( ) +class ValLoaderConfig(BaseConfig): + num_workers: int = 0 + + clipping_strategy: ClipConfig = Field( + default_factory=lambda: PaddedClipConfig() + ) + + def build_val_loader( clip_annotations: Sequence[data.ClipAnnotation], audio_loader: Optional[AudioLoader] = None, diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 18fb248..09470e1 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -8,20 +8,21 @@ from loguru 
import logger from soundevent import data from batdetect2.audio import build_audio_loader -from batdetect2.evaluate.evaluator import Evaluator, build_evaluator +from batdetect2.evaluate.evaluator import build_evaluator +from batdetect2.logging import build_logger from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.dataset import build_train_loader, build_val_loader from batdetect2.train.labels import build_clip_labeler from batdetect2.train.lightning import build_training_module -from batdetect2.train.logging import build_logger if TYPE_CHECKING: from batdetect2.config import BatDetect2Config from batdetect2.typing import ( AudioLoader, ClipLabeller, + EvaluatorProtocol, PreprocessorProtocol, TargetProtocol, ) @@ -122,7 +123,7 @@ def train( def build_trainer_callbacks( targets: "TargetProtocol", - evaluator: Optional[Evaluator] = None, + evaluator: Optional["EvaluatorProtocol"] = None, checkpoint_dir: Optional[Path] = None, experiment_name: Optional[str] = None, run_name: Optional[str] = None, @@ -142,7 +143,6 @@ def build_trainer_callbacks( ModelCheckpoint( dirpath=str(checkpoint_dir), save_top_k=1, - filename="best-{epoch:02d}-{val_loss:.0f}", monitor="total_loss/val", ), ValidationMetrics(evaluator), @@ -152,7 +152,7 @@ def build_trainer_callbacks( def build_trainer( conf: "BatDetect2Config", targets: "TargetProtocol", - evaluator: Optional[Evaluator] = None, + evaluator: Optional["EvaluatorProtocol"] = None, checkpoint_dir: Optional[Path] = None, log_dir: Optional[Path] = None, experiment_name: Optional[str] = None, diff --git a/src/batdetect2/typing/__init__.py b/src/batdetect2/typing/__init__.py index 5395a38..a06f387 100644 --- a/src/batdetect2/typing/__init__.py +++ b/src/batdetect2/typing/__init__.py @@ -1,7 +1,9 @@ from batdetect2.typing.evaluate import ( ClipEvaluation, + EvaluatorProtocol, MatchEvaluation, MetricsProtocol, + PlotterProtocol, ) from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput from batdetect2.typing.postprocess import ( @@ -49,6 +51,7 @@ __all__ = [ "MatchEvaluation", "MetricsProtocol", "ModelOutput", + "PlotterProtocol", "Position", "PostprocessorProtocol", "PreprocessorProtocol", @@ -60,4 +63,5 @@ __all__ = [ "SoundEventFilter", "TargetProtocol", "TrainExample", + "EvaluatorProtocol", ] diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/typing/evaluate.py index fad09db..3c71405 100644 --- a/src/batdetect2/typing/evaluate.py +++ b/src/batdetect2/typing/evaluate.py @@ -14,7 +14,11 @@ from typing import ( from matplotlib.figure import Figure from soundevent import data +from batdetect2.typing.postprocess import RawPrediction +from batdetect2.typing.targets import TargetProtocol + __all__ = [ + "EvaluatorProtocol", "MetricsProtocol", "MatchEvaluation", ] @@ -107,3 +111,21 @@ class PlotterProtocol(Protocol): def __call__( self, clip_evaluations: Sequence[ClipEvaluation] ) -> Iterable[Tuple[str, Figure]]: ... + + +class EvaluatorProtocol(Protocol): + targets: TargetProtocol + + def evaluate( + self, + clip_annotations: Sequence[data.ClipAnnotation], + predictions: Sequence[Sequence[RawPrediction]], + ) -> List[ClipEvaluation]: ... + + def compute_metrics( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: ... + + def generate_plots( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Iterable[Tuple[str, Figure]]: ... 
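
Note: a minimal sketch of how the `EvaluatorProtocol` defined above is meant to be consumed, mirroring the evaluate → compute_metrics → generate_plots flow that `EvaluationModule` follows in this change. The variables `annotations` and `predictions` are hypothetical placeholders for ground-truth clip annotations and per-clip raw predictions collected elsewhere; writing figures to disk is only illustrative.

    from batdetect2.evaluate import build_evaluator
    from batdetect2.targets import build_targets

    targets = build_targets()
    evaluator = build_evaluator(targets=targets)  # returns an object satisfying EvaluatorProtocol

    # annotations: Sequence[data.ClipAnnotation]
    # predictions: Sequence[Sequence[RawPrediction]]
    clip_evaluations = evaluator.evaluate(annotations, predictions)

    metric_values = evaluator.compute_metrics(clip_evaluations)  # Dict[str, float]
    for figure_name, fig in evaluator.generate_plots(clip_evaluations):
        fig.savefig(f"{figure_name}.png")
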
diff --git a/src/batdetect2/typing/postprocess.py b/src/batdetect2/typing/postprocess.py index 5f2addf..df45759 100644 --- a/src/batdetect2/typing/postprocess.py +++ b/src/batdetect2/typing/postprocess.py @@ -12,7 +12,7 @@ system that deal with model predictions. """ from dataclasses import dataclass -from typing import List, NamedTuple, Optional, Protocol +from typing import List, NamedTuple, Optional, Protocol, Sequence import numpy as np import torch @@ -47,15 +47,13 @@ class GeometryDecoder(Protocol): class RawPrediction(NamedTuple): - """Intermediate representation of a single detected sound event.""" - geometry: data.Geometry detection_score: float class_scores: np.ndarray features: np.ndarray -class DetectionsArray(NamedTuple): +class ClipDetectionsArray(NamedTuple): scores: np.ndarray sizes: np.ndarray class_scores: np.ndarray @@ -64,7 +62,7 @@ class DetectionsArray(NamedTuple): features: np.ndarray -class DetectionsTensor(NamedTuple): +class ClipDetectionsTensor(NamedTuple): scores: torch.Tensor sizes: torch.Tensor class_scores: torch.Tensor @@ -72,8 +70,8 @@ class DetectionsTensor(NamedTuple): frequencies: torch.Tensor features: torch.Tensor - def numpy(self) -> DetectionsArray: - return DetectionsArray( + def numpy(self) -> ClipDetectionsArray: + return ClipDetectionsArray( scores=self.scores.detach().cpu().numpy(), sizes=self.sizes.detach().cpu().numpy(), class_scores=self.class_scores.detach().cpu().numpy(), @@ -92,10 +90,8 @@ class BatDetect2Prediction: class PostprocessorProtocol(Protocol): """Protocol defining the interface for the full postprocessing pipeline.""" - def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ... - - def get_detections( + def __call__( self, output: ModelOutput, - start_times: Optional[List[float]] = None, - ) -> List[DetectionsTensor]: ... + start_times: Optional[Sequence[float]] = None, + ) -> List[ClipDetectionsTensor]: ...
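
Note: a hedged end-to-end sketch of the evaluation entry point this diff introduces via `BatDetect2API`. The checkpoint path, `test_annotations` list, output directory, experiment name, and run name below are illustrative placeholders, not values from this change.

    from batdetect2.api.base import BatDetect2API

    # "model.ckpt" and `test_annotations` are placeholders for a trained
    # checkpoint and a Sequence[data.ClipAnnotation] loaded by the caller.
    api = BatDetect2API.from_checkpoint("model.ckpt")  # optionally pass config=... to override the stored one
    results = api.evaluate(
        test_annotations,
        num_workers=4,
        output_dir="outputs/evaluation",
        experiment_name="bd2-eval",  # hypothetical experiment name
        run_name="run-01",           # hypothetical run name
    )

The CLI path added in `cli/evaluate.py` follows the same route: it loads the test dataset, builds the API from the checkpoint (optionally with `--config`), and forwards `--output-dir`, `--experiment-name`, `--run-name`, and `--workers` to `api.evaluate`.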