Updated callback to include plotting

mbsantiago 2025-08-08 13:06:28 +01:00
parent aaec66c15e
commit d9395d3eeb
2 changed files with 45 additions and 18 deletions


@@ -2,11 +2,13 @@ from typing import List
 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.loggers import TensorBoardLogger
 from soundevent import data
 from torch.utils.data import DataLoader
 from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
-from batdetect2.evaluate.types import Match, MetricsProtocol
+from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
+from batdetect2.plotting.evaluation import plot_examples
 from batdetect2.targets.types import TargetProtocol
 from batdetect2.train.dataset import LabeledDataset, TrainExample
 from batdetect2.train.lightning import TrainingModule
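
The Match type is renamed to MatchEvaluation here, and the rename carries through the rest of the callback. For reference, a minimal sketch of the MetricsProtocol contract as the callback uses it, inferred from the call site metric(self.matches).items() rather than copied from batdetect2.evaluate.types:

from typing import Dict, List, Protocol

from batdetect2.evaluate.types import MatchEvaluation


class MetricsProtocolSketch(Protocol):
    # Assumed shape: called once per validation epoch with all collected
    # matches; returns named scalar values suitable for log_dict.
    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...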
@@ -14,25 +16,55 @@ from batdetect2.train.types import ModelOutput
 class ValidationMetrics(Callback):
-    def __init__(self, metrics: List[MetricsProtocol]):
+    def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
         super().__init__()
         if len(metrics) == 0:
             raise ValueError("At least one metric needs to be provided")
-        self.matches: List[Match] = []
+        self.matches: List[MatchEvaluation] = []
         self.metrics = metrics
+        self.plot = plot
+
+    def get_dataset(self, trainer: Trainer) -> LabeledDataset:
+        dataloaders = trainer.val_dataloaders
+        assert isinstance(dataloaders, DataLoader)
+        dataset = dataloaders.dataset
+        assert isinstance(dataset, LabeledDataset)
+        return dataset
+
+    def plot_examples(self, pl_module: LightningModule):
+        if not isinstance(pl_module.logger, TensorBoardLogger):
+            return
+
+        for class_name, fig in plot_examples(
+            self.matches,
+            preprocessor=pl_module.preprocessor,
+            n_examples=5,
+        ):
+            pl_module.logger.experiment.add_figure(
+                f"{class_name}/examples",
+                fig,
+                pl_module.global_step,
+            )
+
+    def log_metrics(self, pl_module: LightningModule):
+        metrics = {}
+        for metric in self.metrics:
+            metrics.update(metric(self.matches).items())
+        pl_module.log_dict(metrics)

     def on_validation_epoch_end(
         self,
         trainer: Trainer,
         pl_module: LightningModule,
     ) -> None:
-        metrics = {}
-        for metric in self.metrics:
-            metrics.update(metric(self.matches).items())
-        pl_module.log_dict(metrics)
+        self.log_metrics(pl_module)
+
+        if self.plot:
+            self.plot_examples(pl_module)
+
         return super().on_validation_epoch_end(trainer, pl_module)

     def on_validation_epoch_start(
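
The new plot_examples method only relies on the imported helper yielding (class_name, figure) pairs (the real helper also receives a preprocessor). A minimal stand-in with that shape, handy for exercising the callback without real predictions; the class names and figure contents below are purely illustrative:

from typing import Iterator, List, Tuple

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from batdetect2.evaluate.types import MatchEvaluation


def fake_plot_examples(
    matches: List[MatchEvaluation],
    n_examples: int = 5,
) -> Iterator[Tuple[str, Figure]]:
    # One placeholder figure per made-up class name, mirroring the
    # (class_name, fig) pairs the callback passes to add_figure.
    for class_name in ("Myotis myotis", "Pipistrellus pipistrellus"):
        fig, ax = plt.subplots()
        ax.set_title(f"{class_name} examples (n={n_examples})")
        yield class_name, fig

Since TensorBoardLogger.experiment is a torch SummaryWriter, add_figure(tag, figure, global_step) is the standard TensorBoard call, which is why the callback first checks the logger type and bails out otherwise.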
@@ -52,11 +84,7 @@ class ValidationMetrics(Callback):
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        dataloaders = trainer.val_dataloaders
-        assert isinstance(dataloaders, DataLoader)
-        dataset = dataloaders.dataset
-        assert isinstance(dataset, LabeledDataset)
+        dataset = self.get_dataset(trainer)

         clip_annotations = [
             _get_subclip(
@@ -74,7 +102,7 @@ class ValidationMetrics(Callback):
         clips = [clip_annotation.clip for clip_annotation in clip_annotations]

-        raw_predictions = pl_module.postprocessor.get_raw_predictions(
+        raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
             outputs,
             clips,
         )
@@ -84,7 +112,7 @@ class ValidationMetrics(Callback):
         ):
             self.matches.extend(
                 match_sound_events_and_raw_predictions(
-                    sound_events=clip_annotation.sound_events,
+                    clip_annotation=clip_annotation,
                     raw_predictions=clip_predictions,
                     targets=pl_module.targets,
                 )
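
With these pieces in place, the callback can be attached to a Lightning Trainer alongside a TensorBoard logger to get both metrics and per-class example plots each validation epoch. A hedged wiring sketch; my_metric, module, and val_loader are placeholders, not names from this repository:

from lightning import Trainer
from lightning.pytorch.loggers import TensorBoardLogger

# my_metric, module, and val_loader stand in for real objects.
callback = ValidationMetrics(metrics=[my_metric], plot=True)
trainer = Trainer(
    logger=TensorBoardLogger(save_dir="logs"),
    callbacks=[callback],
)
trainer.validate(module, dataloaders=val_loader)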


@@ -48,20 +48,19 @@ class TrainingModule(L.LightningModule):
     def training_step(self, batch: TrainExample):
         outputs = self.forward(batch.spec)
         losses = self.loss(outputs, batch)

         self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
         self.log("detection_loss/train", losses.detection, logger=True)
         self.log("size_loss/train", losses.size, logger=True)
         self.log("classification_loss/train", losses.classification, logger=True)

         return losses.total

     def validation_step( # type: ignore
-        self, batch: TrainExample, batch_idx: int
+        self,
+        batch: TrainExample,
+        batch_idx: int,
     ) -> ModelOutput:
         outputs = self.forward(batch.spec)
         losses = self.loss(outputs, batch)

         self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
         self.log("detection_loss/val", losses.detection, logger=True)
         self.log("size_loss/val", losses.size, logger=True)