From 7c89e82579f32a1ba530cc8e7e71883ce7c90e8b Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Tue, 22 Apr 2025 00:36:34 +0100
Subject: [PATCH] Fixing imports after restructuring

---
 batdetect2/modules.py             |  10 +-
 batdetect2/train/augmentations.py | 757 +++++++++++++++++++++++++-----
 batdetect2/train/callbacks.py     |  39 +-
 batdetect2/train/dataset.py       |  16 +-
 batdetect2/train/labels.py        | 333 ++++---------
 batdetect2/train/preprocess.py    | 252 ++++++----
 batdetect2/train/types.py         |  48 ++
 batdetect2/types.py               |   6 +-
 tests/test_models/test_inputs.py  |  34 --
 tests/test_train/test_labels.py   |   2 +-
 10 files changed, 991 insertions(+), 506 deletions(-)
 create mode 100644 batdetect2/train/types.py
 delete mode 100644 tests/test_models/test_inputs.py

diff --git a/batdetect2/modules.py b/batdetect2/modules.py
index b379771..c79a6a3 100644
--- a/batdetect2/modules.py
+++ b/batdetect2/modules.py
@@ -10,13 +10,13 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
 from batdetect2.configs import BaseConfig
 from batdetect2.models import (
+    BackboneConfig,
     BBoxHead,
     ClassifierHead,
-    ModelConfig,
     ModelOutput,
-    build_architecture,
+    build_backbone,
 )
-from batdetect2.post_process import (
+from batdetect2.postprocess import (
     PostprocessConfig,
     postprocess_model_outputs,
 )
@@ -37,7 +37,7 @@ __all__ = [
 class ModuleConfig(BaseConfig):
     train: TrainingConfig = Field(default_factory=TrainingConfig)
     targets: TargetConfig = Field(default_factory=TargetConfig)
-    architecture: ModelConfig = Field(default_factory=ModelConfig)
+    architecture: BackboneConfig = Field(default_factory=BackboneConfig)
     preprocessing: PreprocessingConfig = Field(
         default_factory=PreprocessingConfig
     )
@@ -58,7 +58,7 @@ class DetectorModel(L.LightningModule):
         self.config = config or ModuleConfig()
         self.save_hyperparameters()

-        self.backbone = build_architecture(self.config.architecture)
+        self.backbone = build_backbone(self.config.architecture)

         self.classifier = ClassifierHead(
             num_classes=len(self.config.targets.classes),
diff --git a/batdetect2/train/augmentations.py b/batdetect2/train/augmentations.py
index 3686701..322d4ee 100644
--- a/batdetect2/train/augmentations.py
+++ b/batdetect2/train/augmentations.py
@@ -1,4 +1,30 @@
-from typing import Callable, Optional, Union
+"""Applies data augmentation techniques to BatDetect2 training examples.
+
+This module provides various functions and configurable components for applying
+data augmentation to the training examples (`xr.Dataset` containing audio,
+spectrogram, and target heatmaps) generated by the
+`batdetect2.train.preprocess` module.
+
+Data augmentation artificially increases the diversity of the training data by
+applying random transformations, which generally helps improve the robustness
+and generalization performance of trained models.
+
+Augmentations included:
+- Time-based: `select_subclip`, `warp_spectrogram`, `mask_time`.
+- Frequency-based: `mask_frequency`.
+- Amplitude/Noise-based: `mix_examples`, `add_echo`, `scale_volume`.
+
+Some augmentations modify the audio waveform and require recomputing the
+spectrogram using the `PreprocessorProtocol`, while others operate directly
+on the spectrogram or target heatmaps. The entire augmentation pipeline can be
+configured using the `AugmentationsConfig` class, specifying a sequence of
+augmentation steps, each with its own parameters and application probability.
+The `build_augmentations` function constructs the final augmentation callable
+from this configuration.
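+
+For illustration, a pipeline with an echo step followed by a time mask might
+be configured roughly like this (a sketch; the YAML keys mirror the config
+classes defined below)::
+
+    steps:
+      - augmentation_type: add_echo
+        probability: 0.2
+        max_delay: 0.005
+      - augmentation_type: mask_time
+        probability: 0.2
+        max_perc: 0.05
+        max_masks: 3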
+""" + +from functools import partial +from typing import Annotated, Callable, List, Literal, Optional, Union import numpy as np import xarray as xr @@ -6,15 +32,14 @@ from pydantic import Field from soundevent import arrays, data from batdetect2.configs import BaseConfig, load_config -from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram -from batdetect2.preprocess.arrays import adjust_width - -Augmentation = Callable[[xr.Dataset], xr.Dataset] - +from batdetect2.preprocess import PreprocessorProtocol +from batdetect2.train.types import Augmentation +from batdetect2.utils.arrays import adjust_width __all__ = [ "AugmentationsConfig", - "load_agumentation_config", + "load_augmentation_config", + "build_augmentations", "select_subclip", "mix_examples", "add_echo", @@ -22,13 +47,21 @@ __all__ = [ "warp_spectrogram", "mask_time", "mask_frequency", - "augment_example", + "MixAugmentationConfig", + "EchoAugmentationConfig", + "VolumeAugmentationConfig", + "WarpAugmentationConfig", + "TimeMaskAugmentationConfig", + "FrequencyMaskAugmentationConfig", + "AugmentationConfig", + "ExampleSource", ] +ExampleSource = Callable[[], xr.Dataset] +"""Type alias for a function that returns a training example (`xr.Dataset`). -class BaseAugmentationConfig(BaseConfig): - enable: bool = True - probability: float = 0.2 +Used by the `mix_examples` augmentation to fetch another example to mix with. +""" def select_subclip( @@ -38,7 +71,47 @@ def select_subclip( width: Optional[int] = None, random: bool = False, ) -> xr.Dataset: - """Select a random subclip from a clip.""" + """Extract a sub-clip (time segment) from a training example dataset. + + Selects a portion of the 'time' dimension from all relevant DataArrays + (`audio`, `spectrogram`, `detection`, `class`, `size`) within the example + Dataset. The segment can be defined by a fixed start time and + duration/width, or a random start time can be chosen. + + Parameters + ---------- + example : xr.Dataset + The input training example containing 'audio', 'spectrogram', and + target heatmaps, all with compatible 'time' (or 'audio_time') + coordinates. + start_time : float, optional + Desired start time (seconds) of the subclip. If None and `random` is + False, starts from the beginning of the example. If None and `random` + is True, a random start time is chosen. + duration : float, optional + Desired duration (seconds) of the subclip. Either `duration` or `width` + must be provided. + width : int, optional + Desired width (number of time bins) of the subclip's + spectrogram/heatmaps. Either `duration` or `width` must be provided. If + both are given, `duration` takes precedence. + random : bool, default=False + If True and `start_time` is None, selects a random start time ensuring + the subclip fits within the original example's duration. + + Returns + ------- + xr.Dataset + A new dataset containing only the selected time segment. Coordinates + are adjusted accordingly. Returns the original example if the requested + subclip cannot be extracted (e.g., duration too long). + + Raises + ------ + ValueError + If neither `duration` nor `width` is provided, or if time coordinates + are missing or invalid. 
+ """ step = arrays.get_dim_step(example, "time") # type: ignore start, stop = arrays.get_dim_range(example, "time") # type: ignore @@ -77,22 +150,62 @@ def select_subclip( ) -class MixAugmentationConfig(BaseAugmentationConfig): +class MixAugmentationConfig(BaseConfig): + """Configuration for MixUp augmentation (mixing two examples).""" + + augmentation_type: Literal["mix_audio"] = "mix_audio" + + probability: float = 0.2 + """Probability of applying this augmentation to an example.""" + min_weight: float = 0.3 + """Minimum mixing weight (lambda) applied to the primary example.""" + max_weight: float = 0.7 + """Maximum mixing weight (lambda) applied to the primary example.""" def mix_examples( example: xr.Dataset, other: xr.Dataset, + preprocessor: PreprocessorProtocol, weight: Optional[float] = None, min_weight: float = 0.3, max_weight: float = 0.7, - config: Optional[PreprocessingConfig] = None, ) -> xr.Dataset: - """Combine two audio clips.""" - config = config or PreprocessingConfig() + """Combine two training examples using MixUp augmentation. + Performs a weighted linear combination of the audio waveforms from two + examples (`example` and `other`). The spectrogram is then *recomputed* + from the mixed audio using the provided `preprocessor`. Target heatmaps + (detection, class) are combined by taking the element-wise maximum. Target + size heatmaps are combined by element-wise addition. + + Parameters + ---------- + example : xr.Dataset + The primary training example. + other : xr.Dataset + The second training example to mix with `example`. + preprocessor : PreprocessorProtocol + The preprocessor used to recompute the spectrogram from mixed audio. + Ensures consistency with original preprocessing. + weight : float, optional + The mixing weight (lambda) applied to the primary `example`. The weight + for `other` will be `(1 - weight)`. If None, a random weight is chosen + uniformly between `min_weight` and `max_weight`. + min_weight : float, default=0.3 + Minimum value for the random weight lambda. + max_weight : float, default=0.7 + Maximum value for the random weight lambda. Must be >= `min_weight`. + + Returns + ------- + xr.Dataset + A new dataset representing the mixed example, with combined audio, + recomputed spectrogram, and combined target heatmaps. Attributes from + the primary `example` are preserved. 
+ """ if weight is None: weight = np.random.uniform(min_weight, max_weight) @@ -101,9 +214,8 @@ def mix_examples( combined = weight * audio1 + (1 - weight) * audio2 - spectrogram = compute_spectrogram( - combined.rename({"audio_time": "time"}), - config=config.spectrogram, + spectrogram = preprocessor.compute_spectrogram( + combined.rename({"audio_time": "time"}) ).data # NOTE: The subclip's spectrogram might be slightly longer than the @@ -147,7 +259,14 @@ def mix_examples( ) -class EchoAugmentationConfig(BaseAugmentationConfig): +class EchoAugmentationConfig(BaseConfig): + """Configuration for adding synthetic echo/reverb.""" + + augmentation_type: Literal["add_echo"] = "add_echo" + + probability: float = 0.2 + """Probability of applying this augmentation.""" + max_delay: float = 0.005 min_weight: float = 0.0 max_weight: float = 1.0 @@ -155,15 +274,45 @@ class EchoAugmentationConfig(BaseAugmentationConfig): def add_echo( example: xr.Dataset, + preprocessor: PreprocessorProtocol, delay: Optional[float] = None, weight: Optional[float] = None, min_weight: float = 0.1, max_weight: float = 1.0, max_delay: float = 0.005, - config: Optional[PreprocessingConfig] = None, ) -> xr.Dataset: - """Add a delay to the audio.""" - config = config or PreprocessingConfig() + """Add a synthetic echo to the audio waveform. + + Creates an echo by adding a delayed and attenuated version of the original + audio waveform back onto itself. The spectrogram is then recomputed from + the modified audio. Target heatmaps remain unchanged. + + Parameters + ---------- + example : xr.Dataset + The input training example containing 'audio', 'spectrogram', etc. + preprocessor : PreprocessorProtocol + Preprocessor used to recompute the spectrogram from the modified audio. + delay : float, optional + The delay time in seconds for the echo. If None, chosen randomly + between 0 and `max_delay`. + weight : float, optional + The relative amplitude (weight) of the echo compared to the original. + If None, chosen randomly between `min_weight` and `max_weight`. + min_weight : float, default=0.1 + Minimum value for the random echo weight. + max_weight : float, default=1.0 + Maximum value for the random echo weight. Must be >= `min_weight`. + max_delay : float, default=0.005 + Maximum value for the random echo delay in seconds. Must be >= 0. + + Returns + ------- + xr.Dataset + A new dataset with the echo added to the 'audio' variable and the + 'spectrogram' variable recomputed. Other variables (targets, attrs) + are copied from the original example. + """ if delay is None: delay = np.random.uniform(0, max_delay) @@ -176,9 +325,8 @@ def add_echo( audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0) audio = audio + weight * audio_delay - spectrogram = compute_spectrogram( + spectrogram = preprocessor.compute_spectrogram( audio.rename({"audio_time": "time"}), - config=config.spectrogram, ).data # NOTE: The subclip's spectrogram might be slightly longer than the @@ -202,7 +350,11 @@ def add_echo( ) -class VolumeAugmentationConfig(BaseAugmentationConfig): +class VolumeAugmentationConfig(BaseConfig): + """Configuration for random volume scaling of the spectrogram.""" + + augmentation_type: Literal["scale_volume"] = "scale_volume" + probability: float = 0.2 min_scaling: float = 0.0 max_scaling: float = 2.0 @@ -213,14 +365,44 @@ def scale_volume( max_scaling: float = 2, min_scaling: float = 0, ) -> xr.Dataset: - """Scale the volume of a spectrogram.""" + """Scale the amplitude of the spectrogram by a random factor. 
+
+    Multiplies the entire spectrogram DataArray by a scaling factor chosen
+    uniformly between `min_scaling` and `max_scaling`. This simulates changes
+    in recording volume or distance to the sound source directly in the
+    spectrogram domain. Audio and target heatmaps are unchanged.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        The input training example containing 'spectrogram'.
+    factor : float, optional
+        The scaling factor to apply. If None, chosen randomly between
+        `min_scaling` and `max_scaling`.
+    min_scaling : float, default=0.0
+        Minimum value for the random scaling factor. Must be non-negative.
+    max_scaling : float, default=2.0
+        Maximum value for the random scaling factor. Must be >= `min_scaling`.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset with the 'spectrogram' variable scaled.
+    """
     if factor is None:
         factor = np.random.uniform(min_scaling, max_scaling)

     return example.assign(spectrogram=example["spectrogram"] * factor)


-class WarpAugmentationConfig(BaseAugmentationConfig):
+class WarpAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["warp"] = "warp"
+    probability: float = 0.2
     delta: float = 0.04


@@ -229,11 +411,39 @@ def warp_spectrogram(
     example: xr.Dataset,
     factor: Optional[float] = None,
     delta: float = 0.04,
 ) -> xr.Dataset:
-    """Warp a spectrogram."""
+    """Apply time warping by resampling the time axis.
+
+    Stretches or compresses the time axis of the spectrogram and all target
+    heatmaps by a random `factor` (chosen between `1-delta` and `1+delta`).
+    Uses linear interpolation for the spectrogram and nearest neighbor for the
+    discrete-like target heatmaps. Updates the 'time' coordinate accordingly.
+    The audio waveform is not modified.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        Input training example with 'spectrogram', 'detection', 'class',
+        'size', and 'time' coordinates.
+    factor : float, optional
+        The warping factor. If None, chosen randomly between `1-delta` and
+        `1+delta`. Values > 1 stretch time, < 1 compress time.
+    delta : float, default=0.04
+        Controls the range `[1-delta, 1+delta]` for the random warp factor.
+        Must be >= 0 and < 1.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset with time-warped spectrogram and target heatmaps, and
+        updated 'time' coordinates.
+    """
     if factor is None:
         factor = np.random.uniform(1 - delta, 1 + delta)

-    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
+    start_time, end_time = arrays.get_dim_range(
+        example,  # type: ignore
+        "time",
+    )
     duration = end_time - start_time

     new_time = np.linspace(
@@ -296,6 +506,39 @@ def mask_axis(
     end: float,
     mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean,
 ) -> xr.DataArray:
+    """Mask values along a specified dimension.
+
+    Sets values in the DataArray to `mask_value` where the coordinate along
+    `dim` falls within the range [`start`, `end`). Values outside this range
+    are kept. Used as a helper for time/frequency masking.
+
+    Parameters
+    ----------
+    array : xr.DataArray
+        The input DataArray (e.g., spectrogram).
+    dim : str
+        The name of the dimension along which to mask
+        (e.g., "time", "frequency").
+    start : float
+        The starting coordinate value for the mask range.
+    end : float
+        The ending coordinate value for the mask range (exclusive, typically).
+        Values >= start and < end will be masked.
+    mask_value : float or Callable[[xr.DataArray], float], default=np.mean
+        The value to use for masking. 
Can be a fixed float (e.g., 0.0) or a
+        callable (like `np.mean`, `np.median`) that computes the value from
+        the input `array`.
+
+    Returns
+    -------
+    xr.DataArray
+        The DataArray with the specified range along `dim` masked.
+
+    Raises
+    ------
+    ValueError
+        If `dim` is not found in the array's dimensions or coordinates.
+    """
     if dim not in array.dims:
         raise ValueError(f"Axis {dim} not found in array")

@@ -308,7 +551,9 @@
     return array.where(condition, other=mask_value)


-class TimeMaskAugmentationConfig(BaseAugmentationConfig):
+class TimeMaskAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["mask_time"] = "mask_time"
+    probability: float = 0.2
     max_perc: float = 0.05
     max_masks: int = 3


 def mask_time(
     example: xr.Dataset,
     max_perc: float = 0.05,
     max_mask: int = 3,
 ) -> xr.Dataset:
-    """Mask a random section of the time axis."""
+    """Apply random time masking (SpecAugment) to the spectrogram.
+
+    Randomly selects a number of time intervals (up to `max_mask`) and masks
+    (sets to the mean value) the spectrogram within those intervals. The width
+    of each mask is chosen randomly up to `max_perc` of the total duration.
+    Only the 'spectrogram' variable is modified.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        Input training example containing 'spectrogram' and 'time' coordinate.
+    max_perc : float, default=0.05
+        Maximum width of a single mask as a fraction of total duration.
+    max_mask : int, default=3
+        Maximum number of time masks to apply.
+
+    Returns
+    -------
+    xr.Dataset
+        Dataset with time masking applied to the 'spectrogram'.
+    """
     num_masks = np.random.randint(1, max_mask + 1)
-    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
+    start_time, end_time = arrays.get_dim_range(
+        example,  # type: ignore
+        "time",
+    )
     spectrogram = example["spectrogram"]
     for _ in range(num_masks):
@@ -332,7 +600,9 @@
     return example.assign(spectrogram=spectrogram)


-class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
+class FrequencyMaskAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["mask_freq"] = "mask_freq"
+    probability: float = 0.2
     max_perc: float = 0.10
     max_masks: int = 3


 def mask_frequency(
     example: xr.Dataset,
     max_perc: float = 0.10,
     max_masks: int = 3,
 ) -> xr.Dataset:
-    """Mask a random section of the frequency axis."""
+    """Apply random frequency masking (SpecAugment) to the spectrogram.
+
+    Randomly selects a number of frequency intervals (up to `max_masks`) and
+    masks (sets to the mean value) the spectrogram within those intervals. The
+    height of each mask is chosen randomly up to `max_perc` of the total
+    frequency range. Only the 'spectrogram' variable is modified.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        Input training example containing 'spectrogram' and 'frequency'
+        coordinate.
+    max_perc : float, default=0.10
+        Maximum height of a single mask as a fraction of total frequency range.
+    max_masks : int, default=3
+        Maximum number of frequency masks to apply.
+
+    Returns
+    -------
+    xr.Dataset
+        Dataset with frequency masking applied to the 'spectrogram'.
+
+    Raises
+    ------
+    ValueError
+        If frequency coordinate info is missing or invalid. 
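+
+    Examples
+    --------
+    Apply up to two masks, each covering at most 5% of the frequency range
+    (a sketch; assumes ``example`` holds a 'spectrogram' with a 'frequency'
+    coordinate):
+
+    >>> masked = mask_frequency(example, max_perc=0.05, max_masks=2)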
+ """ num_masks = np.random.randint(1, max_masks + 1) - min_freq, max_freq = arrays.get_dim_range(example, "frequency") # type: ignore + min_freq, max_freq = arrays.get_dim_range( + example, # type: ignore + "frequency", + ) spectrogram = example["spectrogram"] for _ in range(num_masks): @@ -356,88 +655,326 @@ def mask_frequency( return example.assign(spectrogram=spectrogram) +AugmentationConfig = Annotated[ + Union[ + MixAugmentationConfig, + EchoAugmentationConfig, + VolumeAugmentationConfig, + WarpAugmentationConfig, + FrequencyMaskAugmentationConfig, + TimeMaskAugmentationConfig, + ], + Field(discriminator="augmentation_type"), +] +"""Type alias for the discriminated union of individual augmentation config.""" + + class AugmentationsConfig(BaseConfig): - mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig) - echo: EchoAugmentationConfig = Field( - default_factory=EchoAugmentationConfig - ) - volume: VolumeAugmentationConfig = Field( - default_factory=VolumeAugmentationConfig - ) - warp: WarpAugmentationConfig = Field( - default_factory=WarpAugmentationConfig - ) - time_mask: TimeMaskAugmentationConfig = Field( - default_factory=TimeMaskAugmentationConfig - ) - frequency_mask: FrequencyMaskAugmentationConfig = Field( - default_factory=FrequencyMaskAugmentationConfig + """Configuration for a sequence of data augmentations. + + Attributes + ---------- + steps : List[AugmentationConfig] + An ordered list of configuration objects, each defining a single + augmentation step (e.g., MixAugmentationConfig, + TimeMaskAugmentationConfig). Each step's configuration must include an + `augmentation_type` field and a `probability` field, along with + type-specific parameters. The augmentations will be applied + (probabilistically) in the sequence defined by this list. + """ + + steps: List[AugmentationConfig] = Field(default_factory=list) + + +class MaybeApply: + """Applies an augmentation function probabilistically.""" + + def __init__( + self, + augmentation: Augmentation, + probability: float = 0.2, + ): + """Initialize the wrapper. + + Parameters + ---------- + augmentation : Augmentation (Callable[[xr.Dataset], xr.Dataset]) + The augmentation function to potentially apply. + probability : float, default=0.5 + The probability (0.0 to 1.0) of applying the augmentation. + """ + self.augmentation = augmentation + self.probability = probability + + def __call__(self, example: xr.Dataset) -> xr.Dataset: + """Apply the wrapped augmentation with configured probability. + + Parameters + ---------- + example : xr.Dataset + The input training example. + + Returns + ------- + xr.Dataset + The potentially augmented training example. + """ + if np.random.random() > self.probability: + return example + + return self.augmentation(example) + + +class AudioMixer: + """Callable class for MixUp augmentation, handling example fetching. + + Wraps the `mix_examples` logic, providing the necessary `example_source` + to fetch a second example dynamically when called. + + Parameters + ---------- + min_weight : float + Minimum mixing weight (lambda) for the primary example. + max_weight : float + Maximum mixing weight (lambda) for the primary example. + example_source : ExampleSource (Callable[[], xr.Dataset]) + A function that, when called, returns another random training example + dataset (`xr.Dataset`) to be mixed with the input example. + preprocessor : PreprocessorProtocol + The preprocessor needed to recompute the spectrogram after mixing + audio. 
+ """ + + def __init__( + self, + min_weight: float, + max_weight: float, + example_source: ExampleSource, + preprocessor: PreprocessorProtocol, + ): + """Initialize the AudioMixer.""" + self.min_weight = min_weight + self.example_source = example_source + self.max_weight = max_weight + self.preprocessor = preprocessor + + def __call__(self, example: xr.Dataset) -> xr.Dataset: + """Fetch another example and perform mixup. + + Parameters + ---------- + example : xr.Dataset + The primary input training example. + + Returns + ------- + xr.Dataset + The mixed training example. Returns the original example if + fetching the second example fails. + """ + other = self.example_source() + return mix_examples( + example, + other, + self.preprocessor, + min_weight=self.min_weight, + max_weight=self.max_weight, + ) + + +def build_augmentation_from_config( + config: AugmentationConfig, + preprocessor: PreprocessorProtocol, + example_source: Optional[ExampleSource] = None, +) -> Augmentation: + """Factory function to build a single augmentation from its config. + + Takes a configuration object for one augmentation step (which includes the + `augmentation_type` discriminator) and returns the corresponding functional + augmentation callable (e.g., a `partial` function or a callable class like + `AudioMixer`). + + Parameters + ---------- + config : AugmentationConfig + Configuration object for a single augmentation (e.g., instance of + `MixAugmentationConfig`, `EchoAugmentationConfig`, etc.). + preprocessor : PreprocessorProtocol + The preprocessor object, required by augmentations that modify audio + and need to recompute the spectrogram (e.g., mixup, echo). + example_source : ExampleSource, optional + A callable that provides other training examples. Required only if + the configuration includes `MixAugmentationConfig` (`augmentation_type` + is "mix_audio"). + + Returns + ------- + Augmentation (Callable[[xr.Dataset], xr.Dataset]) + A callable function that takes a training example (`xr.Dataset`) and + returns a potentially augmented version. + + Raises + ------ + ValueError + If `config.augmentation_type` is "mix_audio" but `example_source` is + None. + NotImplementedError + If `config.augmentation_type` does not match any known augmentation + type. + """ + if config.augmentation_type == "mix_audio": + if example_source is None: + raise ValueError( + "Mix audio augmentation ('mix_audio') requires an " + "'example_source' callable to be provided." 
+ ) + + return AudioMixer( + example_source=example_source, + preprocessor=preprocessor, + min_weight=config.min_weight, + max_weight=config.max_weight, + ) + + if config.augmentation_type == "add_echo": + return partial( + add_echo, + preprocessor=preprocessor, + max_delay=config.max_delay, + min_weight=config.min_weight, + max_weight=config.max_weight, + ) + + if config.augmentation_type == "scale_volume": + return partial( + scale_volume, + max_scaling=config.max_scaling, + min_scaling=config.min_scaling, + ) + + if config.augmentation_type == "warp": + return partial( + warp_spectrogram, + delta=config.delta, + ) + + if config.augmentation_type == "mask_time": + return partial( + mask_time, + max_perc=config.max_perc, + max_mask=config.max_masks, + ) + + if config.augmentation_type == "mask_freq": + return partial( + mask_frequency, + max_perc=config.max_perc, + max_masks=config.max_masks, + ) + + raise NotImplementedError( + "Invalid or unimplemented augmentation type: " + f"{config.augmentation_type}" ) -def load_agumentation_config( +def build_augmentations( + config: AugmentationsConfig, + preprocessor: PreprocessorProtocol, + example_source: Optional[ExampleSource] = None, +) -> Augmentation: + """Build a composite augmentation pipeline function from configuration. + + Takes the overall `AugmentationsConfig` (containing a list of individual + augmentation steps), builds the callable function for each step using + `build_augmentation_from_config`, wraps each function with `MaybeApply` + to handle its application probability, and returns a single + `Augmentation` function that applies the entire sequence. + + Parameters + ---------- + config : AugmentationsConfig + The configuration object detailing the sequence of augmentation steps. + preprocessor : PreprocessorProtocol + The preprocessor object, needed for audio-modifying augmentations. + example_source : ExampleSource, optional + A callable providing other examples, required if 'mix_audio' is used. + + Returns + ------- + Augmentation (Callable[[xr.Dataset], xr.Dataset]) + A single callable function that takes a training example (`xr.Dataset`) + and applies the configured sequence of augmentations probabilistically, + returning the augmented example. Returns the original example if + `config.steps` is empty. + + Raises + ------ + ValueError + If 'mix_audio' is configured but `example_source` is not provided. + NotImplementedError + If an unknown `augmentation_type` is encountered in `config.steps`. + """ + augmentations = [] + + for step_config in config.steps: + augmentation = build_augmentation_from_config( + step_config, + preprocessor=preprocessor, + example_source=example_source, + ) + augmentations.append( + MaybeApply( + augmentation=augmentation, + probability=step_config.probability, + ) + ) + + return partial(_apply_augmentations, augmentations=augmentations) + + +def load_augmentation_config( path: data.PathLike, field: Optional[str] = None ) -> AugmentationsConfig: + """Load the augmentations configuration from a file. + + Reads a configuration file (YAML) and validates it against the + `AugmentationsConfig` schema, potentially extracting data from a nested + field. + + Parameters + ---------- + path : PathLike + Path to the configuration file. + field : str, optional + Dot-separated path to a nested section within the file containing the + augmentations configuration (e.g., "training.augmentations"). If None, + the entire file content is used. 
+ + Returns + ------- + AugmentationsConfig + The loaded and validated augmentations configuration object. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. + yaml.YAMLError + If the file content is not valid YAML. + pydantic.ValidationError + If the loaded config data does not conform to `AugmentationsConfig`. + KeyError, TypeError + If `field` specifies an invalid path. + """ return load_config(path, schema=AugmentationsConfig, field=field) -def should_apply(config: BaseAugmentationConfig) -> bool: - if not config.enable: - return False - - return np.random.uniform() < config.probability - - -def augment_example( +def _apply_augmentations( example: xr.Dataset, - config: AugmentationsConfig, - preprocessing_config: Optional[PreprocessingConfig] = None, - others: Optional[Callable[[], xr.Dataset]] = None, -) -> xr.Dataset: - if should_apply(config.mix) and (others is not None): - other = others() - example = mix_examples( - example, - other, - min_weight=config.mix.min_weight, - max_weight=config.mix.max_weight, - config=preprocessing_config, - ) - - if should_apply(config.echo): - example = add_echo( - example, - max_delay=config.echo.max_delay, - min_weight=config.echo.min_weight, - max_weight=config.echo.max_weight, - config=preprocessing_config, - ) - - if should_apply(config.volume): - example = scale_volume( - example, - max_scaling=config.volume.max_scaling, - min_scaling=config.volume.min_scaling, - ) - - if should_apply(config.warp): - example = warp_spectrogram( - example, - delta=config.warp.delta, - ) - - if should_apply(config.time_mask): - example = mask_time( - example, - max_perc=config.time_mask.max_perc, - max_mask=config.time_mask.max_masks, - ) - - if should_apply(config.frequency_mask): - example = mask_frequency( - example, - max_perc=config.frequency_mask.max_perc, - max_masks=config.frequency_mask.max_masks, - ) - + augmentations: List[Augmentation], +): + """Apply a sequence of augmentation functions to an example.""" + for augmentation in augmentations: + example = augmentation(example) return example diff --git a/batdetect2/train/callbacks.py b/batdetect2/train/callbacks.py index b6bb05e..998bbbe 100644 --- a/batdetect2/train/callbacks.py +++ b/batdetect2/train/callbacks.py @@ -3,14 +3,15 @@ from lightning.pytorch.callbacks import Callback from torch.utils.data import DataLoader from batdetect2.evaluate import match_predictions_and_annotations -from batdetect2.post_process import postprocess_model_outputs +from batdetect2.postprocess import PostprocessorProtocol from batdetect2.train.dataset import LabeledDataset, TrainExample from batdetect2.types import ModelOutput class ValidationMetrics(Callback): - def __init__(self): + def __init__(self, postprocessor: PostprocessorProtocol): super().__init__() + self.postprocessor = postprocessor self.predictions = [] def on_validation_epoch_start( @@ -36,20 +37,20 @@ class ValidationMetrics(Callback): assert isinstance(dataset, LabeledDataset) clip_annotation = dataset.get_clip_annotation(batch_idx) - clip_prediction = postprocess_model_outputs( - outputs, - clips=[clip_annotation.clip], - classes=self.class_names, - decoder=self.decoder, - config=self.config.postprocessing, - )[0] - - matches = match_predictions_and_annotations( - clip_annotation, - clip_prediction, - ) - - self.validation_predictions.extend(matches) - return super().on_validation_batch_end( - trainer, pl_module, outputs, batch, batch_idx, dataloader_idx - ) + # clip_prediction = postprocess_model_outputs( + # outputs, + # 
clips=[clip_annotation.clip],
+        #     classes=self.class_names,
+        #     decoder=self.decoder,
+        #     config=self.config.postprocessing,
+        # )[0]
+        #
+        # matches = match_predictions_and_annotations(
+        #     clip_annotation,
+        #     clip_prediction,
+        # )
+        #
+        # self.validation_predictions.extend(matches)
+        # return super().on_validation_batch_end(
+        #     trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
+        # )
diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py
index 49205d9..c1be352 100644
--- a/batdetect2/train/dataset.py
+++ b/batdetect2/train/dataset.py
@@ -10,13 +10,13 @@ from soundevent import data
 from torch.utils.data import Dataset

 from batdetect2.configs import BaseConfig
-from batdetect2.preprocess.tensors import adjust_width
 from batdetect2.train.augmentations import (
     AugmentationsConfig,
-    augment_example,
+    build_augmentations,
     select_subclip,
 )
-from batdetect2.train.preprocess import PreprocessingConfig
+from batdetect2.train.preprocess import PreprocessorProtocol
+from batdetect2.utils.tensors import adjust_width

 __all__ = [
     "TrainExample",
@@ -51,15 +51,15 @@ class DatasetConfig(BaseConfig):
 class LabeledDataset(Dataset):
     def __init__(
         self,
+        preprocessor: PreprocessorProtocol,
         filenames: Sequence[PathLike],
         subclip: Optional[SubclipConfig] = None,
         augmentation: Optional[AugmentationsConfig] = None,
-        preprocessing: Optional[PreprocessingConfig] = None,
     ):
+        self.preprocessor = preprocessor
         self.filenames = filenames
         self.subclip = subclip
         self.augmentation = augmentation
-        self.preprocessing = preprocessing or PreprocessingConfig()

     def __len__(self):
         return len(self.filenames)
@@ -79,7 +79,8 @@
-        dataset = augment_example(
-            dataset,
-            self.augmentation,
-            preprocessing_config=self.preprocessing,
-            others=self.get_random_example,
-        )
+        augment = build_augmentations(
+            self.augmentation,
+            preprocessor=self.preprocessor,
+            example_source=self.get_random_example,
+        )
+        dataset = augment(dataset)
@@ -95,16 +95,16 @@
     def from_directory(
         cls,
         directory: PathLike,
+        preprocessor: PreprocessorProtocol,
         extension: str = ".nc",
         subclip: Optional[SubclipConfig] = None,
         augmentation: Optional[AugmentationsConfig] = None,
-        preprocessing: Optional[PreprocessingConfig] = None,
     ):
         return cls(
-            get_preprocessed_files(directory, extension),
+            preprocessor=preprocessor,
+            filenames=get_preprocessed_files(directory, extension),
             subclip=subclip,
             augmentation=augmentation,
-            preprocessing=preprocessing,
         )

     def get_random_example(self) -> xr.Dataset:
diff --git a/batdetect2/train/labels.py b/batdetect2/train/labels.py
index bb12b82..4f4454c 100644
--- a/batdetect2/train/labels.py
+++ b/batdetect2/train/labels.py
@@ -1,48 +1,47 @@
 """Generate heatmap training targets for BatDetect2 models.

-This module represents the final step in the `batdetect2.targets` pipeline,
-converting processed sound event annotations from an audio clip into the
-specific heatmap formats required for training the BatDetect2 neural network.
+This module is responsible for creating the target labels used for training
+BatDetect2 models. It converts sound event annotations for an audio clip into
+the specific multi-channel heatmap formats required by the neural network. 
-It integrates the filtering, transformation, and class encoding logic defined -in the preceding configuration steps (`filtering`, `transform`, `classes`) -and applies them to generate three core outputs for a given spectrogram: +It uses a pre-configured object adhering to the `TargetProtocol` (from +`batdetect2.targets`) which encapsulates all the logic for filtering +annotations, transforming tags, encoding class names, and mapping annotation +geometry (ROIs) to target positions and sizes. This module then focuses on +rendering this information onto the heatmap grids. -1. **Detection Heatmap**: Indicates the presence and location of relevant - sound events. -2. **Class Heatmap**: Indicates the location and predicted class label for - events that match a specific target class. -3. **Size Heatmap**: Encodes the dimensions (width/time duration, - height/frequency bandwidth) of the detected sound events at their - reference locations. +The pipeline generates three core outputs for a given spectrogram: +1. **Detection Heatmap**: Indicates presence/location of relevant sound events. +2. **Class Heatmap**: Indicates location and class identity for specifically + classified events. +3. **Size Heatmap**: Encodes the target dimensions (width, height) of events. -The primary function generated by this module is a `ClipLabeller`, which takes -a `ClipAnnotation` object and its corresponding spectrogram (`xr.DataArray`) -and returns the calculated `Heatmaps`. Configuration options allow tuning of -the heatmap generation process (e.g., Gaussian smoothing sigma, reference point -within bounding boxes). +The primary function generated by this module is a `ClipLabeller` (defined in +`.types`), which takes a `ClipAnnotation` object and its corresponding +spectrogram and returns the calculated `Heatmaps` tuple. The main configurable +parameter specific to this module is the Gaussian smoothing sigma (`sigma`) +defined in `LabelConfig`. """ import logging from collections.abc import Iterable from functools import partial -from typing import Callable, List, NamedTuple, Optional +from typing import Optional import numpy as np import xarray as xr from scipy.ndimage import gaussian_filter -from soundevent import arrays, data, geometry -from soundevent.geometry.operations import Positions +from soundevent import arrays, data from batdetect2.configs import BaseConfig, load_config -from batdetect2.targets.classes import SoundEventEncoder -from batdetect2.targets.filtering import SoundEventFilter -from batdetect2.targets.transform import SoundEventTransformation +from batdetect2.targets.types import TargetProtocol +from batdetect2.train.types import ( + ClipLabeller, + Heatmaps, +) __all__ = [ "LabelConfig", - "Heatmaps", - "ClipLabeller", "build_clip_labeler", "generate_clip_label", "generate_heatmaps", @@ -50,109 +49,45 @@ __all__ = [ ] +SIZE_DIMENSION = "dimension" +"""Dimension name for the size heatmap.""" + logger = logging.getLogger(__name__) -class Heatmaps(NamedTuple): - """Structure holding the generated heatmap targets. - - Attributes - ---------- - detection : xr.DataArray - Heatmap indicating the probability of sound event presence. Typically - smoothed with a Gaussian kernel centered on event reference points. - Shape matches the input spectrogram. Values normalized [0, 1]. - classes : xr.DataArray - Heatmap indicating the probability of specific class presence. Has an - additional 'category' dimension corresponding to the target class - names. 
Each category slice is typically smoothed with a Gaussian - kernel. Values normalized [0, 1] per category. - size : xr.DataArray - Heatmap encoding the size (width, height) of detected events. Has an - additional 'dimension' coordinate ('width', 'height'). Values represent - scaled dimensions placed at the event reference points. - """ - - detection: xr.DataArray - classes: xr.DataArray - size: xr.DataArray - - -ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps] -"""Type alias for the final clip labelling function. - -This function takes the complete annotations for a clip and the corresponding -spectrogram, applies all configured filtering, transformation, and encoding -steps, and returns the final `Heatmaps` used for model training. -""" - - class LabelConfig(BaseConfig): """Configuration parameters for heatmap generation. Attributes ---------- - position : Positions, default="bottom-left" - Specifies the reference point within each sound event's geometry - (bounding box) that is used to place the 'peak' or value on the - heatmaps. Options include 'center', 'bottom-left', 'top-right', etc. - See `soundevent.geometry.operations.Positions` for valid options. sigma : float, default=3.0 The standard deviation (in pixels/bins) of the Gaussian kernel applied to smooth the detection and class heatmaps. Larger values create more diffuse targets. - time_scale : float, default=1000.0 - A scaling factor applied to the time duration (width) of sound event - bounding boxes before storing them in the 'width' dimension of the - `size` heatmap. The appropriate value depends on how the model input/ - output resolution relates to physical time units (e.g., if time axis - is milliseconds, scale might be 1.0; if seconds, maybe 1000.0). Needs - to match model expectations. - frequency_scale : float, default=1/859.375 - A scaling factor applied to the frequency bandwidth (height) of sound - event bounding boxes before storing them in the 'height' dimension of - the `size` heatmap. The appropriate value depends on the relationship - between the frequency axis resolution (e.g., kHz or Hz per bin) and - the desired output units/scale for the model. Needs to match model - expectations. (The default suggests input frequency might be in Hz - and output is scaled relative to some reference). """ - position: Positions = "bottom-left" sigma: float = 3.0 - time_scale: float = 1000.0 - frequency_scale: float = 1 / 859.375 def build_clip_labeler( - filter_fn: SoundEventFilter, - transform_fn: SoundEventTransformation, - encoder_fn: SoundEventEncoder, - class_names: List[str], + targets: TargetProtocol, config: LabelConfig, ) -> ClipLabeller: - """Construct the clip labelling function. + """Construct the final clip labelling function. - This function takes the pre-built components from the previous target - definition steps (filtering, transformation, encoding) and the label - configuration, then returns a single callable (`ClipLabeller`) that - performs the end-to-end heatmap generation for a given clip and - spectrogram. + This factory function prepares the callable that will perform the + end-to-end heatmap generation for a given clip and spectrogram during + training data loading. It takes the fully configured `targets` object and + the `LabelConfig` and binds them to the `generate_clip_label` function. Parameters ---------- - filter_fn : SoundEventFilter - Function to filter irrelevant sound event annotations. 
- transform_fn : SoundEventTransformation - Function to transform tags of sound event annotations. - encoder_fn : SoundEventEncoder - Function to encode a sound event annotation into a class name. - class_names : List[str] - Ordered list of unique target class names for the classification - heatmap. + targets : TargetProtocol + An initialized object conforming to the `TargetProtocol`, providing all + necessary methods for filtering, transforming, encoding, and ROI + mapping. config : LabelConfig - Configuration object containing heatmap generation parameters (sigma, - etc.). + Configuration object containing heatmap generation parameters. Returns ------- @@ -162,10 +97,7 @@ def build_clip_labeler( """ return partial( generate_clip_label, - filter_fn=filter_fn, - transform_fn=transform_fn, - encoder_fn=encoder_fn, - class_names=class_names, + targets=targets, config=config, ) @@ -173,123 +105,97 @@ def build_clip_labeler( def generate_clip_label( clip_annotation: data.ClipAnnotation, spec: xr.DataArray, - filter_fn: SoundEventFilter, - transform_fn: SoundEventTransformation, - encoder_fn: SoundEventEncoder, - class_names: List[str], + targets: TargetProtocol, config: LabelConfig, ) -> Heatmaps: - """Generate heatmaps for a single clip by applying all processing steps. + """Generate training heatmaps for a single annotated clip. - This function orchestrates the process for one clip: - 1. Filters the sound events using `filter_fn`. - 2. Transforms the tags of filtered events using `transform_fn`. - 3. Passes the processed annotations and other parameters to - `generate_heatmaps` to create the final target heatmaps. + This function orchestrates the target generation process for one clip: + 1. Filters and transforms sound events using `targets.filter` and + `targets.transform`. + 2. Passes the resulting processed annotations, along with the spectrogram, + the `targets` object, and the Gaussian `sigma` from `config`, to the + core `generate_heatmaps` function. Parameters ---------- clip_annotation : data.ClipAnnotation - The complete annotation data for the audio clip. + The complete annotation data for the audio clip, including the list + of `sound_events` to process. spec : xr.DataArray - The spectrogram corresponding to the `clip_annotation`. - filter_fn : SoundEventFilter - Function to filter sound event annotations. - transform_fn : SoundEventTransformation - Function to transform tags of sound event annotations. - encoder_fn : SoundEventEncoder - Function to encode a sound event annotation into a class name. - class_names : List[str] - Ordered list of unique target class names. + The spectrogram corresponding to the `clip_annotation`. Must have + 'time' and 'frequency' dimensions/coordinates. + targets : TargetProtocol + The fully configured target definition object, providing methods for + filtering, transformation, encoding, and ROI mapping. config : LabelConfig - Configuration object containing heatmap generation parameters. + Configuration object providing heatmap parameters (primarily `sigma`). Returns ------- Heatmaps - The generated detection, classes, and size heatmaps for the clip. + A NamedTuple containing the generated 'detection', 'classes', and 'size' + heatmaps for this clip. 
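+
+    Examples
+    --------
+    Usually invoked through the labeller returned by `build_clip_labeler` (a
+    sketch; assumes ``targets`` satisfies `TargetProtocol`):
+
+    >>> labeller = build_clip_labeler(targets, LabelConfig(sigma=3.0))
+    >>> heatmaps = labeller(clip_annotation, spectrogram)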
""" return generate_heatmaps( ( - transform_fn(sound_event_annotation) + targets.transform(sound_event_annotation) for sound_event_annotation in clip_annotation.sound_events - if filter_fn(sound_event_annotation) + if targets.filter(sound_event_annotation) ), spec=spec, - class_names=class_names, - encoder=encoder_fn, + targets=targets, target_sigma=config.sigma, - position=config.position, - time_scale=config.time_scale, - frequency_scale=config.frequency_scale, ) def generate_heatmaps( sound_events: Iterable[data.SoundEventAnnotation], spec: xr.DataArray, - class_names: List[str], - encoder: SoundEventEncoder, + targets: TargetProtocol, target_sigma: float = 3.0, - position: Positions = "bottom-left", - time_scale: float = 1000.0, - frequency_scale: float = 1 / 859.375, dtype=np.float32, ) -> Heatmaps: """Generate detection, class, and size heatmaps from sound events. - Processes an iterable of sound event annotations (assumed to be already - filtered and transformed) and creates heatmap representations suitable - for training models like BatDetect2. + Creates heatmap representations from an iterable of sound event + annotations. This function relies on the provided `targets` object to get + the reference position, scaled size, and class encoding for each + annotation. - The process involves: - 1. Initializing empty heatmaps based on the spectrogram shape. - 2. Iterating through each sound event. - 3. For each event, finding its reference point and placing a '1.0' - on the detection heatmap at that point. - 4. Calculating the scaled bounding box size and placing it on the size - heatmap at the reference point. - 5. Encoding the event to get its class name and placing a '1.0' on the - corresponding class heatmap slice at the reference point (if - classified). - 6. Applying Gaussian smoothing to detection and class heatmaps. - 7. Normalizing detection and class heatmaps to the range [0, 1]. + Process: + 1. Initializes empty heatmap arrays based on `spec` shape and `targets` + metadata. + 2. Iterates through `sound_events`. + 3. For each event: + a. Gets geometry. Skips if missing. + b. Gets reference position using `targets.get_position()`. Skips if out + of bounds. + c. Places a peak (1.0) on the detection heatmap at the position. + d. Gets scaled size using `targets.get_size()` and places it on the + size heatmap. + e. Encodes class using `targets.encode()` and places a peak (1.0) on + the corresponding class heatmap layer if a specific class is + returned. + 4. Applies Gaussian smoothing (using `target_sigma`) to detection and class + heatmaps. + 5. Normalizes detection and class heatmaps to range [0, 1]. Parameters ---------- sound_events : Iterable[data.SoundEventAnnotation] - An iterable of sound event annotations to include in the heatmaps. - These should ideally be the result of prior filtering and tag - transformation steps. + An iterable of sound event annotations to render onto the heatmaps. spec : xr.DataArray The spectrogram array corresponding to the time/frequency range of the annotations. Used for shape and coordinate information. Must have - 'time' and 'frequency' dimensions. - class_names : List[str] - An ordered list of unique class names. The class heatmap will have - a channel ('category' dimension) for each name in this list. Must not - be empty. 
- encoder : SoundEventEncoder - A function that takes a SoundEventAnnotation and returns the - corresponding class name (str) or None if it doesn't belong to a - specific class (e.g., it falls into the generic 'Bat' category). + 'time' and 'frequency' dimensions/coordinates. + targets : TargetProtocol + The fully configured target definition object. Used to access class + names, dimension names, and the methods `get_position`, `get_size`, + `encode`. target_sigma : float, default=3.0 Standard deviation (in pixels/bins) of the Gaussian kernel applied to - smooth the detection and class heatmaps after initial point placement. - position : Positions, default="bottom-left" - The reference point within each annotation's geometry bounding box - used to place the signal on the heatmaps (e.g., "center", - "bottom-left"). See `soundevent.geometry.operations.Positions`. - time_scale : float, default=1000.0 - Scaling factor applied to the time duration (width in seconds) of - annotations when storing them in the size heatmap. The resulting - value's unit depends on this scale (e.g., 1000.0 might convert seconds - to ms). - frequency_scale : float, default=1/859.375 - Scaling factor applied to the frequency bandwidth (height in Hz or kHz) - of annotations when storing them in the size heatmap. The resulting - value's unit depends on this scale and the input unit. (Default - scaling relative to ~860 Hz). + smooth the detection and class heatmaps. dtype : type, default=np.float32 The data type for the generated heatmap arrays (e.g., `np.float32`). @@ -303,28 +209,10 @@ def generate_heatmaps( ------ ValueError If the input spectrogram `spec` does not have both 'time' and - 'frequency' dimensions, or if `class_names` is empty. - - Notes - ----- - * This function expects `sound_events` to be already filtered and - transformed. - * It includes error handling to skip individual annotations that cause - issues (e.g., missing geometry, out-of-bounds coordinates, encoder - errors) allowing the rest of the clip to be processed. Warnings or - errors are logged. - * The `time_scale` and `frequency_scale` parameters are crucial and must be - set according to the expectations of the specific BatDetect2 model - architecture being trained. Consult model documentation for required - units/scales. - * Gaussian filtering and normalization are applied only to detection and - class heatmaps, not the size heatmap. + 'frequency' dimensions, or if `targets.class_names` is empty. """ shape = dict(zip(spec.dims, spec.shape)) - if len(class_names) == 0: - raise ValueError("No class names provided.") - if "time" not in shape or "frequency" not in shape: raise ValueError( "Spectrogram must have time and frequency dimensions." 
@@ -333,18 +221,18 @@ def generate_heatmaps( # Initialize heatmaps detection_heatmap = xr.zeros_like(spec, dtype=dtype) class_heatmap = xr.DataArray( - data=np.zeros((len(class_names), *spec.shape), dtype=dtype), + data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype), dims=["category", *spec.dims], coords={ - "category": [*class_names], + "category": [*targets.class_names], **spec.coords, }, ) size_heatmap = xr.DataArray( data=np.zeros((2, *spec.shape), dtype=dtype), - dims=["dimension", *spec.dims], + dims=[SIZE_DIMENSION, *spec.dims], coords={ - "dimension": ["width", "height"], + SIZE_DIMENSION: targets.dimension_names, **spec.coords, }, ) @@ -359,7 +247,7 @@ def generate_heatmaps( continue # Get the position of the sound event - time, frequency = geometry.get_geometry_point(geom, position=position) + time, frequency = targets.get_position(sound_event_annotation) # Set 1.0 at the position of the sound event in the detection heatmap try: @@ -379,17 +267,7 @@ def generate_heatmaps( ) continue - # Set the size of the sound event at the position in the size heatmap - start_time, low_freq, end_time, high_freq = geometry.compute_bounds( - geom - ) - - size = np.array( - [ - (end_time - start_time) * time_scale, - (high_freq - low_freq) * frequency_scale, - ] - ) + size = targets.get_size(sound_event_annotation) size_heatmap = arrays.set_value_at_pos( size_heatmap, @@ -400,7 +278,7 @@ def generate_heatmaps( # Get the class name of the sound event try: - class_name = encoder(sound_event_annotation) + class_name = targets.encode(sound_event_annotation) except ValueError as e: logger.warning( "Skipping annotation %s: Unexpected error while encoding " @@ -414,19 +292,6 @@ def generate_heatmaps( # If the label is None skip the sound event continue - if class_name not in class_names: - # If the label is not in the class names skip the sound event - logger.warning( - ( - "Skipping annotation %s for class heatmap: " - "class name '%s' not in class names. Class names: %s" - ), - sound_event_annotation.uuid, - class_name, - class_names, - ) - continue - try: class_heatmap = arrays.set_value_at_pos( class_heatmap, @@ -453,7 +318,7 @@ def generate_heatmaps( ) class_heatmap = class_heatmap.groupby("category").map( - gaussian_filter, + gaussian_filter, # type: ignore args=(target_sigma,), ) diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py index 3a67135..b90770f 100644 --- a/batdetect2/train/preprocess.py +++ b/batdetect2/train/preprocess.py @@ -1,98 +1,101 @@ -"""Module for preprocessing data for training.""" +"""Preprocesses datasets for BatDetect2 model training. + +This module provides functions to take a collection of annotated audio clips +(`soundevent.data.ClipAnnotation`) and process them into the final format +required for training a BatDetect2 model. This typically involves: + +1. Loading the relevant audio segment for each annotation using a configured + `PreprocessorProtocol`. +2. Generating the corresponding input spectrogram using the + `PreprocessorProtocol`. +3. Generating the target heatmaps (detection, classification, size) using a + configured `ClipLabeller` (which encapsulates the `TargetProtocol` logic). +4. Packaging the input spectrogram, target heatmaps, and potentially the + processed audio waveform into an `xarray.Dataset`. +5. Saving each processed `xarray.Dataset` to a separate file (typically NetCDF) + in an output directory. 
+ +This offline preprocessing is often preferred for large datasets as it avoids +computationally intensive steps during the actual training loop. The module +includes utilities for parallel processing using `multiprocessing`. +""" -import os from functools import partial from multiprocessing import Pool from pathlib import Path -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Optional, Sequence import xarray as xr -from pydantic import Field from soundevent import data from tqdm.auto import tqdm -from batdetect2.configs import BaseConfig -from batdetect2.preprocess import ( - PreprocessingConfig, - compute_spectrogram, - load_clip_audio, -) -from batdetect2.targets import ( - LabelConfig, - TargetConfig, - build_sound_event_filter, - build_target_encoder, - generate_heatmaps, - get_class_names, -) - -PathLike = Union[Path, str, os.PathLike] -FilenameFn = Callable[[data.ClipAnnotation], str] +from batdetect2.preprocess.types import PreprocessorProtocol +from batdetect2.train.types import ClipLabeller __all__ = [ "preprocess_annotations", "preprocess_single_annotation", "generate_train_example", - "TrainPreprocessingConfig", ] - -class TrainPreprocessingConfig(BaseConfig): - preprocessing: PreprocessingConfig = Field( - default_factory=PreprocessingConfig - ) - target: TargetConfig = Field(default_factory=TargetConfig) - labels: LabelConfig = Field(default_factory=LabelConfig) +FilenameFn = Callable[[data.ClipAnnotation], str] +"""Type alias for a function that generates an output filename.""" def generate_train_example( clip_annotation: data.ClipAnnotation, - preprocessing_config: Optional[PreprocessingConfig] = None, - target_config: Optional[TargetConfig] = None, - label_config: Optional[LabelConfig] = None, + preprocessor: PreprocessorProtocol, + labeller: ClipLabeller, ) -> xr.Dataset: - """Generate a training example.""" - config = TrainPreprocessingConfig( - preprocessing=preprocessing_config or PreprocessingConfig(), - target=target_config or TargetConfig(), - labels=label_config or LabelConfig(), - ) + """Generate a complete training example for one annotation. - wave = load_clip_audio( - clip_annotation.clip, - config=config.preprocessing.audio, - ) + This function takes a single `ClipAnnotation`, applies the configured + preprocessing (`PreprocessorProtocol`) to get the processed waveform and + input spectrogram, applies the configured target generation + (`ClipLabeller`) to get the target heatmaps, and packages them all into a + single `xr.Dataset`. - spectrogram = compute_spectrogram( - wave, - config=config.preprocessing.spectrogram, - ) + Parameters + ---------- + clip_annotation : data.ClipAnnotation + The annotated clip to process. Contains the reference to the `Clip` + (audio segment) and the associated `SoundEventAnnotation` objects. + preprocessor : PreprocessorProtocol + An initialized preprocessor object responsible for loading/processing + audio and computing the input spectrogram. + labeller : ClipLabeller + An initialized clip labeller function responsible for generating the + target heatmaps (detection, class, size) from the `clip_annotation` + and the computed spectrogram. - filter_fn = build_sound_event_filter( - include=config.target.include, - exclude=config.target.exclude, - ) + Returns + ------- + xr.Dataset + An xarray Dataset containing the following data variables: + - `audio`: The preprocessed audio waveform (dims: 'audio_time'). + - `spectrogram`: The computed input spectrogram + (dims: 'time', 'frequency'). 
+        - `detection`: The target detection heatmap
+          (dims: 'time', 'frequency').
+        - `class`: The target class heatmap
+          (dims: 'category', 'time', 'frequency').
+        - `size`: The target size heatmap
+          (dims: 'dimension', 'time', 'frequency').
+        The Dataset also includes metadata in its attributes.
 
-    selected_events = [
-        event for event in clip_annotation.sound_events if filter_fn(event)
-    ]
+    Notes
+    -----
+    - The 'time' dimension of the 'audio' DataArray is renamed to 'audio_time'
+      within the output Dataset to avoid coordinate conflicts with the
+      spectrogram's 'time' dimension when stored together.
+    - The original `ClipAnnotation` metadata is stored as a JSON string in the
+      Dataset's attributes for provenance.
+    """
+    wave = preprocessor.load_clip_audio(clip_annotation.clip)
 
-    encoder = build_target_encoder(
-        config.target.classes,
-        replacement_rules=config.target.replace,
-    )
-    class_names = get_class_names(config.target.classes)
+    spectrogram = preprocessor.compute_spectrogram(wave)
 
-    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
-        selected_events,
-        spectrogram,
-        class_names,
-        encoder,
-        target_sigma=config.labels.heatmaps.sigma,
-        position=config.labels.heatmaps.position,
-        time_scale=config.labels.heatmaps.time_scale,
-        frequency_scale=config.labels.heatmaps.frequency_scale,
-    )
+    heatmaps = labeller(clip_annotation, spectrogram)
 
     dataset = xr.Dataset(
         {
@@ -102,15 +115,14 @@ def generate_train_example(
             # as the waveform.
             "audio": wave.rename({"time": "audio_time"}),
             "spectrogram": spectrogram,
-            "detection": detection_heatmap,
-            "class": class_heatmap,
-            "size": size_heatmap,
+            "detection": heatmaps.detection,
+            "class": heatmaps.classes,
+            "size": heatmaps.size,
         }
     )
 
     return dataset.assign_attrs(
         title=f"Training example for {clip_annotation.uuid}",
-        config=config.model_dump_json(),
         clip_annotation=clip_annotation.model_dump_json(
             exclude_none=True,
             exclude_defaults=True,
@@ -119,10 +131,21 @@ def generate_train_example(
     )
 
 
-def save_to_file(
+def _save_xr_dataset_to_file(
     dataset: xr.Dataset,
-    path: PathLike,
+    path: data.PathLike,
 ) -> None:
+    """Save an xarray Dataset to a NetCDF file with compression.
+
+    Internal helper function used by `preprocess_single_annotation`.
+
+    Parameters
+    ----------
+    dataset : xr.Dataset
+        The training example dataset to save.
+    path : data.PathLike
+        The output file path (e.g., 'output/uuid.nc').
+    """
     dataset.to_netcdf(
         path,
         encoding={
@@ -135,20 +158,70 @@
 
 
 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
+    """Generate a default output filename based on the annotation UUID."""
     return f"{clip_annotation.uuid}.nc"
 
 
 def preprocess_annotations(
     clip_annotations: Sequence[data.ClipAnnotation],
-    output_dir: PathLike,
+    output_dir: data.PathLike,
+    preprocessor: PreprocessorProtocol,
+    labeller: ClipLabeller,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
-    preprocessing_config: Optional[PreprocessingConfig] = None,
-    target_config: Optional[TargetConfig] = None,
-    label_config: Optional[LabelConfig] = None,
     max_workers: Optional[int] = None,
 ) -> None:
-    """Preprocess annotations and save to disk."""
+    """Preprocess a sequence of ClipAnnotations and save results to disk.
+
+    Generates the full training example (spectrogram, heatmaps, etc.) for each
+    `ClipAnnotation` in the input sequence using the provided `preprocessor`
+    and `labeller`. Saves each example as a separate NetCDF file in the
+    `output_dir`. Utilizes multiprocessing for potentially faster processing.
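+
+    A minimal sketch, assuming `preprocessor` and `labeller` were constructed
+    elsewhere (the argument values are illustrative)::
+
+        preprocess_annotations(
+            clip_annotations,
+            output_dir="train_examples/",
+            preprocessor=preprocessor,
+            labeller=labeller,
+        )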
+
+    Parameters
+    ----------
+    clip_annotations : Sequence[data.ClipAnnotation]
+        A sequence (e.g., list) of the clip annotations to preprocess.
+    output_dir : data.PathLike
+        Path to the directory where the processed NetCDF files will be saved.
+        Will be created if it doesn't exist.
+    preprocessor : PreprocessorProtocol
+        Initialized preprocessor object to generate spectrograms.
+    labeller : ClipLabeller
+        Initialized labeller function to generate target heatmaps.
+    filename_fn : FilenameFn, optional
+        Function to generate the output filename for each
+        `ClipAnnotation`. Defaults to using the annotation UUID via
+        `_get_filename`.
+    replace : bool, default=False
+        If True, existing files in `output_dir` with the same generated name
+        will be overwritten. If False (default), existing files are skipped.
+    max_workers : int, optional
+        Maximum number of worker processes to use for parallel processing.
+        If None (default), uses the number of CPUs available (`os.cpu_count()`).
+
+    Returns
+    -------
+    None
+        This function does not return anything; its side effect is creating
+        files in the `output_dir`.
+
+    Raises
+    ------
+    RuntimeError
+        If processing fails for any individual annotation when using
+        multiprocessing. The original exception will be attached as the cause.
+    """
     output_dir = Path(output_dir)
 
     if not output_dir.is_dir():
@@ -163,9 +236,8 @@ def preprocess_annotations(
             output_dir=output_dir,
             filename_fn=filename_fn,
             replace=replace,
-            preprocessing_config=preprocessing_config,
-            target_config=target_config,
-            label_config=label_config,
+            preprocessor=preprocessor,
+            labeller=labeller,
         ),
         clip_annotations,
     ),
@@ -176,13 +248,46 @@
 def preprocess_single_annotation(
     clip_annotation: data.ClipAnnotation,
-    output_dir: PathLike,
-    preprocessing_config: Optional[PreprocessingConfig] = None,
-    target_config: Optional[TargetConfig] = None,
-    label_config: Optional[LabelConfig] = None,
+    output_dir: data.PathLike,
+    preprocessor: PreprocessorProtocol,
+    labeller: ClipLabeller,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
 ) -> None:
+    """Process a single ClipAnnotation and save the result to a file.
+
+    Internal function designed to be called by `preprocess_annotations`, often
+    in parallel worker processes. It generates the training example using
+    `generate_train_example` and saves it using `_save_xr_dataset_to_file`.
+    It handles file existence checks based on the `replace` flag.
+
+    Parameters
+    ----------
+    clip_annotation : data.ClipAnnotation
+        The single annotation to process.
+    output_dir : data.PathLike
+        The directory where the output NetCDF file should be saved.
+    preprocessor : PreprocessorProtocol
+        Initialized preprocessor object.
+    labeller : ClipLabeller
+        Initialized labeller function.
+    filename_fn : FilenameFn, default=_get_filename
+        Function to determine the output filename.
+    replace : bool, default=False
+        Whether to overwrite existing output files.
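+
+    Examples
+    --------
+    A minimal sketch; `preprocessor` and `labeller` are assumed to have been
+    built elsewhere, and the argument values are illustrative:
+
+    >>> preprocess_single_annotation(
+    ...     clip_annotation,
+    ...     output_dir="train_examples/",
+    ...     preprocessor=preprocessor,
+    ...     labeller=labeller,
+    ... )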
+    """
     output_dir = Path(output_dir)
 
     filename = filename_fn(clip_annotation)
@@ -197,13 +302,12 @@ def preprocess_single_annotation(
     try:
         sample = generate_train_example(
             clip_annotation,
-            preprocessing_config=preprocessing_config,
-            target_config=target_config,
-            label_config=label_config,
+            preprocessor=preprocessor,
+            labeller=labeller,
         )
     except Exception as error:
         raise RuntimeError(
             f"Failed to process annotation: {clip_annotation.uuid}"
         ) from error
 
-    save_to_file(sample, path)
+    _save_xr_dataset_to_file(sample, path)
diff --git a/batdetect2/train/types.py b/batdetect2/train/types.py
new file mode 100644
index 0000000..5bc071b
--- /dev/null
+++ b/batdetect2/train/types.py
@@ -0,0 +1,59 @@
+from typing import Callable, NamedTuple
+
+import xarray as xr
+from soundevent import data
+
+__all__ = [
+    "Heatmaps",
+    "ClipLabeller",
+    "Augmentation",
+]
+
+
+class Heatmaps(NamedTuple):
+    """Structure holding the generated heatmap targets.
+
+    Attributes
+    ----------
+    detection : xr.DataArray
+        Heatmap indicating the probability of sound event presence. Typically
+        smoothed with a Gaussian kernel centered on event reference points.
+        Shape matches the input spectrogram. Values normalized [0, 1].
+    classes : xr.DataArray
+        Heatmap indicating the probability of specific class presence. Has an
+        additional 'category' dimension corresponding to the target class
+        names. Each category slice is typically smoothed with a Gaussian
+        kernel. Values normalized [0, 1] per category.
+    size : xr.DataArray
+        Heatmap encoding the size (width, height) of detected events. Has an
+        additional 'dimension' coordinate ('width', 'height'). Values represent
+        scaled dimensions placed at the event reference points.
+    """
+
+    detection: xr.DataArray
+    classes: xr.DataArray
+    size: xr.DataArray
+
+
+ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
+"""Type alias for the final clip labelling function.
+
+This function takes the complete annotations for a clip and the corresponding
+spectrogram, applies all configured filtering, transformation, and encoding
+steps, and returns the final `Heatmaps` used for model training.
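+
+A trivial stand-in that satisfies this signature (illustrative only; real
+labellers are built from the target configuration):
+
+>>> def dummy_labeller(clip_annotation, spectrogram):
+...     empty = xr.zeros_like(spectrogram)
+...     return Heatmaps(
+...         detection=empty,
+...         classes=empty.expand_dims(category=["myotis"]),
+...         size=empty.expand_dims(dimension=["width", "height"]),
+...     )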
+""" + +Augmentation = Callable[[xr.Dataset], xr.Dataset] + + diff --git a/batdetect2/types.py b/batdetect2/types.py index ec9ea8b..78b229a 100644 --- a/batdetect2/types.py +++ b/batdetect2/types.py @@ -1,14 +1,10 @@ """Types used in the code base.""" -from typing import Any, List, NamedTuple, Optional - +from typing import Any, List, NamedTuple, Optional, TypedDict import numpy as np import torch -from typing import TypedDict - - try: from typing import Protocol except ImportError: diff --git a/tests/test_models/test_inputs.py b/tests/test_models/test_inputs.py deleted file mode 100644 index af3e34d..0000000 --- a/tests/test_models/test_inputs.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -from hypothesis import given -from hypothesis import strategies as st - -from batdetect2.models import ModelConfig, ModelType, build_architecture - - -@given( - input_width=st.integers(min_value=50, max_value=1500), - input_height=st.integers(min_value=1, max_value=16), - model_type=st.sampled_from(ModelType), -) -def test_model_can_process_spectrograms_of_any_width( - input_width, - input_height, - model_type, -): - # Input height must be divisible by 8 - input_height = 8 * input_height - - input = torch.rand([1, 1, input_height, input_width]) - - model = build_architecture( - config=ModelConfig( - name=model_type, # type: ignore - input_height=input_height, - ), - ) - - output = model(input) - assert output.shape[0] == 1 - assert output.shape[1] == model.out_channels - assert output.shape[2] == input_height - assert output.shape[3] == input_width diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py index 9b68f54..e86aae7 100644 --- a/tests/test_train/test_labels.py +++ b/tests/test_train/test_labels.py @@ -4,7 +4,7 @@ import numpy as np import xarray as xr from soundevent import data -from batdetect2.targets import generate_heatmaps +from batdetect2.train.labels import generate_heatmaps recording = data.Recording( samplerate=256_000,