Use match config for training

This commit is contained in:
mbsantiago 2025-08-11 12:02:46 +01:00
parent 8aa2d0cd11
commit c41551b59c
5 changed files with 77 additions and 28 deletions

View File

@ -19,7 +19,7 @@ class ClassExamples:
cross_triggers: List[MatchEvaluation] = field(default_factory=list) cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def plot_examples( def plot_example_gallery(
matches: List[MatchEvaluation], matches: List[MatchEvaluation],
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
n_examples: int = 5, n_examples: int = 5,

View File

@ -179,7 +179,7 @@ def plot_false_positive_match(
plt.text( plt.text(
start_time, start_time,
high_freq, high_freq,
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ", f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,
@ -326,7 +326,7 @@ def plot_true_positive_match(
plt.text( plt.text(
start_time, start_time,
high_freq, high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ", f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,
@ -407,7 +407,7 @@ def plot_cross_trigger_match(
plt.text( plt.text(
start_time, start_time,
high_freq, high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ", f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,

View File

@ -1,14 +1,21 @@
from typing import List from functools import partial
from multiprocessing import Pool
from typing import List, Optional, Tuple
from lightning import LightningModule, Trainer from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import TensorBoardLogger
from loguru import logger
from soundevent import data from soundevent import data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions from batdetect2.evaluate.match import (
MatchConfig,
match_sound_events_and_raw_predictions,
)
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
from batdetect2.plotting.evaluation import plot_examples from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
@ -16,16 +23,25 @@ from batdetect2.train.types import ModelOutput
class ValidationMetrics(Callback): class ValidationMetrics(Callback):
def __init__(self, metrics: List[MetricsProtocol], plot: bool = True): def __init__(
self,
metrics: List[MetricsProtocol],
plot: bool = True,
match_config: Optional[MatchConfig] = None,
):
super().__init__() super().__init__()
if len(metrics) == 0: if len(metrics) == 0:
raise ValueError("At least one metric needs to be provided") raise ValueError("At least one metric needs to be provided")
self.matches: List[MatchEvaluation] = [] self.match_config = match_config
self.metrics = metrics self.metrics = metrics
self.plot = plot self.plot = plot
self._matches: List[
Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
] = []
def get_dataset(self, trainer: Trainer) -> LabeledDataset: def get_dataset(self, trainer: Trainer) -> LabeledDataset:
dataloaders = trainer.val_dataloaders dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader) assert isinstance(dataloaders, DataLoader)
@ -33,25 +49,33 @@ class ValidationMetrics(Callback):
assert isinstance(dataset, LabeledDataset) assert isinstance(dataset, LabeledDataset)
return dataset return dataset
def plot_examples(self, pl_module: LightningModule): def plot_examples(
self,
pl_module: LightningModule,
matches: List[MatchEvaluation],
):
if not isinstance(pl_module.logger, TensorBoardLogger): if not isinstance(pl_module.logger, TensorBoardLogger):
return return
for class_name, fig in plot_examples( for class_name, fig in plot_example_gallery(
self.matches, matches,
preprocessor=pl_module.preprocessor, preprocessor=pl_module.preprocessor,
n_examples=5, n_examples=5,
): ):
pl_module.logger.experiment.add_figure( pl_module.logger.experiment.add_figure(
f"{class_name}/examples", f"images/{class_name}_examples",
fig, fig,
pl_module.global_step, pl_module.global_step,
) )
def log_metrics(self, pl_module: LightningModule): def log_metrics(
self,
pl_module: LightningModule,
matches: List[MatchEvaluation],
):
metrics = {} metrics = {}
for metric in self.metrics: for metric in self.metrics:
metrics.update(metric(self.matches).items()) metrics.update(metric(matches).items())
pl_module.log_dict(metrics) pl_module.log_dict(metrics)
@ -60,10 +84,16 @@ class ValidationMetrics(Callback):
trainer: Trainer, trainer: Trainer,
pl_module: LightningModule, pl_module: LightningModule,
) -> None: ) -> None:
self.log_metrics(pl_module) matches = _match_all_collected_examples(
self._matches,
pl_module.targets,
config=self.match_config,
)
self.log_metrics(pl_module, matches)
if self.plot: if self.plot:
self.plot_examples(pl_module) self.plot_examples(pl_module, matches)
return super().on_validation_epoch_end(trainer, pl_module) return super().on_validation_epoch_end(trainer, pl_module)
@ -72,7 +102,7 @@ class ValidationMetrics(Callback):
trainer: Trainer, trainer: Trainer,
pl_module: LightningModule, pl_module: LightningModule,
) -> None: ) -> None:
self.matches = [] self._matches = []
return super().on_validation_epoch_start(trainer, pl_module) return super().on_validation_epoch_start(trainer, pl_module)
def on_validation_batch_end( # type: ignore def on_validation_batch_end( # type: ignore
@ -110,13 +140,26 @@ class ValidationMetrics(Callback):
for clip_annotation, clip_predictions in zip( for clip_annotation, clip_predictions in zip(
clip_annotations, raw_predictions clip_annotations, raw_predictions
): ):
self.matches.extend( self._matches.append((clip_annotation, clip_predictions))
match_sound_events_and_raw_predictions(
clip_annotation=clip_annotation,
raw_predictions=clip_predictions, def _match_all_collected_examples(
targets=pl_module.targets, pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]],
) targets: TargetProtocol,
) config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions")
with Pool() as p:
matches = p.starmap(
partial(
match_sound_events_and_raw_predictions,
targets=targets,
config=config,
),
pre_matches,
)
return [match for clip_matches in matches for match in clip_matches]
def _is_in_subclip( def _is_in_subclip(

View File

@ -4,6 +4,7 @@ from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.models import BackboneConfig from batdetect2.models import BackboneConfig
from batdetect2.postprocess import PostprocessConfig from batdetect2.postprocess import PostprocessConfig
from batdetect2.preprocess import PreprocessingConfig from batdetect2.preprocess import PreprocessingConfig
@ -94,6 +95,7 @@ class FullTrainingConfig(BaseConfig):
default_factory=PreprocessingConfig default_factory=PreprocessingConfig
) )
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
def load_full_training_config( def load_full_training_config(

View File

@ -7,6 +7,7 @@ from loguru import logger
from soundevent import data from soundevent import data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.metrics import ( from batdetect2.evaluate.metrics import (
ClassificationAccuracy, ClassificationAccuracy,
ClassificationMeanAveragePrecision, ClassificationMeanAveragePrecision,
@ -82,7 +83,9 @@ def train(
logger.info("Training complete.") logger.info("Training complete.")
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]: def build_trainer_callbacks(
targets: TargetProtocol, config: EvaluationConfig
) -> List[Callback]:
return [ return [
ModelCheckpoint( ModelCheckpoint(
dirpath="outputs/checkpoints", dirpath="outputs/checkpoints",
@ -96,7 +99,8 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
class_names=targets.class_names class_names=targets.class_names
), ),
ClassificationAccuracy(class_names=targets.class_names), ClassificationAccuracy(class_names=targets.class_names),
] ],
match_config=config.match,
), ),
] ]
@ -113,7 +117,7 @@ def build_trainer(
return Trainer( return Trainer(
**trainer_conf.model_dump(exclude_none=True), **trainer_conf.model_dump(exclude_none=True),
logger=build_logger(conf.train.logger), logger=build_logger(conf.train.logger),
callbacks=build_trainer_callbacks(targets), callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
) )