Compare commits

...

4 Commits

Author      SHA1        Message                                  Date
mbsantiago  93e89ecc46  LR Scheduler takes num of total batches  2025-08-28 08:52:11 +01:00
mbsantiago  34ef9e92a1  Make sure preprocessing is batchable     2025-08-27 23:58:38 +01:00
mbsantiago  0b5ac96fe8  Update model config                      2025-08-27 23:58:07 +01:00
mbsantiago  dba6d2d918  Updating configs                         2025-08-27 23:44:49 +01:00
24 changed files with 669 additions and 506 deletions

View File

@ -1,14 +1,3 @@
datasets:
train:
name: example dataset
description: Only for demonstration purposes
sources:
- format: batdetect2
name: Example Data
description: Examples included for testing batdetect2
annotations_dir: example_data/anns
audio_dir: example_data/audio
targets:
classes:
classes:
@ -99,7 +88,9 @@ model:
out_channels: 256
bottleneck:
channels: 256
self_attention: true
layers:
- block_type: SelfAttention
attention_channels: 256
decoder:
layers:
- block_type: FreqCoordConvUp
@ -114,9 +105,19 @@ model:
out_channels: 32
train:
batch_size: 8
learning_rate: 0.001
t_max: 100
dataloaders:
train:
batch_size: 8
num_workers: 2
shuffle: True
val:
batch_size: 8
num_workers: 2
loss:
detection:
weight: 1.0
@ -130,14 +131,12 @@ train:
alpha: 2
size:
weight: 0.1
logger:
logger_type: mlflow
experiment_name: batdetect2
tracking_uri: http://localhost:5000
log_model: true
logger_type: csv
save_dir: outputs/log/
artifact_location: outputs/artifacts/
checkpoint_path_prefix: outputs/checkpoints/
name: logs
augmentations:
steps:
- augmentation_type: mix_audio

View File

@ -0,0 +1,8 @@
name: example dataset
description: Only for demonstration purposes
sources:
- format: batdetect2
name: Example Data
description: Examples included for testing batdetect2
annotations_dir: example_data/anns
audio_dir: example_data/audio

View File

@ -1,6 +1,6 @@
from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple
from typing import List, Literal, Optional, Protocol, Tuple
import numpy as np
from soundevent import data
@ -23,6 +23,16 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""
class AffinityFunction(Protocol):
def __call__(
self,
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float: ...
class MatchConfig(BaseConfig):
"""Configuration for matching geometries.
@ -74,6 +84,65 @@ _geometry_cast_functions: Mapping[
}
def _timestamp_affinity(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
assert isinstance(geometry1, data.TimeStamp)
assert isinstance(geometry2, data.TimeStamp)
start_time1 = geometry1.coordinates
start_time2 = geometry2.coordinates
a = min(start_time1, start_time2)
b = max(start_time1, start_time2)
if b - a >= 2 * time_buffer:
return 0
intersection = a - b + 2 * time_buffer
union = b - a + 2 * time_buffer
return intersection / union
def _interval_affinity(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
assert isinstance(geometry1, data.TimeInterval)
assert isinstance(geometry2, data.TimeInterval)
start_time1, end_time1 = geometry1.coordinates
start_time2, end_time2 = geometry2.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
intersection = max(
0, min(end_time1, end_time2) - max(start_time1, start_time2)
)
union = (
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
)
if union == 0:
return 0
return intersection / union
_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
"timestamp": _timestamp_affinity,
"interval": _interval_affinity,
}
def match_geometries(
source: List[data.Geometry],
target: List[data.Geometry],
@ -81,6 +150,10 @@ def match_geometries(
scores: Optional[List[float]] = None,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
geometry_cast = _geometry_cast_functions[config.geometry]
affinity_function = _affinity_functions.get(
config.geometry,
compute_affinity,
)
if config.strategy == "optimal":
return optimal_match(
@ -98,6 +171,7 @@ def match_geometries(
time_buffer=config.time_buffer,
freq_buffer=config.frequency_buffer,
affinity_threshold=config.affinity_threshold,
affinity_function=affinity_function,
scores=scores,
)
@ -111,6 +185,7 @@ def greedy_match(
target: List[data.Geometry],
scores: Optional[List[float]] = None,
affinity_threshold: float = 0.5,
affinity_function: AffinityFunction = compute_affinity,
time_buffer: float = 0.001,
freq_buffer: float = 1000,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
@ -168,7 +243,7 @@ def greedy_match(
affinities = np.array(
[
compute_affinity(
affinity_function(
source_geometry,
target_geometry,
time_buffer=time_buffer,
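
The timestamp affinity above reduces to an intersection-over-union of the two buffered timestamps. A small worked sketch (values illustrative, not from the diff):

time_buffer = 0.01
t1, t2 = 0.100, 0.105                       # two timestamp coordinates, in seconds
a, b = min(t1, t2), max(t1, t2)
intersection = 2 * time_buffer - (b - a)    # overlap of the [t - buffer, t + buffer] windows: 0.015
union = 2 * time_buffer + (b - a)           # combined extent of both windows: 0.025
affinity = intersection / union             # 0.6; drops to 0 once |t1 - t2| >= 2 * time_buffer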

View File

@ -26,13 +26,14 @@ for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
"""
from typing import Optional
from typing import List, Optional
import torch
from lightning import LightningModule
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
Backbone,
BackboneConfig,
@ -66,8 +67,8 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.typing.models import DetectionModel, ModelOutput
from batdetect2.typing.postprocess import PostprocessorProtocol
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import Detections, PostprocessorProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
@ -119,9 +120,12 @@ class Model(LightningModule):
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.targets = targets
self.save_hyperparameters()
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.detector(spec)
def forward(self, wav: torch.Tensor) -> List[Detections]:
spec = self.preprocessor(wav)
outputs = self.detector(spec)
return self.postprocessor(outputs)
class ModelConfig(BaseConfig):
@ -139,7 +143,6 @@ def build_model(config: Optional[ModelConfig] = None):
targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess)
postprocessor = build_postprocessor(
targets=targets,
preprocessor=preprocessor,
config=config.postprocess,
)
@ -153,3 +156,9 @@ def build_model(config: Optional[ModelConfig] = None):
preprocessor=preprocessor,
targets=targets,
)
def load_model_config(
path: PathLike, field: Optional[str] = None
) -> ModelConfig:
return load_config(path, schema=ModelConfig, field=field)
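
A minimal usage sketch of the new forward pass (input shape and sample count are illustrative assumptions, not taken from the diff): Model.forward now accepts a raw waveform and runs preprocessing, the detector, and postprocessing in one call, returning one Detections entry per batch item.

import torch

from batdetect2.models import build_model

model = build_model()                      # default ModelConfig
wav = torch.randn(1, 1, 128_000)           # hypothetical (batch, channels, samples) waveform
detections = model(wav)                    # preprocessor -> detector -> postprocessor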

View File

@ -55,6 +55,12 @@ __all__ = [
]
class SelfAttentionConfig(BaseConfig):
block_type: Literal["SelfAttention"] = "SelfAttention"
attention_channels: int
temperature: float = 1
class SelfAttention(nn.Module):
"""Self-Attention mechanism operating along the time dimension.
@ -115,6 +121,7 @@ class SelfAttention(nn.Module):
# Note, does not encode position information (absolute or relative)
self.temperature = temperature
self.att_dim = attention_channels
self.key_fun = nn.Linear(in_channels, attention_channels)
self.value_fun = nn.Linear(in_channels, attention_channels)
self.query_fun = nn.Linear(in_channels, attention_channels)
@ -654,6 +661,7 @@ LayerConfig = Annotated[
StandardConvDownConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
SelfAttentionConfig,
"LayerGroupConfig",
],
Field(discriminator="block_type"),
@ -769,6 +777,17 @@ def build_layer_from_config(
input_height * 2,
)
if config.block_type == "SelfAttention":
return (
SelfAttention(
in_channels=in_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
),
config.attention_channels,
input_height,
)
if config.block_type == "LayerGroup":
current_channels = in_channels
current_height = input_height

View File

@ -14,47 +14,27 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration.
"""
from typing import Optional
from typing import Annotated, List, Optional, Union
import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import SelfAttention, VerticalConv
from batdetect2.models.blocks import (
LayerConfig,
SelfAttentionConfig,
VerticalConv,
build_layer_from_config,
)
__all__ = [
"BottleneckConfig",
"Bottleneck",
"BottleneckAttn",
"build_bottleneck",
]
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes
----------
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
This also defines the channel dimensions used within the optional
`SelfAttention` layer.
self_attention : bool
If True, includes a `SelfAttention` layer operating on the time
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
"""
channels: int
self_attention: bool
class Bottleneck(nn.Module):
"""Base Bottleneck module for Encoder-Decoder architectures.
@ -99,16 +79,24 @@ class Bottleneck(nn.Module):
input_height: int,
in_channels: int,
out_channels: int,
bottleneck_channels: Optional[int] = None,
layers: Optional[List[torch.nn.Module]] = None,
) -> None:
"""Initialize the base Bottleneck layer."""
super().__init__()
self.in_channels = in_channels
self.input_height = input_height
self.out_channels = out_channels
self.bottleneck_channels = (
bottleneck_channels
if bottleneck_channels is not None
else out_channels
)
self.layers = nn.ModuleList(layers or [])
self.conv_vert = VerticalConv(
in_channels=in_channels,
out_channels=out_channels,
out_channels=self.bottleneck_channels,
input_height=input_height,
)
@ -132,73 +120,52 @@ class Bottleneck(nn.Module):
convolution.
"""
x = self.conv_vert(x)
for layer in self.layers:
x = layer(x)
return x.repeat([1, 1, self.input_height, 1])
class BottleneckAttn(Bottleneck):
"""Bottleneck module including a Self-Attention layer.
BottleneckLayerConfig = Annotated[
Union[SelfAttentionConfig,],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
the initial `VerticalConv`. This allows the bottleneck to capture global
temporal dependencies in the summarized frequency features before passing
them to the decoder.
Sequence: VerticalConv -> SelfAttention -> Repeat Height.
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Parameters
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes
----------
input_height : int
Height (frequency bins) of the input tensor from the encoder.
in_channels : int
Number of channels in the input tensor from the encoder.
out_channels : int
Number of output channels produced by the `VerticalConv` and
subsequently processed and output by this bottleneck. Also determines
the input/output channels of the internal `SelfAttention` layer.
attention : nn.Module
An initialized `SelfAttention` module instance.
Raises
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
This also defines the channel dimensions used within the optional
`SelfAttention` layer.
self_attention : bool
If True, includes a `SelfAttention` layer operating on the time
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
"""
def __init__(
self,
input_height: int,
in_channels: int,
out_channels: int,
attention: nn.Module,
) -> None:
"""Initialize the Bottleneck with Self-Attention."""
super().__init__(input_height, in_channels, out_channels)
self.attention = attention
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input tensor.
Parameters
----------
x : torch.Tensor
Input tensor from the encoder bottleneck, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
`H_in` must match `self.input_height`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
and repeating the height dimension.
"""
x = self.conv_vert(x)
x = self.attention(x)
return x.repeat([1, 1, self.input_height, 1])
channels: int
layers: List[BottleneckLayerConfig] = Field(
default_factory=list,
)
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
channels=256,
self_attention=True,
layers=[
SelfAttentionConfig(attention_channels=256),
],
)
@ -234,21 +201,25 @@ def build_bottleneck(
"""
config = config or DEFAULT_BOTTLENECK_CONFIG
if config.self_attention:
attention = SelfAttention(
in_channels=config.channels,
attention_channels=config.channels,
)
current_channels = in_channels
current_height = input_height
return BottleneckAttn(
input_height=input_height,
in_channels=in_channels,
out_channels=config.channels,
attention=attention,
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
input_height=current_height,
in_channels=current_channels,
config=layer_config,
)
assert current_height == input_height, (
"Bottleneck layers should not change the spectrogram height"
)
layers.append(layer)
return Bottleneck(
input_height=input_height,
in_channels=in_channels,
out_channels=config.channels,
layers=layers,
)
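
A sketch of the config migration this implies (field and class names are taken from the diff; the bottleneck module path is assumed): the boolean self_attention flag becomes an explicit list of layer configs.

from batdetect2.models.blocks import SelfAttentionConfig
from batdetect2.models.bottleneck import (  # module path assumed
    BottleneckConfig,
    build_bottleneck,
)

# Previously: BottleneckConfig(channels=256, self_attention=True)
config = BottleneckConfig(
    channels=256,
    layers=[SelfAttentionConfig(attention_channels=256)],
)
bottleneck = build_bottleneck(
    input_height=8,        # illustrative encoder output height
    in_channels=256,
    config=config,
)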

View File

@ -26,19 +26,32 @@ def create_ax(
def plot_spectrogram(
spec: Union[torch.Tensor, np.ndarray],
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
min_freq: Optional[float] = None,
max_freq: Optional[float] = None,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
cmap="gray",
) -> axes.Axes:
if isinstance(spec, torch.Tensor):
spec = spec.numpy()
ax = create_ax(ax=ax, figsize=figsize)
if start_time is None:
start_time = 0
if end_time is None:
end_time = spec.shape[-1]
if min_freq is None:
min_freq = 0
if max_freq is None:
max_freq = spec.shape[-2]
ax.pcolormesh(
np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),

View File

@ -2,6 +2,7 @@
from typing import List, Optional
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
@ -20,13 +21,15 @@ from batdetect2.postprocess.nms import (
)
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing import ModelOutput, PreprocessorProtocol, TargetProtocol
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
Detections,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"DEFAULT_CLASSIFICATION_THRESHOLD",
@ -128,7 +131,6 @@ def load_postprocess_config(
def build_postprocessor(
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
@ -139,29 +141,52 @@ def build_postprocessor(
lambda: config.to_yaml_string(),
)
return Postprocessor(
targets=targets,
preprocessor=preprocessor,
config=config,
samplerate=preprocessor.output_samplerate,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
top_k_per_sec=config.top_k_per_sec,
detection_threshold=config.detection_threshold,
)
class Postprocessor(PostprocessorProtocol):
class Postprocessor(torch.nn.Module, PostprocessorProtocol):
"""Standard implementation of the postprocessing pipeline."""
targets: TargetProtocol
preprocessor: PreprocessorProtocol
def __init__(
self,
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
config: PostprocessConfig,
samplerate: float,
min_freq: float,
max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
):
"""Initialize the Postprocessor."""
self.targets = targets
self.preprocessor = preprocessor
self.config = config
super().__init__()
self.samplerate = samplerate
self.min_freq = min_freq
self.max_freq = max_freq
self.top_k_per_sec = top_k_per_sec
self.detection_threshold = detection_threshold
def forward(self, output: ModelOutput) -> List[Detections]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.detection_threshold,
)
return [
map_detection_to_clip(
detection,
start_time=0,
end_time=duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection in detections
]
def get_detections(
self,
@ -169,13 +194,13 @@ class Postprocessor(PostprocessorProtocol):
clips: Optional[List[data.Clip]] = None,
) -> List[Detections]:
width = output.detection_probs.shape[-1]
duration = width / self.preprocessor.output_samplerate
max_detections = int(self.config.top_k_per_sec * duration)
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.config.detection_threshold,
threshold=self.detection_threshold,
)
if clips is None:
@ -186,17 +211,19 @@ class Postprocessor(PostprocessorProtocol):
detection,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=self.preprocessor.min_freq,
max_freq=self.preprocessor.max_freq,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection, clip in zip(detections, clips)
]
def get_raw_predictions(
self,
def get_raw_predictions(
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[RawPrediction]]:
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes raw model output through remapping, NMS, detection, data
@ -216,21 +243,29 @@ class Postprocessor(PostprocessorProtocol):
List of lists (one inner list per input clip). Each inner list
contains `RawPrediction` objects for detections in that clip.
"""
detections = self.get_detections(output, clips)
detections = postprocessor.get_detections(output, clips)
return [
convert_detections_to_raw_predictions(
dataset,
targets=self.targets,
targets=targets,
)
for dataset in detections
]
def get_sound_event_predictions(
self,
def get_sound_event_predictions(
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[BatDetect2Prediction]]:
raw_predictions = self.get_raw_predictions(output, clips)
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[List[BatDetect2Prediction]]:
raw_predictions = get_raw_predictions(
output,
clips,
targets=targets,
postprocessor=postprocessor,
)
return [
[
BatDetect2Prediction(
@ -238,8 +273,8 @@ class Postprocessor(PostprocessorProtocol):
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=self.targets,
classification_threshold=self.config.classification_threshold,
targets=targets,
classification_threshold=classification_threshold,
),
)
for raw in predictions
@ -247,9 +282,14 @@ class Postprocessor(PostprocessorProtocol):
for predictions, clip in zip(raw_predictions, clips)
]
def get_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[data.ClipPrediction]:
def get_predictions(
output: ModelOutput,
clips: List[data.Clip],
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output and corresponding clips, applies the entire
@ -269,13 +309,18 @@ class Postprocessor(PostprocessorProtocol):
List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects.
"""
raw_predictions = self.get_raw_predictions(output, clips)
raw_predictions = get_raw_predictions(
output,
clips,
targets=targets,
postprocessor=postprocessor,
)
return [
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
targets=self.targets,
classification_threshold=self.config.classification_threshold,
targets=targets,
classification_threshold=classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]
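
Usage sketch of the refactor (the output, clips, and model objects are assumed to exist): the prediction helpers are now module-level functions that take targets and a postprocessor explicitly instead of being Postprocessor methods, mirroring the call added in the validation callback further down.

from batdetect2.postprocess import get_sound_event_predictions

# `output` is a ModelOutput from the detector and `clips` the matching
# List[data.Clip]; `model` is a built Model instance.
predictions = get_sound_event_predictions(
    output,
    clips,
    targets=model.targets,
    postprocessor=model.postprocessor,
)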

View File

@ -139,7 +139,21 @@ class FrequencyClip(torch.nn.Module):
self.high_index = high_index
def forward(self, spec: torch.Tensor) -> torch.Tensor:
return spec[self.low_index : self.high_index]
low_index = self.low_index
if low_index is None:
low_index = 0
if self.high_index is None:
length = spec.shape[-2] - low_index
else:
length = self.high_index - low_index
return torch.narrow(
spec,
dim=-2,
start=low_index,
length=length,
)
class PcenConfig(BaseConfig):
@ -256,15 +270,21 @@ class ResizeSpec(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
current_length = spec.shape[-1]
target_length = int(self.time_factor * current_length)
return (
torch.nn.functional.interpolate(
spec.unsqueeze(0).unsqueeze(0),
original_ndim = spec.ndim
while spec.ndim < 4:
spec = spec.unsqueeze(0)
resized = torch.nn.functional.interpolate(
spec,
size=(self.height, target_length),
mode="bilinear",
)
.squeeze(0)
.squeeze(0)
)
while resized.ndim != original_ndim:
resized = resized.squeeze(0)
return resized
class PeakNormalizeConfig(BaseConfig):

View File

@ -2,6 +2,7 @@ from batdetect2.train.augmentations import (
AugmentationsConfig,
EchoAugmentationConfig,
FrequencyMaskAugmentationConfig,
RandomExampleSource,
TimeMaskAugmentationConfig,
VolumeAugmentationConfig,
WarpAugmentationConfig,
@ -23,7 +24,6 @@ from batdetect2.train.config import (
)
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
list_preprocessed_files,
)
from batdetect2.train.labels import build_clip_labeler, load_label_config

View File

@ -1,6 +1,7 @@
"""Applies data augmentation techniques to BatDetect2 training examples."""
import warnings
from collections.abc import Sequence
from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union
import numpy as np
@ -10,8 +11,12 @@ from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.preprocess import (
list_preprocessed_files,
load_preprocessed_example,
)
from batdetect2.typing import Augmentation, PreprocessorProtocol
from batdetect2.typing.train import PreprocessedExample
from batdetect2.typing.train import ClipperProtocol, PreprocessedExample
from batdetect2.utils.arrays import adjust_width
__all__ = [
@ -39,21 +44,6 @@ ExampleSource = Callable[[], PreprocessedExample]
"""Type alias for a function that returns a training example"""
class MixAugmentationConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples)."""
augmentation_type: Literal["mix_audio"] = "mix_audio"
probability: float = 0.2
"""Probability of applying this augmentation to an example."""
min_weight: float = 0.3
"""Minimum mixing weight (lambda) applied to the primary example."""
max_weight: float = 0.7
"""Maximum mixing weight (lambda) applied to the primary example."""
def mix_examples(
example: PreprocessedExample,
other: PreprocessedExample,
@ -149,7 +139,12 @@ def add_echo(
audio = example.audio
delay_steps = int(preprocessor.input_samplerate * delay)
audio_delay = adjust_width(audio[delay_steps:], audio.shape[-1])
slices = [slice(None)] * audio.ndim
slices[-1] = slice(None, -delay_steps)
audio_delay = adjust_width(audio[tuple(slices)], audio.shape[-1]).roll(
delay_steps, dims=-1
)
audio = audio + weight * audio_delay
spectrogram = preprocessor(audio)
@ -184,7 +179,7 @@ class VolumeAugmentationConfig(BaseConfig):
class ScaleVolume(torch.nn.Module):
def __init__(self, min_scaling: float, max_scaling: float):
def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0):
super().__init__()
self.min_scaling = min_scaling
self.max_scaling = max_scaling
@ -228,32 +223,22 @@ def warp_spectrogram(
example: PreprocessedExample, factor: float
) -> PreprocessedExample:
"""Apply time warping by resampling the time axis."""
target_shape = example.spectrogram.shape
width = example.spectrogram.shape[-1]
height = example.spectrogram.shape[-2]
target_shape = [height, width]
new_width = int(target_shape[-1] * factor)
spectrogram = (
torch.nn.functional.interpolate(
adjust_width(example.spectrogram, new_width)
.unsqueeze(0)
.unsqueeze(0),
spectrogram = torch.nn.functional.interpolate(
adjust_width(example.spectrogram, new_width).unsqueeze(0),
size=target_shape,
mode="bilinear",
)
.squeeze(0)
.squeeze(0)
)
).squeeze(0)
detection = (
torch.nn.functional.interpolate(
adjust_width(example.detection_heatmap, new_width)
.unsqueeze(0)
.unsqueeze(0),
detection = torch.nn.functional.interpolate(
adjust_width(example.detection_heatmap, new_width).unsqueeze(0),
size=target_shape,
mode="nearest",
)
.squeeze(0)
.squeeze(0)
)
).squeeze(0)
classification = torch.nn.functional.interpolate(
adjust_width(example.class_heatmap, new_width).unsqueeze(1),
@ -284,10 +269,16 @@ class TimeMaskAugmentationConfig(BaseConfig):
class MaskTime(torch.nn.Module):
def __init__(self, max_perc: float = 0.05, max_masks: int = 3) -> None:
def __init__(
self,
max_perc: float = 0.05,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
num_masks = np.random.randint(1, self.max_masks + 1)
@ -306,20 +297,28 @@ class MaskTime(torch.nn.Module):
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_time(example, masks)
return mask_time(example, masks, mask_heatmaps=self.mask_heatmaps)
def mask_time(
example: PreprocessedExample,
masks: List[Tuple[int, int]],
mask_heatmaps: bool = False,
) -> PreprocessedExample:
"""Apply time masking to the spectrogram."""
for start, end in masks:
example.spectrogram[:, start:end] = example.spectrogram.mean()
example.class_heatmap[:, :, start:end] = 0
example.size_heatmap[:, :, start:end] = 0
example.detection_heatmap[:, start:end] = 0
slices = [slice(None)] * example.spectrogram.ndim
slices[-1] = slice(start, end)
example.spectrogram[tuple(slices)] = 0
if not mask_heatmaps:
continue
example.class_heatmap[tuple(slices)] = 0
example.size_heatmap[tuple(slices)] = 0
example.detection_heatmap[tuple(slices)] = 0
return PreprocessedExample(
audio=example.audio,
@ -335,13 +334,20 @@ class FrequencyMaskAugmentationConfig(BaseConfig):
probability: float = 0.2
max_perc: float = 0.10
max_masks: int = 3
mask_heatmaps: bool = False
class MaskFrequency(torch.nn.Module):
def __init__(self, max_perc: float = 0.10, max_masks: int = 3) -> None:
def __init__(
self,
max_perc: float = 0.10,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
num_masks = np.random.randint(1, self.max_masks + 1)
@ -360,19 +366,26 @@ class MaskFrequency(torch.nn.Module):
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_frequency(example, masks)
return mask_frequency(example, masks, mask_heatmaps=self.mask_heatmaps)
def mask_frequency(
example: PreprocessedExample,
masks: List[Tuple[int, int]],
mask_heatmaps: bool = False,
) -> PreprocessedExample:
"""Apply frequency masking to the spectrogram."""
for start, end in masks:
example.spectrogram[start:end, :] = example.spectrogram.mean()
example.class_heatmap[:, start:end, :] = 0
example.size_heatmap[:, start:end, :] = 0
example.detection_heatmap[start:end, :] = 0
slices = [slice(None)] * example.spectrogram.ndim
slices[-2] = slice(start, end)
example.spectrogram[tuple(slices)] = 0
if not mask_heatmaps:
continue
example.class_heatmap[tuple(slices)] = 0
example.size_heatmap[tuple(slices)] = 0
example.detection_heatmap[tuple(slices)] = 0
return PreprocessedExample(
audio=example.audio,
@ -383,6 +396,50 @@ def mask_frequency(
)
class MixAugmentationConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples)."""
augmentation_type: Literal["mix_audio"] = "mix_audio"
probability: float = 0.2
"""Probability of applying this augmentation to an example."""
min_weight: float = 0.3
"""Minimum mixing weight (lambda) applied to the primary example."""
max_weight: float = 0.7
"""Maximum mixing weight (lambda) applied to the primary example."""
class MixAudio(torch.nn.Module):
"""Callable class for MixUp augmentation, handling example fetching."""
def __init__(
self,
example_source: ExampleSource,
preprocessor: PreprocessorProtocol,
min_weight: float = 0.3,
max_weight: float = 0.7,
):
"""Initialize the AudioMixer."""
super().__init__()
self.min_weight = min_weight
self.example_source = example_source
self.max_weight = max_weight
self.preprocessor = preprocessor
def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
"""Fetch another example and perform mixup."""
other = self.example_source()
weight = np.random.uniform(self.min_weight, self.max_weight)
return mix_examples(
example,
other,
self.preprocessor,
weight=weight,
)
AugmentationConfig = Annotated[
Union[
MixAugmentationConfig,
@ -445,35 +502,6 @@ class MaybeApply(torch.nn.Module):
return self.augmentation(example)
class AudioMixer(torch.nn.Module):
"""Callable class for MixUp augmentation, handling example fetching."""
def __init__(
self,
min_weight: float,
max_weight: float,
example_source: ExampleSource,
preprocessor: PreprocessorProtocol,
):
"""Initialize the AudioMixer."""
super().__init__()
self.min_weight = min_weight
self.example_source = example_source
self.max_weight = max_weight
self.preprocessor = preprocessor
def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
"""Fetch another example and perform mixup."""
other = self.example_source()
weight = np.random.uniform(self.min_weight, self.max_weight)
return mix_examples(
example,
other,
self.preprocessor,
weight=weight,
)
def build_augmentation_from_config(
config: AugmentationConfig,
preprocessor: PreprocessorProtocol,
@ -489,7 +517,7 @@ def build_augmentation_from_config(
)
return None
return AudioMixer(
return MixAudio(
example_source=example_source,
preprocessor=preprocessor,
min_weight=config.min_weight,
@ -585,3 +613,25 @@ def load_augmentation_config(
) -> AugmentationsConfig:
"""Load the augmentations configuration from a file."""
return load_config(path, schema=AugmentationsConfig, field=field)
class RandomExampleSource:
def __init__(
self,
filenames: Sequence[data.PathLike],
clipper: ClipperProtocol,
):
self.filenames = filenames
self.clipper = clipper
def __call__(self) -> PreprocessedExample:
index = int(np.random.randint(len(self.filenames)))
filename = self.filenames[index]
example = load_preprocessed_example(filename)
example, _, _ = self.clipper(example)
return example
@classmethod
def from_directory(cls, path: data.PathLike, clipper: ClipperProtocol):
filenames = list_preprocessed_files(path)
return cls(filenames, clipper=clipper)
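
Wiring sketch for the relocated example source (the directory path and the clipper/preprocessor objects are assumed): RandomExampleSource now lives alongside the augmentations and feeds MixAudio with extra preprocessed examples.

from batdetect2.train.augmentations import MixAudio, RandomExampleSource

example_source = RandomExampleSource.from_directory(
    "outputs/preprocessed/",        # hypothetical directory of .npz training examples
    clipper=clipper,                # any ClipperProtocol implementation
)
mixer = MixAudio(
    example_source=example_source,
    preprocessor=preprocessor,
    min_weight=0.3,
    max_weight=0.7,
)
augmented = mixer(example)          # mixes `example` with a randomly drawn other example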

View File

@ -14,7 +14,9 @@ from batdetect2.evaluate.match import (
MatchConfig,
match_sound_events_and_raw_predictions,
)
from batdetect2.models import Model
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess import get_sound_event_predictions
from batdetect2.train.dataset import LabeledDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import (
@ -22,7 +24,6 @@ from batdetect2.typing import (
MatchEvaluation,
MetricsProtocol,
ModelOutput,
PostprocessorProtocol,
TargetProtocol,
TrainExample,
)
@ -127,8 +128,7 @@ class ValidationMetrics(Callback):
batch,
outputs,
dataset=self.get_dataset(trainer),
postprocessor=pl_module.model.postprocessor,
targets=pl_module.model.targets,
model=pl_module.model,
)
)
@ -137,15 +137,14 @@ def _get_batch_clips_and_predictions(
batch: TrainExample,
outputs: ModelOutput,
dataset: LabeledDataset,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
model: Model,
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
clip_annotations = [
_get_subclip(
dataset.get_clip_annotation(example_id),
start_time=start_time.item(),
end_time=end_time.item(),
targets=targets,
targets=model.targets,
)
for example_id, start_time, end_time in zip(
batch.idx,
@ -156,9 +155,11 @@ def _get_batch_clips_and_predictions(
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
raw_predictions = postprocessor.get_sound_event_predictions(
raw_predictions = get_sound_event_predictions(
outputs,
clips,
targets=model.targets,
postprocessor=model.postprocessor
)
return [

View File

@ -8,7 +8,7 @@ from batdetect2.configs import BaseConfig
from batdetect2.typing import ClipperProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.train import PreprocessedExample
from batdetect2.utils.arrays import adjust_width
from batdetect2.utils.arrays import adjust_width, slice_tensor
DEFAULT_TRAIN_CLIP_DURATION = 0.512
DEFAULT_MAX_EMPTY_CLIP = 0.1
@ -90,7 +90,12 @@ def select_subclip(
audio_start = int(np.floor(start * input_samplerate))
audio = adjust_width(
example.audio[audio_start : audio_start + audio_width],
slice_tensor(
example.audio,
start=audio_start,
end=audio_start + audio_width,
dim=-1,
),
audio_width,
value=fill_value,
)
@ -100,19 +105,39 @@ def select_subclip(
return PreprocessedExample(
audio=audio,
spectrogram=adjust_width(
example.spectrogram[:, spec_start : spec_start + spec_width],
slice_tensor(
example.spectrogram,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
class_heatmap=adjust_width(
example.class_heatmap[:, :, spec_start : spec_start + spec_width],
slice_tensor(
example.class_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
detection_heatmap=adjust_width(
example.detection_heatmap[:, spec_start : spec_start + spec_width],
slice_tensor(
example.detection_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
size_heatmap=adjust_width(
example.size_heatmap[:, :, spec_start : spec_start + spec_width],
slice_tensor(
example.size_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
)

View File

@ -44,8 +44,8 @@ class PLTrainerConfig(BaseConfig):
class DataLoaderConfig(BaseConfig):
batch_size: int
shuffle: bool
batch_size: int = 8
shuffle: bool = False
num_workers: int = 0

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import List, Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple
import numpy as np
import torch
@ -7,6 +6,10 @@ from soundevent import data
from torch.utils.data import Dataset
from batdetect2.train.augmentations import Augmentation
from batdetect2.train.preprocess import (
list_preprocessed_files,
load_preprocessed_example,
)
from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.typing.train import PreprocessedExample
@ -38,8 +41,8 @@ class LabeledDataset(Dataset):
example = self.augmentation(example)
return TrainExample(
spec=example.spectrogram.unsqueeze(0),
detection_heatmap=example.detection_heatmap.unsqueeze(0),
spec=example.spectrogram,
detection_heatmap=example.detection_heatmap,
class_heatmap=example.class_heatmap,
size_heatmap=example.size_heatmap,
idx=torch.tensor(idx),
@ -73,37 +76,3 @@ class LabeledDataset(Dataset):
def get_clip_annotation(self, idx) -> data.ClipAnnotation:
item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
return item["clip_annotation"].tolist()
def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
item = np.load(path, mmap_mode="r+")
return PreprocessedExample(
audio=torch.tensor(item["audio"]),
spectrogram=torch.tensor(item["spectrogram"]),
size_heatmap=torch.tensor(item["size_heatmap"]),
detection_heatmap=torch.tensor(item["detection_heatmap"]),
class_heatmap=torch.tensor(item["class_heatmap"]),
)
def list_preprocessed_files(
directory: data.PathLike, extension: str = ".npz"
) -> List[Path]:
return list(Path(directory).glob(f"*{extension}"))
class RandomExampleSource:
def __init__(
self,
filenames: List[data.PathLike],
clipper: ClipperProtocol,
):
self.filenames = filenames
self.clipper = clipper
def __call__(self) -> PreprocessedExample:
index = int(np.random.randint(len(self.filenames)))
filename = self.filenames[index]
example = load_preprocessed_example(filename)
example, _, _ = self.clipper(example)
return example

View File

@ -41,7 +41,6 @@ from batdetect2.typing import (
__all__ = [
"LabelConfig",
"build_clip_labeler",
"generate_clip_label",
"generate_heatmaps",
"load_label_config",
]
@ -99,21 +98,26 @@ def build_clip_labeler(
lambda: config.to_yaml_string(),
)
return partial(
generate_clip_label,
generate_heatmaps,
targets=targets,
config=config,
min_freq=min_freq,
max_freq=max_freq,
target_sigma=config.sigma,
)
def generate_clip_label(
def map_to_pixels(x, size, min_val, max_val) -> int:
return int(np.interp(x, [min_val, max_val], [0, size]))
def generate_heatmaps(
clip_annotation: data.ClipAnnotation,
spec: torch.Tensor,
targets: TargetProtocol,
config: LabelConfig,
min_freq: float,
max_freq: float,
target_sigma: float = 3.0,
dtype=torch.float32,
) -> Heatmaps:
"""Generate training heatmaps for a single annotated clip.
@ -150,57 +154,14 @@ def generate_clip_label(
num=len(clip_annotation.sound_events),
)
sound_events = []
for sound_event_annotation in clip_annotation.sound_events:
if not targets.filter(sound_event_annotation):
logger.debug(
"Sound event {sound_event} did not pass the filter. Tags: {tags}",
sound_event=sound_event_annotation,
tags=sound_event_annotation.tags,
)
continue
sound_events.append(targets.transform(sound_event_annotation))
return generate_heatmaps(
clip_annotation.model_copy(update=dict(sound_events=sound_events)),
spec=spec,
targets=targets,
target_sigma=config.sigma,
min_freq=min_freq,
max_freq=max_freq,
)
def map_to_pixels(x, size, min_val, max_val) -> int:
return int(np.interp(x, [min_val, max_val], [0, size]))
def generate_heatmaps(
clip_annotation: data.ClipAnnotation,
spec: torch.Tensor,
targets: TargetProtocol,
min_freq: float,
max_freq: float,
target_sigma: float = 3.0,
dtype=torch.float32,
) -> Heatmaps:
if not spec.ndim == 2:
raise ValueError(
"Expecting a 2-dimensional tensor of shape (H, W), "
"H is the height of the spectrogram "
"(frequency bins), and W is the width of the spectrogram "
f"(temporal bins). Instead got: {spec.shape}"
)
height, width = spec.shape
height = spec.shape[-2]
width = spec.shape[-1]
num_classes = len(targets.class_names)
num_dims = len(targets.dimension_names)
clip = clip_annotation.clip
# Initialize heatmaps
detection_heatmap = torch.zeros([height, width], dtype=dtype)
detection_heatmap = torch.zeros([1, height, width], dtype=dtype)
class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype)
size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype)
@ -214,6 +175,16 @@ def generate_heatmaps(
times = times.to(spec.device)
for sound_event_annotation in clip_annotation.sound_events:
if not targets.filter(sound_event_annotation):
logger.debug(
"Sound event {sound_event} did not pass the filter. Tags: {tags}",
sound_event=sound_event_annotation,
tags=sound_event_annotation.tags,
)
continue
sound_event_annotation = targets.transform(sound_event_annotation)
geom = sound_event_annotation.sound_event.geometry
if geom is None:
logger.debug(
@ -245,7 +216,10 @@ def generate_heatmaps(
distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
gaussian_blob = torch.exp(-distance / (2 * target_sigma**2))
detection_heatmap = torch.maximum(detection_heatmap, gaussian_blob)
detection_heatmap[0] = torch.maximum(
detection_heatmap[0],
gaussian_blob,
)
size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])
# Get the class name of the sound event

View File

@ -34,7 +34,7 @@ class TrainingModule(L.LightningModule):
return self.model(spec)
def training_step(self, batch: TrainExample):
outputs = self.model(batch.spec)
outputs = self.model.detector(batch.spec)
losses = self.loss(outputs, batch)
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/train", losses.total, logger=True)
@ -47,7 +47,7 @@ class TrainingModule(L.LightningModule):
batch: TrainExample,
batch_idx: int,
) -> ModelOutput:
outputs = self.model(batch.spec)
outputs = self.model.detector(batch.spec)
losses = self.loss(outputs, batch)
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/val", losses.total, logger=True)

View File

@ -2,7 +2,7 @@
import os
from pathlib import Path
from typing import Callable, Optional, Sequence, TypedDict
from typing import Callable, List, Optional, Sequence, TypedDict
import numpy as np
import torch
@ -28,6 +28,8 @@ __all__ = [
"preprocess_dataset",
"TrainPreprocessConfig",
"load_train_preprocessing_config",
"save_preprocessed_example",
"load_preprocessed_example",
]
FilenameFn = Callable[[data.ClipAnnotation], str]
@ -94,8 +96,10 @@ def generate_train_example(
labeller: ClipLabeller,
) -> PreprocessedExample:
"""Generate a complete training example for one annotation."""
wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
spectrogram = preprocessor(wave)
wave = torch.tensor(
audio_loader.load_clip(clip_annotation.clip)
).unsqueeze(0)
spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
heatmaps = labeller(clip_annotation, spectrogram)
return PreprocessedExample(
audio=wave,
@ -145,7 +149,7 @@ class PreprocessingDataset(torch.utils.data.Dataset):
labeller=self.labeller,
)
save_example_to_file(example, clip_annotation, path)
save_preprocessed_example(example, clip_annotation, path)
return idx
@ -153,7 +157,7 @@ class PreprocessingDataset(torch.utils.data.Dataset):
return len(self.clips)
def save_example_to_file(
def save_preprocessed_example(
example: PreprocessedExample,
clip_annotation: data.ClipAnnotation,
path: data.PathLike,
@ -169,6 +173,23 @@ def save_example_to_file(
)
def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
item = np.load(path, mmap_mode="r+")
return PreprocessedExample(
audio=torch.tensor(item["audio"]),
spectrogram=torch.tensor(item["spectrogram"]),
size_heatmap=torch.tensor(item["size_heatmap"]),
detection_heatmap=torch.tensor(item["detection_heatmap"]),
class_heatmap=torch.tensor(item["class_heatmap"]),
)
def list_preprocessed_files(
directory: data.PathLike, extension: str = ".npz"
) -> List[Path]:
return list(Path(directory).glob(f"*{extension}"))
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
"""Generate a default output filename based on the annotation UUID."""
return f"{clip_annotation.uuid}"

View File

@ -14,14 +14,16 @@ from batdetect2.evaluate.metrics import (
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.models import build_model
from batdetect2.train.augmentations import build_augmentations
from batdetect2.models import Model, build_model
from batdetect2.train.augmentations import (
RandomExampleSource,
build_augmentations,
)
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
)
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger
@ -53,17 +55,13 @@ def train(
):
config = config or FullTrainingConfig()
if model_path is not None:
logger.debug("Loading model from: {path}", path=model_path)
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = build_training_module(config)
model = build_model(config=config)
trainer = build_trainer(config, targets=module.model.targets)
trainer = build_trainer(config, targets=model.targets)
train_dataloader = build_train_loader(
train_examples,
preprocessor=module.model.preprocessor,
preprocessor=model.preprocessor,
config=config.train,
num_workers=train_workers,
)
@ -71,7 +69,7 @@ def train(
val_dataloader = (
build_val_loader(
val_examples,
preprocessor=module.model.preprocessor,
preprocessor=model.preprocessor,
config=config.train,
num_workers=val_workers,
)
@ -79,6 +77,16 @@ def train(
else None
)
if model_path is not None:
logger.debug("Loading model from: {path}", path=model_path)
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = build_training_module(
model,
config,
batches_per_epoch=len(train_dataloader),
)
logger.info("Starting main training loop...")
trainer.fit(
module,
@ -88,14 +96,17 @@ def train(
logger.info("Training complete.")
def build_training_module(config: FullTrainingConfig) -> TrainingModule:
model = build_model(config=config)
def build_training_module(
model: Model,
config: FullTrainingConfig,
batches_per_epoch: int,
) -> TrainingModule:
loss = build_loss(config=config.train.loss)
return TrainingModule(
model=model,
loss=loss,
learning_rate=config.train.learning_rate,
t_max=config.train.t_max,
t_max=config.train.t_max * batches_per_epoch,
)
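
Arithmetic sketch of the t_max change (values illustrative): with the scheduler now counting batches, the configured t_max of 100 is scaled by the dataloader length so the schedule spans the whole run rather than 100 optimizer steps.

t_max_epochs = 100                                  # config.train.t_max, as in the YAML above
batches_per_epoch = len(train_dataloader)           # e.g. 250 batches per epoch
t_max_steps = t_max_epochs * batches_per_epoch      # 25_000 scheduler steps in total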

View File

@ -95,69 +95,10 @@ class BatDetect2Prediction:
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline."""
def __call__(self, output: ModelOutput) -> List[Detections]: ...
def get_detections(
self,
output: ModelOutput,
clips: Optional[List[data.Clip]] = None,
) -> List[Detections]: ...
def get_raw_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes the raw model output for a batch through remapping, NMS,
detection, data extraction, and geometry recovery to produce a list of
`RawPrediction` objects for each corresponding input clip. This provides
a simplified, intermediate representation before final tag decoding.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[List[RawPrediction]]
A list of lists (one inner list per input clip, in order). Each
inner list contains the `RawPrediction` objects extracted for the
corresponding input clip.
"""
...
def get_sound_event_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[BatDetect2Prediction]]: ...
def get_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output for a batch and corresponding clips, applies the
entire postprocessing chain, and returns the final, interpretable
predictions as a list of `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[data.ClipPrediction]
A list containing one `ClipPrediction` object for each input clip
(in the same order), populated with `SoundEventPrediction` objects
representing the final detections with decoded tags and geometry.
"""
...

View File

@ -12,8 +12,8 @@ that components responsible for these tasks can be interacted with consistently
throughout BatDetect2.
"""
from collections.abc import Callable
from typing import List, Optional, Protocol
from collections.abc import Callable, Iterable
from typing import List, Optional, Protocol, Tuple
import numpy as np
from soundevent import data

View File

@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
import torch
import xarray as xr
@ -80,3 +82,14 @@ def adjust_width(
for index in range(dims)
]
return tensor[tuple(slices)]
def slice_tensor(
tensor: torch.Tensor,
start: Optional[int] = None,
end: Optional[int] = None,
dim: int = -1,
) -> torch.Tensor:
slices = [slice(None)] * tensor.ndim
slices[dim] = slice(start, end)
return tensor[tuple(slices)]
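
Usage sketch of the new slice_tensor helper (shapes illustrative): it slices along an arbitrary dimension without hand-written per-rank indexing.

import torch

from batdetect2.utils.arrays import slice_tensor

spec = torch.randn(1, 128, 512)                     # (channels, freq bins, time bins)
window = slice_tensor(spec, start=100, end=356, dim=-1)
assert window.shape == (1, 128, 256)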

View File

@ -38,7 +38,6 @@ def build_from_config(
max_freq=preprocessor.max_freq,
)
postprocessor = build_postprocessor(
targets,
preprocessor=preprocessor,
config=postprocessing_config,
)