From e383a33cbfb21f0868d5866c7e34adc92d9d9cb9 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 3 Apr 2025 16:49:58 +0100 Subject: [PATCH] Improved train module --- batdetect2/train/__init__.py | 48 +++++++++++ batdetect2/train/augmentations.py | 40 ++++++--- batdetect2/train/dataset.py | 45 +++++----- batdetect2/train/labels.py | 20 ++++- batdetect2/train/losses.py | 8 +- batdetect2/train/preprocess.py | 46 +++++++---- batdetect2/train/targets.py | 8 ++ batdetect2/train/train.py | 132 +++++++++++++----------------- 8 files changed, 222 insertions(+), 125 deletions(-) diff --git a/batdetect2/train/__init__.py b/batdetect2/train/__init__.py index e69de29..f364c30 100644 --- a/batdetect2/train/__init__.py +++ b/batdetect2/train/__init__.py @@ -0,0 +1,48 @@ +from batdetect2.train.augmentations import ( + AugmentationsConfig, + add_echo, + augment_example, + load_agumentation_config, + mask_frequency, + mask_time, + mix_examples, + scale_volume, + select_subclip, + warp_spectrogram, +) +from batdetect2.train.config import TrainingConfig, load_train_config +from batdetect2.train.dataset import ( + LabeledDataset, + SubclipConfig, + TrainExample, +) +from batdetect2.train.labels import LabelConfig, load_label_config +from batdetect2.train.preprocess import preprocess_annotations +from batdetect2.train.targets import TargetConfig, load_target_config +from batdetect2.train.train import TrainerConfig, load_trainer_config, train + +__all__ = [ + "AugmentationsConfig", + "LabelConfig", + "LabeledDataset", + "SubclipConfig", + "TargetConfig", + "TrainExample", + "TrainerConfig", + "TrainingConfig", + "add_echo", + "augment_example", + "load_agumentation_config", + "load_label_config", + "load_target_config", + "load_train_config", + "load_trainer_config", + "mask_frequency", + "mask_time", + "mix_examples", + "preprocess_annotations", + "scale_volume", + "select_subclip", + "train", + "warp_spectrogram", +] diff --git a/batdetect2/train/augmentations.py b/batdetect2/train/augmentations.py index 317e07f..3686701 100644 --- a/batdetect2/train/augmentations.py +++ b/batdetect2/train/augmentations.py @@ -3,16 +3,30 @@ from typing import Callable, Optional, Union import numpy as np import xarray as xr from pydantic import Field -from soundevent import arrays +from soundevent import arrays, data -from batdetect2.configs import BaseConfig +from batdetect2.configs import BaseConfig, load_config from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram from batdetect2.preprocess.arrays import adjust_width Augmentation = Callable[[xr.Dataset], xr.Dataset] -class AugmentationConfig(BaseConfig): +__all__ = [ + "AugmentationsConfig", + "load_agumentation_config", + "select_subclip", + "mix_examples", + "add_echo", + "scale_volume", + "warp_spectrogram", + "mask_time", + "mask_frequency", + "augment_example", +] + + +class BaseAugmentationConfig(BaseConfig): enable: bool = True probability: float = 0.2 @@ -63,7 +77,7 @@ def select_subclip( ) -class MixAugmentationConfig(AugmentationConfig): +class MixAugmentationConfig(BaseAugmentationConfig): min_weight: float = 0.3 max_weight: float = 0.7 @@ -133,7 +147,7 @@ def mix_examples( ) -class EchoAugmentationConfig(AugmentationConfig): +class EchoAugmentationConfig(BaseAugmentationConfig): max_delay: float = 0.005 min_weight: float = 0.0 max_weight: float = 1.0 @@ -188,7 +202,7 @@ def add_echo( ) -class VolumeAugmentationConfig(AugmentationConfig): +class VolumeAugmentationConfig(BaseAugmentationConfig): min_scaling: float = 0.0 max_scaling: float = 2.0 @@ -206,7 +220,7 @@ def scale_volume( return example.assign(spectrogram=example["spectrogram"] * factor) -class WarpAugmentationConfig(AugmentationConfig): +class WarpAugmentationConfig(BaseAugmentationConfig): delta: float = 0.04 @@ -294,7 +308,7 @@ def mask_axis( return array.where(condition, other=mask_value) -class TimeMaskAugmentationConfig(AugmentationConfig): +class TimeMaskAugmentationConfig(BaseAugmentationConfig): max_perc: float = 0.05 max_masks: int = 3 @@ -318,7 +332,7 @@ def mask_time( return example.assign(spectrogram=spectrogram) -class FrequencyMaskAugmentationConfig(AugmentationConfig): +class FrequencyMaskAugmentationConfig(BaseAugmentationConfig): max_perc: float = 0.10 max_masks: int = 3 @@ -361,7 +375,13 @@ class AugmentationsConfig(BaseConfig): ) -def should_apply(config: AugmentationConfig) -> bool: +def load_agumentation_config( + path: data.PathLike, field: Optional[str] = None +) -> AugmentationsConfig: + return load_config(path, schema=AugmentationsConfig, field=field) + + +def should_apply(config: BaseAugmentationConfig) -> bool: if not config.enable: return False diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py index b444f42..a5c754d 100644 --- a/batdetect2/train/dataset.py +++ b/batdetect2/train/dataset.py @@ -43,28 +43,23 @@ class SubclipConfig(BaseConfig): class DatasetConfig(BaseConfig): subclip: SubclipConfig = Field(default_factory=SubclipConfig) - preprocessing: PreprocessingConfig = Field( - default_factory=PreprocessingConfig - ) augmentation: AugmentationsConfig = Field( default_factory=AugmentationsConfig ) class LabeledDataset(Dataset): - config: DatasetConfig - def __init__( self, filenames: Sequence[PathLike], - augment: bool = False, - subclip: bool = False, - config: Optional[DatasetConfig] = None, + subclip: Optional[SubclipConfig] = None, + augmentation: Optional[AugmentationsConfig] = None, + preprocessing: Optional[PreprocessingConfig] = None, ): self.filenames = filenames - self.augment = augment self.subclip = subclip - self.config = config or DatasetConfig() + self.augmentation = augmentation + self.preprocessing = preprocessing or PreprocessingConfig() def __len__(self): return len(self.filenames) @@ -75,16 +70,16 @@ class LabeledDataset(Dataset): if self.subclip: dataset = select_subclip( dataset, - duration=self.config.subclip.duration, - width=self.config.subclip.width, - random=self.config.subclip.random, + duration=self.subclip.duration, + width=self.subclip.width, + random=self.subclip.random, ) - if self.augment: + if self.augmentation: dataset = augment_example( dataset, - self.config.augmentation, - preprocessing_config=self.config.preprocessing, + self.augmentation, + preprocessing_config=self.preprocessing, others=self.get_random_example, ) @@ -101,15 +96,15 @@ class LabeledDataset(Dataset): cls, directory: PathLike, extension: str = ".nc", - config: Optional[DatasetConfig] = None, - augment: bool = False, - subclip: bool = False, + subclip: Optional[SubclipConfig] = None, + augmentation: Optional[AugmentationsConfig] = None, + preprocessing: Optional[PreprocessingConfig] = None, ): return cls( get_files(directory, extension), - config=config, - augment=augment, subclip=subclip, + augmentation=augmentation, + preprocessing=preprocessing, ) def get_random_example(self) -> xr.Dataset: @@ -119,9 +114,9 @@ class LabeledDataset(Dataset): if self.subclip: dataset = select_subclip( dataset, - duration=self.config.subclip.duration, - width=self.config.subclip.width, - random=self.config.subclip.random, + duration=self.subclip.duration, + width=self.subclip.width, + random=self.subclip.random, ) return dataset @@ -144,7 +139,7 @@ class LabeledDataset(Dataset): if not self.subclip: return tensor - width = self.config.subclip.width + width = self.subclip.width return adjust_width(tensor, width) diff --git a/batdetect2/train/labels.py b/batdetect2/train/labels.py index 326138f..9f339b7 100644 --- a/batdetect2/train/labels.py +++ b/batdetect2/train/labels.py @@ -3,13 +3,19 @@ from typing import Callable, List, Optional, Sequence, Tuple import numpy as np import xarray as xr +from pydantic import Field from scipy.ndimage import gaussian_filter from soundevent import arrays, data, geometry from soundevent.geometry.operations import Positions -from batdetect2.configs import BaseConfig +from batdetect2.configs import BaseConfig, load_config -__all__ = ["generate_heatmaps"] +__all__ = [ + "HeatmapsConfig", + "LabelConfig", + "generate_heatmaps", + "load_label_config", +] class HeatmapsConfig(BaseConfig): @@ -19,6 +25,10 @@ class HeatmapsConfig(BaseConfig): frequency_scale: float = 1 / 859.375 +class LabelConfig(BaseConfig): + heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig) + + def generate_heatmaps( sound_events: Sequence[data.SoundEventAnnotation], spec: xr.DataArray, @@ -132,3 +142,9 @@ def generate_heatmaps( ).fillna(0.0) return detection_heatmap, class_heatmap, size_heatmap + + +def load_label_config( + path: data.PathLike, field: Optional[str] = None +) -> LabelConfig: + return load_config(path, schema=LabelConfig, field=field) diff --git a/batdetect2/train/losses.py b/batdetect2/train/losses.py index 22ac986..74739d1 100644 --- a/batdetect2/train/losses.py +++ b/batdetect2/train/losses.py @@ -6,9 +6,15 @@ from pydantic import Field from batdetect2.configs import BaseConfig from batdetect2.models.typing import ModelOutput -from batdetect2.plot import detection from batdetect2.train.dataset import TrainExample +__all__ = [ + "bbox_size_loss", + "compute_loss", + "focal_loss", + "mse_loss", +] + class SizeLossConfig(BaseConfig): weight: float = 0.1 diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py index 416d805..92ae3ca 100644 --- a/batdetect2/train/preprocess.py +++ b/batdetect2/train/preprocess.py @@ -17,7 +17,7 @@ from batdetect2.preprocess import ( compute_spectrogram, load_clip_audio, ) -from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps +from batdetect2.train.labels import LabelConfig, generate_heatmaps from batdetect2.train.targets import ( TargetConfig, build_encoder, @@ -30,6 +30,9 @@ FilenameFn = Callable[[data.ClipAnnotation], str] __all__ = [ "preprocess_annotations", + "preprocess_single_annotation", + "generate_train_example", + "TrainPreprocessingConfig", ] @@ -38,15 +41,21 @@ class TrainPreprocessingConfig(BaseConfig): default_factory=PreprocessingConfig ) target: TargetConfig = Field(default_factory=TargetConfig) - heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig) + labels: LabelConfig = Field(default_factory=LabelConfig) def generate_train_example( clip_annotation: data.ClipAnnotation, - config: Optional[TrainPreprocessingConfig] = None, + preprocessing_config: Optional[PreprocessingConfig] = None, + target_config: Optional[TargetConfig] = None, + label_config: Optional[LabelConfig] = None, ) -> xr.Dataset: """Generate a training example.""" - config = config or TrainPreprocessingConfig() + config = TrainPreprocessingConfig( + preprocessing=preprocessing_config or PreprocessingConfig(), + target=target_config or TargetConfig(), + labels=label_config or LabelConfig(), + ) wave = load_clip_audio( clip_annotation.clip, @@ -78,10 +87,10 @@ def generate_train_example( spectrogram, class_names, encoder, - target_sigma=config.heatmaps.sigma, - position=config.heatmaps.position, - time_scale=config.heatmaps.time_scale, - frequency_scale=config.heatmaps.frequency_scale, + target_sigma=config.labels.heatmaps.sigma, + position=config.labels.heatmaps.position, + time_scale=config.labels.heatmaps.time_scale, + frequency_scale=config.labels.heatmaps.frequency_scale, ) dataset = xr.Dataset( @@ -133,14 +142,14 @@ def preprocess_annotations( output_dir: PathLike, filename_fn: FilenameFn = _get_filename, replace: bool = False, - config: Optional[TrainPreprocessingConfig] = None, + preprocessing_config: Optional[PreprocessingConfig] = None, + target_config: Optional[TargetConfig] = None, + label_config: Optional[LabelConfig] = None, max_workers: Optional[int] = None, ) -> None: """Preprocess annotations and save to disk.""" output_dir = Path(output_dir) - config = config or TrainPreprocessingConfig() - if not output_dir.is_dir(): output_dir.mkdir(parents=True) @@ -151,9 +160,11 @@ def preprocess_annotations( partial( preprocess_single_annotation, output_dir=output_dir, - config=config, filename_fn=filename_fn, replace=replace, + preprocessing_config=preprocessing_config, + target_config=target_config, + label_config=label_config, ), clip_annotations, ), @@ -165,7 +176,9 @@ def preprocess_annotations( def preprocess_single_annotation( clip_annotation: data.ClipAnnotation, output_dir: PathLike, - config: TrainPreprocessingConfig, + preprocessing_config: Optional[PreprocessingConfig] = None, + target_config: Optional[TargetConfig] = None, + label_config: Optional[LabelConfig] = None, filename_fn: FilenameFn = _get_filename, replace: bool = False, ) -> None: @@ -181,7 +194,12 @@ def preprocess_single_annotation( path.unlink() try: - sample = generate_train_example(clip_annotation, config=config) + sample = generate_train_example( + clip_annotation, + preprocessing_config=preprocessing_config, + target_config=target_config, + label_config=label_config, + ) except Exception as error: raise RuntimeError( f"Failed to process annotation: {clip_annotation.uuid}" diff --git a/batdetect2/train/targets.py b/batdetect2/train/targets.py index c0771fa..d1d1143 100644 --- a/batdetect2/train/targets.py +++ b/batdetect2/train/targets.py @@ -9,6 +9,14 @@ from soundevent import data from batdetect2.configs import BaseConfig, load_config from batdetect2.terms import TagInfo, get_tag_from_info +__all__ = [ + "TargetConfig", + "load_target_config", + "build_encoder", + "build_decoder", + "filter_sound_event", +] + class ReplaceConfig(BaseConfig): """Configuration for replacing tags.""" diff --git a/batdetect2/train/train.py b/batdetect2/train/train.py index b41f230..6ab8a6c 100644 --- a/batdetect2/train/train.py +++ b/batdetect2/train/train.py @@ -1,82 +1,68 @@ -from typing import Callable, NamedTuple, Optional +from typing import Optional, Union -import torch -from soundevent import data -from torch.optim import Adam -from torch.optim.lr_scheduler import CosineAnnealingLR +from lightning import LightningModule +from lightning.pytorch import Trainer +from soundevent.data import PathLike from torch.utils.data import DataLoader -from batdetect2.data.datasets import ClipAnnotationDataset -from batdetect2.models.typing import DetectionModel +from batdetect2.configs import BaseConfig, load_config +from batdetect2.train.dataset import LabeledDataset + +__all__ = [ + "train", + "TrainerConfig", + "load_trainer_config", +] -class TrainInputs(NamedTuple): - spec: torch.Tensor - detection_heatmap: torch.Tensor - class_heatmap: torch.Tensor - size_heatmap: torch.Tensor +class TrainerConfig(BaseConfig): + accelerator: str = "auto" + accumulate_grad_batches: int = 1 + deterministic: bool = True + check_val_every_n_epoch: int = 1 + devices: Union[str, int] = "auto" + enable_checkpointing: bool = True + gradient_clip_val: Optional[float] = None + limit_train_batches: Optional[Union[int, float]] = None + limit_test_batches: Optional[Union[int, float]] = None + limit_val_batches: Optional[Union[int, float]] = None + log_every_n_steps: Optional[int] = None + max_epochs: Optional[int] = None + min_epochs: Optional[int] = 100 + max_steps: Optional[int] = None + min_steps: Optional[int] = None + max_time: Optional[str] = None + precision: Optional[str] = None + reload_dataloaders_every_n_epochs: Optional[int] = None + val_check_interval: Optional[Union[int, float]] = None -def train_loop( - model: DetectionModel, - train_dataset: ClipAnnotationDataset[TrainInputs], - validation_dataset: ClipAnnotationDataset[TrainInputs], - device: Optional[torch.device] = None, - num_epochs: int = 100, - learning_rate: float = 1e-4, +def load_trainer_config(path: PathLike, field: Optional[str] = None): + return load_config(path, schema=TrainerConfig, field=field) + + +def train( + module: LightningModule, + train_dataset: LabeledDataset, + trainer_config: Optional[TrainerConfig] = None, + dev_run: bool = False, + overfit_batches: bool = False, + profiler: Optional[str] = None, ): - train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) - validation_loader = DataLoader(validation_dataset, batch_size=32) - - model.to(device) - - optimizer = Adam(model.parameters(), lr=learning_rate) - scheduler = CosineAnnealingLR( - optimizer, - num_epochs * len(train_loader), + trainer_config = trainer_config or TrainerConfig() + trainer = Trainer( + **trainer_config.model_dump( + exclude_unset=True, + exclude_none=True, + ), + fast_dev_run=dev_run, + overfit_batches=overfit_batches, + profiler=profiler, ) - - for epoch in range(num_epochs): - train_loss = train_single_epoch( - model, - train_loader, - optimizer, - device, - scheduler, - ) - - -def train_single_epoch( - model: DetectionModel, - train_loader: DataLoader, - optimizer: Adam, - device: torch.device, - scheduler: CosineAnnealingLR, -): - model.train() - train_loss = tu.AverageMeter() - - for batch in train_loader: - optimizer.zero_grad() - - spec = batch.spec.to(device) - detection_heatmap = batch.detection_heatmap.to(device) - class_heatmap = batch.class_heatmap.to(device) - size_heatmap = batch.size_heatmap.to(device) - - outputs = model(spec) - - loss = loss_fun( - outputs, - gt_det, - gt_size, - gt_class, - det_criterion, - params, - class_inv_freq, - ) - - train_loss.update(loss.item(), data.shape[0]) - loss.backward() - optimizer.step() - scheduler.step() + train_loader = DataLoader( + train_dataset, + batch_size=module.config.train.batch_size, + shuffle=True, + num_workers=7, + ) + trainer.fit(module, train_dataloaders=train_loader)