Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)

Compare commits: ff754a1269...93e89ecc46

4 commits: 93e89ecc46, 34ef9e92a1, 0b5ac96fe8, dba6d2d918
@@ -1,14 +1,3 @@
-datasets:
-  train:
-    name: example dataset
-    description: Only for demonstration purposes
-    sources:
-      - format: batdetect2
-        name: Example Data
-        description: Examples included for testing batdetect2
-        annotations_dir: example_data/anns
-        audio_dir: example_data/audio
-
 targets:
   classes:
     classes:
@@ -99,7 +88,9 @@ model:
       out_channels: 256
   bottleneck:
     channels: 256
-    self_attention: true
+    layers:
+      - block_type: SelfAttention
+        attention_channels: 256
   decoder:
     layers:
       - block_type: FreqCoordConvUp
@@ -114,9 +105,19 @@ model:
       out_channels: 32
 
 train:
-  batch_size: 8
   learning_rate: 0.001
   t_max: 100
 
+  dataloaders:
+    train:
+      batch_size: 8
+      num_workers: 2
+      shuffle: True
+
+    val:
+      batch_size: 8
+      num_workers: 2
+
   loss:
     detection:
       weight: 1.0
@@ -130,14 +131,12 @@ train:
         alpha: 2
     size:
       weight: 0.1
 
   logger:
-    logger_type: mlflow
-    experiment_name: batdetect2
-    tracking_uri: http://localhost:5000
-    log_model: true
+    logger_type: csv
     save_dir: outputs/log/
-    artifact_location: outputs/artifacts/
-    checkpoint_path_prefix: outputs/checkpoints/
+    name: logs
 
 augmentations:
   steps:
     - augmentation_type: mix_audio
example_data/dataset.yaml (new file, 8 lines)

@@ -0,0 +1,8 @@
+name: example dataset
+description: Only for demonstration purposes
+sources:
+  - format: batdetect2
+    name: Example Data
+    description: Examples included for testing batdetect2
+    annotations_dir: example_data/anns
+    audio_dir: example_data/audio
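The extracted dataset file can be sanity-checked on its own; a minimal sketch using PyYAML (validation through batdetect2's own config loader is omitted, since the exact dataset schema class is not shown in this diff):

```python
import yaml

# Parse the standalone dataset description added in this change.
with open("example_data/dataset.yaml") as fp:
    dataset = yaml.safe_load(fp)

print(dataset["name"])  # -> "example dataset"
for source in dataset["sources"]:
    # Each source points batdetect2 at a pair of annotation/audio folders.
    print(source["format"], source["annotations_dir"], source["audio_dir"])
```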
@@ -1,6 +1,6 @@
 from collections.abc import Callable, Iterable, Mapping
 from dataclasses import dataclass, field
-from typing import List, Literal, Optional, Tuple
+from typing import List, Literal, Optional, Protocol, Tuple
 
 import numpy as np
 from soundevent import data
@@ -23,6 +23,16 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"]
 """The geometry representation to use for matching."""
 
 
+class AffinityFunction(Protocol):
+    def __call__(
+        self,
+        geometry1: data.Geometry,
+        geometry2: data.Geometry,
+        time_buffer: float = 0.01,
+        freq_buffer: float = 1000,
+    ) -> float: ...
+
+
 class MatchConfig(BaseConfig):
     """Configuration for matching geometries.
 
@@ -74,6 +84,65 @@ _geometry_cast_functions: Mapping[
 }
 
 
+def _timestamp_affinity(
+    geometry1: data.Geometry,
+    geometry2: data.Geometry,
+    time_buffer: float = 0.01,
+    freq_buffer: float = 1000,
+) -> float:
+    assert isinstance(geometry1, data.TimeStamp)
+    assert isinstance(geometry2, data.TimeStamp)
+
+    start_time1 = geometry1.coordinates
+    start_time2 = geometry2.coordinates
+
+    a = min(start_time1, start_time2)
+    b = max(start_time1, start_time2)
+
+    if b - a >= 2 * time_buffer:
+        return 0
+
+    intersection = a - b + 2 * time_buffer
+    union = b - a + 2 * time_buffer
+    return intersection / union
+
+
+def _interval_affinity(
+    geometry1: data.Geometry,
+    geometry2: data.Geometry,
+    time_buffer: float = 0.01,
+    freq_buffer: float = 1000,
+) -> float:
+    assert isinstance(geometry1, data.TimeInterval)
+    assert isinstance(geometry2, data.TimeInterval)
+
+    start_time1, end_time1 = geometry1.coordinates
+    start_time2, end_time2 = geometry2.coordinates
+
+    start_time1 -= time_buffer
+    start_time2 -= time_buffer
+    end_time1 += time_buffer
+    end_time2 += time_buffer
+
+    intersection = max(
+        0, min(end_time1, end_time2) - max(start_time1, start_time2)
+    )
+    union = (
+        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
+    )
+
+    if union == 0:
+        return 0
+
+    return intersection / union
+
+
+_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
+    "timestamp": _timestamp_affinity,
+    "interval": _interval_affinity,
+}
+
+
 def match_geometries(
     source: List[data.Geometry],
     target: List[data.Geometry],
@@ -81,6 +150,10 @@ def match_geometries(
     scores: Optional[List[float]] = None,
 ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
     geometry_cast = _geometry_cast_functions[config.geometry]
+    affinity_function = _affinity_functions.get(
+        config.geometry,
+        compute_affinity,
+    )
 
     if config.strategy == "optimal":
         return optimal_match(
@@ -98,6 +171,7 @@ def match_geometries(
         time_buffer=config.time_buffer,
         freq_buffer=config.frequency_buffer,
         affinity_threshold=config.affinity_threshold,
+        affinity_function=affinity_function,
         scores=scores,
     )
 
@@ -111,6 +185,7 @@ def greedy_match(
     target: List[data.Geometry],
     scores: Optional[List[float]] = None,
     affinity_threshold: float = 0.5,
+    affinity_function: AffinityFunction = compute_affinity,
     time_buffer: float = 0.001,
     freq_buffer: float = 1000,
 ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
@@ -168,7 +243,7 @@ def greedy_match(
 
     affinities = np.array(
         [
-            compute_affinity(
+            affinity_function(
                 source_geometry,
                 target_geometry,
                 time_buffer=time_buffer,
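Since `AffinityFunction` is a `typing.Protocol`, any callable with a compatible signature satisfies it structurally, with no subclassing; below is a minimal sketch of a custom affinity that could be passed to `greedy_match` via its new `affinity_function` parameter. The exponential-decay scoring is purely illustrative, not part of the library.

```python
import math

from soundevent import data


def timestamp_decay_affinity(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
    freq_buffer: float = 1000,
) -> float:
    """Hypothetical affinity: exponential decay with the timestamp gap."""
    assert isinstance(geometry1, data.TimeStamp)
    assert isinstance(geometry2, data.TimeStamp)
    gap = abs(geometry1.coordinates - geometry2.coordinates)
    # 1.0 for coincident timestamps, decaying towards 0 as the gap grows
    # relative to the time buffer.
    return math.exp(-gap / time_buffer)
```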
@@ -26,13 +26,14 @@ for creating a standard BatDetect2 model instance is the `build_model` function
 provided here.
 """
 
-from typing import Optional
+from typing import List, Optional
 
 import torch
 from lightning import LightningModule
 from pydantic import Field
+from soundevent.data import PathLike
 
-from batdetect2.configs import BaseConfig
+from batdetect2.configs import BaseConfig, load_config
 from batdetect2.models.backbones import (
     Backbone,
     BackboneConfig,
@@ -66,8 +67,8 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
 from batdetect2.postprocess import PostprocessConfig, build_postprocessor
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
 from batdetect2.targets import TargetConfig, build_targets
-from batdetect2.typing.models import DetectionModel, ModelOutput
-from batdetect2.typing.postprocess import PostprocessorProtocol
+from batdetect2.typing.models import DetectionModel
+from batdetect2.typing.postprocess import Detections, PostprocessorProtocol
 from batdetect2.typing.preprocess import PreprocessorProtocol
 from batdetect2.typing.targets import TargetProtocol
 
@@ -119,9 +120,12 @@ class Model(LightningModule):
         self.preprocessor = preprocessor
         self.postprocessor = postprocessor
         self.targets = targets
+        self.save_hyperparameters()
 
-    def forward(self, spec: torch.Tensor) -> ModelOutput:
-        return self.detector(spec)
+    def forward(self, wav: torch.Tensor) -> List[Detections]:
+        spec = self.preprocessor(wav)
+        outputs = self.detector(spec)
+        return self.postprocessor(outputs)
 
 
 class ModelConfig(BaseConfig):
@@ -139,7 +143,6 @@ def build_model(config: Optional[ModelConfig] = None):
     targets = build_targets(config=config.targets)
     preprocessor = build_preprocessor(config=config.preprocess)
     postprocessor = build_postprocessor(
-        targets=targets,
         preprocessor=preprocessor,
         config=config.postprocess,
     )
@@ -153,3 +156,9 @@ def build_model(config: Optional[ModelConfig] = None):
         preprocessor=preprocessor,
         targets=targets,
     )
+
+
+def load_model_config(
+    path: PathLike, field: Optional[str] = None
+) -> ModelConfig:
+    return load_config(path, schema=ModelConfig, field=field)
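After this change `Model.forward` runs the whole pipeline on raw audio (preprocess, detect, postprocess), so inference no longer requires calling the stages separately. A minimal sketch, assuming the module is importable as `batdetect2.models`; the waveform shape and config path are illustrative:

```python
import torch

from batdetect2.models import build_model, load_model_config  # path assumed

# Either start from the defaults or load a ModelConfig from YAML.
config = None  # or: config = load_model_config("model.yaml", field="model")
model = build_model(config)
model.eval()

wav = torch.randn(1, 256_000)  # hypothetical batch of one audio clip

with torch.no_grad():
    detections = model(wav)  # List[Detections], already post-processed
```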
@@ -55,6 +55,12 @@ __all__ = [
 ]
 
 
+class SelfAttentionConfig(BaseConfig):
+    block_type: Literal["SelfAttention"] = "SelfAttention"
+    attention_channels: int
+    temperature: float = 1
+
+
 class SelfAttention(nn.Module):
     """Self-Attention mechanism operating along the time dimension.
 
@@ -115,6 +121,7 @@ class SelfAttention(nn.Module):
         # Note, does not encode position information (absolute or relative)
         self.temperature = temperature
         self.att_dim = attention_channels
+
         self.key_fun = nn.Linear(in_channels, attention_channels)
         self.value_fun = nn.Linear(in_channels, attention_channels)
         self.query_fun = nn.Linear(in_channels, attention_channels)
@@ -654,6 +661,7 @@ LayerConfig = Annotated[
         StandardConvDownConfig,
         FreqCoordConvUpConfig,
         StandardConvUpConfig,
+        SelfAttentionConfig,
         "LayerGroupConfig",
     ],
     Field(discriminator="block_type"),
@@ -769,6 +777,17 @@ def build_layer_from_config(
             input_height * 2,
         )
 
+    if config.block_type == "SelfAttention":
+        return (
+            SelfAttention(
+                in_channels=in_channels,
+                attention_channels=config.attention_channels,
+                temperature=config.temperature,
+            ),
+            config.attention_channels,
+            input_height,
+        )
+
     if config.block_type == "LayerGroup":
         current_channels = in_channels
         current_height = input_height
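`build_layer_from_config` returns a `(module, out_channels, out_height)` triple, so the new attention block participates in the same shape bookkeeping as the convolutional blocks. A minimal sketch, assuming the import path `batdetect2.models.blocks`; the height and channel numbers are illustrative:

```python
from batdetect2.models.blocks import (
    SelfAttentionConfig,
    build_layer_from_config,
)

config = SelfAttentionConfig(attention_channels=256, temperature=1.0)

layer, out_channels, out_height = build_layer_from_config(
    input_height=8,    # hypothetical feature-map height at this stage
    in_channels=256,
    config=config,
)

# Self-attention keeps the height unchanged and emits attention_channels
# feature channels, as encoded in the returned triple.
assert (out_channels, out_height) == (256, 8)
```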
@@ -14,47 +14,27 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
 module based on the provided configuration.
 """
 
-from typing import Optional
+from typing import Annotated, List, Optional, Union
 
 import torch
+from pydantic import Field
 from torch import nn
 
 from batdetect2.configs import BaseConfig
-from batdetect2.models.blocks import SelfAttention, VerticalConv
+from batdetect2.models.blocks import (
+    LayerConfig,
+    SelfAttentionConfig,
+    VerticalConv,
+    build_layer_from_config,
+)
 
 __all__ = [
     "BottleneckConfig",
     "Bottleneck",
-    "BottleneckAttn",
     "build_bottleneck",
 ]
 
 
-class BottleneckConfig(BaseConfig):
-    """Configuration for the bottleneck layer(s).
-
-    Defines the number of channels within the bottleneck and whether to include
-    a self-attention mechanism.
-
-    Attributes
-    ----------
-    channels : int
-        The number of output channels produced by the main convolutional layer
-        within the bottleneck. This often matches the number of channels coming
-        from the last encoder stage, but can be different. Must be positive.
-        This also defines the channel dimensions used within the optional
-        `SelfAttention` layer.
-    self_attention : bool
-        If True, includes a `SelfAttention` layer operating on the time
-        dimension after an initial `VerticalConv` layer within the bottleneck.
-        If False, only the initial `VerticalConv` (and height repetition) is
-        performed.
-    """
-
-    channels: int
-    self_attention: bool
-
-
 class Bottleneck(nn.Module):
     """Base Bottleneck module for Encoder-Decoder architectures.
 
@@ -99,16 +79,24 @@ class Bottleneck(nn.Module):
         input_height: int,
         in_channels: int,
         out_channels: int,
+        bottleneck_channels: Optional[int] = None,
+        layers: Optional[List[torch.nn.Module]] = None,
     ) -> None:
         """Initialize the base Bottleneck layer."""
         super().__init__()
         self.in_channels = in_channels
         self.input_height = input_height
         self.out_channels = out_channels
+        self.bottleneck_channels = (
+            bottleneck_channels
+            if bottleneck_channels is not None
+            else out_channels
+        )
+        self.layers = nn.ModuleList(layers or [])
 
         self.conv_vert = VerticalConv(
             in_channels=in_channels,
-            out_channels=out_channels,
+            out_channels=self.bottleneck_channels,
             input_height=input_height,
         )
 
@@ -132,73 +120,52 @@ class Bottleneck(nn.Module):
         convolution.
         """
         x = self.conv_vert(x)
+
+        for layer in self.layers:
+            x = layer(x)
+
         return x.repeat([1, 1, self.input_height, 1])
 
 
-class BottleneckAttn(Bottleneck):
-    """Bottleneck module including a Self-Attention layer.
-
-    Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
-    the initial `VerticalConv`. This allows the bottleneck to capture global
-    temporal dependencies in the summarized frequency features before passing
-    them to the decoder.
-
-    Sequence: VerticalConv -> SelfAttention -> Repeat Height.
-
-    Parameters
-    ----------
-    input_height : int
-        Height (frequency bins) of the input tensor from the encoder.
-    in_channels : int
-        Number of channels in the input tensor from the encoder.
-    out_channels : int
-        Number of output channels produced by the `VerticalConv` and
-        subsequently processed and output by this bottleneck. Also determines
-        the input/output channels of the internal `SelfAttention` layer.
-    attention : nn.Module
-        An initialized `SelfAttention` module instance.
-
-    Raises
-    ------
-    ValueError
-        If `input_height`, `in_channels`, or `out_channels` are not positive.
-    """
-
-    def __init__(
-        self,
-        input_height: int,
-        in_channels: int,
-        out_channels: int,
-        attention: nn.Module,
-    ) -> None:
-        """Initialize the Bottleneck with Self-Attention."""
-        super().__init__(input_height, in_channels, out_channels)
-        self.attention = attention
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Process input tensor.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor from the encoder bottleneck, shape
-            `(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
-            `H_in` must match `self.input_height`.
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
-            and repeating the height dimension.
-        """
-        x = self.conv_vert(x)
-        x = self.attention(x)
-        return x.repeat([1, 1, self.input_height, 1])
+BottleneckLayerConfig = Annotated[
+    Union[SelfAttentionConfig,],
+    Field(discriminator="block_type"),
+]
+"""Type alias for the discriminated union of block configs usable in Decoder."""
+
+
+class BottleneckConfig(BaseConfig):
+    """Configuration for the bottleneck layer(s).
+
+    Defines the number of channels within the bottleneck and whether to include
+    a self-attention mechanism.
+
+    Attributes
+    ----------
+    channels : int
+        The number of output channels produced by the main convolutional layer
+        within the bottleneck. This often matches the number of channels coming
+        from the last encoder stage, but can be different. Must be positive.
+        This also defines the channel dimensions used within the optional
+        `SelfAttention` layer.
+    self_attention : bool
+        If True, includes a `SelfAttention` layer operating on the time
+        dimension after an initial `VerticalConv` layer within the bottleneck.
+        If False, only the initial `VerticalConv` (and height repetition) is
+        performed.
+    """
+
+    channels: int
+    layers: List[BottleneckLayerConfig] = Field(
+        default_factory=list,
+    )
 
 
 DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
     channels=256,
-    self_attention=True,
+    layers=[
+        SelfAttentionConfig(attention_channels=256),
+    ],
 )
@@ -234,21 +201,25 @@ def build_bottleneck(
     """
     config = config or DEFAULT_BOTTLENECK_CONFIG
 
-    if config.self_attention:
-        attention = SelfAttention(
-            in_channels=config.channels,
-            attention_channels=config.channels,
-        )
-
-        return BottleneckAttn(
-            input_height=input_height,
-            in_channels=in_channels,
-            out_channels=config.channels,
-            attention=attention,
-        )
+    current_channels = in_channels
+    current_height = input_height
+
+    layers = []
+
+    for layer_config in config.layers:
+        layer, current_channels, current_height = build_layer_from_config(
+            input_height=current_height,
+            in_channels=current_channels,
+            config=layer_config,
+        )
+        assert current_height == input_height, (
+            "Bottleneck layers should not change the spectrogram height"
+        )
+        layers.append(layer)
 
     return Bottleneck(
         input_height=input_height,
         in_channels=in_channels,
         out_channels=config.channels,
+        layers=layers,
     )
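The boolean `self_attention` switch is replaced by an explicit, ordered list of block configs. A minimal sketch of building the equivalent of the old default under the new schema; the import path for the bottleneck module and the `build_bottleneck` keyword names are assumed from the code in this diff:

```python
from batdetect2.models.blocks import SelfAttentionConfig
from batdetect2.models.bottleneck import (  # module path assumed
    BottleneckConfig,
    build_bottleneck,
)

config = BottleneckConfig(
    channels=256,
    layers=[SelfAttentionConfig(attention_channels=256)],
)

# Equivalent to the old `self_attention: true`, but open-ended: more block
# types can be appended once they join the BottleneckLayerConfig union.
bottleneck = build_bottleneck(
    input_height=8,   # hypothetical encoder output height
    in_channels=256,
    config=config,
)
```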
@@ -26,19 +26,32 @@ def create_ax(
 
 def plot_spectrogram(
     spec: Union[torch.Tensor, np.ndarray],
-    start_time: float,
-    end_time: float,
-    min_freq: float,
-    max_freq: float,
+    start_time: Optional[float] = None,
+    end_time: Optional[float] = None,
+    min_freq: Optional[float] = None,
+    max_freq: Optional[float] = None,
     ax: Optional[axes.Axes] = None,
     figsize: Optional[Tuple[int, int]] = None,
     cmap="gray",
 ) -> axes.Axes:
 
     if isinstance(spec, torch.Tensor):
         spec = spec.numpy()
 
     ax = create_ax(ax=ax, figsize=figsize)
 
+    if start_time is None:
+        start_time = 0
+
+    if end_time is None:
+        end_time = spec.shape[-1]
+
+    if min_freq is None:
+        min_freq = 0
+
+    if max_freq is None:
+        max_freq = spec.shape[-2]
+
     ax.pcolormesh(
         np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
         np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),
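With every bound optional, the axes fall back to spectrogram bin indices when no times or frequencies are given. A minimal sketch; the module path is assumed and a random array stands in for a real spectrogram:

```python
import numpy as np

from batdetect2.plotting import plot_spectrogram  # module path assumed

spec = np.random.rand(128, 512)  # (freq_bins, time_bins), illustrative

# No bounds: axes are labelled in bin indices (0..512 and 0..128).
ax = plot_spectrogram(spec)

# Explicit bounds still label the axes in seconds and Hz.
ax = plot_spectrogram(
    spec, start_time=0.0, end_time=0.5, min_freq=10_000, max_freq=120_000
)
```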
@@ -2,6 +2,7 @@
 
 from typing import List, Optional
 
+import torch
 from loguru import logger
 from pydantic import Field
 from soundevent import data
@@ -20,13 +21,15 @@ from batdetect2.postprocess.nms import (
 )
 from batdetect2.postprocess.remapping import map_detection_to_clip
 from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
-from batdetect2.typing import ModelOutput, PreprocessorProtocol, TargetProtocol
+from batdetect2.typing import ModelOutput
 from batdetect2.typing.postprocess import (
     BatDetect2Prediction,
     Detections,
     PostprocessorProtocol,
     RawPrediction,
 )
+from batdetect2.typing.preprocess import PreprocessorProtocol
+from batdetect2.typing.targets import TargetProtocol
 
 __all__ = [
     "DEFAULT_CLASSIFICATION_THRESHOLD",
@@ -128,7 +131,6 @@ def load_postprocess_config(
 
 
 def build_postprocessor(
-    targets: TargetProtocol,
     preprocessor: PreprocessorProtocol,
     config: Optional[PostprocessConfig] = None,
 ) -> PostprocessorProtocol:
@@ -139,29 +141,52 @@ def build_postprocessor(
         lambda: config.to_yaml_string(),
     )
     return Postprocessor(
-        targets=targets,
-        preprocessor=preprocessor,
-        config=config,
+        samplerate=preprocessor.output_samplerate,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
+        top_k_per_sec=config.top_k_per_sec,
+        detection_threshold=config.detection_threshold,
     )
 
 
-class Postprocessor(PostprocessorProtocol):
+class Postprocessor(torch.nn.Module, PostprocessorProtocol):
     """Standard implementation of the postprocessing pipeline."""
 
-    targets: TargetProtocol
-
-    preprocessor: PreprocessorProtocol
-
     def __init__(
         self,
-        targets: TargetProtocol,
-        preprocessor: PreprocessorProtocol,
-        config: PostprocessConfig,
+        samplerate: float,
+        min_freq: float,
+        max_freq: float,
+        top_k_per_sec: int = 200,
+        detection_threshold: float = 0.01,
     ):
         """Initialize the Postprocessor."""
-        self.targets = targets
-        self.preprocessor = preprocessor
-        self.config = config
+        super().__init__()
+        self.samplerate = samplerate
+        self.min_freq = min_freq
+        self.max_freq = max_freq
+        self.top_k_per_sec = top_k_per_sec
+        self.detection_threshold = detection_threshold
+
+    def forward(self, output: ModelOutput) -> List[Detections]:
+        width = output.detection_probs.shape[-1]
+        duration = width / self.samplerate
+        max_detections = int(self.top_k_per_sec * duration)
+        detections = extract_prediction_tensor(
+            output,
+            max_detections=max_detections,
+            threshold=self.detection_threshold,
+        )
+        return [
+            map_detection_to_clip(
+                detection,
+                start_time=0,
+                end_time=duration,
+                min_freq=self.min_freq,
+                max_freq=self.max_freq,
+            )
+            for detection in detections
+        ]
 
     def get_detections(
         self,
@@ -169,13 +194,13 @@ class Postprocessor(PostprocessorProtocol):
         clips: Optional[List[data.Clip]] = None,
     ) -> List[Detections]:
         width = output.detection_probs.shape[-1]
-        duration = width / self.preprocessor.output_samplerate
-        max_detections = int(self.config.top_k_per_sec * duration)
+        duration = width / self.samplerate
+        max_detections = int(self.top_k_per_sec * duration)
 
         detections = extract_prediction_tensor(
             output,
             max_detections=max_detections,
-            threshold=self.config.detection_threshold,
+            threshold=self.detection_threshold,
         )
 
         if clips is None:
@@ -186,96 +211,116 @@ class Postprocessor(PostprocessorProtocol):
                 detection,
                 start_time=clip.start_time,
                 end_time=clip.end_time,
-                min_freq=self.preprocessor.min_freq,
-                max_freq=self.preprocessor.max_freq,
+                min_freq=self.min_freq,
+                max_freq=self.max_freq,
             )
             for detection, clip in zip(detections, clips)
         ]
 
-    def get_raw_predictions(
-        self,
-        output: ModelOutput,
-        clips: List[data.Clip],
-    ) -> List[List[RawPrediction]]:
-        """Extract intermediate RawPrediction objects for a batch.
-
-        Processes raw model output through remapping, NMS, detection, data
-        extraction, and geometry recovery via the configured
-        `targets.recover_roi`.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            Raw output from the neural network model for a batch.
-        clips : List[data.Clip]
-            List of `soundevent.data.Clip` objects corresponding to the batch.
-
-        Returns
-        -------
-        List[List[RawPrediction]]
-            List of lists (one inner list per input clip). Each inner list
-            contains `RawPrediction` objects for detections in that clip.
-        """
-        detections = self.get_detections(output, clips)
-        return [
-            convert_detections_to_raw_predictions(
-                dataset,
-                targets=self.targets,
-            )
-            for dataset in detections
-        ]
-
-    def get_sound_event_predictions(
-        self,
-        output: ModelOutput,
-        clips: List[data.Clip],
-    ) -> List[List[BatDetect2Prediction]]:
-        raw_predictions = self.get_raw_predictions(output, clips)
-        return [
-            [
-                BatDetect2Prediction(
-                    raw=raw,
-                    sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
-                        raw,
-                        recording=clip.recording,
-                        targets=self.targets,
-                        classification_threshold=self.config.classification_threshold,
-                    ),
-                )
-                for raw in predictions
-            ]
-            for predictions, clip in zip(raw_predictions, clips)
-        ]
-
-    def get_predictions(
-        self, output: ModelOutput, clips: List[data.Clip]
-    ) -> List[data.ClipPrediction]:
-        """Perform the full postprocessing pipeline for a batch.
-
-        Takes raw model output and corresponding clips, applies the entire
-        configured chain (NMS, remapping, extraction, geometry recovery, class
-        decoding), producing final `soundevent.data.ClipPrediction` objects.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            Raw output from the neural network model for a batch.
-        clips : List[data.Clip]
-            List of `soundevent.data.Clip` objects corresponding to the batch.
-
-        Returns
-        -------
-        List[data.ClipPrediction]
-            List containing one `ClipPrediction` object for each input clip,
-            populated with `SoundEventPrediction` objects.
-        """
-        raw_predictions = self.get_raw_predictions(output, clips)
-        return [
-            convert_raw_predictions_to_clip_prediction(
-                prediction,
-                clip,
-                targets=self.targets,
-                classification_threshold=self.config.classification_threshold,
-            )
-            for prediction, clip in zip(raw_predictions, clips)
-        ]
+
+def get_raw_predictions(
+    output: ModelOutput,
+    clips: List[data.Clip],
+    targets: TargetProtocol,
+    postprocessor: PostprocessorProtocol,
+) -> List[List[RawPrediction]]:
+    """Extract intermediate RawPrediction objects for a batch.
+
+    Processes raw model output through remapping, NMS, detection, data
+    extraction, and geometry recovery via the configured
+    `targets.recover_roi`.
+
+    Parameters
+    ----------
+    output : ModelOutput
+        Raw output from the neural network model for a batch.
+    clips : List[data.Clip]
+        List of `soundevent.data.Clip` objects corresponding to the batch.
+
+    Returns
+    -------
+    List[List[RawPrediction]]
+        List of lists (one inner list per input clip). Each inner list
+        contains `RawPrediction` objects for detections in that clip.
+    """
+    detections = postprocessor.get_detections(output, clips)
+    return [
+        convert_detections_to_raw_predictions(
+            dataset,
+            targets=targets,
+        )
+        for dataset in detections
+    ]
+
+
+def get_sound_event_predictions(
+    output: ModelOutput,
+    clips: List[data.Clip],
+    targets: TargetProtocol,
+    postprocessor: PostprocessorProtocol,
+    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
+) -> List[List[BatDetect2Prediction]]:
+    raw_predictions = get_raw_predictions(
+        output,
+        clips,
+        targets=targets,
+        postprocessor=postprocessor,
+    )
+    return [
+        [
+            BatDetect2Prediction(
+                raw=raw,
+                sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
+                    raw,
+                    recording=clip.recording,
+                    targets=targets,
+                    classification_threshold=classification_threshold,
+                ),
+            )
+            for raw in predictions
+        ]
+        for predictions, clip in zip(raw_predictions, clips)
+    ]
+
+
+def get_predictions(
+    output: ModelOutput,
+    clips: List[data.Clip],
+    targets: TargetProtocol,
+    postprocessor: PostprocessorProtocol,
+    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
+) -> List[data.ClipPrediction]:
+    """Perform the full postprocessing pipeline for a batch.
+
+    Takes raw model output and corresponding clips, applies the entire
+    configured chain (NMS, remapping, extraction, geometry recovery, class
+    decoding), producing final `soundevent.data.ClipPrediction` objects.
+
+    Parameters
+    ----------
+    output : ModelOutput
+        Raw output from the neural network model for a batch.
+    clips : List[data.Clip]
+        List of `soundevent.data.Clip` objects corresponding to the batch.
+
+    Returns
+    -------
+    List[data.ClipPrediction]
+        List containing one `ClipPrediction` object for each input clip,
+        populated with `SoundEventPrediction` objects.
+    """
+    raw_predictions = get_raw_predictions(
+        output,
+        clips,
+        targets=targets,
+        postprocessor=postprocessor,
+    )
+    return [
+        convert_raw_predictions_to_clip_prediction(
+            prediction,
+            clip,
+            targets=targets,
+            classification_threshold=classification_threshold,
+        )
+        for prediction, clip in zip(raw_predictions, clips)
+    ]
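`get_raw_predictions`, `get_sound_event_predictions`, and `get_predictions` are now module-level functions that take `targets` and `postprocessor` explicitly instead of reading them off the `Postprocessor` instance. A minimal calling sketch; `model_output`, `clips`, `targets`, and `postprocessor` are assumed to exist already (e.g. from a detector batch, a dataset, `build_targets`, and `build_postprocessor`):

```python
from batdetect2.postprocess import (
    get_raw_predictions,
    get_sound_event_predictions,
)

raw = get_raw_predictions(
    model_output,
    clips,
    targets=targets,
    postprocessor=postprocessor,
)

predictions = get_sound_event_predictions(
    model_output,
    clips,
    targets=targets,
    postprocessor=postprocessor,
    classification_threshold=0.3,  # illustrative threshold
)
```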
@@ -139,7 +139,21 @@ class FrequencyClip(torch.nn.Module):
         self.high_index = high_index
 
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        return spec[self.low_index : self.high_index]
+        low_index = self.low_index
+        if low_index is None:
+            low_index = 0
+
+        if self.high_index is None:
+            length = spec.shape[-2] - low_index
+        else:
+            length = self.high_index - low_index
+
+        return torch.narrow(
+            spec,
+            dim=-2,
+            start=low_index,
+            length=length,
+        )
 
 
 class PcenConfig(BaseConfig):
@@ -256,16 +270,22 @@ class ResizeSpec(torch.nn.Module):
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
         current_length = spec.shape[-1]
         target_length = int(self.time_factor * current_length)
-        return (
-            torch.nn.functional.interpolate(
-                spec.unsqueeze(0).unsqueeze(0),
-                size=(self.height, target_length),
-                mode="bilinear",
-            )
-            .squeeze(0)
-            .squeeze(0)
-        )
+
+        original_ndim = spec.ndim
+        while spec.ndim < 4:
+            spec = spec.unsqueeze(0)
+
+        resized = torch.nn.functional.interpolate(
+            spec,
+            size=(self.height, target_length),
+            mode="bilinear",
+        )
+
+        while resized.ndim != original_ndim:
+            resized = resized.squeeze(0)
+
+        return resized
 
 
 class PeakNormalizeConfig(BaseConfig):
     name: Literal["peak_normalize"] = "peak_normalize"
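Both rewrites make the transforms agnostic to leading batch/channel dimensions. The `torch.narrow` call in the new `FrequencyClip.forward` selects along the frequency axis wherever it sits; a small self-contained sketch of the equivalence it relies on:

```python
import torch

spec = torch.randn(2, 1, 128, 512)  # (batch, channel, freq, time), illustrative

low, high = 10, 100
clipped = torch.narrow(spec, dim=-2, start=low, length=high - low)

# Identical to slicing the frequency axis directly, but independent of how
# many leading dimensions the spectrogram carries.
assert torch.equal(clipped, spec[..., low:high, :])
```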
@@ -2,6 +2,7 @@ from batdetect2.train.augmentations import (
     AugmentationsConfig,
     EchoAugmentationConfig,
     FrequencyMaskAugmentationConfig,
+    RandomExampleSource,
     TimeMaskAugmentationConfig,
     VolumeAugmentationConfig,
     WarpAugmentationConfig,
@@ -23,7 +24,6 @@ from batdetect2.train.config import (
 )
 from batdetect2.train.dataset import (
     LabeledDataset,
-    RandomExampleSource,
     list_preprocessed_files,
 )
 from batdetect2.train.labels import build_clip_labeler, load_label_config
@@ -1,6 +1,7 @@
 """Applies data augmentation techniques to BatDetect2 training examples."""
 
 import warnings
+from collections.abc import Sequence
 from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union
 
 import numpy as np
@@ -10,8 +11,12 @@ from pydantic import Field
 from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
+from batdetect2.train.preprocess import (
+    list_preprocessed_files,
+    load_preprocessed_example,
+)
 from batdetect2.typing import Augmentation, PreprocessorProtocol
-from batdetect2.typing.train import PreprocessedExample
+from batdetect2.typing.train import ClipperProtocol, PreprocessedExample
 from batdetect2.utils.arrays import adjust_width
 
 __all__ = [
@@ -39,21 +44,6 @@ ExampleSource = Callable[[], PreprocessedExample]
 """Type alias for a function that returns a training example"""
 
 
-class MixAugmentationConfig(BaseConfig):
-    """Configuration for MixUp augmentation (mixing two examples)."""
-
-    augmentation_type: Literal["mix_audio"] = "mix_audio"
-
-    probability: float = 0.2
-    """Probability of applying this augmentation to an example."""
-
-    min_weight: float = 0.3
-    """Minimum mixing weight (lambda) applied to the primary example."""
-
-    max_weight: float = 0.7
-    """Maximum mixing weight (lambda) applied to the primary example."""
-
-
 def mix_examples(
     example: PreprocessedExample,
     other: PreprocessedExample,
@@ -149,7 +139,12 @@ def add_echo(
 
     audio = example.audio
     delay_steps = int(preprocessor.input_samplerate * delay)
-    audio_delay = adjust_width(audio[delay_steps:], audio.shape[-1])
+
+    slices = [slice(None)] * audio.ndim
+    slices[-1] = slice(None, -delay_steps)
+    audio_delay = adjust_width(audio[tuple(slices)], audio.shape[-1]).roll(
+        delay_steps, dims=-1
+    )
 
     audio = audio + weight * audio_delay
     spectrogram = preprocessor(audio)
@@ -184,7 +179,7 @@ class VolumeAugmentationConfig(BaseConfig):
 
 
 class ScaleVolume(torch.nn.Module):
-    def __init__(self, min_scaling: float, max_scaling: float):
+    def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0):
         super().__init__()
         self.min_scaling = min_scaling
         self.max_scaling = max_scaling
@@ -228,32 +223,22 @@ def warp_spectrogram(
     example: PreprocessedExample, factor: float
 ) -> PreprocessedExample:
     """Apply time warping by resampling the time axis."""
-    target_shape = example.spectrogram.shape
+    width = example.spectrogram.shape[-1]
+    height = example.spectrogram.shape[-2]
+    target_shape = [height, width]
     new_width = int(target_shape[-1] * factor)
 
-    spectrogram = (
-        torch.nn.functional.interpolate(
-            adjust_width(example.spectrogram, new_width)
-            .unsqueeze(0)
-            .unsqueeze(0),
-            size=target_shape,
-            mode="bilinear",
-        )
-        .squeeze(0)
-        .squeeze(0)
-    )
+    spectrogram = torch.nn.functional.interpolate(
+        adjust_width(example.spectrogram, new_width).unsqueeze(0),
+        size=target_shape,
+        mode="bilinear",
+    ).squeeze(0)
 
-    detection = (
-        torch.nn.functional.interpolate(
-            adjust_width(example.detection_heatmap, new_width)
-            .unsqueeze(0)
-            .unsqueeze(0),
-            size=target_shape,
-            mode="nearest",
-        )
-        .squeeze(0)
-        .squeeze(0)
-    )
+    detection = torch.nn.functional.interpolate(
+        adjust_width(example.detection_heatmap, new_width).unsqueeze(0),
+        size=target_shape,
+        mode="nearest",
+    ).squeeze(0)
 
     classification = torch.nn.functional.interpolate(
         adjust_width(example.class_heatmap, new_width).unsqueeze(1),
@@ -284,10 +269,16 @@ class TimeMaskAugmentationConfig(BaseConfig):
 
 
 class MaskTime(torch.nn.Module):
-    def __init__(self, max_perc: float = 0.05, max_masks: int = 3) -> None:
+    def __init__(
+        self,
+        max_perc: float = 0.05,
+        max_masks: int = 3,
+        mask_heatmaps: bool = False,
+    ) -> None:
         super().__init__()
         self.max_perc = max_perc
         self.max_masks = max_masks
+        self.mask_heatmaps = mask_heatmaps
 
     def forward(self, example: PreprocessedExample) -> PreprocessedExample:
         num_masks = np.random.randint(1, self.max_masks + 1)
@@ -306,20 +297,28 @@ class MaskTime(torch.nn.Module):
         masks = [
             (start, start + size) for start, size in zip(mask_start, mask_size)
         ]
-        return mask_time(example, masks)
+        return mask_time(example, masks, mask_heatmaps=self.mask_heatmaps)
 
 
 def mask_time(
     example: PreprocessedExample,
     masks: List[Tuple[int, int]],
+    mask_heatmaps: bool = False,
 ) -> PreprocessedExample:
     """Apply time masking to the spectrogram."""
     for start, end in masks:
-        example.spectrogram[:, start:end] = example.spectrogram.mean()
-        example.class_heatmap[:, :, start:end] = 0
-        example.size_heatmap[:, :, start:end] = 0
-        example.detection_heatmap[:, start:end] = 0
+        slices = [slice(None)] * example.spectrogram.ndim
+        slices[-1] = slice(start, end)
+        example.spectrogram[tuple(slices)] = 0
+
+        if not mask_heatmaps:
+            continue
+
+        example.class_heatmap[tuple(slices)] = 0
+        example.size_heatmap[tuple(slices)] = 0
+        example.detection_heatmap[tuple(slices)] = 0
 
     return PreprocessedExample(
         audio=example.audio,
@@ -335,13 +334,20 @@ class FrequencyMaskAugmentationConfig(BaseConfig):
     probability: float = 0.2
     max_perc: float = 0.10
     max_masks: int = 3
+    mask_heatmaps: bool = False
 
 
 class MaskFrequency(torch.nn.Module):
-    def __init__(self, max_perc: float = 0.10, max_masks: int = 3) -> None:
+    def __init__(
+        self,
+        max_perc: float = 0.10,
+        max_masks: int = 3,
+        mask_heatmaps: bool = False,
+    ) -> None:
         super().__init__()
         self.max_perc = max_perc
         self.max_masks = max_masks
+        self.mask_heatmaps = mask_heatmaps
 
     def forward(self, example: PreprocessedExample) -> PreprocessedExample:
         num_masks = np.random.randint(1, self.max_masks + 1)
@@ -360,19 +366,26 @@ class MaskFrequency(torch.nn.Module):
         masks = [
             (start, start + size) for start, size in zip(mask_start, mask_size)
         ]
-        return mask_frequency(example, masks)
+        return mask_frequency(example, masks, mask_heatmaps=self.mask_heatmaps)
 
 
 def mask_frequency(
     example: PreprocessedExample,
     masks: List[Tuple[int, int]],
+    mask_heatmaps: bool = False,
 ) -> PreprocessedExample:
     """Apply frequency masking to the spectrogram."""
     for start, end in masks:
-        example.spectrogram[start:end, :] = example.spectrogram.mean()
-        example.class_heatmap[:, start:end, :] = 0
-        example.size_heatmap[:, start:end, :] = 0
-        example.detection_heatmap[start:end, :] = 0
+        slices = [slice(None)] * example.spectrogram.ndim
+        slices[-2] = slice(start, end)
+        example.spectrogram[tuple(slices)] = 0
+
+        if not mask_heatmaps:
+            continue
+
+        example.class_heatmap[tuple(slices)] = 0
+        example.size_heatmap[tuple(slices)] = 0
+        example.detection_heatmap[tuple(slices)] = 0
 
     return PreprocessedExample(
         audio=example.audio,
@@ -383,6 +396,50 @@ def mask_frequency(
     )
 
 
+class MixAugmentationConfig(BaseConfig):
+    """Configuration for MixUp augmentation (mixing two examples)."""
+
+    augmentation_type: Literal["mix_audio"] = "mix_audio"
+
+    probability: float = 0.2
+    """Probability of applying this augmentation to an example."""
+
+    min_weight: float = 0.3
+    """Minimum mixing weight (lambda) applied to the primary example."""
+
+    max_weight: float = 0.7
+    """Maximum mixing weight (lambda) applied to the primary example."""
+
+
+class MixAudio(torch.nn.Module):
+    """Callable class for MixUp augmentation, handling example fetching."""
+
+    def __init__(
+        self,
+        example_source: ExampleSource,
+        preprocessor: PreprocessorProtocol,
+        min_weight: float = 0.3,
+        max_weight: float = 0.7,
+    ):
+        """Initialize the AudioMixer."""
+        super().__init__()
+        self.min_weight = min_weight
+        self.example_source = example_source
+        self.max_weight = max_weight
+        self.preprocessor = preprocessor
+
+    def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
+        """Fetch another example and perform mixup."""
+        other = self.example_source()
+        weight = np.random.uniform(self.min_weight, self.max_weight)
+        return mix_examples(
+            example,
+            other,
+            self.preprocessor,
+            weight=weight,
+        )
+
+
 AugmentationConfig = Annotated[
     Union[
         MixAugmentationConfig,
@@ -445,35 +502,6 @@ class MaybeApply(torch.nn.Module):
         return self.augmentation(example)
 
 
-class AudioMixer(torch.nn.Module):
-    """Callable class for MixUp augmentation, handling example fetching."""
-
-    def __init__(
-        self,
-        min_weight: float,
-        max_weight: float,
-        example_source: ExampleSource,
-        preprocessor: PreprocessorProtocol,
-    ):
-        """Initialize the AudioMixer."""
-        super().__init__()
-        self.min_weight = min_weight
-        self.example_source = example_source
-        self.max_weight = max_weight
-        self.preprocessor = preprocessor
-
-    def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
-        """Fetch another example and perform mixup."""
-        other = self.example_source()
-        weight = np.random.uniform(self.min_weight, self.max_weight)
-        return mix_examples(
-            example,
-            other,
-            self.preprocessor,
-            weight=weight,
-        )
-
-
 def build_augmentation_from_config(
     config: AugmentationConfig,
     preprocessor: PreprocessorProtocol,
@@ -489,7 +517,7 @@ def build_augmentation_from_config(
         )
         return None
 
-    return AudioMixer(
+    return MixAudio(
         example_source=example_source,
         preprocessor=preprocessor,
         min_weight=config.min_weight,
@@ -585,3 +613,25 @@ def load_augmentation_config(
 ) -> AugmentationsConfig:
     """Load the augmentations configuration from a file."""
     return load_config(path, schema=AugmentationsConfig, field=field)
+
+
+class RandomExampleSource:
+    def __init__(
+        self,
+        filenames: Sequence[data.PathLike],
+        clipper: ClipperProtocol,
+    ):
+        self.filenames = filenames
+        self.clipper = clipper
+
+    def __call__(self) -> PreprocessedExample:
+        index = int(np.random.randint(len(self.filenames)))
+        filename = self.filenames[index]
+        example = load_preprocessed_example(filename)
+        example, _, _ = self.clipper(example)
+        return example
+
+    @classmethod
+    def from_directory(cls, path: data.PathLike, clipper: ClipperProtocol):
+        filenames = list_preprocessed_files(path)
+        return cls(filenames, clipper=clipper)
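`RandomExampleSource` slots in as the `example_source` that `MixAudio` draws its second example from. A minimal wiring sketch; the directory path is illustrative and the `clipper`, `preprocessor`, and `example` objects are assumed to satisfy the corresponding protocols:

```python
source = RandomExampleSource.from_directory(
    "outputs/preprocessed/",  # illustrative path of preprocessed examples
    clipper=clipper,
)

mixer = MixAudio(
    example_source=source,
    preprocessor=preprocessor,
    min_weight=0.3,
    max_weight=0.7,
)

# Mixes `example` with a randomly drawn, clipped training example.
augmented = mixer(example)
```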
@@ -14,7 +14,9 @@ from batdetect2.evaluate.match import (
     MatchConfig,
     match_sound_events_and_raw_predictions,
 )
+from batdetect2.models import Model
 from batdetect2.plotting.evaluation import plot_example_gallery
+from batdetect2.postprocess import get_sound_event_predictions
 from batdetect2.train.dataset import LabeledDataset
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.typing import (
@@ -22,7 +24,6 @@ from batdetect2.typing import (
     MatchEvaluation,
     MetricsProtocol,
     ModelOutput,
-    PostprocessorProtocol,
     TargetProtocol,
     TrainExample,
 )
@@ -127,8 +128,7 @@ class ValidationMetrics(Callback):
                 batch,
                 outputs,
                 dataset=self.get_dataset(trainer),
-                postprocessor=pl_module.model.postprocessor,
-                targets=pl_module.model.targets,
+                model=pl_module.model,
             )
         )
 
@@ -137,15 +137,14 @@ def _get_batch_clips_and_predictions(
     batch: TrainExample,
     outputs: ModelOutput,
     dataset: LabeledDataset,
-    postprocessor: PostprocessorProtocol,
-    targets: TargetProtocol,
+    model: Model,
 ) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
     clip_annotations = [
         _get_subclip(
            dataset.get_clip_annotation(example_id),
            start_time=start_time.item(),
            end_time=end_time.item(),
-            targets=targets,
+            targets=model.targets,
        )
        for example_id, start_time, end_time in zip(
            batch.idx,
@@ -156,9 +155,11 @@ def _get_batch_clips_and_predictions(
 
     clips = [clip_annotation.clip for clip_annotation in clip_annotations]
 
-    raw_predictions = postprocessor.get_sound_event_predictions(
+    raw_predictions = get_sound_event_predictions(
         outputs,
         clips,
+        targets=model.targets,
+        postprocessor=model.postprocessor
     )
 
     return [
|
|||||||
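
With this change the validation callback hands over the composite model instead of separate postprocessor and targets arguments, and prediction extraction moves to the free function get_sound_event_predictions. A hedged sketch of the resulting call shape, with stand-in types (only the keyword layout comes from the hunks above):

from dataclasses import dataclass
from typing import Any, List


@dataclass
class ModelStandIn:  # stand-in for batdetect2.models.Model
    targets: Any
    postprocessor: Any


def get_sound_event_predictions_stub(
    outputs: Any,
    clips: List[Any],
    targets: Any,
    postprocessor: Any,
) -> List[List[Any]]:
    # The real function lives in batdetect2.postprocess; this stub only
    # mirrors how the callback now threads model.targets and
    # model.postprocessor through as keyword arguments.
    return [[] for _ in clips]


model = ModelStandIn(targets=None, postprocessor=None)
predictions = get_sound_event_predictions_stub(
    None, [object()], targets=model.targets, postprocessor=model.postprocessor
)
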
@@ -8,7 +8,7 @@ from batdetect2.configs import BaseConfig
 from batdetect2.typing import ClipperProtocol
 from batdetect2.typing.preprocess import PreprocessorProtocol
 from batdetect2.typing.train import PreprocessedExample
-from batdetect2.utils.arrays import adjust_width
+from batdetect2.utils.arrays import adjust_width, slice_tensor

 DEFAULT_TRAIN_CLIP_DURATION = 0.512
 DEFAULT_MAX_EMPTY_CLIP = 0.1
@@ -90,7 +90,12 @@ def select_subclip(
     audio_start = int(np.floor(start * input_samplerate))

     audio = adjust_width(
-        example.audio[audio_start : audio_start + audio_width],
+        slice_tensor(
+            example.audio,
+            start=audio_start,
+            end=audio_start + audio_width,
+            dim=-1,
+        ),
         audio_width,
         value=fill_value,
     )
@@ -100,19 +105,39 @@ def select_subclip(
     return PreprocessedExample(
         audio=audio,
         spectrogram=adjust_width(
-            example.spectrogram[:, spec_start : spec_start + spec_width],
+            slice_tensor(
+                example.spectrogram,
+                start=spec_start,
+                end=spec_start + spec_width,
+                dim=-1,
+            ),
             spec_width,
         ),
         class_heatmap=adjust_width(
-            example.class_heatmap[:, :, spec_start : spec_start + spec_width],
+            slice_tensor(
+                example.class_heatmap,
+                start=spec_start,
+                end=spec_start + spec_width,
+                dim=-1,
+            ),
             spec_width,
         ),
         detection_heatmap=adjust_width(
-            example.detection_heatmap[:, spec_start : spec_start + spec_width],
+            slice_tensor(
+                example.detection_heatmap,
+                start=spec_start,
+                end=spec_start + spec_width,
+                dim=-1,
+            ),
             spec_width,
         ),
         size_heatmap=adjust_width(
-            example.size_heatmap[:, :, spec_start : spec_start + spec_width],
+            slice_tensor(
+                example.size_heatmap,
+                start=spec_start,
+                end=spec_start + spec_width,
+                dim=-1,
+            ),
             spec_width,
         ),
     )
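
A worked example of the index arithmetic in select_subclip. The 0.512 s clip duration comes from the constant above; the sample rate is illustrative, not a batdetect2 default:

import numpy as np

samplerate = 256_000                                     # assumed input rate (Hz)
clip_duration = 0.512                                    # DEFAULT_TRAIN_CLIP_DURATION
start = 1.3                                              # subclip start within the recording (s)

audio_width = int(np.floor(clip_duration * samplerate))  # 131072 samples
audio_start = int(np.floor(start * samplerate))          # 332800

print(audio_start, audio_start + audio_width)            # 332800 463872

Slicing with dim=-1 keeps the same call valid whether the audio is stored as (samples,) or (channels, samples), which is the point of swapping the explicit bracket indexing for slice_tensor.
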
@@ -44,8 +44,8 @@ class PLTrainerConfig(BaseConfig):


 class DataLoaderConfig(BaseConfig):
-    batch_size: int
-    shuffle: bool
+    batch_size: int = 8
+    shuffle: bool = False
     num_workers: int = 0

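
Because BaseConfig in batdetect2 is pydantic-based, giving every field a default means the dataloader section can be omitted from a training config entirely. A small sketch of the effect, assuming nothing about BaseConfig beyond ordinary pydantic semantics:

from pydantic import BaseModel


class DataLoaderConfigSketch(BaseModel):
    batch_size: int = 8
    shuffle: bool = False
    num_workers: int = 0


print(DataLoaderConfigSketch())             # batch_size=8 shuffle=False num_workers=0
print(DataLoaderConfigSketch(shuffle=True, num_workers=2))
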
@@ -1,5 +1,4 @@
-from pathlib import Path
-from typing import List, Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple

 import numpy as np
 import torch
@@ -7,6 +6,10 @@ from soundevent import data
 from torch.utils.data import Dataset

 from batdetect2.train.augmentations import Augmentation
+from batdetect2.train.preprocess import (
+    list_preprocessed_files,
+    load_preprocessed_example,
+)
 from batdetect2.typing import ClipperProtocol, TrainExample
 from batdetect2.typing.train import PreprocessedExample

@@ -38,8 +41,8 @@ class LabeledDataset(Dataset):
             example = self.augmentation(example)

         return TrainExample(
-            spec=example.spectrogram.unsqueeze(0),
-            detection_heatmap=example.detection_heatmap.unsqueeze(0),
+            spec=example.spectrogram,
+            detection_heatmap=example.detection_heatmap,
             class_heatmap=example.class_heatmap,
             size_heatmap=example.size_heatmap,
             idx=torch.tensor(idx),
@@ -73,37 +76,3 @@ class LabeledDataset(Dataset):
     def get_clip_annotation(self, idx) -> data.ClipAnnotation:
         item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
         return item["clip_annotation"].tolist()
-
-
-def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
-    item = np.load(path, mmap_mode="r+")
-    return PreprocessedExample(
-        audio=torch.tensor(item["audio"]),
-        spectrogram=torch.tensor(item["spectrogram"]),
-        size_heatmap=torch.tensor(item["size_heatmap"]),
-        detection_heatmap=torch.tensor(item["detection_heatmap"]),
-        class_heatmap=torch.tensor(item["class_heatmap"]),
-    )
-
-
-def list_preprocessed_files(
-    directory: data.PathLike, extension: str = ".npz"
-) -> List[Path]:
-    return list(Path(directory).glob(f"*{extension}"))
-
-
-class RandomExampleSource:
-    def __init__(
-        self,
-        filenames: List[data.PathLike],
-        clipper: ClipperProtocol,
-    ):
-        self.filenames = filenames
-        self.clipper = clipper
-
-    def __call__(self) -> PreprocessedExample:
-        index = int(np.random.randint(len(self.filenames)))
-        filename = self.filenames[index]
-        example = load_preprocessed_example(filename)
-        example, _, _ = self.clipper(example)
-        return example
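
Dropping the unsqueeze(0) calls only makes sense if the spectrogram and detection heatmap now arrive with a channel axis of their own, which matches the labelling and preprocessing changes later in this diff. A shape sketch with illustrative sizes:

import torch

height, width, num_classes, num_dims = 128, 512, 5, 2

spectrogram = torch.zeros(1, height, width)        # now produced channel-first
detection_heatmap = torch.zeros(1, height, width)  # was (height, width) before
class_heatmap = torch.zeros(num_classes, height, width)
size_heatmap = torch.zeros(num_dims, height, width)

# Had the dataset kept its unsqueeze(0), the channel-first tensors would
# have gained a second leading axis, ending up as (1, 1, height, width).
assert spectrogram.unsqueeze(0).shape == (1, 1, height, width)
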
@@ -41,7 +41,6 @@ from batdetect2.typing import (
 __all__ = [
     "LabelConfig",
     "build_clip_labeler",
-    "generate_clip_label",
     "generate_heatmaps",
     "load_label_config",
 ]
@@ -99,21 +98,26 @@ def build_clip_labeler(
         lambda: config.to_yaml_string(),
     )
     return partial(
-        generate_clip_label,
+        generate_heatmaps,
         targets=targets,
-        config=config,
         min_freq=min_freq,
         max_freq=max_freq,
+        target_sigma=config.sigma,
     )


-def generate_clip_label(
+def map_to_pixels(x, size, min_val, max_val) -> int:
+    return int(np.interp(x, [min_val, max_val], [0, size]))
+
+
+def generate_heatmaps(
     clip_annotation: data.ClipAnnotation,
     spec: torch.Tensor,
     targets: TargetProtocol,
-    config: LabelConfig,
     min_freq: float,
     max_freq: float,
+    target_sigma: float = 3.0,
+    dtype=torch.float32,
 ) -> Heatmaps:
     """Generate training heatmaps for a single annotated clip.

@@ -150,57 +154,14 @@ def generate_clip_label(
         num=len(clip_annotation.sound_events),
     )

-    sound_events = []
-    for sound_event_annotation in clip_annotation.sound_events:
-        if not targets.filter(sound_event_annotation):
-            logger.debug(
-                "Sound event {sound_event} did not pass the filter. Tags: {tags}",
-                sound_event=sound_event_annotation,
-                tags=sound_event_annotation.tags,
-            )
-            continue
-
-        sound_events.append(targets.transform(sound_event_annotation))
-
-    return generate_heatmaps(
-        clip_annotation.model_copy(update=dict(sound_events=sound_events)),
-        spec=spec,
-        targets=targets,
-        target_sigma=config.sigma,
-        min_freq=min_freq,
-        max_freq=max_freq,
-    )
-
-
-def map_to_pixels(x, size, min_val, max_val) -> int:
-    return int(np.interp(x, [min_val, max_val], [0, size]))
-
-
-def generate_heatmaps(
-    clip_annotation: data.ClipAnnotation,
-    spec: torch.Tensor,
-    targets: TargetProtocol,
-    min_freq: float,
-    max_freq: float,
-    target_sigma: float = 3.0,
-    dtype=torch.float32,
-) -> Heatmaps:
-    if not spec.ndim == 2:
-        raise ValueError(
-            "Expecting a 2-dimensional tensor of shape (H, W), "
-            "H is the height of the spectrogram "
-            "(frequency bins), and W is the width of the spectrogram "
-            f"(temporal bins). Instead got: {spec.shape}"
-        )
-
-    height, width = spec.shape
+    height = spec.shape[-2]
+    width = spec.shape[-1]
     num_classes = len(targets.class_names)
     num_dims = len(targets.dimension_names)
     clip = clip_annotation.clip

     # Initialize heatmaps
-    detection_heatmap = torch.zeros([height, width], dtype=dtype)
+    detection_heatmap = torch.zeros([1, height, width], dtype=dtype)
     class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype)
     size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype)

@@ -214,6 +175,16 @@ def generate_heatmaps(
     times = times.to(spec.device)

     for sound_event_annotation in clip_annotation.sound_events:
+        if not targets.filter(sound_event_annotation):
+            logger.debug(
+                "Sound event {sound_event} did not pass the filter. Tags: {tags}",
+                sound_event=sound_event_annotation,
+                tags=sound_event_annotation.tags,
+            )
+            continue
+
+        sound_event_annotation = targets.transform(sound_event_annotation)
+
         geom = sound_event_annotation.sound_event.geometry
         if geom is None:
             logger.debug(
@@ -245,7 +216,10 @@ def generate_heatmaps(
         distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
         gaussian_blob = torch.exp(-distance / (2 * target_sigma**2))

-        detection_heatmap = torch.maximum(detection_heatmap, gaussian_blob)
+        detection_heatmap[0] = torch.maximum(
+            detection_heatmap[0],
+            gaussian_blob,
+        )
         size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])

         # Get the class name of the sound event
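
The detection target painted around each sound event is an isotropic Gaussian in pixel space, exp(-((t - t0)^2 + (f - f0)^2) / (2 * sigma^2)), now written into channel 0 of the heatmap. A small numeric sketch; the times/freqs grids are built earlier in the function and assumed here to be broadcastable row and column vectors:

import torch

height, width = 16, 16
target_sigma = 3.0
freq_index, time_index = 8, 8

freqs = torch.arange(height, dtype=torch.float32).reshape(-1, 1)
times = torch.arange(width, dtype=torch.float32).reshape(1, -1)

distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
gaussian_blob = torch.exp(-distance / (2 * target_sigma**2))

detection_heatmap = torch.zeros(1, height, width)
detection_heatmap[0] = torch.maximum(detection_heatmap[0], gaussian_blob)

assert detection_heatmap[0, freq_index, time_index] == 1.0  # peak at the event

Taking the elementwise maximum rather than summing keeps overlapping events from pushing the target above 1.
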
@@ -34,7 +34,7 @@ class TrainingModule(L.LightningModule):
         return self.model(spec)

     def training_step(self, batch: TrainExample):
-        outputs = self.model(batch.spec)
+        outputs = self.model.detector(batch.spec)
         losses = self.loss(outputs, batch)
         self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
         self.log("detection_loss/train", losses.total, logger=True)
@@ -47,7 +47,7 @@ class TrainingModule(L.LightningModule):
         batch: TrainExample,
         batch_idx: int,
     ) -> ModelOutput:
-        outputs = self.model(batch.spec)
+        outputs = self.model.detector(batch.spec)
         losses = self.loss(outputs, batch)
         self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
         self.log("detection_loss/val", losses.total, logger=True)
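
Routing training through self.model.detector rather than self.model suggests the composite model's own forward pass is reserved for end-to-end inference while training drives the raw detector head on precomputed spectrograms. A sketch of that split with stand-in modules (the names and bodies are placeholders, not the real API):

import torch
import torch.nn as nn


class DetectorStub(nn.Module):
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return spec  # placeholder for the heatmap outputs


class CompositeModelStub(nn.Module):
    def __init__(self):
        super().__init__()
        self.detector = DetectorStub()

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # Inference path; the training_step above bypasses this and
        # calls self.model.detector(batch.spec) directly.
        return self.detector(spec)
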
@@ -2,7 +2,7 @@

 import os
 from pathlib import Path
-from typing import Callable, Optional, Sequence, TypedDict
+from typing import Callable, List, Optional, Sequence, TypedDict

 import numpy as np
 import torch
@@ -28,6 +28,8 @@ __all__ = [
     "preprocess_dataset",
     "TrainPreprocessConfig",
     "load_train_preprocessing_config",
+    "save_preprocessed_example",
+    "load_preprocessed_example",
 ]

 FilenameFn = Callable[[data.ClipAnnotation], str]
@@ -94,8 +96,10 @@ def generate_train_example(
     labeller: ClipLabeller,
 ) -> PreprocessedExample:
     """Generate a complete training example for one annotation."""
-    wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
-    spectrogram = preprocessor(wave)
+    wave = torch.tensor(
+        audio_loader.load_clip(clip_annotation.clip)
+    ).unsqueeze(0)
+    spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
     heatmaps = labeller(clip_annotation, spectrogram)
     return PreprocessedExample(
         audio=wave,
@@ -145,7 +149,7 @@ class PreprocessingDataset(torch.utils.data.Dataset):
             labeller=self.labeller,
         )

-        save_example_to_file(example, clip_annotation, path)
+        save_preprocessed_example(example, clip_annotation, path)

         return idx

@@ -153,7 +157,7 @@ class PreprocessingDataset(torch.utils.data.Dataset):
         return len(self.clips)


-def save_example_to_file(
+def save_preprocessed_example(
     example: PreprocessedExample,
     clip_annotation: data.ClipAnnotation,
     path: data.PathLike,
@@ -169,6 +173,23 @@ def save_example_to_file(
     )


+def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
+    item = np.load(path, mmap_mode="r+")
+    return PreprocessedExample(
+        audio=torch.tensor(item["audio"]),
+        spectrogram=torch.tensor(item["spectrogram"]),
+        size_heatmap=torch.tensor(item["size_heatmap"]),
+        detection_heatmap=torch.tensor(item["detection_heatmap"]),
+        class_heatmap=torch.tensor(item["class_heatmap"]),
+    )
+
+
+def list_preprocessed_files(
+    directory: data.PathLike, extension: str = ".npz"
+) -> List[Path]:
+    return list(Path(directory).glob(f"*{extension}"))
+
+
 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
     """Generate a default output filename based on the annotation UUID."""
     return f"{clip_annotation.uuid}"
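
load_preprocessed_example reads five arrays back out of an .npz archive. Assuming save_preprocessed_example writes them with np.savez under the same keys (its body sits outside this hunk), the round trip looks like this, with illustrative shapes:

import numpy as np
import torch

arrays = {
    "audio": np.zeros((1, 1024), dtype=np.float32),
    "spectrogram": np.zeros((1, 128, 512), dtype=np.float32),
    "detection_heatmap": np.zeros((1, 128, 512), dtype=np.float32),
    "class_heatmap": np.zeros((5, 128, 512), dtype=np.float32),
    "size_heatmap": np.zeros((2, 128, 512), dtype=np.float32),
}
np.savez("example.npz", **arrays)

item = np.load("example.npz")
spectrogram = torch.tensor(item["spectrogram"])  # as in load_preprocessed_example
assert spectrogram.shape == (1, 128, 512)
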
@@ -14,14 +14,16 @@ from batdetect2.evaluate.metrics import (
     ClassificationMeanAveragePrecision,
     DetectionAveragePrecision,
 )
-from batdetect2.models import build_model
-from batdetect2.train.augmentations import build_augmentations
+from batdetect2.models import Model, build_model
+from batdetect2.train.augmentations import (
+    RandomExampleSource,
+    build_augmentations,
+)
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
 from batdetect2.train.config import FullTrainingConfig, TrainingConfig
 from batdetect2.train.dataset import (
     LabeledDataset,
-    RandomExampleSource,
 )
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
@@ -53,17 +55,13 @@ def train(
 ):
     config = config or FullTrainingConfig()

-    if model_path is not None:
-        logger.debug("Loading model from: {path}", path=model_path)
-        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
-    else:
-        module = build_training_module(config)
+    model = build_model(config=config)

-    trainer = build_trainer(config, targets=module.model.targets)
+    trainer = build_trainer(config, targets=model.targets)

     train_dataloader = build_train_loader(
         train_examples,
-        preprocessor=module.model.preprocessor,
+        preprocessor=model.preprocessor,
         config=config.train,
         num_workers=train_workers,
     )
@@ -71,7 +69,7 @@ def train(
     val_dataloader = (
         build_val_loader(
             val_examples,
-            preprocessor=module.model.preprocessor,
+            preprocessor=model.preprocessor,
             config=config.train,
             num_workers=val_workers,
         )
@@ -79,6 +77,16 @@ def train(
         else None
     )

+    if model_path is not None:
+        logger.debug("Loading model from: {path}", path=model_path)
+        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
+    else:
+        module = build_training_module(
+            model,
+            config,
+            batches_per_epoch=len(train_dataloader),
+        )
+
     logger.info("Starting main training loop...")
     trainer.fit(
         module,
@@ -88,14 +96,17 @@ def train(
     logger.info("Training complete.")


-def build_training_module(config: FullTrainingConfig) -> TrainingModule:
-    model = build_model(config=config)
+def build_training_module(
+    model: Model,
+    config: FullTrainingConfig,
+    batches_per_epoch: int,
+) -> TrainingModule:
     loss = build_loss(config=config.train.loss)
     return TrainingModule(
         model=model,
         loss=loss,
         learning_rate=config.train.learning_rate,
-        t_max=config.train.t_max,
+        t_max=config.train.t_max * batches_per_epoch,
     )
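
Multiplying t_max by the number of batches converts the cosine-annealing horizon from epochs to optimizer steps, which is consistent with a scheduler that steps once per batch (an assumption; the scheduler setup itself is not part of this diff). The arithmetic:

t_max_epochs = 100         # config.train.t_max, as in the YAML at the top
batches_per_epoch = 250    # len(train_dataloader); illustrative

t_max_steps = t_max_epochs * batches_per_epoch
print(t_max_steps)         # 25000 scheduler steps in one cosine cycle
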
@@ -95,69 +95,10 @@ class BatDetect2Prediction:
 class PostprocessorProtocol(Protocol):
     """Protocol defining the interface for the full postprocessing pipeline."""

+    def __call__(self, output: ModelOutput) -> List[Detections]: ...
+
     def get_detections(
         self,
         output: ModelOutput,
         clips: Optional[List[data.Clip]] = None,
     ) -> List[Detections]: ...
-
-    def get_raw_predictions(
-        self,
-        output: ModelOutput,
-        clips: List[data.Clip],
-    ) -> List[List[RawPrediction]]:
-        """Extract intermediate RawPrediction objects for a batch.
-
-        Processes the raw model output for a batch through remapping, NMS,
-        detection, data extraction, and geometry recovery to produce a list of
-        `RawPrediction` objects for each corresponding input clip. This provides
-        a simplified, intermediate representation before final tag decoding.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            The raw output from the neural network model for a batch.
-        clips : List[data.Clip]
-            A list of `soundevent.data.Clip` objects corresponding to the batch
-            items, providing context. Must match the batch size of `output`.
-
-        Returns
-        -------
-        List[List[RawPrediction]]
-            A list of lists (one inner list per input clip, in order). Each
-            inner list contains the `RawPrediction` objects extracted for the
-            corresponding input clip.
-        """
-        ...
-
-    def get_sound_event_predictions(
-        self, output: ModelOutput, clips: List[data.Clip]
-    ) -> List[List[BatDetect2Prediction]]: ...
-
-    def get_predictions(
-        self,
-        output: ModelOutput,
-        clips: List[data.Clip],
-    ) -> List[data.ClipPrediction]:
-        """Perform the full postprocessing pipeline for a batch.
-
-        Takes raw model output for a batch and corresponding clips, applies the
-        entire postprocessing chain, and returns the final, interpretable
-        predictions as a list of `soundevent.data.ClipPrediction` objects.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            The raw output from the neural network model for a batch.
-        clips : List[data.Clip]
-            A list of `soundevent.data.Clip` objects corresponding to the batch
-            items, providing context. Must match the batch size of `output`.
-
-        Returns
-        -------
-        List[data.ClipPrediction]
-            A list containing one `ClipPrediction` object for each input clip
-            (in the same order), populated with `SoundEventPrediction` objects
-            representing the final detections with decoded tags and geometry.
-        """
-        ...
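
After this hunk the protocol promises only __call__ and get_detections; the batch-level helpers it used to declare are now reachable as free functions such as the get_sound_event_predictions used in the callback changes above. A minimal structural sketch of a conforming class, with Any standing in for the real batdetect2 types:

from typing import Any, List, Optional, Protocol


class PostprocessorSketch(Protocol):
    def __call__(self, output: Any) -> List[Any]: ...

    def get_detections(
        self,
        output: Any,
        clips: Optional[List[Any]] = None,
    ) -> List[Any]: ...


class NaivePostprocessor:
    """Trivially conforming implementation (returns no detections)."""

    def get_detections(self, output, clips=None):
        return []

    def __call__(self, output):
        return self.get_detections(output)
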
@@ -12,8 +12,8 @@ that components responsible for these tasks can be interacted with consistently
 throughout BatDetect2.
 """

-from collections.abc import Callable
-from typing import List, Optional, Protocol
+from collections.abc import Callable, Iterable
+from typing import List, Optional, Protocol, Tuple

 import numpy as np
 from soundevent import data
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import torch
 import xarray as xr
@@ -80,3 +82,14 @@ def adjust_width(
         for index in range(dims)
     ]
     return tensor[tuple(slices)]
+
+
+def slice_tensor(
+    tensor: torch.Tensor,
+    start: Optional[int] = None,
+    end: Optional[int] = None,
+    dim: int = -1,
+) -> torch.Tensor:
+    slices = [slice(None)] * tensor.ndim
+    slices[dim] = slice(start, end)
+    return tensor[tuple(slices)]
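
A usage sketch for the new slice_tensor helper: slicing along dim=-1 is rank-agnostic, which is what lets select_subclip earlier in this diff treat audio, spectrograms, and heatmaps uniformly. The function body is copied from the hunk above; the shapes are illustrative:

from typing import Optional

import torch


def slice_tensor(
    tensor: torch.Tensor,
    start: Optional[int] = None,
    end: Optional[int] = None,
    dim: int = -1,
) -> torch.Tensor:
    # Build a full slice for every axis, then narrow only `dim`.
    slices = [slice(None)] * tensor.ndim
    slices[dim] = slice(start, end)
    return tensor[tuple(slices)]


audio = torch.zeros(1, 1024)         # (channels, samples)
heatmap = torch.zeros(5, 128, 512)   # (classes, freq, time)

assert slice_tensor(audio, 0, 256).shape == (1, 256)
assert slice_tensor(heatmap, 0, 256).shape == (5, 128, 256)
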
@@ -38,7 +38,6 @@ def build_from_config(
         max_freq=preprocessor.max_freq,
     )
     postprocessor = build_postprocessor(
-        targets,
         preprocessor=preprocessor,
         config=postprocessing_config,
     )