Use match config for training

This commit is contained in:
mbsantiago 2025-08-11 12:02:46 +01:00
parent 8aa2d0cd11
commit c41551b59c
5 changed files with 77 additions and 28 deletions

View File

@ -19,7 +19,7 @@ class ClassExamples:
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def plot_examples(
def plot_example_gallery(
matches: List[MatchEvaluation],
preprocessor: PreprocessorProtocol,
n_examples: int = 5,

View File

@ -179,7 +179,7 @@ def plot_false_positive_match(
plt.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ",
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
@ -326,7 +326,7 @@ def plot_true_positive_match(
plt.text(
start_time,
high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
@ -407,7 +407,7 @@ def plot_cross_trigger_match(
plt.text(
start_time,
high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,

View File

@ -1,14 +1,21 @@
from typing import List
from functools import partial
from multiprocessing import Pool
from typing import List, Optional, Tuple
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import TensorBoardLogger
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
from batdetect2.evaluate.match import (
MatchConfig,
match_sound_events_and_raw_predictions,
)
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
from batdetect2.plotting.evaluation import plot_examples
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.lightning import TrainingModule
@ -16,16 +23,25 @@ from batdetect2.train.types import ModelOutput
class ValidationMetrics(Callback):
def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
def __init__(
self,
metrics: List[MetricsProtocol],
plot: bool = True,
match_config: Optional[MatchConfig] = None,
):
super().__init__()
if len(metrics) == 0:
raise ValueError("At least one metric needs to be provided")
self.matches: List[MatchEvaluation] = []
self.match_config = match_config
self.metrics = metrics
self.plot = plot
self._matches: List[
Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
] = []
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
@ -33,25 +49,33 @@ class ValidationMetrics(Callback):
assert isinstance(dataset, LabeledDataset)
return dataset
def plot_examples(self, pl_module: LightningModule):
def plot_examples(
self,
pl_module: LightningModule,
matches: List[MatchEvaluation],
):
if not isinstance(pl_module.logger, TensorBoardLogger):
return
for class_name, fig in plot_examples(
self.matches,
for class_name, fig in plot_example_gallery(
matches,
preprocessor=pl_module.preprocessor,
n_examples=5,
):
pl_module.logger.experiment.add_figure(
f"{class_name}/examples",
f"images/{class_name}_examples",
fig,
pl_module.global_step,
)
def log_metrics(self, pl_module: LightningModule):
def log_metrics(
self,
pl_module: LightningModule,
matches: List[MatchEvaluation],
):
metrics = {}
for metric in self.metrics:
metrics.update(metric(self.matches).items())
metrics.update(metric(matches).items())
pl_module.log_dict(metrics)
@ -60,10 +84,16 @@ class ValidationMetrics(Callback):
trainer: Trainer,
pl_module: LightningModule,
) -> None:
self.log_metrics(pl_module)
matches = _match_all_collected_examples(
self._matches,
pl_module.targets,
config=self.match_config,
)
self.log_metrics(pl_module, matches)
if self.plot:
self.plot_examples(pl_module)
self.plot_examples(pl_module, matches)
return super().on_validation_epoch_end(trainer, pl_module)
@ -72,7 +102,7 @@ class ValidationMetrics(Callback):
trainer: Trainer,
pl_module: LightningModule,
) -> None:
self.matches = []
self._matches = []
return super().on_validation_epoch_start(trainer, pl_module)
def on_validation_batch_end( # type: ignore
@ -110,13 +140,26 @@ class ValidationMetrics(Callback):
for clip_annotation, clip_predictions in zip(
clip_annotations, raw_predictions
):
self.matches.extend(
match_sound_events_and_raw_predictions(
clip_annotation=clip_annotation,
raw_predictions=clip_predictions,
targets=pl_module.targets,
)
)
self._matches.append((clip_annotation, clip_predictions))
def _match_all_collected_examples(
pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]],
targets: TargetProtocol,
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions")
with Pool() as p:
matches = p.starmap(
partial(
match_sound_events_and_raw_predictions,
targets=targets,
config=config,
),
pre_matches,
)
return [match for clip_matches in matches for match in clip_matches]
def _is_in_subclip(

View File

@ -4,6 +4,7 @@ from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.models import BackboneConfig
from batdetect2.postprocess import PostprocessConfig
from batdetect2.preprocess import PreprocessingConfig
@ -94,6 +95,7 @@ class FullTrainingConfig(BaseConfig):
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
def load_full_training_config(

View File

@ -7,6 +7,7 @@ from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision,
@ -82,7 +83,9 @@ def train(
logger.info("Training complete.")
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
def build_trainer_callbacks(
targets: TargetProtocol, config: EvaluationConfig
) -> List[Callback]:
return [
ModelCheckpoint(
dirpath="outputs/checkpoints",
@ -96,7 +99,8 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
class_names=targets.class_names
),
ClassificationAccuracy(class_names=targets.class_names),
]
],
match_config=config.match,
),
]
@ -113,7 +117,7 @@ def build_trainer(
return Trainer(
**trainer_conf.model_dump(exclude_none=True),
logger=build_logger(conf.train.logger),
callbacks=build_trainer_callbacks(targets),
callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
)