Updated callback to include plotting

This commit is contained in:
mbsantiago 2025-08-08 13:06:28 +01:00
parent aaec66c15e
commit d9395d3eeb
2 changed files with 45 additions and 18 deletions

View File

@ -2,11 +2,13 @@ from typing import List
from lightning import LightningModule, Trainer from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import TensorBoardLogger
from soundevent import data from soundevent import data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
from batdetect2.evaluate.types import Match, MetricsProtocol from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
from batdetect2.plotting.evaluation import plot_examples
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
@ -14,25 +16,55 @@ from batdetect2.train.types import ModelOutput
class ValidationMetrics(Callback): class ValidationMetrics(Callback):
def __init__(self, metrics: List[MetricsProtocol]): def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
super().__init__() super().__init__()
if len(metrics) == 0: if len(metrics) == 0:
raise ValueError("At least one metric needs to be provided") raise ValueError("At least one metric needs to be provided")
self.matches: List[Match] = [] self.matches: List[MatchEvaluation] = []
self.metrics = metrics self.metrics = metrics
self.plot = plot
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset)
return dataset
def plot_examples(self, pl_module: LightningModule):
if not isinstance(pl_module.logger, TensorBoardLogger):
return
for class_name, fig in plot_examples(
self.matches,
preprocessor=pl_module.preprocessor,
n_examples=5,
):
pl_module.logger.experiment.add_figure(
f"{class_name}/examples",
fig,
pl_module.global_step,
)
def log_metrics(self, pl_module: LightningModule):
metrics = {}
for metric in self.metrics:
metrics.update(metric(self.matches).items())
pl_module.log_dict(metrics)
def on_validation_epoch_end( def on_validation_epoch_end(
self, self,
trainer: Trainer, trainer: Trainer,
pl_module: LightningModule, pl_module: LightningModule,
) -> None: ) -> None:
metrics = {} self.log_metrics(pl_module)
for metric in self.metrics:
metrics.update(metric(self.matches).items()) if self.plot:
self.plot_examples(pl_module)
pl_module.log_dict(metrics)
return super().on_validation_epoch_end(trainer, pl_module) return super().on_validation_epoch_end(trainer, pl_module)
def on_validation_epoch_start( def on_validation_epoch_start(
@ -52,11 +84,7 @@ class ValidationMetrics(Callback):
batch_idx: int, batch_idx: int,
dataloader_idx: int = 0, dataloader_idx: int = 0,
) -> None: ) -> None:
dataloaders = trainer.val_dataloaders dataset = self.get_dataset(trainer)
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset)
clip_annotations = [ clip_annotations = [
_get_subclip( _get_subclip(
@ -74,7 +102,7 @@ class ValidationMetrics(Callback):
clips = [clip_annotation.clip for clip_annotation in clip_annotations] clips = [clip_annotation.clip for clip_annotation in clip_annotations]
raw_predictions = pl_module.postprocessor.get_raw_predictions( raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
outputs, outputs,
clips, clips,
) )
@ -84,7 +112,7 @@ class ValidationMetrics(Callback):
): ):
self.matches.extend( self.matches.extend(
match_sound_events_and_raw_predictions( match_sound_events_and_raw_predictions(
sound_events=clip_annotation.sound_events, clip_annotation=clip_annotation,
raw_predictions=clip_predictions, raw_predictions=clip_predictions,
targets=pl_module.targets, targets=pl_module.targets,
) )

View File

@ -48,20 +48,19 @@ class TrainingModule(L.LightningModule):
def training_step(self, batch: TrainExample): def training_step(self, batch: TrainExample):
outputs = self.forward(batch.spec) outputs = self.forward(batch.spec)
losses = self.loss(outputs, batch) losses = self.loss(outputs, batch)
self.log("total_loss/train", losses.total, prog_bar=True, logger=True) self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/train", losses.total, logger=True) self.log("detection_loss/train", losses.total, logger=True)
self.log("size_loss/train", losses.total, logger=True) self.log("size_loss/train", losses.total, logger=True)
self.log("classification_loss/train", losses.total, logger=True) self.log("classification_loss/train", losses.total, logger=True)
return losses.total return losses.total
def validation_step( # type: ignore def validation_step( # type: ignore
self, batch: TrainExample, batch_idx: int self,
batch: TrainExample,
batch_idx: int,
) -> ModelOutput: ) -> ModelOutput:
outputs = self.forward(batch.spec) outputs = self.forward(batch.spec)
losses = self.loss(outputs, batch) losses = self.loss(outputs, batch)
self.log("total_loss/val", losses.total, prog_bar=True, logger=True) self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/val", losses.total, logger=True) self.log("detection_loss/val", losses.total, logger=True)
self.log("size_loss/val", losses.total, logger=True) self.log("size_loss/val", losses.total, logger=True)