From e65d5a68464cb1119b5cfd523e24599c057caa47 Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Wed, 10 Sep 2025 21:09:51 +0100
Subject: [PATCH] Added more clipping options for validation

---
 src/batdetect2/train/clips.py   |  71 +++++++++++++++++-----
 src/batdetect2/train/config.py  |  47 ++++++++------
 src/batdetect2/train/dataset.py |   7 ++-
 src/batdetect2/train/train.py   | 108 +++++++++++++++++---------------
 4 files changed, 148 insertions(+), 85 deletions(-)

diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py
index 7578aa7..0ebd203 100644
--- a/src/batdetect2/train/clips.py
+++ b/src/batdetect2/train/clips.py
@@ -1,25 +1,32 @@
-from typing import List, Optional
+from typing import Annotated, List, Literal, Optional, Union
 
 import numpy as np
 from loguru import logger
+from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds, intervals_overlap
 
 from batdetect2.configs import BaseConfig
+from batdetect2.data._core import Registry
 from batdetect2.typing import ClipperProtocol
 
 DEFAULT_TRAIN_CLIP_DURATION = 0.256
 DEFAULT_MAX_EMPTY_CLIP = 0.1
 
 
-class ClipingConfig(BaseConfig):
+registry: Registry[ClipperProtocol] = Registry("clipper")
+
+
+class RandomClipConfig(BaseConfig):
+    name: Literal["random_subclip"] = "random_subclip"
     duration: float = DEFAULT_TRAIN_CLIP_DURATION
     random: bool = True
     max_empty: float = DEFAULT_MAX_EMPTY_CLIP
     min_sound_event_overlap: float = 0
 
 
-class Clipper:
+@registry.register(RandomClipConfig)
+class RandomClip:
     def __init__(
         self,
         duration: float = 0.5,
@@ -45,6 +52,15 @@ class Clipper:
             min_sound_event_overlap=self.min_sound_event_overlap,
         )
 
+    @classmethod
+    def from_config(cls, config: RandomClipConfig):
+        return cls(
+            duration=config.duration,
+            random=config.random,
+            max_empty=config.max_empty,
+            min_sound_event_overlap=config.min_sound_event_overlap,
+        )
+
 
 def get_subclip_annotation(
     clip_annotation: data.ClipAnnotation,
@@ -136,17 +152,46 @@ def select_sound_event_annotations(
     return selected
 
 
-def build_clipper(
-    config: Optional[ClipingConfig] = None,
-    random: Optional[bool] = None,
-) -> ClipperProtocol:
-    config = config or ClipingConfig()
+class PaddedClipConfig(BaseConfig):
+    name: Literal["whole_audio_padded"] = "whole_audio_padded"
+    chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION
+
+
+@registry.register(PaddedClipConfig)
+class PaddedClip:
+    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
+        self.duration = duration
+
+    def __call__(
+        self,
+        clip_annotation: data.ClipAnnotation,
+    ) -> data.ClipAnnotation:
+        clip = clip_annotation.clip
+        duration = clip.duration
+
+        target_duration = self.duration * np.ceil(duration / self.duration)
+        clip = clip.model_copy(
+            update=dict(
+                end_time=clip.start_time + target_duration,
+            )
+        )
+        return clip_annotation.model_copy(update=dict(clip=clip))
+
+    @classmethod
+    def from_config(cls, config: PaddedClipConfig):
+        return cls(duration=config.chunk_size)
+
+
+ClipConfig = Annotated[
+    Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
+]
+
+
+def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
+    config = config or RandomClipConfig()
+
     logger.opt(lazy=True).debug(
         "Building clipper with config: \n{}",
         lambda: config.to_yaml_string(),
    )
-    return Clipper(
-        duration=config.duration,
-        max_empty=config.max_empty,
-        random=config.random if random else False,
-    )
+    return registry.build(config)
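
Note: clippers are now built from their config through the registry, dispatching
on the `name` discriminator. A minimal usage sketch (assumes the patched
batdetect2 package is importable; the durations are illustrative):

    from batdetect2.train.clips import (
        PaddedClipConfig,
        RandomClipConfig,
        build_clipper,
    )

    # Random fixed-duration subclips, as used for training.
    train_clipper = build_clipper(RandomClipConfig(duration=0.256))

    # Whole recordings padded up to a multiple of chunk_size: a 0.9 s clip
    # is extended to 0.256 * ceil(0.9 / 0.256) = 1.024 s.
    val_clipper = build_clipper(PaddedClipConfig(chunk_size=0.256))
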
diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py
index 010ae63..bcebfcb 100644
--- a/src/batdetect2/train/config.py
+++ b/src/batdetect2/train/config.py
@@ -10,7 +10,11 @@ from batdetect2.train.augmentations import (
     DEFAULT_AUGMENTATION_CONFIG,
     AugmentationsConfig,
 )
-from batdetect2.train.clips import ClipingConfig
+from batdetect2.train.clips import (
+    ClipConfig,
+    PaddedClipConfig,
+    RandomClipConfig,
+)
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
 from batdetect2.train.losses import LossConfig
@@ -44,34 +48,39 @@ class PLTrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None
 
 
-class DataLoaderConfig(BaseConfig):
-    batch_size: int = 8
-    shuffle: bool = False
+class ValLoaderConfig(BaseConfig):
     num_workers: int = 0
 
-
-DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
-DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=1, shuffle=False)
-
-
-class LoadersConfig(BaseConfig):
-    train: DataLoaderConfig = Field(
-        default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy()
+    clipping_strategy: ClipConfig = Field(
+        default_factory=lambda: PaddedClipConfig()
     )
-    val: DataLoaderConfig = Field(
-        default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy()
+
+
+class TrainLoaderConfig(BaseConfig):
+    num_workers: int = 0
+
+    batch_size: int = 8
+
+    shuffle: bool = False
+
+    augmentations: AugmentationsConfig = Field(
+        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
+    )
+
+    clipping_strategy: ClipConfig = Field(
+        default_factory=lambda: RandomClipConfig()
     )
 
 
 class TrainingConfig(BaseConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
-    dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
+
+    train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
+    val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
+
     loss: LossConfig = Field(default_factory=LossConfig)
-    augmentations: AugmentationsConfig = Field(
-        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
-    )
-    cliping: ClipingConfig = Field(default_factory=ClipingConfig)
+    cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py
index 4add71f..865c1aa 100644
--- a/src/batdetect2/train/dataset.py
+++ b/src/batdetect2/train/dataset.py
@@ -87,6 +87,7 @@ class ValidationDataset(Dataset):
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         labeller: ClipLabeller,
+        clipper: Optional[ClipperProtocol] = None,
         audio_dir: Optional[data.PathLike] = None,
     ):
         self.clip_annotations = clip_annotations
@@ -94,14 +95,18 @@ class ValidationDataset(Dataset):
         self.preprocessor = preprocessor
         self.audio_loader = audio_loader
         self.audio_dir = audio_dir
+        self.clipper = clipper
 
     def __len__(self):
         return len(self.clip_annotations)
 
     def __getitem__(self, idx) -> TrainExample:
         clip_annotation = self.clip_annotations[idx]
-        clip = clip_annotation.clip
 
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
+        clip = clip_annotation.clip
         wav = self.audio_loader.load_clip(
             clip_annotation.clip,
             audio_dir=self.audio_dir,
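
Note: each loader now owns its clipping strategy. A configuration sketch in
Python (field names as defined in this patch; values illustrative):

    from batdetect2.train.clips import PaddedClipConfig, RandomClipConfig
    from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig

    train_loader = TrainLoaderConfig(
        batch_size=8,
        shuffle=True,
        clipping_strategy=RandomClipConfig(duration=0.256),
    )

    # Validation sees each recording whole, padded to a multiple of the
    # chunk size, so metrics are computed over deterministic clips.
    val_loader = ValLoaderConfig(
        clipping_strategy=PaddedClipConfig(chunk_size=0.256),
    )
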
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index 7d0f21d..dfb4770 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -24,7 +24,11 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import FullTrainingConfig, TrainingConfig
+from batdetect2.train.config import (
+    FullTrainingConfig,
+    TrainLoaderConfig,
+    ValLoaderConfig,
+)
 from batdetect2.train.dataset import TrainingDataset, ValidationDataset
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import TrainingModule
@@ -85,7 +89,7 @@ def train(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train,
+        config=config.train.train_loader,
         num_workers=train_workers,
     )
 
@@ -95,7 +99,7 @@ def train(
             audio_loader=audio_loader,
             labeller=labeller,
             preprocessor=preprocessor,
-            config=config.train,
+            config=config.train.val_loader,
             num_workers=val_workers,
         )
         if val_annotations is not None
@@ -225,10 +229,17 @@ def build_train_loader(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[TrainLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
-    config = config or TrainingConfig()
+    config = config or TrainLoaderConfig()
+
+    logger.info("Building training data loader...")
+    logger.opt(lazy=True).debug(
+        "Training data loader config: \n{config}",
+        config=lambda: config.to_yaml_string(exclude_none=True),
+    )
+
     train_dataset = build_train_dataset(
         clip_annotations,
         audio_loader=audio_loader,
@@ -237,17 +248,11 @@ def build_train_loader(
         config=config,
     )
 
-    logger.info("Building training data loader...")
-    loader_conf = config.dataloaders.train
-    logger.opt(lazy=True).debug(
-        "Training data loader config: \n{config}",
-        config=lambda: loader_conf.to_yaml_string(exclude_none=True),
-    )
-    num_workers = num_workers or loader_conf.num_workers
+    num_workers = num_workers or config.num_workers
     return DataLoader(
         train_dataset,
-        batch_size=loader_conf.batch_size,
-        shuffle=loader_conf.shuffle,
+        batch_size=config.batch_size,
+        shuffle=config.shuffle,
         num_workers=num_workers,
         collate_fn=_collate_fn,
     )
@@ -258,11 +263,15 @@ def build_val_loader(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[ValLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ):
     logger.info("Building validation data loader...")
-    config = config or TrainingConfig()
+    config = config or ValLoaderConfig()
+    logger.opt(lazy=True).debug(
+        "Validation data loader config: \n{config}",
+        config=lambda: config.to_yaml_string(exclude_none=True),
+    )
 
     val_dataset = build_val_dataset(
         clip_annotations,
@@ -271,56 +280,28 @@ def build_val_loader(
         preprocessor=preprocessor,
         config=config,
     )
-    loader_conf = config.dataloaders.val
-    logger.opt(lazy=True).debug(
-        "Validation data loader config: \n{config}",
-        config=lambda: loader_conf.to_yaml_string(exclude_none=True),
-    )
-    num_workers = num_workers or loader_conf.num_workers
+
+    num_workers = num_workers or config.num_workers
     return DataLoader(
         val_dataset,
         batch_size=1,
-        shuffle=loader_conf.shuffle,
+        shuffle=False,
         num_workers=num_workers,
         collate_fn=_collate_fn,
     )
 
 
-def _collate_fn(batch: List[TrainExample]) -> TrainExample:
-    max_width = max(item.spec.shape[-1] for item in batch)
-    return TrainExample(
-        spec=torch.stack(
-            [adjust_width(item.spec, max_width) for item in batch]
-        ),
-        detection_heatmap=torch.stack(
-            [adjust_width(item.detection_heatmap, max_width) for item in batch]
-        ),
-        size_heatmap=torch.stack(
-            [adjust_width(item.size_heatmap, max_width) for item in batch]
-        ),
-        class_heatmap=torch.stack(
-            [adjust_width(item.class_heatmap, max_width) for item in batch]
-        ),
-        idx=torch.stack([item.idx for item in batch]),
-        start_time=torch.stack([item.start_time for item in batch]),
-        end_time=torch.stack([item.end_time for item in batch]),
-    )
-
-
 def build_train_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[TrainLoaderConfig] = None,
 ) -> TrainingDataset:
     logger.info("Building training dataset...")
-    config = config or TrainingConfig()
+    config = config or TrainLoaderConfig()
 
-    clipper = build_clipper(
-        config=config.cliping,
-        random=True,
-    )
+    clipper = build_clipper(config=config.clipping_strategy)
 
     random_example_source = RandomAudioSource(
         clip_annotations,
@@ -354,14 +335,37 @@ def build_val_dataset(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[ValLoaderConfig] = None,
 ) -> ValidationDataset:
     logger.info("Building validation dataset...")
-    config = config or TrainingConfig()
+    config = config or ValLoaderConfig()
 
+    clipper = build_clipper(config.clipping_strategy)
     return ValidationDataset(
         clip_annotations,
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
+        clipper=clipper,
     )
+
+
+def _collate_fn(batch: List[TrainExample]) -> TrainExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return TrainExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        detection_heatmap=torch.stack(
+            [adjust_width(item.detection_heatmap, max_width) for item in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(item.size_heatmap, max_width) for item in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(item.class_heatmap, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
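
Note: because ClipConfig discriminates on the "name" field, a plain mapping
(e.g. parsed from a YAML training config) resolves to the right variant. A
sketch, assuming BaseConfig is a pydantic v2 model, as its use of Field and
model_copy suggests:

    from batdetect2.train.clips import PaddedClipConfig
    from batdetect2.train.config import ValLoaderConfig

    config = ValLoaderConfig.model_validate(
        {
            "clipping_strategy": {
                "name": "whole_audio_padded",
                "chunk_size": 0.512,
            }
        }
    )
    assert isinstance(config.clipping_strategy, PaddedClipConfig)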