mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Independent preprocessor for generating validation plots
This commit is contained in:
parent
d3d2a28130
commit
71c2301c21
@ -9,6 +9,7 @@ from batdetect2.evaluate.match import (
|
|||||||
MatchConfig,
|
MatchConfig,
|
||||||
match_all_predictions,
|
match_all_predictions,
|
||||||
)
|
)
|
||||||
|
from batdetect2.plotting.clips import PreprocessorProtocol
|
||||||
from batdetect2.plotting.evaluation import plot_example_gallery
|
from batdetect2.plotting.evaluation import plot_example_gallery
|
||||||
from batdetect2.postprocess import get_raw_predictions
|
from batdetect2.postprocess import get_raw_predictions
|
||||||
from batdetect2.train.dataset import ValidationDataset
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
@ -27,6 +28,7 @@ class ValidationMetrics(Callback):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
metrics: List[MetricsProtocol],
|
metrics: List[MetricsProtocol],
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
plot: bool = True,
|
plot: bool = True,
|
||||||
match_config: Optional[MatchConfig] = None,
|
match_config: Optional[MatchConfig] = None,
|
||||||
):
|
):
|
||||||
@ -37,6 +39,7 @@ class ValidationMetrics(Callback):
|
|||||||
|
|
||||||
self.match_config = match_config
|
self.match_config = match_config
|
||||||
self.metrics = metrics
|
self.metrics = metrics
|
||||||
|
self.preprocessor = preprocessor
|
||||||
self.plot = plot
|
self.plot = plot
|
||||||
|
|
||||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||||
@ -61,7 +64,7 @@ class ValidationMetrics(Callback):
|
|||||||
|
|
||||||
for class_name, fig in plot_example_gallery(
|
for class_name, fig in plot_example_gallery(
|
||||||
matches,
|
matches,
|
||||||
preprocessor=pl_module.model.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
n_examples=4,
|
n_examples=4,
|
||||||
):
|
):
|
||||||
plotter(
|
plotter(
|
||||||
|
|||||||
@ -126,7 +126,9 @@ def build_training_module(
|
|||||||
|
|
||||||
|
|
||||||
def build_trainer_callbacks(
|
def build_trainer_callbacks(
|
||||||
targets: TargetProtocol, config: EvaluationConfig
|
targets: TargetProtocol,
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
|
config: EvaluationConfig,
|
||||||
) -> List[Callback]:
|
) -> List[Callback]:
|
||||||
return [
|
return [
|
||||||
ModelCheckpoint(
|
ModelCheckpoint(
|
||||||
@ -142,6 +144,7 @@ def build_trainer_callbacks(
|
|||||||
),
|
),
|
||||||
ClassificationAccuracy(class_names=targets.class_names),
|
ClassificationAccuracy(class_names=targets.class_names),
|
||||||
],
|
],
|
||||||
|
preprocessor=preprocessor,
|
||||||
match_config=config.match,
|
match_config=config.match,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@ -163,7 +166,11 @@ def build_trainer(
|
|||||||
return Trainer(
|
return Trainer(
|
||||||
**trainer_conf.model_dump(exclude_none=True),
|
**trainer_conf.model_dump(exclude_none=True),
|
||||||
logger=train_logger,
|
logger=train_logger,
|
||||||
callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
|
callbacks=build_trainer_callbacks(
|
||||||
|
targets,
|
||||||
|
config=conf.evaluation,
|
||||||
|
preprocessor=build_preprocessor(conf.preprocess),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user