From cd4955d4f32a44dce8cd7632ac2daa5bd3c60e31 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 8 Sep 2025 22:04:30 +0100 Subject: [PATCH] Eval --- example_data/config.yaml | 7 +- src/batdetect2/cli/__init__.py | 2 + src/batdetect2/cli/evaluate.py | 63 +++++++++++++++++ src/batdetect2/cli/train.py | 6 ++ src/batdetect2/configs.py | 2 +- src/batdetect2/evaluate/dataframe.py | 62 +++++++++++++++++ src/batdetect2/evaluate/evaluate.py | 100 +++++++++++++++++++++++++++ src/batdetect2/models/__init__.py | 28 ++++---- src/batdetect2/train/config.py | 2 - src/batdetect2/train/lightning.py | 32 +++++++-- src/batdetect2/train/logging.py | 40 +++++++---- src/batdetect2/train/train.py | 50 +++++++++----- 12 files changed, 341 insertions(+), 53 deletions(-) create mode 100644 src/batdetect2/cli/evaluate.py create mode 100644 src/batdetect2/evaluate/dataframe.py create mode 100644 src/batdetect2/evaluate/evaluate.py diff --git a/example_data/config.yaml b/example_data/config.yaml index 083f873..90e1e42 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -108,6 +108,9 @@ train: labels: sigma: 3 + trainer: + max_epochs: 40 + dataloaders: train: batch_size: 8 @@ -115,7 +118,7 @@ train: shuffle: True val: - batch_size: 8 + batch_size: 1 num_workers: 2 loss: @@ -133,7 +136,7 @@ train: weight: 0.1 logger: - logger_type: tensorboard + logger_type: csv # save_dir: outputs/log/ # name: logs diff --git a/src/batdetect2/cli/__init__.py b/src/batdetect2/cli/__init__.py index bd5583e..cde6f8f 100644 --- a/src/batdetect2/cli/__init__.py +++ b/src/batdetect2/cli/__init__.py @@ -1,6 +1,7 @@ from batdetect2.cli.base import cli from batdetect2.cli.compat import detect from batdetect2.cli.data import data +from batdetect2.cli.evaluate import evaluate_command from batdetect2.cli.train import train_command __all__ = [ @@ -8,6 +9,7 @@ __all__ = [ "detect", "data", "train_command", + "evaluate_command", ] diff --git a/src/batdetect2/cli/evaluate.py 
@cli.command(name="evaluate")
@click.argument(
    "model-path", type=click.Path(exists=True, path_type=Path)
)
@click.argument(
    "test_dataset", type=click.Path(exists=True, path_type=Path)
)
@click.option("--output-dir", type=click.Path(path_type=Path))
@click.option("--workers", type=int)
@click.option(
    "-v",
    "--verbose",
    count=True,
    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command(
    model_path: Path,
    test_dataset: Path,
    output_dir: Optional[Path] = None,
    workers: Optional[int] = None,
    verbose: int = 0,
):
    """Evaluate a trained checkpoint against a test dataset.

    Loads the model from MODEL_PATH, runs it over the annotations in
    TEST_DATASET, prints the aggregate metrics, and optionally writes the
    per-match dataframe to ``<output-dir>/results.csv``.
    """
    # Reset loguru and map -v / -vv to log levels.
    logger.remove()
    if verbose == 0:
        log_level = "WARNING"
    elif verbose == 1:
        log_level = "INFO"
    else:
        log_level = "DEBUG"
    logger.add(sys.stderr, level=log_level)

    logger.info("Initiating evaluation process...")

    test_annotations = load_dataset_from_config(test_dataset)
    logger.debug(
        "Loaded {num_annotations} test examples",
        num_annotations=len(test_annotations),
    )

    model, train_config = load_model_from_checkpoint(model_path)

    df, results = evaluate(
        model,
        test_annotations,
        config=train_config,
        num_workers=workers,
    )

    print(results)

    if output_dir:
        # BUG FIX: click.Path without path_type returns plain `str`, so the
        # original `output_dir / "results.csv"` raised TypeError. Coerce to
        # Path defensively and ensure the directory exists before writing.
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        df.to_csv(output_dir / "results.csv")
type=click.Path(exists=True)) @click.option("--model-path", type=click.Path(exists=True)) +@click.option("--ckpt-dir", type=click.Path(exists=True)) +@click.option("--log-dir", type=click.Path(exists=True)) @click.option("--config", type=click.Path(exists=True)) @click.option("--config-field", type=str) @click.option("--train-workers", type=int) @@ -34,6 +36,8 @@ def train_command( train_dataset: Path, val_dataset: Optional[Path] = None, model_path: Optional[Path] = None, + ckpt_dir: Optional[Path] = None, + log_dir: Optional[Path] = None, config: Optional[Path] = None, config_field: Optional[str] = None, train_workers: int = 0, @@ -83,4 +87,6 @@ def train_command( model_path=model_path, train_workers=train_workers, val_workers=val_workers, + log_dir=log_dir, + checkpoint_dir=ckpt_dir, ) diff --git a/src/batdetect2/configs.py b/src/batdetect2/configs.py index c7ffcd3..7399d6e 100644 --- a/src/batdetect2/configs.py +++ b/src/batdetect2/configs.py @@ -27,7 +27,7 @@ class BaseConfig(BaseModel): and serialization capabilities. 
""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="ignore") def to_yaml_string( self, diff --git a/src/batdetect2/evaluate/dataframe.py b/src/batdetect2/evaluate/dataframe.py new file mode 100644 index 0000000..4cc0ff9 --- /dev/null +++ b/src/batdetect2/evaluate/dataframe.py @@ -0,0 +1,62 @@ +from typing import List + +import pandas as pd +from soundevent.geometry import compute_bounds + +from batdetect2.typing.evaluate import MatchEvaluation + + +def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame: + data = [] + + for match in matches: + gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None + pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None + + sound_event_annotation = match.sound_event_annotation + + if sound_event_annotation is not None: + geometry = sound_event_annotation.sound_event.geometry + assert geometry is not None + gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = ( + compute_bounds(geometry) + ) + + if match.pred_geometry is not None: + pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = ( + compute_bounds(match.pred_geometry) + ) + + data.append( + { + ("recording", "uuid"): match.clip.recording.uuid, + ("clip", "uuid"): match.clip.uuid, + ("clip", "start_time"): match.clip.start_time, + ("clip", "end_time"): match.clip.end_time, + ("gt", "uuid"): match.sound_event_annotation.uuid + if match.sound_event_annotation is not None + else None, + ("gt", "class"): match.gt_class, + ("gt", "det"): match.gt_det, + ("gt", "start_time"): gt_start_time, + ("gt", "end_time"): gt_end_time, + ("gt", "low_freq"): gt_low_freq, + ("gt", "high_freq"): gt_high_freq, + ("pred", "score"): match.pred_score, + ("pred", "class"): match.pred_class, + ("pred", "class_score"): match.pred_class_score, + ("pred", "start_time"): pred_start_time, + ("pred", "end_time"): pred_end_time, + ("pred", "low_freq"): pred_low_freq, + ("pred", "high_freq"): pred_high_freq, + ("match", 
"affinity"): match.affinity, + **{ + ("pred_class_score", key): value + for key, value in match.pred_class_scores.items() + }, + } + ) + + df = pd.DataFrame(data) + df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore + return df diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py new file mode 100644 index 0000000..a4cef4d --- /dev/null +++ b/src/batdetect2/evaluate/evaluate.py @@ -0,0 +1,100 @@ +from typing import List, Optional, Tuple + +import pandas as pd +from soundevent import data + +from batdetect2.evaluate.dataframe import extract_matches_dataframe +from batdetect2.evaluate.match import match_all_predictions +from batdetect2.evaluate.metrics import ( + ClassificationAccuracy, + ClassificationMeanAveragePrecision, + DetectionAveragePrecision, +) +from batdetect2.models import Model +from batdetect2.plotting.clips import build_audio_loader +from batdetect2.postprocess import get_raw_predictions +from batdetect2.preprocess import build_preprocessor +from batdetect2.targets import build_targets +from batdetect2.train.config import FullTrainingConfig +from batdetect2.train.dataset import ValidationDataset +from batdetect2.train.labels import build_clip_labeler +from batdetect2.train.train import build_val_loader + + +def evaluate( + model: Model, + test_annotations: List[data.ClipAnnotation], + config: Optional[FullTrainingConfig] = None, + num_workers: Optional[int] = None, +) -> Tuple[pd.DataFrame, dict]: + config = config or FullTrainingConfig() + + audio_loader = build_audio_loader(config.preprocess.audio) + + preprocessor = build_preprocessor(config.preprocess) + + targets = build_targets(config.targets) + + labeller = build_clip_labeler( + targets, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, + config=config.train.labels, + ) + + loader = build_val_loader( + test_annotations, + audio_loader=audio_loader, + labeller=labeller, + preprocessor=preprocessor, + config=config.train, + 
def evaluate(
    model: Model,
    test_annotations: List[data.ClipAnnotation],
    config: Optional[FullTrainingConfig] = None,
    num_workers: Optional[int] = None,
) -> Tuple[pd.DataFrame, dict]:
    """Evaluate a trained model on a set of annotated clips.

    Runs the model's detector over every batch produced from
    ``test_annotations``, matches the raw predictions against the ground
    truth, and computes detection/classification metrics.

    Parameters
    ----------
    model : Model
        Trained model whose ``detector`` and ``postprocessor`` are used.
    test_annotations : List[data.ClipAnnotation]
        Ground-truth clip annotations to evaluate against.
    config : Optional[FullTrainingConfig]
        Full training configuration; a default one is built when omitted.
    num_workers : Optional[int]
        Number of dataloader workers.

    Returns
    -------
    Tuple[pd.DataFrame, dict]
        Per-match results dataframe and a dict of aggregate metric values.
    """
    config = config or FullTrainingConfig()

    audio_loader = build_audio_loader(config.preprocess.audio)
    preprocessor = build_preprocessor(config.preprocess)
    targets = build_targets(config.targets)

    labeller = build_clip_labeler(
        targets,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        config=config.train.labels,
    )

    loader = build_val_loader(
        test_annotations,
        audio_loader=audio_loader,
        labeller=labeller,
        preprocessor=preprocessor,
        config=config.train,
        num_workers=num_workers,
    )

    dataset: ValidationDataset = loader.dataset  # type: ignore

    clip_annotations: List[data.ClipAnnotation] = []
    predictions = []

    for batch in loader:
        outputs = model.detector(batch.spec)

        # BUG FIX: the previous version rebound `clip_annotations` and
        # `predictions` here, shadowing the accumulators declared above —
        # every `.extend()` then extended the per-batch list with itself
        # (doubling it) and only the final batch survived the loop. Use
        # distinct per-batch names and extend the accumulators instead.
        batch_annotations = [
            dataset.clip_annotations[int(example_idx)]
            for example_idx in batch.idx
        ]

        batch_predictions = get_raw_predictions(
            outputs,
            clips=[annotation.clip for annotation in batch_annotations],
            targets=targets,
            postprocessor=model.postprocessor,
        )

        clip_annotations.extend(batch_annotations)
        predictions.extend(batch_predictions)

    matches = match_all_predictions(
        clip_annotations,
        predictions,
        targets=targets,
        config=config.evaluation.match,
    )

    df = extract_matches_dataframe(matches)

    metrics = [
        DetectionAveragePrecision(),
        ClassificationMeanAveragePrecision(class_names=targets.class_names),
        ClassificationAccuracy(class_names=targets.class_names),
    ]

    # Merge every metric's named results into a single flat dict.
    results = {
        name: value
        for metric in metrics
        for name, value in metric(matches).items()
    }

    return df, results
class ModelConfig(BaseConfig):
    """Aggregate configuration needed to assemble a complete `Model`."""

    model: BackboneConfig = Field(default_factory=BackboneConfig)
    preprocess: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)


class Model(torch.nn.Module):
    """End-to-end pipeline module: preprocess, detect, postprocess.

    Holds the detector network together with the preprocessing and
    postprocessing stages and the target definitions, plus the config
    used to build it (stored on ``self.config``).
    """

    detector: DetectionModel
    preprocessor: PreprocessorProtocol
    postprocessor: PostprocessorProtocol
    targets: TargetProtocol

    def __init__(
        self,
        detector: DetectionModel,
        preprocessor: PreprocessorProtocol,
        postprocessor: PostprocessorProtocol,
        targets: TargetProtocol,
        config: ModelConfig,
    ):
        super().__init__()
        self.detector = detector
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.targets = targets
        self.config = config

    def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
        """Run the full pipeline on a raw audio tensor."""
        spectrogram = self.preprocessor(wav)
        raw_outputs = self.detector(spectrogram)
        return self.postprocessor(raw_outputs)


def build_model(config: Optional[ModelConfig] = None):
    """Assemble a `Model` from a config (defaults built when omitted)."""
    config = config or ModelConfig()

    targets = build_targets(config=config.targets)

    preprocessor = build_preprocessor(config=config.preprocess)

    postprocessor = build_postprocessor(
        preprocessor=preprocessor,
        config=config.postprocess,
    )

    detector = build_detector(
        num_classes=len(targets.class_names),
        config=config.model,
    )

    return Model(
        config=config,
        detector=detector,
        postprocessor=postprocessor,
        preprocessor=preprocessor,
        targets=targets,
    )
preprocessor=preprocessor, diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index acffb82..010ae63 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -6,7 +6,6 @@ from soundevent import data from batdetect2.configs import BaseConfig, load_config from batdetect2.evaluate import EvaluationConfig from batdetect2.models import ModelConfig -from batdetect2.targets import TargetConfig from batdetect2.train.augmentations import ( DEFAULT_AUGMENTATION_CONFIG, AugmentationsConfig, @@ -75,7 +74,6 @@ class TrainingConfig(BaseConfig): cliping: ClipingConfig = Field(default_factory=ClipingConfig) trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig) logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) - targets: TargetConfig = Field(default_factory=TargetConfig) labels: LabelConfig = Field(default_factory=LabelConfig) diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index bc9edd3..317f65b 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -1,9 +1,14 @@ +from typing import Optional, Tuple + import lightning as L import torch +from soundevent.data import PathLike from torch.optim.adam import Adam from torch.optim.lr_scheduler import CosineAnnealingLR -from batdetect2.models import Model +from batdetect2.models import Model, build_model +from batdetect2.train.config import FullTrainingConfig +from batdetect2.train.losses import build_loss from batdetect2.typing import ModelOutput, TrainExample __all__ = [ @@ -16,22 +21,28 @@ class TrainingModule(L.LightningModule): def __init__( self, - model: Model, - loss: torch.nn.Module, + config: FullTrainingConfig, learning_rate: float = 0.001, t_max: int = 100, + model: Optional[Model] = None, + loss: Optional[torch.nn.Module] = None, ): super().__init__() + self.save_hyperparameters(logger=False) + + self.config = config self.learning_rate = learning_rate self.t_max = t_max + if loss is 
def load_model_from_checkpoint(
    path: PathLike,
) -> Tuple[Model, FullTrainingConfig]:
    """Restore a trained model and its training config from a checkpoint.

    Parameters
    ----------
    path : PathLike
        Path to a Lightning checkpoint produced by `TrainingModule`.

    Returns
    -------
    Tuple[Model, FullTrainingConfig]
        The wrapped model and the full training configuration it was
        trained with.
    """
    # The checkpoint stores the hyperparameters, so the module can rebuild
    # its model and loss without any extra arguments.
    module = TrainingModule.load_from_checkpoint(path)  # type: ignore
    return (module.model, module.config)
def create_csv_logger(
    config: CSVLoggerConfig,
    log_dir: Optional[data.PathLike] = None,
) -> Logger:
    """Build a Lightning CSVLogger.

    An explicit ``log_dir`` takes precedence over ``config.save_dir``.
    """
    from lightning.pytorch.loggers import CSVLogger

    save_dir = config.save_dir if log_dir is None else str(log_dir)
    return CSVLogger(
        save_dir=save_dir,
        name=config.name,
        version=config.version,
        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
    )


def create_tensorboard_logger(
    config: TensorBoardLoggerConfig,
    log_dir: Optional[data.PathLike] = None,
) -> Logger:
    """Build a Lightning TensorBoardLogger.

    An explicit ``log_dir`` takes precedence over ``config.save_dir``.
    """
    from lightning.pytorch.loggers import TensorBoardLogger

    save_dir = config.save_dir if log_dir is None else str(log_dir)
    return TensorBoardLogger(
        save_dir=save_dir,
        name=config.name,
        version=config.version,
        log_graph=config.log_graph,
    )
Pydantic config object. """ @@ -141,7 +157,7 @@ def build_logger(config: LoggerConfig) -> Logger: creation_func = LOGGER_FACTORY[logger_type] - return creation_func(config) + return creation_func(config, log_dir=log_dir) def get_image_plotter(logger: Logger): diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index aaa67f8..0eb2d6f 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -14,9 +14,9 @@ from batdetect2.evaluate.metrics import ( ClassificationMeanAveragePrecision, DetectionAveragePrecision, ) -from batdetect2.models import Model, build_model from batdetect2.plotting.clips import AudioLoader, build_audio_loader from batdetect2.preprocess import build_preprocessor +from batdetect2.targets import build_targets from batdetect2.train.augmentations import ( RandomAudioSource, build_augmentations, @@ -28,7 +28,6 @@ from batdetect2.train.dataset import TrainingDataset, ValidationDataset from batdetect2.train.labels import build_clip_labeler from batdetect2.train.lightning import TrainingModule from batdetect2.train.logging import build_logger -from batdetect2.train.losses import build_loss from batdetect2.typing import ( PreprocessorProtocol, TargetProtocol, @@ -54,19 +53,21 @@ def train( model_path: Optional[data.PathLike] = None, train_workers: Optional[int] = None, val_workers: Optional[int] = None, + checkpoint_dir: Optional[data.PathLike] = None, + log_dir: Optional[data.PathLike] = None, ): config = config or FullTrainingConfig() - model = build_model(config=config) + targets = build_targets(config.targets) - trainer = build_trainer(config, targets=model.targets) + preprocessor = build_preprocessor(config.preprocess) audio_loader = build_audio_loader(config=config.preprocess.audio) labeller = build_clip_labeler( - model.targets, - min_freq=model.preprocessor.min_freq, - max_freq=model.preprocessor.max_freq, + targets, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, 
def build_training_module(
    config: Optional[FullTrainingConfig] = None,
    t_max: int = 200,
) -> TrainingModule:
    """Create a `TrainingModule` from a full training configuration.

    Parameters
    ----------
    config : Optional[FullTrainingConfig]
        Full training configuration; a default one is built when omitted.
    t_max : int
        Period (in steps) for the cosine-annealing LR schedule.

    Returns
    -------
    TrainingModule
        A Lightning module ready for `Trainer.fit`.
    """
    cfg = config or FullTrainingConfig()

    return TrainingModule(
        config=cfg,
        learning_rate=cfg.train.learning_rate,
        t_max=t_max,
    )
config=lambda: trainer_conf.to_yaml_string(exclude_none=True), ) - train_logger = build_logger(conf.train.logger) + train_logger = build_logger(conf.train.logger, log_dir=log_dir) - train_logger.log_hyperparams(conf.model_dump(mode="json")) + train_logger.log_hyperparams( + conf.model_dump( + mode="json", + exclude_none=True, + ) + ) return Trainer( **trainer_conf.model_dump(exclude_none=True), @@ -171,6 +186,7 @@ def build_trainer( targets, config=conf.evaluation, preprocessor=build_preprocessor(conf.preprocess), + checkpoint_dir=checkpoint_dir, ), )