From d9395d3eebae2503fe4b0b674fbfe08f6f24e604 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Fri, 8 Aug 2025 13:06:28 +0100 Subject: [PATCH] Updated callback to include plotting --- src/batdetect2/train/callbacks.py | 56 +++++++++++++++++++++++-------- src/batdetect2/train/lightning.py | 7 ++-- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index fc2d748..a34d6c1 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -2,11 +2,13 @@ from typing import List from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import Callback +from lightning.pytorch.loggers import TensorBoardLogger from soundevent import data from torch.utils.data import DataLoader from batdetect2.evaluate.match import match_sound_events_and_raw_predictions -from batdetect2.evaluate.types import Match, MetricsProtocol +from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol +from batdetect2.plotting.evaluation import plot_examples from batdetect2.targets.types import TargetProtocol from batdetect2.train.dataset import LabeledDataset, TrainExample from batdetect2.train.lightning import TrainingModule @@ -14,25 +16,55 @@ from batdetect2.train.types import ModelOutput class ValidationMetrics(Callback): - def __init__(self, metrics: List[MetricsProtocol]): + def __init__(self, metrics: List[MetricsProtocol], plot: bool = True): super().__init__() if len(metrics) == 0: raise ValueError("At least one metric needs to be provided") - self.matches: List[Match] = [] + self.matches: List[MatchEvaluation] = [] self.metrics = metrics + self.plot = plot + + def get_dataset(self, trainer: Trainer) -> LabeledDataset: + dataloaders = trainer.val_dataloaders + assert isinstance(dataloaders, DataLoader) + dataset = dataloaders.dataset + assert isinstance(dataset, LabeledDataset) + return dataset + + def plot_examples(self, pl_module: LightningModule): + if not isinstance(pl_module.logger, TensorBoardLogger): + return + + for class_name, fig in plot_examples( + self.matches, + preprocessor=pl_module.preprocessor, + n_examples=5, + ): + pl_module.logger.experiment.add_figure( + f"{class_name}/examples", + fig, + pl_module.global_step, + ) + + def log_metrics(self, pl_module: LightningModule): + metrics = {} + for metric in self.metrics: + metrics.update(metric(self.matches).items()) + + pl_module.log_dict(metrics) def on_validation_epoch_end( self, trainer: Trainer, pl_module: LightningModule, ) -> None: - metrics = {} - for metric in self.metrics: - metrics.update(metric(self.matches).items()) + self.log_metrics(pl_module) + + if self.plot: + self.plot_examples(pl_module) - pl_module.log_dict(metrics) return super().on_validation_epoch_end(trainer, pl_module) def on_validation_epoch_start( @@ -52,11 +84,7 @@ class ValidationMetrics(Callback): batch_idx: int, dataloader_idx: int = 0, ) -> None: - dataloaders = trainer.val_dataloaders - assert isinstance(dataloaders, DataLoader) - - dataset = dataloaders.dataset - assert isinstance(dataset, LabeledDataset) + dataset = self.get_dataset(trainer) clip_annotations = [ _get_subclip( @@ -74,7 +102,7 @@ class ValidationMetrics(Callback): clips = [clip_annotation.clip for clip_annotation in clip_annotations] - raw_predictions = pl_module.postprocessor.get_raw_predictions( + raw_predictions = pl_module.postprocessor.get_sound_event_predictions( outputs, clips, ) @@ -84,7 +112,7 @@ class ValidationMetrics(Callback): ): self.matches.extend( match_sound_events_and_raw_predictions( - sound_events=clip_annotation.sound_events, + clip_annotation=clip_annotation, raw_predictions=clip_predictions, targets=pl_module.targets, ) diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 74dd282..5d9151e 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -48,20 +48,19 @@ class TrainingModule(L.LightningModule): def training_step(self, batch: TrainExample): outputs = self.forward(batch.spec) losses = self.loss(outputs, batch) - self.log("total_loss/train", losses.total, prog_bar=True, logger=True) self.log("detection_loss/train", losses.total, logger=True) self.log("size_loss/train", losses.total, logger=True) self.log("classification_loss/train", losses.total, logger=True) - return losses.total def validation_step( # type: ignore - self, batch: TrainExample, batch_idx: int + self, + batch: TrainExample, + batch_idx: int, ) -> ModelOutput: outputs = self.forward(batch.spec) losses = self.loss(outputs, batch) - self.log("total_loss/val", losses.total, prog_bar=True, logger=True) self.log("detection_loss/val", losses.total, logger=True) self.log("size_loss/val", losses.total, logger=True)