From 71c2301c216c52fc516b11bcda9b8f1e86ea1570 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 31 Aug 2025 23:04:16 +0100 Subject: [PATCH] Independent preprocessor for generating validation plots --- src/batdetect2/train/callbacks.py | 5 ++++- src/batdetect2/train/train.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index ae6636b..342d88a 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -9,6 +9,7 @@ from batdetect2.evaluate.match import ( MatchConfig, match_all_predictions, ) +from batdetect2.plotting.clips import PreprocessorProtocol from batdetect2.plotting.evaluation import plot_example_gallery from batdetect2.postprocess import get_raw_predictions from batdetect2.train.dataset import ValidationDataset @@ -27,6 +28,7 @@ class ValidationMetrics(Callback): def __init__( self, metrics: List[MetricsProtocol], + preprocessor: PreprocessorProtocol, plot: bool = True, match_config: Optional[MatchConfig] = None, ): @@ -37,6 +39,7 @@ class ValidationMetrics(Callback): self.match_config = match_config self.metrics = metrics + self.preprocessor = preprocessor self.plot = plot self._clip_annotations: List[data.ClipAnnotation] = [] @@ -61,7 +64,7 @@ class ValidationMetrics(Callback): for class_name, fig in plot_example_gallery( matches, - preprocessor=pl_module.model.preprocessor, + preprocessor=self.preprocessor, n_examples=4, ): plotter( diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index b1b05cc..62cd7a9 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -126,7 +126,9 @@ def build_training_module( def build_trainer_callbacks( - targets: TargetProtocol, config: EvaluationConfig + targets: TargetProtocol, + preprocessor: PreprocessorProtocol, + config: EvaluationConfig, ) -> List[Callback]: return [ ModelCheckpoint( @@ -142,6 +144,7 @@ def build_trainer_callbacks( ), ClassificationAccuracy(class_names=targets.class_names), ], + preprocessor=preprocessor, match_config=config.match, ), ] @@ -163,7 +166,11 @@ def build_trainer( return Trainer( **trainer_conf.model_dump(exclude_none=True), logger=train_logger, - callbacks=build_trainer_callbacks(targets, config=conf.evaluation), + callbacks=build_trainer_callbacks( + targets, + config=conf.evaluation, + preprocessor=build_preprocessor(conf.preprocess), + ), )