diff --git a/src/batdetect2/plotting/evaluation.py b/src/batdetect2/plotting/evaluation.py index bdb2702..41b92fa 100644 --- a/src/batdetect2/plotting/evaluation.py +++ b/src/batdetect2/plotting/evaluation.py @@ -19,7 +19,7 @@ class ClassExamples: cross_triggers: List[MatchEvaluation] = field(default_factory=list) -def plot_examples( +def plot_example_gallery( matches: List[MatchEvaluation], preprocessor: PreprocessorProtocol, n_examples: int = 5, diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index 7061918..3a001a4 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ -179,7 +179,7 @@ def plot_false_positive_match( plt.text( start_time, high_freq, - f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ", + f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ", va="top", ha="right", color=color, @@ -326,7 +326,7 @@ def plot_true_positive_match( plt.text( start_time, high_freq, - f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ", + f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", va="top", ha="right", color=color, @@ -407,7 +407,7 @@ def plot_cross_trigger_match( plt.text( start_time, high_freq, - f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ", + f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", va="top", ha="right", color=color, diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index a34d6c1..e9f3730 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -1,14 +1,21 @@ -from typing import List +from functools import partial +from multiprocessing import Pool +from typing import List, Optional, Tuple from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import Callback from lightning.pytorch.loggers import TensorBoardLogger +from loguru import logger from soundevent import data from torch.utils.data import DataLoader -from batdetect2.evaluate.match import match_sound_events_and_raw_predictions +from batdetect2.evaluate.match import ( + MatchConfig, + match_sound_events_and_raw_predictions, +) from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol -from batdetect2.plotting.evaluation import plot_examples +from batdetect2.plotting.evaluation import plot_example_gallery +from batdetect2.postprocess.types import BatDetect2Prediction from batdetect2.targets.types import TargetProtocol from batdetect2.train.dataset import LabeledDataset, TrainExample from batdetect2.train.lightning import TrainingModule @@ -16,16 +23,25 @@ from batdetect2.train.types import ModelOutput class ValidationMetrics(Callback): - def __init__(self, metrics: List[MetricsProtocol], plot: bool = True): + def __init__( + self, + metrics: List[MetricsProtocol], + plot: bool = True, + match_config: Optional[MatchConfig] = None, + ): super().__init__() if len(metrics) == 0: raise ValueError("At least one metric needs to be provided") - self.matches: List[MatchEvaluation] = [] + self.match_config = match_config self.metrics = metrics self.plot = plot + self._matches: List[ + Tuple[data.ClipAnnotation, List[BatDetect2Prediction]] + ] = [] + def get_dataset(self, trainer: Trainer) -> LabeledDataset: dataloaders = trainer.val_dataloaders assert isinstance(dataloaders, DataLoader) @@ -33,25 +49,33 @@ class ValidationMetrics(Callback): assert isinstance(dataset, LabeledDataset) return dataset - def plot_examples(self, pl_module: LightningModule): + def plot_examples( + self, + pl_module: LightningModule, + matches: List[MatchEvaluation], + ): if not isinstance(pl_module.logger, TensorBoardLogger): return - for class_name, fig in plot_examples( - self.matches, + for class_name, fig in plot_example_gallery( + matches, preprocessor=pl_module.preprocessor, n_examples=5, ): pl_module.logger.experiment.add_figure( - f"{class_name}/examples", + f"images/{class_name}_examples", fig, pl_module.global_step, ) - def log_metrics(self, pl_module: LightningModule): + def log_metrics( + self, + pl_module: LightningModule, + matches: List[MatchEvaluation], + ): metrics = {} for metric in self.metrics: - metrics.update(metric(self.matches).items()) + metrics.update(metric(matches).items()) pl_module.log_dict(metrics) @@ -60,10 +84,16 @@ class ValidationMetrics(Callback): trainer: Trainer, pl_module: LightningModule, ) -> None: - self.log_metrics(pl_module) + matches = _match_all_collected_examples( + self._matches, + pl_module.targets, + config=self.match_config, + ) + + self.log_metrics(pl_module, matches) if self.plot: - self.plot_examples(pl_module) + self.plot_examples(pl_module, matches) return super().on_validation_epoch_end(trainer, pl_module) @@ -72,7 +102,7 @@ class ValidationMetrics(Callback): trainer: Trainer, pl_module: LightningModule, ) -> None: - self.matches = [] + self._matches = [] return super().on_validation_epoch_start(trainer, pl_module) def on_validation_batch_end( # type: ignore @@ -110,13 +140,26 @@ class ValidationMetrics(Callback): for clip_annotation, clip_predictions in zip( clip_annotations, raw_predictions ): - self.matches.extend( - match_sound_events_and_raw_predictions( - clip_annotation=clip_annotation, - raw_predictions=clip_predictions, - targets=pl_module.targets, - ) - ) + self._matches.append((clip_annotation, clip_predictions)) + + +def _match_all_collected_examples( + pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]], + targets: TargetProtocol, + config: Optional[MatchConfig] = None, +) -> List[MatchEvaluation]: + logger.info("Matching all annotations and predictions") + + with Pool() as p: + matches = p.starmap( + partial( + match_sound_events_and_raw_predictions, + targets=targets, + config=config, + ), + pre_matches, + ) + return [match for clip_matches in matches for match in clip_matches] def _is_in_subclip( diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index 0a013aa..a21101b 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -4,6 +4,7 @@ from pydantic import Field from soundevent import data from batdetect2.configs import BaseConfig, load_config +from batdetect2.evaluate.config import EvaluationConfig from batdetect2.models import BackboneConfig from batdetect2.postprocess import PostprocessConfig from batdetect2.preprocess import PreprocessingConfig @@ -94,6 +95,7 @@ class FullTrainingConfig(BaseConfig): default_factory=PreprocessingConfig ) postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) + evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig) def load_full_training_config( diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index caaa768..dacaae5 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -7,6 +7,7 @@ from loguru import logger from soundevent import data from torch.utils.data import DataLoader +from batdetect2.evaluate.config import EvaluationConfig from batdetect2.evaluate.metrics import ( ClassificationAccuracy, ClassificationMeanAveragePrecision, @@ -82,7 +83,9 @@ def train( logger.info("Training complete.") -def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]: +def build_trainer_callbacks( + targets: TargetProtocol, config: EvaluationConfig +) -> List[Callback]: return [ ModelCheckpoint( dirpath="outputs/checkpoints", @@ -96,7 +99,8 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]: class_names=targets.class_names ), ClassificationAccuracy(class_names=targets.class_names), - ] + ], + match_config=config.match, ), ] @@ -113,7 +117,7 @@ def build_trainer( return Trainer( **trainer_conf.model_dump(exclude_none=True), logger=build_logger(conf.train.logger), - callbacks=build_trainer_callbacks(targets), + callbacks=build_trainer_callbacks(targets, config=conf.evaluation), )