Independent preprocessor for generating validation plots

This commit is contained in:
mbsantiago 2025-08-31 23:04:16 +01:00
parent d3d2a28130
commit 71c2301c21
2 changed files with 13 additions and 3 deletions

View File

@ -9,6 +9,7 @@ from batdetect2.evaluate.match import (
MatchConfig,
match_all_predictions,
)
from batdetect2.plotting.clips import PreprocessorProtocol
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess import get_raw_predictions
from batdetect2.train.dataset import ValidationDataset
@ -27,6 +28,7 @@ class ValidationMetrics(Callback):
def __init__(
self,
metrics: List[MetricsProtocol],
preprocessor: PreprocessorProtocol,
plot: bool = True,
match_config: Optional[MatchConfig] = None,
):
@ -37,6 +39,7 @@ class ValidationMetrics(Callback):
self.match_config = match_config
self.metrics = metrics
self.preprocessor = preprocessor
self.plot = plot
self._clip_annotations: List[data.ClipAnnotation] = []
@ -61,7 +64,7 @@ class ValidationMetrics(Callback):
for class_name, fig in plot_example_gallery(
matches,
preprocessor=pl_module.model.preprocessor,
preprocessor=self.preprocessor,
n_examples=4,
):
plotter(

View File

@ -126,7 +126,9 @@ def build_training_module(
def build_trainer_callbacks(
targets: TargetProtocol, config: EvaluationConfig
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
config: EvaluationConfig,
) -> List[Callback]:
return [
ModelCheckpoint(
@ -142,6 +144,7 @@ def build_trainer_callbacks(
),
ClassificationAccuracy(class_names=targets.class_names),
],
preprocessor=preprocessor,
match_config=config.match,
),
]
@ -163,7 +166,11 @@ def build_trainer(
return Trainer(
**trainer_conf.model_dump(exclude_none=True),
logger=train_logger,
callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
callbacks=build_trainer_callbacks(
targets,
config=conf.evaluation,
preprocessor=build_preprocessor(conf.preprocess),
),
)