mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Independent preprocessor for generating validation plots
This commit is contained in:
parent
d3d2a28130
commit
71c2301c21
@ -9,6 +9,7 @@ from batdetect2.evaluate.match import (
|
||||
MatchConfig,
|
||||
match_all_predictions,
|
||||
)
|
||||
from batdetect2.plotting.clips import PreprocessorProtocol
|
||||
from batdetect2.plotting.evaluation import plot_example_gallery
|
||||
from batdetect2.postprocess import get_raw_predictions
|
||||
from batdetect2.train.dataset import ValidationDataset
|
||||
@ -27,6 +28,7 @@ class ValidationMetrics(Callback):
|
||||
def __init__(
|
||||
self,
|
||||
metrics: List[MetricsProtocol],
|
||||
preprocessor: PreprocessorProtocol,
|
||||
plot: bool = True,
|
||||
match_config: Optional[MatchConfig] = None,
|
||||
):
|
||||
@ -37,6 +39,7 @@ class ValidationMetrics(Callback):
|
||||
|
||||
self.match_config = match_config
|
||||
self.metrics = metrics
|
||||
self.preprocessor = preprocessor
|
||||
self.plot = plot
|
||||
|
||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||
@ -61,7 +64,7 @@ class ValidationMetrics(Callback):
|
||||
|
||||
for class_name, fig in plot_example_gallery(
|
||||
matches,
|
||||
preprocessor=pl_module.model.preprocessor,
|
||||
preprocessor=self.preprocessor,
|
||||
n_examples=4,
|
||||
):
|
||||
plotter(
|
||||
|
||||
@ -126,7 +126,9 @@ def build_training_module(
|
||||
|
||||
|
||||
def build_trainer_callbacks(
|
||||
targets: TargetProtocol, config: EvaluationConfig
|
||||
targets: TargetProtocol,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
config: EvaluationConfig,
|
||||
) -> List[Callback]:
|
||||
return [
|
||||
ModelCheckpoint(
|
||||
@ -142,6 +144,7 @@ def build_trainer_callbacks(
|
||||
),
|
||||
ClassificationAccuracy(class_names=targets.class_names),
|
||||
],
|
||||
preprocessor=preprocessor,
|
||||
match_config=config.match,
|
||||
),
|
||||
]
|
||||
@ -163,7 +166,11 @@ def build_trainer(
|
||||
return Trainer(
|
||||
**trainer_conf.model_dump(exclude_none=True),
|
||||
logger=train_logger,
|
||||
callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
|
||||
callbacks=build_trainer_callbacks(
|
||||
targets,
|
||||
config=conf.evaluation,
|
||||
preprocessor=build_preprocessor(conf.preprocess),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user