mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 17:29:34 +01:00
Updated callback to include plotting
This commit is contained in:
parent
aaec66c15e
commit
d9395d3eeb
@ -2,11 +2,13 @@ from typing import List
|
|||||||
|
|
||||||
from lightning import LightningModule, Trainer
|
from lightning import LightningModule, Trainer
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch.callbacks import Callback
|
||||||
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
|
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
|
||||||
from batdetect2.evaluate.types import Match, MetricsProtocol
|
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||||
|
from batdetect2.plotting.evaluation import plot_examples
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
@ -14,25 +16,55 @@ from batdetect2.train.types import ModelOutput
|
|||||||
|
|
||||||
|
|
||||||
class ValidationMetrics(Callback):
|
class ValidationMetrics(Callback):
|
||||||
def __init__(self, metrics: List[MetricsProtocol]):
|
def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if len(metrics) == 0:
|
if len(metrics) == 0:
|
||||||
raise ValueError("At least one metric needs to be provided")
|
raise ValueError("At least one metric needs to be provided")
|
||||||
|
|
||||||
self.matches: List[Match] = []
|
self.matches: List[MatchEvaluation] = []
|
||||||
self.metrics = metrics
|
self.metrics = metrics
|
||||||
|
self.plot = plot
|
||||||
|
|
||||||
|
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
|
||||||
|
dataloaders = trainer.val_dataloaders
|
||||||
|
assert isinstance(dataloaders, DataLoader)
|
||||||
|
dataset = dataloaders.dataset
|
||||||
|
assert isinstance(dataset, LabeledDataset)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def plot_examples(self, pl_module: LightningModule):
|
||||||
|
if not isinstance(pl_module.logger, TensorBoardLogger):
|
||||||
|
return
|
||||||
|
|
||||||
|
for class_name, fig in plot_examples(
|
||||||
|
self.matches,
|
||||||
|
preprocessor=pl_module.preprocessor,
|
||||||
|
n_examples=5,
|
||||||
|
):
|
||||||
|
pl_module.logger.experiment.add_figure(
|
||||||
|
f"{class_name}/examples",
|
||||||
|
fig,
|
||||||
|
pl_module.global_step,
|
||||||
|
)
|
||||||
|
|
||||||
|
def log_metrics(self, pl_module: LightningModule):
|
||||||
|
metrics = {}
|
||||||
|
for metric in self.metrics:
|
||||||
|
metrics.update(metric(self.matches).items())
|
||||||
|
|
||||||
|
pl_module.log_dict(metrics)
|
||||||
|
|
||||||
def on_validation_epoch_end(
|
def on_validation_epoch_end(
|
||||||
self,
|
self,
|
||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
) -> None:
|
) -> None:
|
||||||
metrics = {}
|
self.log_metrics(pl_module)
|
||||||
for metric in self.metrics:
|
|
||||||
metrics.update(metric(self.matches).items())
|
if self.plot:
|
||||||
|
self.plot_examples(pl_module)
|
||||||
|
|
||||||
pl_module.log_dict(metrics)
|
|
||||||
return super().on_validation_epoch_end(trainer, pl_module)
|
return super().on_validation_epoch_end(trainer, pl_module)
|
||||||
|
|
||||||
def on_validation_epoch_start(
|
def on_validation_epoch_start(
|
||||||
@ -52,11 +84,7 @@ class ValidationMetrics(Callback):
|
|||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
dataloader_idx: int = 0,
|
dataloader_idx: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
dataloaders = trainer.val_dataloaders
|
dataset = self.get_dataset(trainer)
|
||||||
assert isinstance(dataloaders, DataLoader)
|
|
||||||
|
|
||||||
dataset = dataloaders.dataset
|
|
||||||
assert isinstance(dataset, LabeledDataset)
|
|
||||||
|
|
||||||
clip_annotations = [
|
clip_annotations = [
|
||||||
_get_subclip(
|
_get_subclip(
|
||||||
@ -74,7 +102,7 @@ class ValidationMetrics(Callback):
|
|||||||
|
|
||||||
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
||||||
|
|
||||||
raw_predictions = pl_module.postprocessor.get_raw_predictions(
|
raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
|
||||||
outputs,
|
outputs,
|
||||||
clips,
|
clips,
|
||||||
)
|
)
|
||||||
@ -84,7 +112,7 @@ class ValidationMetrics(Callback):
|
|||||||
):
|
):
|
||||||
self.matches.extend(
|
self.matches.extend(
|
||||||
match_sound_events_and_raw_predictions(
|
match_sound_events_and_raw_predictions(
|
||||||
sound_events=clip_annotation.sound_events,
|
clip_annotation=clip_annotation,
|
||||||
raw_predictions=clip_predictions,
|
raw_predictions=clip_predictions,
|
||||||
targets=pl_module.targets,
|
targets=pl_module.targets,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -48,20 +48,19 @@ class TrainingModule(L.LightningModule):
|
|||||||
def training_step(self, batch: TrainExample):
|
def training_step(self, batch: TrainExample):
|
||||||
outputs = self.forward(batch.spec)
|
outputs = self.forward(batch.spec)
|
||||||
losses = self.loss(outputs, batch)
|
losses = self.loss(outputs, batch)
|
||||||
|
|
||||||
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
||||||
self.log("detection_loss/train", losses.total, logger=True)
|
self.log("detection_loss/train", losses.total, logger=True)
|
||||||
self.log("size_loss/train", losses.total, logger=True)
|
self.log("size_loss/train", losses.total, logger=True)
|
||||||
self.log("classification_loss/train", losses.total, logger=True)
|
self.log("classification_loss/train", losses.total, logger=True)
|
||||||
|
|
||||||
return losses.total
|
return losses.total
|
||||||
|
|
||||||
def validation_step( # type: ignore
|
def validation_step( # type: ignore
|
||||||
self, batch: TrainExample, batch_idx: int
|
self,
|
||||||
|
batch: TrainExample,
|
||||||
|
batch_idx: int,
|
||||||
) -> ModelOutput:
|
) -> ModelOutput:
|
||||||
outputs = self.forward(batch.spec)
|
outputs = self.forward(batch.spec)
|
||||||
losses = self.loss(outputs, batch)
|
losses = self.loss(outputs, batch)
|
||||||
|
|
||||||
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
||||||
self.log("detection_loss/val", losses.total, logger=True)
|
self.log("detection_loss/val", losses.total, logger=True)
|
||||||
self.log("size_loss/val", losses.total, logger=True)
|
self.log("size_loss/val", losses.total, logger=True)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user