Compare commits

...

3 Commits

Author SHA1 Message Date
mbsantiago
b997a122f1 Enable batdetect2 in the cli 2025-08-11 12:25:04 +01:00
mbsantiago
c41551b59c Use match config for training 2025-08-11 12:02:46 +01:00
mbsantiago
8aa2d0cd11 Add evaluation config with match strategy 2025-08-11 12:02:30 +01:00
8 changed files with 111 additions and 29 deletions

View File

@ -1,6 +1,7 @@
"""BatDetect2 command line interface."""
import click
from loguru import logger
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
@ -21,4 +22,5 @@ BatDetect2 - Detection and Classification
def cli():
"""BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR)
logger.enable("batdetect2")
# click.echo(BATDETECT_ASCII_ART)

View File

@ -1,9 +1,15 @@
from batdetect2.evaluate.config import (
EvaluationConfig,
load_evaluation_config,
)
from batdetect2.evaluate.match import (
match_predictions_and_annotations,
match_sound_events_and_raw_predictions,
)
__all__ = [
"match_sound_events_and_raw_predictions",
"EvaluationConfig",
"load_evaluation_config",
"match_predictions_and_annotations",
"match_sound_events_and_raw_predictions",
]

View File

@ -0,0 +1,25 @@
from typing import Optional
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.match import DEFAULT_MATCH_CONFIG, MatchConfig
__all__ = [
"EvaluationConfig",
"load_evaluation_config",
]
class EvaluationConfig(BaseConfig):
match: MatchConfig = Field(
default_factory=lambda: DEFAULT_MATCH_CONFIG.model_copy(),
)
def load_evaluation_config(
path: data.PathLike,
field: Optional[str] = None,
) -> EvaluationConfig:
return load_config(path, schema=EvaluationConfig, field=field)

View File

@ -19,7 +19,7 @@ class ClassExamples:
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def plot_examples(
def plot_example_gallery(
matches: List[MatchEvaluation],
preprocessor: PreprocessorProtocol,
n_examples: int = 5,

View File

@ -179,7 +179,7 @@ def plot_false_positive_match(
plt.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ",
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
@ -326,7 +326,7 @@ def plot_true_positive_match(
plt.text(
start_time,
high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
@ -407,7 +407,7 @@ def plot_cross_trigger_match(
plt.text(
start_time,
high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,

View File

@ -1,14 +1,21 @@
from typing import List
from functools import partial
from multiprocessing import Pool
from typing import List, Optional, Tuple
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import TensorBoardLogger
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
from batdetect2.evaluate.match import (
MatchConfig,
match_sound_events_and_raw_predictions,
)
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
from batdetect2.plotting.evaluation import plot_examples
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.lightning import TrainingModule
@ -16,16 +23,25 @@ from batdetect2.train.types import ModelOutput
class ValidationMetrics(Callback):
def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
def __init__(
self,
metrics: List[MetricsProtocol],
plot: bool = True,
match_config: Optional[MatchConfig] = None,
):
super().__init__()
if len(metrics) == 0:
raise ValueError("At least one metric needs to be provided")
self.matches: List[MatchEvaluation] = []
self.match_config = match_config
self.metrics = metrics
self.plot = plot
self._matches: List[
Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
] = []
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
@ -33,25 +49,33 @@ class ValidationMetrics(Callback):
assert isinstance(dataset, LabeledDataset)
return dataset
def plot_examples(self, pl_module: LightningModule):
def plot_examples(
self,
pl_module: LightningModule,
matches: List[MatchEvaluation],
):
if not isinstance(pl_module.logger, TensorBoardLogger):
return
for class_name, fig in plot_examples(
self.matches,
for class_name, fig in plot_example_gallery(
matches,
preprocessor=pl_module.preprocessor,
n_examples=5,
):
pl_module.logger.experiment.add_figure(
f"{class_name}/examples",
f"images/{class_name}_examples",
fig,
pl_module.global_step,
)
def log_metrics(self, pl_module: LightningModule):
def log_metrics(
self,
pl_module: LightningModule,
matches: List[MatchEvaluation],
):
metrics = {}
for metric in self.metrics:
metrics.update(metric(self.matches).items())
metrics.update(metric(matches).items())
pl_module.log_dict(metrics)
@ -60,10 +84,16 @@ class ValidationMetrics(Callback):
trainer: Trainer,
pl_module: LightningModule,
) -> None:
self.log_metrics(pl_module)
matches = _match_all_collected_examples(
self._matches,
pl_module.targets,
config=self.match_config,
)
self.log_metrics(pl_module, matches)
if self.plot:
self.plot_examples(pl_module)
self.plot_examples(pl_module, matches)
return super().on_validation_epoch_end(trainer, pl_module)
@ -72,7 +102,7 @@ class ValidationMetrics(Callback):
trainer: Trainer,
pl_module: LightningModule,
) -> None:
self.matches = []
self._matches = []
return super().on_validation_epoch_start(trainer, pl_module)
def on_validation_batch_end( # type: ignore
@ -110,13 +140,26 @@ class ValidationMetrics(Callback):
for clip_annotation, clip_predictions in zip(
clip_annotations, raw_predictions
):
self.matches.extend(
match_sound_events_and_raw_predictions(
clip_annotation=clip_annotation,
raw_predictions=clip_predictions,
targets=pl_module.targets,
)
)
self._matches.append((clip_annotation, clip_predictions))
def _match_all_collected_examples(
pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]],
targets: TargetProtocol,
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions")
with Pool() as p:
matches = p.starmap(
partial(
match_sound_events_and_raw_predictions,
targets=targets,
config=config,
),
pre_matches,
)
return [match for clip_matches in matches for match in clip_matches]
def _is_in_subclip(

View File

@ -4,6 +4,7 @@ from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.models import BackboneConfig
from batdetect2.postprocess import PostprocessConfig
from batdetect2.preprocess import PreprocessingConfig
@ -94,6 +95,7 @@ class FullTrainingConfig(BaseConfig):
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
def load_full_training_config(

View File

@ -7,6 +7,7 @@ from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision,
@ -82,7 +83,9 @@ def train(
logger.info("Training complete.")
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
def build_trainer_callbacks(
targets: TargetProtocol, config: EvaluationConfig
) -> List[Callback]:
return [
ModelCheckpoint(
dirpath="outputs/checkpoints",
@ -96,7 +99,8 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
class_names=targets.class_names
),
ClassificationAccuracy(class_names=targets.class_names),
]
],
match_config=config.match,
),
]
@ -113,7 +117,7 @@ def build_trainer(
return Trainer(
**trainer_conf.model_dump(exclude_none=True),
logger=build_logger(conf.train.logger),
callbacks=build_trainer_callbacks(targets),
callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
)