Review notes on this patch (text before the first "diff --git" header is
ignored when the patch is applied):

* _interval_affinity unpacked geometry1.coordinates twice, so geometry2 was
  never read and every interval was compared with itself; the second unpack
  now reads geometry2.coordinates.
* Docstrings were added to the two new affinity helpers.
* NOTE(review): train() now builds the trainer and the data loaders from a
  freshly built model even when resuming from model_path; confirm that the
  checkpoint's targets/preprocessor agree with the config.

diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py
index c8f244b..44e372e 100644
--- a/src/batdetect2/evaluate/match.py
+++ b/src/batdetect2/evaluate/match.py
@@ -1,6 +1,6 @@
 from collections.abc import Callable, Iterable, Mapping
 from dataclasses import dataclass, field
-from typing import List, Literal, Optional, Tuple
+from typing import List, Literal, Optional, Protocol, Tuple
 
 import numpy as np
 from soundevent import data
@@ -23,6 +23,16 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"]
 """The geometry representation to use for matching."""
 
 
+class AffinityFunction(Protocol):
+    def __call__(
+        self,
+        geometry1: data.Geometry,
+        geometry2: data.Geometry,
+        time_buffer: float = 0.01,
+        freq_buffer: float = 1000,
+    ) -> float: ...
+
+
 class MatchConfig(BaseConfig):
     """Configuration for matching geometries.
 
@@ -74,6 +84,78 @@ _geometry_cast_functions: Mapping[
 }
 
 
+def _timestamp_affinity(
+    geometry1: data.Geometry,
+    geometry2: data.Geometry,
+    time_buffer: float = 0.01,
+    freq_buffer: float = 1000,
+) -> float:
+    """Compute the temporal IoU of two timestamps.
+
+    Each timestamp is widened by ``time_buffer`` on both sides and the
+    intersection over union of the resulting intervals is returned.
+    ``freq_buffer`` is unused; it is accepted only so the signature
+    matches ``AffinityFunction``.
+    """
+    assert isinstance(geometry1, data.TimeStamp)
+    assert isinstance(geometry2, data.TimeStamp)
+
+    start_time1 = geometry1.coordinates
+    start_time2 = geometry2.coordinates
+
+    a = min(start_time1, start_time2)
+    b = max(start_time1, start_time2)
+
+    if b - a >= 2 * time_buffer:
+        return 0
+
+    intersection = a - b + 2 * time_buffer
+    union = b - a + 2 * time_buffer
+    return intersection / union
+
+
+def _interval_affinity(
+    geometry1: data.Geometry,
+    geometry2: data.Geometry,
+    time_buffer: float = 0.01,
+    freq_buffer: float = 1000,
+) -> float:
+    """Compute the temporal IoU of two time intervals.
+
+    Both intervals are widened by ``time_buffer`` on each side before
+    the intersection over union is computed. ``freq_buffer`` is unused;
+    it is accepted only so the signature matches ``AffinityFunction``.
+    """
+    assert isinstance(geometry1, data.TimeInterval)
+    assert isinstance(geometry2, data.TimeInterval)
+
+    start_time1, end_time1 = geometry1.coordinates
+    start_time2, end_time2 = geometry2.coordinates
+
+    start_time1 -= time_buffer
+    start_time2 -= time_buffer
+    end_time1 += time_buffer
+    end_time2 += time_buffer
+
+    intersection = max(
+        0, min(end_time1, end_time2) - max(start_time1, start_time2)
+    )
+    union = (
+        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
+    )
+
+    if union == 0:
+        return 0
+
+    return intersection / union
+
+
+_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
+    "timestamp": _timestamp_affinity,
+    "interval": _interval_affinity,
+}
+
+
 def match_geometries(
     source: List[data.Geometry],
     target: List[data.Geometry],
@@ -81,6 +163,10 @@ def match_geometries(
     scores: Optional[List[float]] = None,
 ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
     geometry_cast = _geometry_cast_functions[config.geometry]
+    affinity_function = _affinity_functions.get(
+        config.geometry,
+        compute_affinity,
+    )
 
     if config.strategy == "optimal":
         return optimal_match(
@@ -98,6 +184,7 @@ def match_geometries(
             time_buffer=config.time_buffer,
             freq_buffer=config.frequency_buffer,
             affinity_threshold=config.affinity_threshold,
+            affinity_function=affinity_function,
             scores=scores,
         )
 
@@ -111,6 +198,7 @@ def greedy_match(
     target: List[data.Geometry],
     scores: Optional[List[float]] = None,
     affinity_threshold: float = 0.5,
+    affinity_function: AffinityFunction = compute_affinity,
     time_buffer: float = 0.001,
     freq_buffer: float = 1000,
 ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
@@ -168,7 +256,7 @@ def greedy_match(
 
         affinities = np.array(
             [
-                compute_affinity(
+                affinity_function(
                     source_geometry,
                     target_geometry,
                     time_buffer=time_buffer,
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index 376c08c..a24b187 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -14,7 +14,7 @@ from batdetect2.evaluate.metrics import (
     ClassificationMeanAveragePrecision,
     DetectionAveragePrecision,
 )
-from batdetect2.models import build_model
+from batdetect2.models import Model, build_model
 from batdetect2.train.augmentations import (
     RandomExampleSource,
     build_augmentations,
@@ -55,17 +55,13 @@ def train(
 ):
     config = config or FullTrainingConfig()
 
-    if model_path is not None:
-        logger.debug("Loading model from: {path}", path=model_path)
-        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
-    else:
-        module = build_training_module(config)
+    model = build_model(config=config)
 
-    trainer = build_trainer(config, targets=module.model.targets)
+    trainer = build_trainer(config, targets=model.targets)
 
     train_dataloader = build_train_loader(
         train_examples,
-        preprocessor=module.model.preprocessor,
+        preprocessor=model.preprocessor,
         config=config.train,
         num_workers=train_workers,
     )
@@ -73,7 +69,7 @@ def train(
     val_dataloader = (
         build_val_loader(
             val_examples,
-            preprocessor=module.model.preprocessor,
+            preprocessor=model.preprocessor,
             config=config.train,
             num_workers=val_workers,
         )
@@ -81,6 +77,16 @@ def train(
         else None
     )
 
+    if model_path is not None:
+        logger.debug("Loading model from: {path}", path=model_path)
+        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
+    else:
+        module = build_training_module(
+            model,
+            config,
+            batches_per_epoch=len(train_dataloader),
+        )
+
     logger.info("Starting main training loop...")
     trainer.fit(
         module,
@@ -90,14 +96,17 @@ def train(
     logger.info("Training complete.")
 
 
-def build_training_module(config: FullTrainingConfig) -> TrainingModule:
-    model = build_model(config=config)
+def build_training_module(
+    model: Model,
+    config: FullTrainingConfig,
+    batches_per_epoch: int,
+) -> TrainingModule:
     loss = build_loss(config=config.train.loss)
     return TrainingModule(
         model=model,
         loss=loss,
         learning_rate=config.train.learning_rate,
-        t_max=config.train.t_max,
+        t_max=config.train.t_max * batches_per_epoch,
     )
 
 