Compare commits

..

3 Commits

Author SHA1 Message Date
mbsantiago
b997a122f1 Enable batdetect2 in the cli 2025-08-11 12:25:04 +01:00
mbsantiago
c41551b59c Use match config for training 2025-08-11 12:02:46 +01:00
mbsantiago
8aa2d0cd11 Add evaluation config with match strategy 2025-08-11 12:02:30 +01:00
8 changed files with 111 additions and 29 deletions

View File

@ -1,6 +1,7 @@
"""BatDetect2 command line interface.""" """BatDetect2 command line interface."""
import click import click
from loguru import logger
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART # from batdetect2.cli.ascii import BATDETECT_ASCII_ART
@ -21,4 +22,5 @@ BatDetect2 - Detection and Classification
def cli(): def cli():
"""BatDetect2 - Bat Call Detection and Classification.""" """BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR) click.echo(INFO_STR)
logger.enable("batdetect2")
# click.echo(BATDETECT_ASCII_ART) # click.echo(BATDETECT_ASCII_ART)

View File

@ -1,9 +1,15 @@
from batdetect2.evaluate.config import (
EvaluationConfig,
load_evaluation_config,
)
from batdetect2.evaluate.match import ( from batdetect2.evaluate.match import (
match_predictions_and_annotations, match_predictions_and_annotations,
match_sound_events_and_raw_predictions, match_sound_events_and_raw_predictions,
) )
__all__ = [ __all__ = [
"match_sound_events_and_raw_predictions", "EvaluationConfig",
"load_evaluation_config",
"match_predictions_and_annotations", "match_predictions_and_annotations",
"match_sound_events_and_raw_predictions",
] ]

View File

@ -0,0 +1,25 @@
from typing import Optional
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.match import DEFAULT_MATCH_CONFIG, MatchConfig
__all__ = [
"EvaluationConfig",
"load_evaluation_config",
]
class EvaluationConfig(BaseConfig):
match: MatchConfig = Field(
default_factory=lambda: DEFAULT_MATCH_CONFIG.model_copy(),
)
def load_evaluation_config(
path: data.PathLike,
field: Optional[str] = None,
) -> EvaluationConfig:
return load_config(path, schema=EvaluationConfig, field=field)

View File

@ -19,7 +19,7 @@ class ClassExamples:
cross_triggers: List[MatchEvaluation] = field(default_factory=list) cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def plot_examples( def plot_example_gallery(
matches: List[MatchEvaluation], matches: List[MatchEvaluation],
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
n_examples: int = 5, n_examples: int = 5,

View File

@ -179,7 +179,7 @@ def plot_false_positive_match(
plt.text( plt.text(
start_time, start_time,
high_freq, high_freq,
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ", f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,
@ -326,7 +326,7 @@ def plot_true_positive_match(
plt.text( plt.text(
start_time, start_time,
high_freq, high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ", f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,
@ -407,7 +407,7 @@ def plot_cross_trigger_match(
plt.text( plt.text(
start_time, start_time,
high_freq, high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ", f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,

View File

@ -1,14 +1,21 @@
from typing import List from functools import partial
from multiprocessing import Pool
from typing import List, Optional, Tuple
from lightning import LightningModule, Trainer from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import TensorBoardLogger
from loguru import logger
from soundevent import data from soundevent import data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions from batdetect2.evaluate.match import (
MatchConfig,
match_sound_events_and_raw_predictions,
)
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
from batdetect2.plotting.evaluation import plot_examples from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
@ -16,16 +23,25 @@ from batdetect2.train.types import ModelOutput
class ValidationMetrics(Callback): class ValidationMetrics(Callback):
def __init__(self, metrics: List[MetricsProtocol], plot: bool = True): def __init__(
self,
metrics: List[MetricsProtocol],
plot: bool = True,
match_config: Optional[MatchConfig] = None,
):
super().__init__() super().__init__()
if len(metrics) == 0: if len(metrics) == 0:
raise ValueError("At least one metric needs to be provided") raise ValueError("At least one metric needs to be provided")
self.matches: List[MatchEvaluation] = [] self.match_config = match_config
self.metrics = metrics self.metrics = metrics
self.plot = plot self.plot = plot
self._matches: List[
Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
] = []
def get_dataset(self, trainer: Trainer) -> LabeledDataset: def get_dataset(self, trainer: Trainer) -> LabeledDataset:
dataloaders = trainer.val_dataloaders dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader) assert isinstance(dataloaders, DataLoader)
@ -33,25 +49,33 @@ class ValidationMetrics(Callback):
assert isinstance(dataset, LabeledDataset) assert isinstance(dataset, LabeledDataset)
return dataset return dataset
def plot_examples(self, pl_module: LightningModule): def plot_examples(
self,
pl_module: LightningModule,
matches: List[MatchEvaluation],
):
if not isinstance(pl_module.logger, TensorBoardLogger): if not isinstance(pl_module.logger, TensorBoardLogger):
return return
for class_name, fig in plot_examples( for class_name, fig in plot_example_gallery(
self.matches, matches,
preprocessor=pl_module.preprocessor, preprocessor=pl_module.preprocessor,
n_examples=5, n_examples=5,
): ):
pl_module.logger.experiment.add_figure( pl_module.logger.experiment.add_figure(
f"{class_name}/examples", f"images/{class_name}_examples",
fig, fig,
pl_module.global_step, pl_module.global_step,
) )
def log_metrics(self, pl_module: LightningModule): def log_metrics(
self,
pl_module: LightningModule,
matches: List[MatchEvaluation],
):
metrics = {} metrics = {}
for metric in self.metrics: for metric in self.metrics:
metrics.update(metric(self.matches).items()) metrics.update(metric(matches).items())
pl_module.log_dict(metrics) pl_module.log_dict(metrics)
@ -60,10 +84,16 @@ class ValidationMetrics(Callback):
trainer: Trainer, trainer: Trainer,
pl_module: LightningModule, pl_module: LightningModule,
) -> None: ) -> None:
self.log_metrics(pl_module) matches = _match_all_collected_examples(
self._matches,
pl_module.targets,
config=self.match_config,
)
self.log_metrics(pl_module, matches)
if self.plot: if self.plot:
self.plot_examples(pl_module) self.plot_examples(pl_module, matches)
return super().on_validation_epoch_end(trainer, pl_module) return super().on_validation_epoch_end(trainer, pl_module)
@ -72,7 +102,7 @@ class ValidationMetrics(Callback):
trainer: Trainer, trainer: Trainer,
pl_module: LightningModule, pl_module: LightningModule,
) -> None: ) -> None:
self.matches = [] self._matches = []
return super().on_validation_epoch_start(trainer, pl_module) return super().on_validation_epoch_start(trainer, pl_module)
def on_validation_batch_end( # type: ignore def on_validation_batch_end( # type: ignore
@ -110,13 +140,26 @@ class ValidationMetrics(Callback):
for clip_annotation, clip_predictions in zip( for clip_annotation, clip_predictions in zip(
clip_annotations, raw_predictions clip_annotations, raw_predictions
): ):
self.matches.extend( self._matches.append((clip_annotation, clip_predictions))
match_sound_events_and_raw_predictions(
clip_annotation=clip_annotation,
raw_predictions=clip_predictions, def _match_all_collected_examples(
targets=pl_module.targets, pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]],
) targets: TargetProtocol,
) config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions")
with Pool() as p:
matches = p.starmap(
partial(
match_sound_events_and_raw_predictions,
targets=targets,
config=config,
),
pre_matches,
)
return [match for clip_matches in matches for match in clip_matches]
def _is_in_subclip( def _is_in_subclip(

View File

@ -4,6 +4,7 @@ from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.models import BackboneConfig from batdetect2.models import BackboneConfig
from batdetect2.postprocess import PostprocessConfig from batdetect2.postprocess import PostprocessConfig
from batdetect2.preprocess import PreprocessingConfig from batdetect2.preprocess import PreprocessingConfig
@ -94,6 +95,7 @@ class FullTrainingConfig(BaseConfig):
default_factory=PreprocessingConfig default_factory=PreprocessingConfig
) )
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
def load_full_training_config( def load_full_training_config(

View File

@ -7,6 +7,7 @@ from loguru import logger
from soundevent import data from soundevent import data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.metrics import ( from batdetect2.evaluate.metrics import (
ClassificationAccuracy, ClassificationAccuracy,
ClassificationMeanAveragePrecision, ClassificationMeanAveragePrecision,
@ -82,7 +83,9 @@ def train(
logger.info("Training complete.") logger.info("Training complete.")
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]: def build_trainer_callbacks(
targets: TargetProtocol, config: EvaluationConfig
) -> List[Callback]:
return [ return [
ModelCheckpoint( ModelCheckpoint(
dirpath="outputs/checkpoints", dirpath="outputs/checkpoints",
@ -96,7 +99,8 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
class_names=targets.class_names class_names=targets.class_names
), ),
ClassificationAccuracy(class_names=targets.class_names), ClassificationAccuracy(class_names=targets.class_names),
] ],
match_config=config.match,
), ),
] ]
@ -113,7 +117,7 @@ def build_trainer(
return Trainer( return Trainer(
**trainer_conf.model_dump(exclude_none=True), **trainer_conf.model_dump(exclude_none=True),
logger=build_logger(conf.train.logger), logger=build_logger(conf.train.logger),
callbacks=build_trainer_callbacks(targets), callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
) )