mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Updated callback to include plotting
This commit is contained in:
parent
aaec66c15e
commit
d9395d3eeb
@ -2,11 +2,13 @@ from typing import List
|
||||
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from soundevent import data
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
|
||||
from batdetect2.evaluate.types import Match, MetricsProtocol
|
||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||
from batdetect2.plotting.evaluation import plot_examples
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
@ -14,25 +16,55 @@ from batdetect2.train.types import ModelOutput
|
||||
|
||||
|
||||
class ValidationMetrics(Callback):
|
||||
def __init__(self, metrics: List[MetricsProtocol]):
|
||||
def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
|
||||
super().__init__()
|
||||
|
||||
if len(metrics) == 0:
|
||||
raise ValueError("At least one metric needs to be provided")
|
||||
|
||||
self.matches: List[Match] = []
|
||||
self.matches: List[MatchEvaluation] = []
|
||||
self.metrics = metrics
|
||||
self.plot = plot
|
||||
|
||||
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
|
||||
dataloaders = trainer.val_dataloaders
|
||||
assert isinstance(dataloaders, DataLoader)
|
||||
dataset = dataloaders.dataset
|
||||
assert isinstance(dataset, LabeledDataset)
|
||||
return dataset
|
||||
|
||||
def plot_examples(self, pl_module: LightningModule):
|
||||
if not isinstance(pl_module.logger, TensorBoardLogger):
|
||||
return
|
||||
|
||||
for class_name, fig in plot_examples(
|
||||
self.matches,
|
||||
preprocessor=pl_module.preprocessor,
|
||||
n_examples=5,
|
||||
):
|
||||
pl_module.logger.experiment.add_figure(
|
||||
f"{class_name}/examples",
|
||||
fig,
|
||||
pl_module.global_step,
|
||||
)
|
||||
|
||||
def log_metrics(self, pl_module: LightningModule):
|
||||
metrics = {}
|
||||
for metric in self.metrics:
|
||||
metrics.update(metric(self.matches).items())
|
||||
|
||||
pl_module.log_dict(metrics)
|
||||
|
||||
def on_validation_epoch_end(
|
||||
self,
|
||||
trainer: Trainer,
|
||||
pl_module: LightningModule,
|
||||
) -> None:
|
||||
metrics = {}
|
||||
for metric in self.metrics:
|
||||
metrics.update(metric(self.matches).items())
|
||||
self.log_metrics(pl_module)
|
||||
|
||||
if self.plot:
|
||||
self.plot_examples(pl_module)
|
||||
|
||||
pl_module.log_dict(metrics)
|
||||
return super().on_validation_epoch_end(trainer, pl_module)
|
||||
|
||||
def on_validation_epoch_start(
|
||||
@ -52,11 +84,7 @@ class ValidationMetrics(Callback):
|
||||
batch_idx: int,
|
||||
dataloader_idx: int = 0,
|
||||
) -> None:
|
||||
dataloaders = trainer.val_dataloaders
|
||||
assert isinstance(dataloaders, DataLoader)
|
||||
|
||||
dataset = dataloaders.dataset
|
||||
assert isinstance(dataset, LabeledDataset)
|
||||
dataset = self.get_dataset(trainer)
|
||||
|
||||
clip_annotations = [
|
||||
_get_subclip(
|
||||
@ -74,7 +102,7 @@ class ValidationMetrics(Callback):
|
||||
|
||||
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
||||
|
||||
raw_predictions = pl_module.postprocessor.get_raw_predictions(
|
||||
raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
|
||||
outputs,
|
||||
clips,
|
||||
)
|
||||
@ -84,7 +112,7 @@ class ValidationMetrics(Callback):
|
||||
):
|
||||
self.matches.extend(
|
||||
match_sound_events_and_raw_predictions(
|
||||
sound_events=clip_annotation.sound_events,
|
||||
clip_annotation=clip_annotation,
|
||||
raw_predictions=clip_predictions,
|
||||
targets=pl_module.targets,
|
||||
)
|
||||
|
||||
@ -48,20 +48,19 @@ class TrainingModule(L.LightningModule):
|
||||
def training_step(self, batch: TrainExample):
|
||||
outputs = self.forward(batch.spec)
|
||||
losses = self.loss(outputs, batch)
|
||||
|
||||
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
||||
self.log("detection_loss/train", losses.total, logger=True)
|
||||
self.log("size_loss/train", losses.total, logger=True)
|
||||
self.log("classification_loss/train", losses.total, logger=True)
|
||||
|
||||
return losses.total
|
||||
|
||||
def validation_step( # type: ignore
|
||||
self, batch: TrainExample, batch_idx: int
|
||||
self,
|
||||
batch: TrainExample,
|
||||
batch_idx: int,
|
||||
) -> ModelOutput:
|
||||
outputs = self.forward(batch.spec)
|
||||
losses = self.loss(outputs, batch)
|
||||
|
||||
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
||||
self.log("detection_loss/val", losses.total, logger=True)
|
||||
self.log("size_loss/val", losses.total, logger=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user