Compare commits

...

4 Commits

Author      SHA1        Message                                   Date
mbsantiago  93e89ecc46  LR Scheduler takes num of total batches   2025-08-28 08:52:11 +01:00
mbsantiago  34ef9e92a1  Make sure preprocessing is batchable      2025-08-27 23:58:38 +01:00
mbsantiago  0b5ac96fe8  Update model config                       2025-08-27 23:58:07 +01:00
mbsantiago  dba6d2d918  Updating configs                          2025-08-27 23:44:49 +01:00
24 changed files with 669 additions and 506 deletions

View File

@@ -1,14 +1,3 @@
-datasets:
-  train:
-    name: example dataset
-    description: Only for demonstration purposes
-    sources:
-      - format: batdetect2
-        name: Example Data
-        description: Examples included for testing batdetect2
-        annotations_dir: example_data/anns
-        audio_dir: example_data/audio
 targets:
   classes:
     classes:
@@ -99,7 +88,9 @@ model:
       out_channels: 256
   bottleneck:
     channels: 256
-    self_attention: true
+    layers:
+      - block_type: SelfAttention
+        attention_channels: 256
   decoder:
     layers:
       - block_type: FreqCoordConvUp
@@ -114,9 +105,19 @@ model:
       out_channels: 32
 train:
-  batch_size: 8
   learning_rate: 0.001
   t_max: 100
+  dataloaders:
+    train:
+      batch_size: 8
+      num_workers: 2
+      shuffle: True
+    val:
+      batch_size: 8
+      num_workers: 2
   loss:
     detection:
       weight: 1.0
@@ -130,14 +131,12 @@ train:
       alpha: 2
     size:
       weight: 0.1
   logger:
-    logger_type: mlflow
-    experiment_name: batdetect2
-    tracking_uri: http://localhost:5000
-    log_model: true
+    logger_type: csv
     save_dir: outputs/log/
-    artifact_location: outputs/artifacts/
-    checkpoint_path_prefix: outputs/checkpoints/
+    name: logs
   augmentations:
     steps:
       - augmentation_type: mix_audio
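For orientation, a minimal sketch of reading the reshaped train section shown above. Indentation is inferred, since the flattened diff view drops it; values are copied from the example config:

import yaml  # assumes PyYAML is installed

snippet = """
train:
  learning_rate: 0.001
  t_max: 100
  dataloaders:
    train:
      batch_size: 8
      num_workers: 2
      shuffle: true
    val:
      batch_size: 8
      num_workers: 2
  logger:
    logger_type: csv
    save_dir: outputs/log/
    name: logs
"""

cfg = yaml.safe_load(snippet)
assert cfg["train"]["dataloaders"]["train"]["batch_size"] == 8
assert cfg["train"]["logger"]["logger_type"] == "csv"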

View File

@@ -0,0 +1,8 @@
+name: example dataset
+description: Only for demonstration purposes
+sources:
+  - format: batdetect2
+    name: Example Data
+    description: Examples included for testing batdetect2
+    annotations_dir: example_data/anns
+    audio_dir: example_data/audio

View File

@ -1,6 +1,6 @@
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple from typing import List, Literal, Optional, Protocol, Tuple
import numpy as np import numpy as np
from soundevent import data from soundevent import data
@ -23,6 +23,16 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching.""" """The geometry representation to use for matching."""
class AffinityFunction(Protocol):
def __call__(
self,
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float: ...
class MatchConfig(BaseConfig): class MatchConfig(BaseConfig):
"""Configuration for matching geometries. """Configuration for matching geometries.
@ -74,6 +84,65 @@ _geometry_cast_functions: Mapping[
} }
def _timestamp_affinity(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
assert isinstance(geometry1, data.TimeStamp)
assert isinstance(geometry2, data.TimeStamp)
start_time1 = geometry1.coordinates
start_time2 = geometry2.coordinates
a = min(start_time1, start_time2)
b = max(start_time1, start_time2)
if b - a >= 2 * time_buffer:
return 0
intersection = a - b + 2 * time_buffer
union = b - a + 2 * time_buffer
return intersection / union
def _interval_affinity(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
assert isinstance(geometry1, data.TimeInterval)
assert isinstance(geometry2, data.TimeInterval)
start_time1, end_time1 = geometry1.coordinates
start_time2, end_time2 = geometry2.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
intersection = max(
0, min(end_time1, end_time2) - max(start_time1, start_time2)
)
union = (
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
)
if union == 0:
return 0
return intersection / union
_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
"timestamp": _timestamp_affinity,
"interval": _interval_affinity,
}
def match_geometries( def match_geometries(
source: List[data.Geometry], source: List[data.Geometry],
target: List[data.Geometry], target: List[data.Geometry],
@ -81,6 +150,10 @@ def match_geometries(
scores: Optional[List[float]] = None, scores: Optional[List[float]] = None,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
geometry_cast = _geometry_cast_functions[config.geometry] geometry_cast = _geometry_cast_functions[config.geometry]
affinity_function = _affinity_functions.get(
config.geometry,
compute_affinity,
)
if config.strategy == "optimal": if config.strategy == "optimal":
return optimal_match( return optimal_match(
@ -98,6 +171,7 @@ def match_geometries(
time_buffer=config.time_buffer, time_buffer=config.time_buffer,
freq_buffer=config.frequency_buffer, freq_buffer=config.frequency_buffer,
affinity_threshold=config.affinity_threshold, affinity_threshold=config.affinity_threshold,
affinity_function=affinity_function,
scores=scores, scores=scores,
) )
@ -111,6 +185,7 @@ def greedy_match(
target: List[data.Geometry], target: List[data.Geometry],
scores: Optional[List[float]] = None, scores: Optional[List[float]] = None,
affinity_threshold: float = 0.5, affinity_threshold: float = 0.5,
affinity_function: AffinityFunction = compute_affinity,
time_buffer: float = 0.001, time_buffer: float = 0.001,
freq_buffer: float = 1000, freq_buffer: float = 1000,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
@ -168,7 +243,7 @@ def greedy_match(
affinities = np.array( affinities = np.array(
[ [
compute_affinity( affinity_function(
source_geometry, source_geometry,
target_geometry, target_geometry,
time_buffer=time_buffer, time_buffer=time_buffer,
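The new _timestamp_affinity computes a buffered IoU along the time axis only. A self-contained sketch of the same arithmetic, with an illustrative worked value (numbers are not from the repository):

def timestamp_affinity(t1: float, t2: float, time_buffer: float = 0.01) -> float:
    # Each timestamp is expanded to the interval [t - buffer, t + buffer].
    a, b = min(t1, t2), max(t1, t2)
    if b - a >= 2 * time_buffer:
        return 0.0
    intersection = 2 * time_buffer - (b - a)
    union = 2 * time_buffer + (b - a)
    return intersection / union

# Timestamps 5 ms apart with a 10 ms buffer: (0.02 - 0.005) / (0.02 + 0.005) = 0.6
print(timestamp_affinity(0.100, 0.105))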

View File

@ -26,13 +26,14 @@ for creating a standard BatDetect2 model instance is the `build_model` function
provided here. provided here.
""" """
from typing import Optional from typing import List, Optional
import torch import torch
from lightning import LightningModule from lightning import LightningModule
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import ( from batdetect2.models.backbones import (
Backbone, Backbone,
BackboneConfig, BackboneConfig,
@ -66,8 +67,8 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.postprocess import PostprocessConfig, build_postprocessor from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets from batdetect2.targets import TargetConfig, build_targets
from batdetect2.typing.models import DetectionModel, ModelOutput from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import PostprocessorProtocol from batdetect2.typing.postprocess import Detections, PostprocessorProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol from batdetect2.typing.targets import TargetProtocol
@ -119,9 +120,12 @@ class Model(LightningModule):
self.preprocessor = preprocessor self.preprocessor = preprocessor
self.postprocessor = postprocessor self.postprocessor = postprocessor
self.targets = targets self.targets = targets
self.save_hyperparameters()
def forward(self, spec: torch.Tensor) -> ModelOutput: def forward(self, wav: torch.Tensor) -> List[Detections]:
return self.detector(spec) spec = self.preprocessor(wav)
outputs = self.detector(spec)
return self.postprocessor(outputs)
class ModelConfig(BaseConfig): class ModelConfig(BaseConfig):
@ -139,7 +143,6 @@ def build_model(config: Optional[ModelConfig] = None):
targets = build_targets(config=config.targets) targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess) preprocessor = build_preprocessor(config=config.preprocess)
postprocessor = build_postprocessor( postprocessor = build_postprocessor(
targets=targets,
preprocessor=preprocessor, preprocessor=preprocessor,
config=config.postprocess, config=config.postprocess,
) )
@ -153,3 +156,9 @@ def build_model(config: Optional[ModelConfig] = None):
preprocessor=preprocessor, preprocessor=preprocessor,
targets=targets, targets=targets,
) )
def load_model_config(
path: PathLike, field: Optional[str] = None
) -> ModelConfig:
return load_config(path, schema=ModelConfig, field=field)
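With forward now running preprocessing and postprocessing, inference is waveform-in, detections-out. A hedged usage sketch, assuming load_model_config is exported next to build_model and that the input layout is (batch, channel, samples); both are assumptions, not confirmed by this diff:

import torch

from batdetect2.models import build_model, load_model_config

config = load_model_config("config.yaml", field="model")  # hypothetical config file
model = build_model(config=config)
model.eval()

wav = torch.zeros(1, 1, 256_000)  # assumed input layout
with torch.no_grad():
    detections = model(wav)  # preprocess -> detector -> postprocess, one entry per clip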

View File

@ -55,6 +55,12 @@ __all__ = [
] ]
class SelfAttentionConfig(BaseConfig):
block_type: Literal["SelfAttention"] = "SelfAttention"
attention_channels: int
temperature: float = 1
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
"""Self-Attention mechanism operating along the time dimension. """Self-Attention mechanism operating along the time dimension.
@ -115,6 +121,7 @@ class SelfAttention(nn.Module):
# Note, does not encode position information (absolute or relative) # Note, does not encode position information (absolute or relative)
self.temperature = temperature self.temperature = temperature
self.att_dim = attention_channels self.att_dim = attention_channels
self.key_fun = nn.Linear(in_channels, attention_channels) self.key_fun = nn.Linear(in_channels, attention_channels)
self.value_fun = nn.Linear(in_channels, attention_channels) self.value_fun = nn.Linear(in_channels, attention_channels)
self.query_fun = nn.Linear(in_channels, attention_channels) self.query_fun = nn.Linear(in_channels, attention_channels)
@ -654,6 +661,7 @@ LayerConfig = Annotated[
StandardConvDownConfig, StandardConvDownConfig,
FreqCoordConvUpConfig, FreqCoordConvUpConfig,
StandardConvUpConfig, StandardConvUpConfig,
SelfAttentionConfig,
"LayerGroupConfig", "LayerGroupConfig",
], ],
Field(discriminator="block_type"), Field(discriminator="block_type"),
@ -769,6 +777,17 @@ def build_layer_from_config(
input_height * 2, input_height * 2,
) )
if config.block_type == "SelfAttention":
return (
SelfAttention(
in_channels=in_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
),
config.attention_channels,
input_height,
)
if config.block_type == "LayerGroup": if config.block_type == "LayerGroup":
current_channels = in_channels current_channels = in_channels
current_height = input_height current_height = input_height
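A short sketch of the new SelfAttention entry in the layer registry, using the build_layer_from_config call pattern shown in the bottleneck changes below (channel and height values are illustrative):

from batdetect2.models.blocks import SelfAttentionConfig, build_layer_from_config

layer, out_channels, out_height = build_layer_from_config(
    input_height=4,
    in_channels=256,
    config=SelfAttentionConfig(attention_channels=256),
)
# Attention keeps the spectrogram height and outputs `attention_channels` channels.
assert (out_channels, out_height) == (256, 4)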

View File

@ -14,47 +14,27 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration. module based on the provided configuration.
""" """
from typing import Optional from typing import Annotated, List, Optional, Union
import torch import torch
from pydantic import Field
from torch import nn from torch import nn
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import SelfAttention, VerticalConv from batdetect2.models.blocks import (
LayerConfig,
SelfAttentionConfig,
VerticalConv,
build_layer_from_config,
)
__all__ = [ __all__ = [
"BottleneckConfig", "BottleneckConfig",
"Bottleneck", "Bottleneck",
"BottleneckAttn",
"build_bottleneck", "build_bottleneck",
] ]
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes
----------
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
This also defines the channel dimensions used within the optional
`SelfAttention` layer.
self_attention : bool
If True, includes a `SelfAttention` layer operating on the time
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
"""
channels: int
self_attention: bool
class Bottleneck(nn.Module): class Bottleneck(nn.Module):
"""Base Bottleneck module for Encoder-Decoder architectures. """Base Bottleneck module for Encoder-Decoder architectures.
@ -99,16 +79,24 @@ class Bottleneck(nn.Module):
input_height: int, input_height: int,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
bottleneck_channels: Optional[int] = None,
layers: Optional[List[torch.nn.Module]] = None,
) -> None: ) -> None:
"""Initialize the base Bottleneck layer.""" """Initialize the base Bottleneck layer."""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.input_height = input_height self.input_height = input_height
self.out_channels = out_channels self.out_channels = out_channels
self.bottleneck_channels = (
bottleneck_channels
if bottleneck_channels is not None
else out_channels
)
self.layers = nn.ModuleList(layers or [])
self.conv_vert = VerticalConv( self.conv_vert = VerticalConv(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=self.bottleneck_channels,
input_height=input_height, input_height=input_height,
) )
@ -132,73 +120,52 @@ class Bottleneck(nn.Module):
convolution. convolution.
""" """
x = self.conv_vert(x) x = self.conv_vert(x)
for layer in self.layers:
x = layer(x)
return x.repeat([1, 1, self.input_height, 1]) return x.repeat([1, 1, self.input_height, 1])
class BottleneckAttn(Bottleneck): BottleneckLayerConfig = Annotated[
"""Bottleneck module including a Self-Attention layer. Union[SelfAttentionConfig,],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
the initial `VerticalConv`. This allows the bottleneck to capture global
temporal dependencies in the summarized frequency features before passing
them to the decoder.
Sequence: VerticalConv -> SelfAttention -> Repeat Height. class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Parameters Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes
---------- ----------
input_height : int channels : int
Height (frequency bins) of the input tensor from the encoder. The number of output channels produced by the main convolutional layer
in_channels : int within the bottleneck. This often matches the number of channels coming
Number of channels in the input tensor from the encoder. from the last encoder stage, but can be different. Must be positive.
out_channels : int This also defines the channel dimensions used within the optional
Number of output channels produced by the `VerticalConv` and `SelfAttention` layer.
subsequently processed and output by this bottleneck. Also determines self_attention : bool
the input/output channels of the internal `SelfAttention` layer. If True, includes a `SelfAttention` layer operating on the time
attention : nn.Module dimension after an initial `VerticalConv` layer within the bottleneck.
An initialized `SelfAttention` module instance. If False, only the initial `VerticalConv` (and height repetition) is
performed.
Raises
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
""" """
def __init__( channels: int
self, layers: List[BottleneckLayerConfig] = Field(
input_height: int, default_factory=list,
in_channels: int, )
out_channels: int,
attention: nn.Module,
) -> None:
"""Initialize the Bottleneck with Self-Attention."""
super().__init__(input_height, in_channels, out_channels)
self.attention = attention
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input tensor.
Parameters
----------
x : torch.Tensor
Input tensor from the encoder bottleneck, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
`H_in` must match `self.input_height`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
and repeating the height dimension.
"""
x = self.conv_vert(x)
x = self.attention(x)
return x.repeat([1, 1, self.input_height, 1])
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig( DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
channels=256, channels=256,
self_attention=True, layers=[
SelfAttentionConfig(attention_channels=256),
],
) )
@ -234,21 +201,25 @@ def build_bottleneck(
""" """
config = config or DEFAULT_BOTTLENECK_CONFIG config = config or DEFAULT_BOTTLENECK_CONFIG
if config.self_attention: current_channels = in_channels
attention = SelfAttention( current_height = input_height
in_channels=config.channels,
attention_channels=config.channels,
)
return BottleneckAttn( layers = []
input_height=input_height,
in_channels=in_channels, for layer_config in config.layers:
out_channels=config.channels, layer, current_channels, current_height = build_layer_from_config(
attention=attention, input_height=current_height,
in_channels=current_channels,
config=layer_config,
) )
assert current_height == input_height, (
"Bottleneck layers should not change the spectrogram height"
)
layers.append(layer)
return Bottleneck( return Bottleneck(
input_height=input_height, input_height=input_height,
in_channels=in_channels, in_channels=in_channels,
out_channels=config.channels, out_channels=config.channels,
layers=layers,
) )
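A hedged sketch of the replacement for the old self_attention: true flag: the bottleneck now takes an explicit list of layer configs. The batdetect2.models.bottleneck module path and the keyword signature of build_bottleneck are assumptions inferred from the surrounding diff:

from batdetect2.models.blocks import SelfAttentionConfig
from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck  # assumed path

config = BottleneckConfig(
    channels=256,
    layers=[SelfAttentionConfig(attention_channels=256)],
)
bottleneck = build_bottleneck(
    input_height=4,   # illustrative encoder output height
    in_channels=256,  # illustrative encoder output channels
    config=config,
)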

View File

@ -26,19 +26,32 @@ def create_ax(
def plot_spectrogram( def plot_spectrogram(
spec: Union[torch.Tensor, np.ndarray], spec: Union[torch.Tensor, np.ndarray],
start_time: float, start_time: Optional[float] = None,
end_time: float, end_time: Optional[float] = None,
min_freq: float, min_freq: Optional[float] = None,
max_freq: float, max_freq: Optional[float] = None,
ax: Optional[axes.Axes] = None, ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Optional[Tuple[int, int]] = None,
cmap="gray", cmap="gray",
) -> axes.Axes: ) -> axes.Axes:
if isinstance(spec, torch.Tensor): if isinstance(spec, torch.Tensor):
spec = spec.numpy() spec = spec.numpy()
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
if start_time is None:
start_time = 0
if end_time is None:
end_time = spec.shape[-1]
if min_freq is None:
min_freq = 0
if max_freq is None:
max_freq = spec.shape[-2]
ax.pcolormesh( ax.pcolormesh(
np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True), np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True), np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),

View File

@ -2,6 +2,7 @@
from typing import List, Optional from typing import List, Optional
import torch
from loguru import logger from loguru import logger
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -20,13 +21,15 @@ from batdetect2.postprocess.nms import (
) )
from batdetect2.postprocess.remapping import map_detection_to_clip from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing import ModelOutput, PreprocessorProtocol, TargetProtocol from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import ( from batdetect2.typing.postprocess import (
BatDetect2Prediction, BatDetect2Prediction,
Detections, Detections,
PostprocessorProtocol, PostprocessorProtocol,
RawPrediction, RawPrediction,
) )
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [ __all__ = [
"DEFAULT_CLASSIFICATION_THRESHOLD", "DEFAULT_CLASSIFICATION_THRESHOLD",
@ -128,7 +131,6 @@ def load_postprocess_config(
def build_postprocessor( def build_postprocessor(
targets: TargetProtocol,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
config: Optional[PostprocessConfig] = None, config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol: ) -> PostprocessorProtocol:
@ -139,29 +141,52 @@ def build_postprocessor(
lambda: config.to_yaml_string(), lambda: config.to_yaml_string(),
) )
return Postprocessor( return Postprocessor(
targets=targets, samplerate=preprocessor.output_samplerate,
preprocessor=preprocessor, min_freq=preprocessor.min_freq,
config=config, max_freq=preprocessor.max_freq,
top_k_per_sec=config.top_k_per_sec,
detection_threshold=config.detection_threshold,
) )
class Postprocessor(PostprocessorProtocol): class Postprocessor(torch.nn.Module, PostprocessorProtocol):
"""Standard implementation of the postprocessing pipeline.""" """Standard implementation of the postprocessing pipeline."""
targets: TargetProtocol
preprocessor: PreprocessorProtocol
def __init__( def __init__(
self, self,
targets: TargetProtocol, samplerate: float,
preprocessor: PreprocessorProtocol, min_freq: float,
config: PostprocessConfig, max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
): ):
"""Initialize the Postprocessor.""" """Initialize the Postprocessor."""
self.targets = targets super().__init__()
self.preprocessor = preprocessor self.samplerate = samplerate
self.config = config self.min_freq = min_freq
self.max_freq = max_freq
self.top_k_per_sec = top_k_per_sec
self.detection_threshold = detection_threshold
def forward(self, output: ModelOutput) -> List[Detections]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.detection_threshold,
)
return [
map_detection_to_clip(
detection,
start_time=0,
end_time=duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection in detections
]
def get_detections( def get_detections(
self, self,
@ -169,13 +194,13 @@ class Postprocessor(PostprocessorProtocol):
clips: Optional[List[data.Clip]] = None, clips: Optional[List[data.Clip]] = None,
) -> List[Detections]: ) -> List[Detections]:
width = output.detection_probs.shape[-1] width = output.detection_probs.shape[-1]
duration = width / self.preprocessor.output_samplerate duration = width / self.samplerate
max_detections = int(self.config.top_k_per_sec * duration) max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor( detections = extract_prediction_tensor(
output, output,
max_detections=max_detections, max_detections=max_detections,
threshold=self.config.detection_threshold, threshold=self.detection_threshold,
) )
if clips is None: if clips is None:
@ -186,96 +211,116 @@ class Postprocessor(PostprocessorProtocol):
detection, detection,
start_time=clip.start_time, start_time=clip.start_time,
end_time=clip.end_time, end_time=clip.end_time,
min_freq=self.preprocessor.min_freq, min_freq=self.min_freq,
max_freq=self.preprocessor.max_freq, max_freq=self.max_freq,
) )
for detection, clip in zip(detections, clips) for detection, clip in zip(detections, clips)
] ]
def get_raw_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes raw model output through remapping, NMS, detection, data def get_raw_predictions(
extraction, and geometry recovery via the configured output: ModelOutput,
`targets.recover_roi`. clips: List[data.Clip],
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Parameters Processes raw model output through remapping, NMS, detection, data
---------- extraction, and geometry recovery via the configured
output : ModelOutput `targets.recover_roi`.
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns Parameters
------- ----------
List[List[RawPrediction]] output : ModelOutput
List of lists (one inner list per input clip). Each inner list Raw output from the neural network model for a batch.
contains `RawPrediction` objects for detections in that clip. clips : List[data.Clip]
""" List of `soundevent.data.Clip` objects corresponding to the batch.
detections = self.get_detections(output, clips)
return [ Returns
convert_detections_to_raw_predictions( -------
dataset, List[List[RawPrediction]]
targets=self.targets, List of lists (one inner list per input clip). Each inner list
contains `RawPrediction` objects for detections in that clip.
"""
detections = postprocessor.get_detections(output, clips)
return [
convert_detections_to_raw_predictions(
dataset,
targets=targets,
)
for dataset in detections
]
def get_sound_event_predictions(
output: ModelOutput,
clips: List[data.Clip],
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[List[BatDetect2Prediction]]:
raw_predictions = get_raw_predictions(
output,
clips,
targets=targets,
postprocessor=postprocessor,
)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=targets,
classification_threshold=classification_threshold,
),
) )
for dataset in detections for raw in predictions
] ]
for predictions, clip in zip(raw_predictions, clips)
]
def get_sound_event_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[BatDetect2Prediction]]:
raw_predictions = self.get_raw_predictions(output, clips)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=self.targets,
classification_threshold=self.config.classification_threshold,
),
)
for raw in predictions
]
for predictions, clip in zip(raw_predictions, clips)
]
def get_predictions( def get_predictions(
self, output: ModelOutput, clips: List[data.Clip] output: ModelOutput,
) -> List[data.ClipPrediction]: clips: List[data.Clip],
"""Perform the full postprocessing pipeline for a batch. targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output and corresponding clips, applies the entire Takes raw model output and corresponding clips, applies the entire
configured chain (NMS, remapping, extraction, geometry recovery, class configured chain (NMS, remapping, extraction, geometry recovery, class
decoding), producing final `soundevent.data.ClipPrediction` objects. decoding), producing final `soundevent.data.ClipPrediction` objects.
Parameters Parameters
---------- ----------
output : ModelOutput output : ModelOutput
Raw output from the neural network model for a batch. Raw output from the neural network model for a batch.
clips : List[data.Clip] clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch. List of `soundevent.data.Clip` objects corresponding to the batch.
Returns Returns
------- -------
List[data.ClipPrediction] List[data.ClipPrediction]
List containing one `ClipPrediction` object for each input clip, List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects. populated with `SoundEventPrediction` objects.
""" """
raw_predictions = self.get_raw_predictions(output, clips) raw_predictions = get_raw_predictions(
return [ output,
convert_raw_predictions_to_clip_prediction( clips,
prediction, targets=targets,
clip, postprocessor=postprocessor,
targets=self.targets, )
classification_threshold=self.config.classification_threshold, return [
) convert_raw_predictions_to_clip_prediction(
for prediction, clip in zip(raw_predictions, clips) prediction,
] clip,
targets=targets,
classification_threshold=classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]

View File

@ -139,7 +139,21 @@ class FrequencyClip(torch.nn.Module):
self.high_index = high_index self.high_index = high_index
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
return spec[self.low_index : self.high_index] low_index = self.low_index
if low_index is None:
low_index = 0
if self.high_index is None:
length = spec.shape[-2] - low_index
else:
length = self.high_index - low_index
return torch.narrow(
spec,
dim=-2,
start=low_index,
length=length,
)
class PcenConfig(BaseConfig): class PcenConfig(BaseConfig):
@ -256,16 +270,22 @@ class ResizeSpec(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
current_length = spec.shape[-1] current_length = spec.shape[-1]
target_length = int(self.time_factor * current_length) target_length = int(self.time_factor * current_length)
return (
torch.nn.functional.interpolate( original_ndim = spec.ndim
spec.unsqueeze(0).unsqueeze(0), while spec.ndim < 4:
size=(self.height, target_length), spec = spec.unsqueeze(0)
mode="bilinear",
) resized = torch.nn.functional.interpolate(
.squeeze(0) spec,
.squeeze(0) size=(self.height, target_length),
mode="bilinear",
) )
while resized.ndim != original_ndim:
resized = resized.squeeze(0)
return resized
class PeakNormalizeConfig(BaseConfig): class PeakNormalizeConfig(BaseConfig):
name: Literal["peak_normalize"] = "peak_normalize" name: Literal["peak_normalize"] = "peak_normalize"
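Both changes above make the spectrogram transforms rank-agnostic so they work on batched input. A toy check of the two tricks (torch.narrow instead of slicing, and padding to 4-D around interpolate), with illustrative sizes:

import torch

spec = torch.rand(6, 32)  # un-batched (freq, time); the same code works with leading batch dims
clipped = torch.narrow(spec, dim=-2, start=1, length=4)
assert torch.equal(clipped, spec[1:5, :])

x = spec
while x.ndim < 4:  # pad to (batch, channel, height, width) for interpolate
    x = x.unsqueeze(0)
resized = torch.nn.functional.interpolate(x, size=(16, 64), mode="bilinear")
while resized.ndim != spec.ndim:
    resized = resized.squeeze(0)
assert resized.shape == (16, 64)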

View File

@ -2,6 +2,7 @@ from batdetect2.train.augmentations import (
AugmentationsConfig, AugmentationsConfig,
EchoAugmentationConfig, EchoAugmentationConfig,
FrequencyMaskAugmentationConfig, FrequencyMaskAugmentationConfig,
RandomExampleSource,
TimeMaskAugmentationConfig, TimeMaskAugmentationConfig,
VolumeAugmentationConfig, VolumeAugmentationConfig,
WarpAugmentationConfig, WarpAugmentationConfig,
@ -23,7 +24,6 @@ from batdetect2.train.config import (
) )
from batdetect2.train.dataset import ( from batdetect2.train.dataset import (
LabeledDataset, LabeledDataset,
RandomExampleSource,
list_preprocessed_files, list_preprocessed_files,
) )
from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.labels import build_clip_labeler, load_label_config

View File

@ -1,6 +1,7 @@
"""Applies data augmentation techniques to BatDetect2 training examples.""" """Applies data augmentation techniques to BatDetect2 training examples."""
import warnings import warnings
from collections.abc import Sequence
from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union
import numpy as np import numpy as np
@ -10,8 +11,12 @@ from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.preprocess import (
list_preprocessed_files,
load_preprocessed_example,
)
from batdetect2.typing import Augmentation, PreprocessorProtocol from batdetect2.typing import Augmentation, PreprocessorProtocol
from batdetect2.typing.train import PreprocessedExample from batdetect2.typing.train import ClipperProtocol, PreprocessedExample
from batdetect2.utils.arrays import adjust_width from batdetect2.utils.arrays import adjust_width
__all__ = [ __all__ = [
@ -39,21 +44,6 @@ ExampleSource = Callable[[], PreprocessedExample]
"""Type alias for a function that returns a training example""" """Type alias for a function that returns a training example"""
class MixAugmentationConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples)."""
augmentation_type: Literal["mix_audio"] = "mix_audio"
probability: float = 0.2
"""Probability of applying this augmentation to an example."""
min_weight: float = 0.3
"""Minimum mixing weight (lambda) applied to the primary example."""
max_weight: float = 0.7
"""Maximum mixing weight (lambda) applied to the primary example."""
def mix_examples( def mix_examples(
example: PreprocessedExample, example: PreprocessedExample,
other: PreprocessedExample, other: PreprocessedExample,
@ -149,7 +139,12 @@ def add_echo(
audio = example.audio audio = example.audio
delay_steps = int(preprocessor.input_samplerate * delay) delay_steps = int(preprocessor.input_samplerate * delay)
audio_delay = adjust_width(audio[delay_steps:], audio.shape[-1])
slices = [slice(None)] * audio.ndim
slices[-1] = slice(None, -delay_steps)
audio_delay = adjust_width(audio[tuple(slices)], audio.shape[-1]).roll(
delay_steps, dims=-1
)
audio = audio + weight * audio_delay audio = audio + weight * audio_delay
spectrogram = preprocessor(audio) spectrogram = preprocessor(audio)
@ -184,7 +179,7 @@ class VolumeAugmentationConfig(BaseConfig):
class ScaleVolume(torch.nn.Module): class ScaleVolume(torch.nn.Module):
def __init__(self, min_scaling: float, max_scaling: float): def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0):
super().__init__() super().__init__()
self.min_scaling = min_scaling self.min_scaling = min_scaling
self.max_scaling = max_scaling self.max_scaling = max_scaling
@ -228,32 +223,22 @@ def warp_spectrogram(
example: PreprocessedExample, factor: float example: PreprocessedExample, factor: float
) -> PreprocessedExample: ) -> PreprocessedExample:
"""Apply time warping by resampling the time axis.""" """Apply time warping by resampling the time axis."""
target_shape = example.spectrogram.shape width = example.spectrogram.shape[-1]
height = example.spectrogram.shape[-2]
target_shape = [height, width]
new_width = int(target_shape[-1] * factor) new_width = int(target_shape[-1] * factor)
spectrogram = ( spectrogram = torch.nn.functional.interpolate(
torch.nn.functional.interpolate( adjust_width(example.spectrogram, new_width).unsqueeze(0),
adjust_width(example.spectrogram, new_width) size=target_shape,
.unsqueeze(0) mode="bilinear",
.unsqueeze(0), ).squeeze(0)
size=target_shape,
mode="bilinear",
)
.squeeze(0)
.squeeze(0)
)
detection = ( detection = torch.nn.functional.interpolate(
torch.nn.functional.interpolate( adjust_width(example.detection_heatmap, new_width).unsqueeze(0),
adjust_width(example.detection_heatmap, new_width) size=target_shape,
.unsqueeze(0) mode="nearest",
.unsqueeze(0), ).squeeze(0)
size=target_shape,
mode="nearest",
)
.squeeze(0)
.squeeze(0)
)
classification = torch.nn.functional.interpolate( classification = torch.nn.functional.interpolate(
adjust_width(example.class_heatmap, new_width).unsqueeze(1), adjust_width(example.class_heatmap, new_width).unsqueeze(1),
@ -284,10 +269,16 @@ class TimeMaskAugmentationConfig(BaseConfig):
class MaskTime(torch.nn.Module): class MaskTime(torch.nn.Module):
def __init__(self, max_perc: float = 0.05, max_masks: int = 3) -> None: def __init__(
self,
max_perc: float = 0.05,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__() super().__init__()
self.max_perc = max_perc self.max_perc = max_perc
self.max_masks = max_masks self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(self, example: PreprocessedExample) -> PreprocessedExample: def forward(self, example: PreprocessedExample) -> PreprocessedExample:
num_masks = np.random.randint(1, self.max_masks + 1) num_masks = np.random.randint(1, self.max_masks + 1)
@ -306,20 +297,28 @@ class MaskTime(torch.nn.Module):
masks = [ masks = [
(start, start + size) for start, size in zip(mask_start, mask_size) (start, start + size) for start, size in zip(mask_start, mask_size)
] ]
return mask_time(example, masks) return mask_time(example, masks, mask_heatmaps=self.mask_heatmaps)
def mask_time( def mask_time(
example: PreprocessedExample, example: PreprocessedExample,
masks: List[Tuple[int, int]], masks: List[Tuple[int, int]],
mask_heatmaps: bool = False,
) -> PreprocessedExample: ) -> PreprocessedExample:
"""Apply time masking to the spectrogram.""" """Apply time masking to the spectrogram."""
for start, end in masks: for start, end in masks:
example.spectrogram[:, start:end] = example.spectrogram.mean() slices = [slice(None)] * example.spectrogram.ndim
example.class_heatmap[:, :, start:end] = 0 slices[-1] = slice(start, end)
example.size_heatmap[:, :, start:end] = 0
example.detection_heatmap[:, start:end] = 0 example.spectrogram[tuple(slices)] = 0
if not mask_heatmaps:
continue
example.class_heatmap[tuple(slices)] = 0
example.size_heatmap[tuple(slices)] = 0
example.detection_heatmap[tuple(slices)] = 0
return PreprocessedExample( return PreprocessedExample(
audio=example.audio, audio=example.audio,
@ -335,13 +334,20 @@ class FrequencyMaskAugmentationConfig(BaseConfig):
probability: float = 0.2 probability: float = 0.2
max_perc: float = 0.10 max_perc: float = 0.10
max_masks: int = 3 max_masks: int = 3
mask_heatmaps: bool = False
class MaskFrequency(torch.nn.Module): class MaskFrequency(torch.nn.Module):
def __init__(self, max_perc: float = 0.10, max_masks: int = 3) -> None: def __init__(
self,
max_perc: float = 0.10,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__() super().__init__()
self.max_perc = max_perc self.max_perc = max_perc
self.max_masks = max_masks self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(self, example: PreprocessedExample) -> PreprocessedExample: def forward(self, example: PreprocessedExample) -> PreprocessedExample:
num_masks = np.random.randint(1, self.max_masks + 1) num_masks = np.random.randint(1, self.max_masks + 1)
@ -360,19 +366,26 @@ class MaskFrequency(torch.nn.Module):
masks = [ masks = [
(start, start + size) for start, size in zip(mask_start, mask_size) (start, start + size) for start, size in zip(mask_start, mask_size)
] ]
return mask_frequency(example, masks) return mask_frequency(example, masks, mask_heatmaps=self.mask_heatmaps)
def mask_frequency( def mask_frequency(
example: PreprocessedExample, example: PreprocessedExample,
masks: List[Tuple[int, int]], masks: List[Tuple[int, int]],
mask_heatmaps: bool = False,
) -> PreprocessedExample: ) -> PreprocessedExample:
"""Apply frequency masking to the spectrogram.""" """Apply frequency masking to the spectrogram."""
for start, end in masks: for start, end in masks:
example.spectrogram[start:end, :] = example.spectrogram.mean() slices = [slice(None)] * example.spectrogram.ndim
example.class_heatmap[:, start:end, :] = 0 slices[-2] = slice(start, end)
example.size_heatmap[:, start:end, :] = 0 example.spectrogram[tuple(slices)] = 0
example.detection_heatmap[start:end, :] = 0
if not mask_heatmaps:
continue
example.class_heatmap[tuple(slices)] = 0
example.size_heatmap[tuple(slices)] = 0
example.detection_heatmap[tuple(slices)] = 0
return PreprocessedExample( return PreprocessedExample(
audio=example.audio, audio=example.audio,
@ -383,6 +396,50 @@ def mask_frequency(
) )
class MixAugmentationConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples)."""
augmentation_type: Literal["mix_audio"] = "mix_audio"
probability: float = 0.2
"""Probability of applying this augmentation to an example."""
min_weight: float = 0.3
"""Minimum mixing weight (lambda) applied to the primary example."""
max_weight: float = 0.7
"""Maximum mixing weight (lambda) applied to the primary example."""
class MixAudio(torch.nn.Module):
"""Callable class for MixUp augmentation, handling example fetching."""
def __init__(
self,
example_source: ExampleSource,
preprocessor: PreprocessorProtocol,
min_weight: float = 0.3,
max_weight: float = 0.7,
):
"""Initialize the AudioMixer."""
super().__init__()
self.min_weight = min_weight
self.example_source = example_source
self.max_weight = max_weight
self.preprocessor = preprocessor
def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
"""Fetch another example and perform mixup."""
other = self.example_source()
weight = np.random.uniform(self.min_weight, self.max_weight)
return mix_examples(
example,
other,
self.preprocessor,
weight=weight,
)
AugmentationConfig = Annotated[ AugmentationConfig = Annotated[
Union[ Union[
MixAugmentationConfig, MixAugmentationConfig,
@ -445,35 +502,6 @@ class MaybeApply(torch.nn.Module):
return self.augmentation(example) return self.augmentation(example)
class AudioMixer(torch.nn.Module):
"""Callable class for MixUp augmentation, handling example fetching."""
def __init__(
self,
min_weight: float,
max_weight: float,
example_source: ExampleSource,
preprocessor: PreprocessorProtocol,
):
"""Initialize the AudioMixer."""
super().__init__()
self.min_weight = min_weight
self.example_source = example_source
self.max_weight = max_weight
self.preprocessor = preprocessor
def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
"""Fetch another example and perform mixup."""
other = self.example_source()
weight = np.random.uniform(self.min_weight, self.max_weight)
return mix_examples(
example,
other,
self.preprocessor,
weight=weight,
)
def build_augmentation_from_config( def build_augmentation_from_config(
config: AugmentationConfig, config: AugmentationConfig,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
@ -489,7 +517,7 @@ def build_augmentation_from_config(
) )
return None return None
return AudioMixer( return MixAudio(
example_source=example_source, example_source=example_source,
preprocessor=preprocessor, preprocessor=preprocessor,
min_weight=config.min_weight, min_weight=config.min_weight,
@ -585,3 +613,25 @@ def load_augmentation_config(
) -> AugmentationsConfig: ) -> AugmentationsConfig:
"""Load the augmentations configuration from a file.""" """Load the augmentations configuration from a file."""
return load_config(path, schema=AugmentationsConfig, field=field) return load_config(path, schema=AugmentationsConfig, field=field)
class RandomExampleSource:
def __init__(
self,
filenames: Sequence[data.PathLike],
clipper: ClipperProtocol,
):
self.filenames = filenames
self.clipper = clipper
def __call__(self) -> PreprocessedExample:
index = int(np.random.randint(len(self.filenames)))
filename = self.filenames[index]
example = load_preprocessed_example(filename)
example, _, _ = self.clipper(example)
return example
@classmethod
def from_directory(cls, path: data.PathLike, clipper: ClipperProtocol):
filenames = list_preprocessed_files(path)
return cls(filenames, clipper=clipper)
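The masking helpers above now build their index with a slice tuple, so the same code works whether the spectrogram carries a leading channel/batch dimension or not. A toy illustration (shapes and indices are illustrative):

import torch

spec = torch.ones(1, 8, 16)  # (channel, freq, time); also works for plain (freq, time)
start, end = 4, 7
slices = [slice(None)] * spec.ndim
slices[-1] = slice(start, end)  # last axis masks time; use index -2 for frequency
spec[tuple(slices)] = 0
assert spec[..., start:end].sum() == 0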

View File

@ -14,7 +14,9 @@ from batdetect2.evaluate.match import (
MatchConfig, MatchConfig,
match_sound_events_and_raw_predictions, match_sound_events_and_raw_predictions,
) )
from batdetect2.models import Model
from batdetect2.plotting.evaluation import plot_example_gallery from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess import get_sound_event_predictions
from batdetect2.train.dataset import LabeledDataset from batdetect2.train.dataset import LabeledDataset
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import ( from batdetect2.typing import (
@ -22,7 +24,6 @@ from batdetect2.typing import (
MatchEvaluation, MatchEvaluation,
MetricsProtocol, MetricsProtocol,
ModelOutput, ModelOutput,
PostprocessorProtocol,
TargetProtocol, TargetProtocol,
TrainExample, TrainExample,
) )
@ -127,8 +128,7 @@ class ValidationMetrics(Callback):
batch, batch,
outputs, outputs,
dataset=self.get_dataset(trainer), dataset=self.get_dataset(trainer),
postprocessor=pl_module.model.postprocessor, model=pl_module.model,
targets=pl_module.model.targets,
) )
) )
@ -137,15 +137,14 @@ def _get_batch_clips_and_predictions(
batch: TrainExample, batch: TrainExample,
outputs: ModelOutput, outputs: ModelOutput,
dataset: LabeledDataset, dataset: LabeledDataset,
postprocessor: PostprocessorProtocol, model: Model,
targets: TargetProtocol,
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]: ) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
clip_annotations = [ clip_annotations = [
_get_subclip( _get_subclip(
dataset.get_clip_annotation(example_id), dataset.get_clip_annotation(example_id),
start_time=start_time.item(), start_time=start_time.item(),
end_time=end_time.item(), end_time=end_time.item(),
targets=targets, targets=model.targets,
) )
for example_id, start_time, end_time in zip( for example_id, start_time, end_time in zip(
batch.idx, batch.idx,
@ -156,9 +155,11 @@ def _get_batch_clips_and_predictions(
clips = [clip_annotation.clip for clip_annotation in clip_annotations] clips = [clip_annotation.clip for clip_annotation in clip_annotations]
raw_predictions = postprocessor.get_sound_event_predictions( raw_predictions = get_sound_event_predictions(
outputs, outputs,
clips, clips,
targets=model.targets,
postprocessor=model.postprocessor
) )
return [ return [

View File

@ -8,7 +8,7 @@ from batdetect2.configs import BaseConfig
from batdetect2.typing import ClipperProtocol from batdetect2.typing import ClipperProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.train import PreprocessedExample from batdetect2.typing.train import PreprocessedExample
from batdetect2.utils.arrays import adjust_width from batdetect2.utils.arrays import adjust_width, slice_tensor
DEFAULT_TRAIN_CLIP_DURATION = 0.512 DEFAULT_TRAIN_CLIP_DURATION = 0.512
DEFAULT_MAX_EMPTY_CLIP = 0.1 DEFAULT_MAX_EMPTY_CLIP = 0.1
@ -90,7 +90,12 @@ def select_subclip(
audio_start = int(np.floor(start * input_samplerate)) audio_start = int(np.floor(start * input_samplerate))
audio = adjust_width( audio = adjust_width(
example.audio[audio_start : audio_start + audio_width], slice_tensor(
example.audio,
start=audio_start,
end=audio_start + audio_width,
dim=-1,
),
audio_width, audio_width,
value=fill_value, value=fill_value,
) )
@ -100,19 +105,39 @@ def select_subclip(
return PreprocessedExample( return PreprocessedExample(
audio=audio, audio=audio,
spectrogram=adjust_width( spectrogram=adjust_width(
example.spectrogram[:, spec_start : spec_start + spec_width], slice_tensor(
example.spectrogram,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width, spec_width,
), ),
class_heatmap=adjust_width( class_heatmap=adjust_width(
example.class_heatmap[:, :, spec_start : spec_start + spec_width], slice_tensor(
example.class_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width, spec_width,
), ),
detection_heatmap=adjust_width( detection_heatmap=adjust_width(
example.detection_heatmap[:, spec_start : spec_start + spec_width], slice_tensor(
example.detection_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width, spec_width,
), ),
size_heatmap=adjust_width( size_heatmap=adjust_width(
example.size_heatmap[:, :, spec_start : spec_start + spec_width], slice_tensor(
example.size_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width, spec_width,
), ),
) )

View File

@ -44,8 +44,8 @@ class PLTrainerConfig(BaseConfig):
class DataLoaderConfig(BaseConfig): class DataLoaderConfig(BaseConfig):
batch_size: int batch_size: int = 8
shuffle: bool shuffle: bool = False
num_workers: int = 0 num_workers: int = 0

View File

@ -1,5 +1,4 @@
from pathlib import Path from typing import Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Tuple
import numpy as np import numpy as np
import torch import torch
@ -7,6 +6,10 @@ from soundevent import data
from torch.utils.data import Dataset from torch.utils.data import Dataset
from batdetect2.train.augmentations import Augmentation from batdetect2.train.augmentations import Augmentation
from batdetect2.train.preprocess import (
list_preprocessed_files,
load_preprocessed_example,
)
from batdetect2.typing import ClipperProtocol, TrainExample from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.typing.train import PreprocessedExample from batdetect2.typing.train import PreprocessedExample
@ -38,8 +41,8 @@ class LabeledDataset(Dataset):
example = self.augmentation(example) example = self.augmentation(example)
return TrainExample( return TrainExample(
spec=example.spectrogram.unsqueeze(0), spec=example.spectrogram,
detection_heatmap=example.detection_heatmap.unsqueeze(0), detection_heatmap=example.detection_heatmap,
class_heatmap=example.class_heatmap, class_heatmap=example.class_heatmap,
size_heatmap=example.size_heatmap, size_heatmap=example.size_heatmap,
idx=torch.tensor(idx), idx=torch.tensor(idx),
@ -73,37 +76,3 @@ class LabeledDataset(Dataset):
def get_clip_annotation(self, idx) -> data.ClipAnnotation: def get_clip_annotation(self, idx) -> data.ClipAnnotation:
item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+") item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
return item["clip_annotation"].tolist() return item["clip_annotation"].tolist()
def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
item = np.load(path, mmap_mode="r+")
return PreprocessedExample(
audio=torch.tensor(item["audio"]),
spectrogram=torch.tensor(item["spectrogram"]),
size_heatmap=torch.tensor(item["size_heatmap"]),
detection_heatmap=torch.tensor(item["detection_heatmap"]),
class_heatmap=torch.tensor(item["class_heatmap"]),
)
def list_preprocessed_files(
directory: data.PathLike, extension: str = ".npz"
) -> List[Path]:
return list(Path(directory).glob(f"*{extension}"))
class RandomExampleSource:
def __init__(
self,
filenames: List[data.PathLike],
clipper: ClipperProtocol,
):
self.filenames = filenames
self.clipper = clipper
def __call__(self) -> PreprocessedExample:
index = int(np.random.randint(len(self.filenames)))
filename = self.filenames[index]
example = load_preprocessed_example(filename)
example, _, _ = self.clipper(example)
return example

View File

@ -41,7 +41,6 @@ from batdetect2.typing import (
__all__ = [ __all__ = [
"LabelConfig", "LabelConfig",
"build_clip_labeler", "build_clip_labeler",
"generate_clip_label",
"generate_heatmaps", "generate_heatmaps",
"load_label_config", "load_label_config",
] ]
@ -99,21 +98,26 @@ def build_clip_labeler(
lambda: config.to_yaml_string(), lambda: config.to_yaml_string(),
) )
return partial( return partial(
generate_clip_label, generate_heatmaps,
targets=targets, targets=targets,
config=config,
min_freq=min_freq, min_freq=min_freq,
max_freq=max_freq, max_freq=max_freq,
target_sigma=config.sigma,
) )
def generate_clip_label( def map_to_pixels(x, size, min_val, max_val) -> int:
return int(np.interp(x, [min_val, max_val], [0, size]))
def generate_heatmaps(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
spec: torch.Tensor, spec: torch.Tensor,
targets: TargetProtocol, targets: TargetProtocol,
config: LabelConfig,
min_freq: float, min_freq: float,
max_freq: float, max_freq: float,
target_sigma: float = 3.0,
dtype=torch.float32,
) -> Heatmaps: ) -> Heatmaps:
"""Generate training heatmaps for a single annotated clip. """Generate training heatmaps for a single annotated clip.
@ -150,57 +154,14 @@ def generate_clip_label(
num=len(clip_annotation.sound_events), num=len(clip_annotation.sound_events),
) )
sound_events = [] height = spec.shape[-2]
width = spec.shape[-1]
for sound_event_annotation in clip_annotation.sound_events:
if not targets.filter(sound_event_annotation):
logger.debug(
"Sound event {sound_event} did not pass the filter. Tags: {tags}",
sound_event=sound_event_annotation,
tags=sound_event_annotation.tags,
)
continue
sound_events.append(targets.transform(sound_event_annotation))
return generate_heatmaps(
clip_annotation.model_copy(update=dict(sound_events=sound_events)),
spec=spec,
targets=targets,
target_sigma=config.sigma,
min_freq=min_freq,
max_freq=max_freq,
)
def map_to_pixels(x, size, min_val, max_val) -> int:
return int(np.interp(x, [min_val, max_val], [0, size]))
def generate_heatmaps(
clip_annotation: data.ClipAnnotation,
spec: torch.Tensor,
targets: TargetProtocol,
min_freq: float,
max_freq: float,
target_sigma: float = 3.0,
dtype=torch.float32,
) -> Heatmaps:
if not spec.ndim == 2:
raise ValueError(
"Expecting a 2-dimensional tensor of shape (H, W), "
"H is the height of the spectrogram "
"(frequency bins), and W is the width of the spectrogram "
f"(temporal bins). Instead got: {spec.shape}"
)
height, width = spec.shape
num_classes = len(targets.class_names) num_classes = len(targets.class_names)
num_dims = len(targets.dimension_names) num_dims = len(targets.dimension_names)
clip = clip_annotation.clip clip = clip_annotation.clip
# Initialize heatmaps # Initialize heatmaps
detection_heatmap = torch.zeros([height, width], dtype=dtype) detection_heatmap = torch.zeros([1, height, width], dtype=dtype)
class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype) class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype)
size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype) size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype)
@ -214,6 +175,16 @@ def generate_heatmaps(
times = times.to(spec.device) times = times.to(spec.device)
for sound_event_annotation in clip_annotation.sound_events: for sound_event_annotation in clip_annotation.sound_events:
if not targets.filter(sound_event_annotation):
logger.debug(
"Sound event {sound_event} did not pass the filter. Tags: {tags}",
sound_event=sound_event_annotation,
tags=sound_event_annotation.tags,
)
continue
sound_event_annotation = targets.transform(sound_event_annotation)
geom = sound_event_annotation.sound_event.geometry geom = sound_event_annotation.sound_event.geometry
if geom is None: if geom is None:
logger.debug( logger.debug(
@ -245,7 +216,10 @@ def generate_heatmaps(
distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2 distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
gaussian_blob = torch.exp(-distance / (2 * target_sigma**2)) gaussian_blob = torch.exp(-distance / (2 * target_sigma**2))
detection_heatmap = torch.maximum(detection_heatmap, gaussian_blob) detection_heatmap[0] = torch.maximum(
detection_heatmap[0],
gaussian_blob,
)
size_heatmap[:, freq_index, time_index] = torch.tensor(size[:]) size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])
# Get the class name of the sound event # Get the class name of the sound event
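A toy version of the Gaussian-blob placement used by generate_heatmaps, matching the detection_heatmap[0] update above (grid size, sigma, and the event position are illustrative):

import torch

height, width, sigma = 32, 64, 3.0
freqs = torch.arange(height, dtype=torch.float32).reshape(-1, 1)
times = torch.arange(width, dtype=torch.float32).reshape(1, -1)

detection_heatmap = torch.zeros(1, height, width)
freq_index, time_index = 10, 20  # event position in pixels
distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
gaussian_blob = torch.exp(-distance / (2 * sigma**2))
detection_heatmap[0] = torch.maximum(detection_heatmap[0], gaussian_blob)
assert detection_heatmap[0, freq_index, time_index] == 1.0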

View File

@ -34,7 +34,7 @@ class TrainingModule(L.LightningModule):
return self.model(spec) return self.model(spec)
def training_step(self, batch: TrainExample): def training_step(self, batch: TrainExample):
outputs = self.model(batch.spec) outputs = self.model.detector(batch.spec)
losses = self.loss(outputs, batch) losses = self.loss(outputs, batch)
self.log("total_loss/train", losses.total, prog_bar=True, logger=True) self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/train", losses.total, logger=True) self.log("detection_loss/train", losses.total, logger=True)
@ -47,7 +47,7 @@ class TrainingModule(L.LightningModule):
batch: TrainExample, batch: TrainExample,
batch_idx: int, batch_idx: int,
) -> ModelOutput: ) -> ModelOutput:
outputs = self.model(batch.spec) outputs = self.model.detector(batch.spec)
losses = self.loss(outputs, batch) losses = self.loss(outputs, batch)
self.log("total_loss/val", losses.total, prog_bar=True, logger=True) self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/val", losses.total, logger=True) self.log("detection_loss/val", losses.total, logger=True)

View File

@ -2,7 +2,7 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Callable, Optional, Sequence, TypedDict from typing import Callable, List, Optional, Sequence, TypedDict
import numpy as np import numpy as np
import torch import torch
@ -28,6 +28,8 @@ __all__ = [
"preprocess_dataset", "preprocess_dataset",
"TrainPreprocessConfig", "TrainPreprocessConfig",
"load_train_preprocessing_config", "load_train_preprocessing_config",
"save_preprocessed_example",
"load_preprocessed_example",
] ]
FilenameFn = Callable[[data.ClipAnnotation], str] FilenameFn = Callable[[data.ClipAnnotation], str]
@ -94,8 +96,10 @@ def generate_train_example(
labeller: ClipLabeller,
) -> PreprocessedExample:
"""Generate a complete training example for one annotation."""
wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
spectrogram = preprocessor(wave)
wave = torch.tensor(
audio_loader.load_clip(clip_annotation.clip)
).unsqueeze(0)
spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
heatmaps = labeller(clip_annotation, spectrogram)
return PreprocessedExample(
audio=wave,
@ -145,7 +149,7 @@ class PreprocessingDataset(torch.utils.data.Dataset):
labeller=self.labeller,
)
save_example_to_file(example, clip_annotation, path)
save_preprocessed_example(example, clip_annotation, path)
return idx
@ -153,7 +157,7 @@ class PreprocessingDataset(torch.utils.data.Dataset):
return len(self.clips)
def save_example_to_file(
def save_preprocessed_example(
example: PreprocessedExample,
clip_annotation: data.ClipAnnotation,
path: data.PathLike,
@ -169,6 +173,23 @@ def save_example_to_file(
)
def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
item = np.load(path, mmap_mode="r+")
return PreprocessedExample(
audio=torch.tensor(item["audio"]),
spectrogram=torch.tensor(item["spectrogram"]),
size_heatmap=torch.tensor(item["size_heatmap"]),
detection_heatmap=torch.tensor(item["detection_heatmap"]),
class_heatmap=torch.tensor(item["class_heatmap"]),
)
def list_preprocessed_files(
directory: data.PathLike, extension: str = ".npz"
) -> List[Path]:
return list(Path(directory).glob(f"*{extension}"))
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
"""Generate a default output filename based on the annotation UUID."""
return f"{clip_annotation.uuid}"
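The renamed save_preprocessed_example together with the new load_preprocessed_example and list_preprocessed_files form a simple on-disk cache of training examples. A hedged round-trip sketch follows; the field names mirror load_preprocessed_example above, but the shapes and the use of np.savez are assumptions, since the body of save_preprocessed_example is not shown in this hunk.

from pathlib import Path
import numpy as np
import torch

out = Path("outputs/preprocessed")
out.mkdir(parents=True, exist_ok=True)
path = out / "example.npz"

# Saving: one archive per clip annotation (shapes are illustrative only).
np.savez(
    path,
    audio=torch.randn(1, 44100).numpy(),
    spectrogram=torch.randn(1, 128, 512).numpy(),
    detection_heatmap=torch.zeros(1, 128, 512).numpy(),
    class_heatmap=torch.zeros(5, 128, 512).numpy(),
    size_heatmap=torch.zeros(2, 128, 512).numpy(),
)

# Loading: np.load returns an NpzFile; each field becomes a tensor again.
item = np.load(path)
spectrogram = torch.tensor(item["spectrogram"])
print(spectrogram.shape)

# Listing: glob for every cached example in the directory.
files = sorted(out.glob("*.npz"))
print(len(files))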

View File

@ -14,14 +14,16 @@ from batdetect2.evaluate.metrics import (
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.models import build_model
from batdetect2.models import Model, build_model
from batdetect2.train.augmentations import build_augmentations
from batdetect2.train.augmentations import (
RandomExampleSource,
build_augmentations,
)
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
)
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger
@ -53,17 +55,13 @@ def train(
):
config = config or FullTrainingConfig()
if model_path is not None:
model = build_model(config=config)
logger.debug("Loading model from: {path}", path=model_path)
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = build_training_module(config)
trainer = build_trainer(config, targets=module.model.targets)
trainer = build_trainer(config, targets=model.targets)
train_dataloader = build_train_loader(
train_examples,
preprocessor=module.model.preprocessor,
preprocessor=model.preprocessor,
config=config.train,
num_workers=train_workers,
)
@ -71,7 +69,7 @@ def train(
val_dataloader = (
build_val_loader(
val_examples,
preprocessor=module.model.preprocessor,
preprocessor=model.preprocessor,
config=config.train,
num_workers=val_workers,
)
@ -79,6 +77,16 @@ def train(
else None
)
if model_path is not None:
logger.debug("Loading model from: {path}", path=model_path)
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = build_training_module(
model,
config,
batches_per_epoch=len(train_dataloader),
)
logger.info("Starting main training loop...")
trainer.fit(
module,
@ -88,14 +96,17 @@ def train(
logger.info("Training complete.")
def build_training_module(config: FullTrainingConfig) -> TrainingModule:
model = build_model(config=config)
def build_training_module(
model: Model,
config: FullTrainingConfig,
batches_per_epoch: int,
) -> TrainingModule:
loss = build_loss(config=config.train.loss)
return TrainingModule(
model=model,
loss=loss,
learning_rate=config.train.learning_rate,
t_max=config.train.t_max,
t_max=config.train.t_max * batches_per_epoch,
)
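Scaling t_max by batches_per_epoch suggests the cosine scheduler is now stepped once per batch rather than once per epoch, so its horizon must be counted in optimizer steps. A small sketch with torch's CosineAnnealingLR under that assumption (the actual configure_optimizers code is not shown in this diff):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 100            # illustrative value for config.train.t_max
batches_per_epoch = 50  # illustrative value for len(train_dataloader)

# T_max counted in batches, so one full cosine cycle spans the whole run.
scheduler = CosineAnnealingLR(optimizer, T_max=epochs * batches_per_epoch)

for _ in range(epochs * batches_per_epoch):
    optimizer.step()   # forward/backward omitted for brevity
    scheduler.step()   # stepped per batch, not per epoch

print(scheduler.get_last_lr())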

View File

@ -95,69 +95,10 @@ class BatDetect2Prediction:
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline."""
def __call__(self, output: ModelOutput) -> List[Detections]: ...
def get_detections(
self,
output: ModelOutput,
clips: Optional[List[data.Clip]] = None,
) -> List[Detections]: ...
def get_raw_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes the raw model output for a batch through remapping, NMS,
detection, data extraction, and geometry recovery to produce a list of
`RawPrediction` objects for each corresponding input clip. This provides
a simplified, intermediate representation before final tag decoding.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[List[RawPrediction]]
A list of lists (one inner list per input clip, in order). Each
inner list contains the `RawPrediction` objects extracted for the
corresponding input clip.
"""
...
def get_sound_event_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[BatDetect2Prediction]]: ...
def get_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output for a batch and corresponding clips, applies the
entire postprocessing chain, and returns the final, interpretable
predictions as a list of `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[data.ClipPrediction]
A list containing one `ClipPrediction` object for each input clip
(in the same order), populated with `SoundEventPrediction` objects
representing the final detections with decoded tags and geometry.
"""
...
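The hunk above strips the protocol down: get_raw_predictions, get_sound_event_predictions, and get_predictions are removed along with their docstrings. As a rough sketch of what an implementation now has to provide, assuming get_detections is the only required method and using placeholder types in place of the real batdetect2 classes:

from typing import List, Optional, Protocol

class Detections:   # placeholder, not the real batdetect2 class
    ...

class ModelOutput:  # placeholder, not the real batdetect2 class
    ...

class PostprocessorProtocol(Protocol):
    def get_detections(
        self,
        output: ModelOutput,
        clips: Optional[list] = None,
    ) -> List[Detections]: ...

class MinimalPostprocessor:
    """Structurally satisfies the slimmed-down protocol."""

    def get_detections(self, output, clips=None):
        return []

postprocessor: PostprocessorProtocol = MinimalPostprocessor()
print(postprocessor.get_detections(ModelOutput()))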

View File

@ -12,8 +12,8 @@ that components responsible for these tasks can be interacted with consistently
throughout BatDetect2.
"""
from collections.abc import Callable
from collections.abc import Callable, Iterable
from typing import List, Optional, Protocol
from typing import List, Optional, Protocol, Tuple
import numpy as np
from soundevent import data

View File

@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
import torch
import xarray as xr
@ -80,3 +82,14 @@ def adjust_width(
for index in range(dims)
]
return tensor[tuple(slices)]
def slice_tensor(
tensor: torch.Tensor,
start: Optional[int] = None,
end: Optional[int] = None,
dim: int = -1,
) -> torch.Tensor:
slices = [slice(None)] * tensor.ndim
slices[dim] = slice(start, end)
return tensor[tuple(slices)]
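slice_tensor is a small utility for narrowing a tensor along an arbitrary dimension without spelling out the full indexing tuple by hand. A usage sketch follows; the function body is copied from the hunk above, while the tensor shape is an arbitrary example.

from typing import Optional
import torch

def slice_tensor(
    tensor: torch.Tensor,
    start: Optional[int] = None,
    end: Optional[int] = None,
    dim: int = -1,
) -> torch.Tensor:
    # Build a full-slice index for every dimension, then narrow just `dim`.
    slices = [slice(None)] * tensor.ndim
    slices[dim] = slice(start, end)
    return tensor[tuple(slices)]

spec = torch.randn(1, 128, 512)           # (channels, freq bins, time frames)
first_100_frames = slice_tensor(spec, end=100, dim=-1)
print(first_100_frames.shape)             # torch.Size([1, 128, 100])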

View File

@ -38,7 +38,6 @@ def build_from_config(
max_freq=preprocessor.max_freq,
)
postprocessor = build_postprocessor(
targets,
preprocessor=preprocessor,
config=postprocessing_config,
)