Fix type errors

This commit is contained in:
mbsantiago 2026-03-08 12:55:36 +00:00
parent cce1b49a8d
commit d52e988b8f
26 changed files with 371 additions and 284 deletions

View File

@ -14,60 +14,67 @@ HTML_COVERAGE_DIR := "htmlcov"
help:
@just --list
install:
uv sync
# Testing & Coverage
# Run tests using pytest.
test:
pytest {{TESTS_DIR}}
uv run pytest {{TESTS_DIR}}
# Run tests and generate coverage data.
coverage:
pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
uv run pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
# Generate an HTML coverage report.
coverage-html: coverage
@echo "Generating HTML coverage report..."
coverage html -d {{HTML_COVERAGE_DIR}}
uv run coverage html -d {{HTML_COVERAGE_DIR}}
@echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"
# Serve the HTML coverage report locally.
coverage-serve: coverage-html
@echo "Serving report at http://localhost:8000/ ..."
python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
uv run python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
# Documentation
# Build documentation using Sphinx.
docs:
sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
# Serve documentation with live reload.
docs-serve:
sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
# Formatting & Linting
# Format code using ruff.
format:
ruff format {{PYTHON_DIRS}}
# Check code formatting using ruff.
format-check:
ruff format --check {{PYTHON_DIRS}}
# Lint code using ruff.
lint:
ruff check {{PYTHON_DIRS}}
fix-format:
uv run ruff format {{PYTHON_DIRS}}
# Lint code using ruff and apply automatic fixes.
lint-fix:
ruff check --fix {{PYTHON_DIRS}}
fix-lint:
uv run ruff check --fix {{PYTHON_DIRS}}
# Combined Formatting & Linting
fix: fix-format fix-lint
# Checking tasks
# Check code formatting using ruff.
check-format:
uv run ruff format --check {{PYTHON_DIRS}}
# Lint code using ruff.
check-lint:
uv run ruff check {{PYTHON_DIRS}}
# Type Checking
# Type check code using pyright.
typecheck:
pyright {{PYTHON_DIRS}}
# Type check code using ty.
check-types:
uv run ty check {{PYTHON_DIRS}}
# Combined Checks
# Run all checks (check-format, check-lint, check-types).
check: format-check lint typecheck test
check: check-format check-lint check-types
# Cleaning tasks
# Remove Python bytecode and cache.
@ -95,7 +102,7 @@ clean: clean-build clean-pyc clean-test clean-docs
# Train on example data.
example-train OPTIONS="":
batdetect2 train \
uv run batdetect2 train \
--val-dataset example_data/dataset.yaml \
--config example_data/config.yaml \
{{OPTIONS}} \

View File

@ -75,7 +75,6 @@ dev = [
"ruff>=0.7.3",
"ipykernel>=6.29.4",
"setuptools>=69.5.1",
"basedpyright>=1.31.0",
"myst-parser>=3.0.1",
"sphinx-autobuild>=2024.10.3",
"numpydoc>=1.8.0",
@ -94,6 +93,12 @@ mlflow = ["mlflow>=3.1.1"]
[tool.ruff]
line-length = 79
target-version = "py310"
exclude = [
"src/batdetect2/train/legacy",
"src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune",
]
[tool.ruff.format]
docstring-code-format = true
@ -105,15 +110,11 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
[tool.ruff.lint.pydocstyle]
convention = "numpy"
[tool.pyright]
[tool.ty.src]
include = ["src", "tests"]
pythonVersion = "3.10"
pythonPlatform = "All"
exclude = [
"src/batdetect2/detector/",
"src/batdetect2/finetune",
"src/batdetect2/utils",
"src/batdetect2/plot",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/train/legacy",
"src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune",
]

View File

@ -113,7 +113,7 @@ def load_file_audio(
samplerate: int | None = None,
config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore
dtype: DTypeLike = np.float32,
) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config."""
try:
@ -137,7 +137,7 @@ def load_recording_audio(
samplerate: int | None = None,
config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore
dtype: DTypeLike = np.float32,
) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config."""
clip = data.Clip(
@ -159,7 +159,7 @@ def load_clip_audio(
samplerate: int | None = None,
config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore
dtype: DTypeLike = np.float32,
) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config."""
try:
@ -286,7 +286,7 @@ def resample_audio_fourier(
If `num` is negative.
"""
ratio = sr_new / sr_orig
return resample( # type: ignore
return resample(
array,
int(array.shape[axis] * ratio),
axis=axis,

View File

@ -103,17 +103,17 @@ def convert_to_annotation_group(
y_inds.append(0)
annotations.append(
{
"start_time": start_time,
"end_time": end_time,
"low_freq": low_freq,
"high_freq": high_freq,
"class_prob": 1.0,
"det_prob": 1.0,
"individual": "0",
"event": event,
"class_id": class_id, # type: ignore
}
Annotation(
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
class_prob=1.0,
det_prob=1.0,
individual="0",
event=event,
class_id=class_id,
)
)
return {

View File

@ -89,7 +89,7 @@ class FileAnnotation(TypedDict):
annotation: List[Annotation]
"""List of annotations."""
file_path: NotRequired[str]
file_path: NotRequired[str] # ty: ignore[invalid-type-form]
"""Path to file."""

View File

@ -26,15 +26,15 @@ def split_dataset_by_recordings(
)
majority_class = (
sound_events.groupby("recording_id")
sound_events.groupby("recording_id") # type: ignore
.apply(
lambda group: (
group["class_name"] # type: ignore
group["class_name"]
.value_counts()
.sort_values(ascending=False)
.index[0]
),
include_groups=False, # type: ignore
include_groups=False,
)
.rename("class_name")
.to_frame()
@ -48,8 +48,8 @@ def split_dataset_by_recordings(
random_state=random_state,
)
train_ids_set = set(train.values) # type: ignore
test_ids_set = set(test.values) # type: ignore
train_ids_set = set(train.values)
test_ids_set = set(test.values)
extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set

View File

@ -175,14 +175,14 @@ def compute_class_summary(
.rename("num recordings")
)
durations = (
sound_events.groupby("class_name")
sound_events.groupby("class_name") # ty: ignore[no-matching-overload]
.apply(
lambda group: recordings[
recordings["clip_annotation_id"].isin(
group["clip_annotation_id"] # type: ignore
group["clip_annotation_id"]
)
]["duration"].sum(),
include_groups=False, # type: ignore
include_groups=False,
)
.sort_values(ascending=False)
.rename("duration")

View File

@ -176,12 +176,7 @@ class MapTagValue:
if self.target_key is None:
tags.append(tag.model_copy(update=dict(value=value)))
else:
tags.append(
data.Tag(
key=self.target_key, # type: ignore
value=value,
)
)
tags.append(data.Tag(key=self.target_key, value=value))
return sound_event_annotation.model_copy(update=dict(tags=tags))

View File

@ -120,11 +120,11 @@ class BBoxIOU(AffinityFunction):
def __call__(
self,
prediction: Detection,
gt: data.SoundEventAnnotation,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
):
target_geometry = gt.sound_event.geometry
source_geometry = prediction.geometry
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry(
@ -168,11 +168,11 @@ class GeometricIOU(AffinityFunction):
def __call__(
self,
prediction: Detection,
gt: data.SoundEventAnnotation,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
):
target_geometry = gt.sound_event.geometry
source_geometry = prediction.geometry
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry(

View File

@ -51,8 +51,8 @@ class TestDataset(Dataset[TestExample]):
def __len__(self):
return len(self.clip_annotations)
def __getitem__(self, idx: int) -> TestExample:
clip_annotation = self.clip_annotations[idx]
def __getitem__(self, index: int) -> TestExample:
clip_annotation = self.clip_annotations[index]
if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation)
@ -63,7 +63,7 @@ class TestDataset(Dataset[TestExample]):
spectrogram = self.preprocessor(wav_tensor)
return TestExample(
spec=spectrogram,
idx=torch.tensor(idx),
idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)

View File

@ -109,7 +109,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
self,
clip_annotation: data.ClipAnnotation,
prediction: ClipDetections,
) -> T_Output: ...
) -> T_Output: ... # ty: ignore[empty-body]
def include_sound_event_annotation(
self,

View File

@ -46,14 +46,14 @@ class InferenceDataset(Dataset[DatasetItem]):
def __len__(self):
return len(self.clips)
def __getitem__(self, idx: int) -> DatasetItem:
clip = self.clips[idx]
def __getitem__(self, index: int) -> DatasetItem:
clip = self.clips[index]
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
wav_tensor = torch.tensor(wav).unsqueeze(0)
spectrogram = self.preprocessor(wav_tensor)
return DatasetItem(
spec=spectrogram,
idx=torch.tensor(idx),
idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)

View File

@ -185,7 +185,7 @@ def build_backbone(config: BackboneConfig) -> BackboneModel:
)
decoder = build_decoder(
in_channels=bottleneck.out_channels,
in_channels=bottleneck.get_output_channels(),
input_height=encoder.output_height,
config=config.decoder,
)

View File

@ -20,20 +20,21 @@ research:
spatial frequency information to filters, potentially enabling them to learn
frequency-dependent patterns more effectively.
These blocks can be utilized directly in custom PyTorch model definitions or
These blocks can be used directly in custom PyTorch model definitions or
assembled into larger architectures.
A unified factory function `build_layer_from_config` allows creating instances
of these blocks based on configuration objects.
"""
from typing import Annotated, List, Literal, Tuple, Union
from typing import Annotated, List, Literal, Protocol, Tuple, Union
import torch
import torch.nn.functional as F
from pydantic import Field
from torch import nn
from batdetect2.core import Registry
from batdetect2.core.configs import BaseConfig
__all__ = [
@ -51,17 +52,30 @@ __all__ = [
"FreqCoordConvUpConfig",
"StandardConvUpConfig",
"LayerConfig",
"build_layer_from_config",
"build_layer",
]
class BlockProtocol(Protocol):
def get_output_channels(self) -> int: ...
def get_output_height(self, input_height: int) -> int:
return input_height
class Block(nn.Module, BlockProtocol): ...
block_registry: Registry[Block, [int, int]] = Registry("block")
class SelfAttentionConfig(BaseConfig):
name: Literal["SelfAttention"] = "SelfAttention"
attention_channels: int
temperature: float = 1
class SelfAttention(nn.Module):
class SelfAttention(Block):
"""Self-Attention mechanism operating along the time dimension.
This module implements a scaled dot-product self-attention mechanism,
@ -121,6 +135,7 @@ class SelfAttention(nn.Module):
# Note, does not encode position information (absolute or relative)
self.temperature = temperature
self.att_dim = attention_channels
self.output_channels = in_channels
self.key_fun = nn.Linear(in_channels, attention_channels)
self.value_fun = nn.Linear(in_channels, attention_channels)
@ -190,6 +205,22 @@ class SelfAttention(nn.Module):
att_weights = F.softmax(kk_qq, 1)
return att_weights
def get_output_channels(self) -> int:
return self.output_channels
@block_registry.register(SelfAttentionConfig)
@staticmethod
def from_config(
config: SelfAttentionConfig,
input_channels: int,
input_height: int,
) -> "SelfAttention":
return SelfAttention(
in_channels=input_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
)
class ConvConfig(BaseConfig):
"""Configuration for a basic ConvBlock."""
@ -207,7 +238,7 @@ class ConvConfig(BaseConfig):
"""Padding size."""
class ConvBlock(nn.Module):
class ConvBlock(Block):
"""Basic Convolutional Block.
A standard building block consisting of a 2D convolution, followed by
@ -258,8 +289,30 @@ class ConvBlock(nn.Module):
"""
return F.relu_(self.batch_norm(self.conv(x)))
def get_output_channels(self) -> int:
return self.conv.out_channels
class VerticalConv(nn.Module):
@block_registry.register(ConvConfig)
@staticmethod
def from_config(
config: ConvConfig,
input_channels: int,
input_height: int,
):
return ConvBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class VerticalConvConfig(BaseConfig):
name: Literal["VerticalConv"] = "VerticalConv"
channels: int
class VerticalConv(Block):
"""Convolutional layer that aggregates features across the entire height.
Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
@ -312,6 +365,22 @@ class VerticalConv(nn.Module):
"""
return F.relu_(self.bn(self.conv(x)))
def get_output_channels(self) -> int:
return self.conv.out_channels
@block_registry.register(VerticalConvConfig)
@staticmethod
def from_config(
config: VerticalConvConfig,
input_channels: int,
input_height: int,
):
return VerticalConv(
in_channels=input_channels,
out_channels=config.channels,
input_height=input_height,
)
class FreqCoordConvDownConfig(BaseConfig):
"""Configuration for a FreqCoordConvDownBlock."""
@ -329,7 +398,7 @@ class FreqCoordConvDownConfig(BaseConfig):
"""Padding size."""
class FreqCoordConvDownBlock(nn.Module):
class FreqCoordConvDownBlock(Block):
"""Downsampling Conv Block incorporating Frequency Coordinate features.
This block implements a downsampling step (Conv2d + MaxPool2d) commonly
@ -402,6 +471,27 @@ class FreqCoordConvDownBlock(nn.Module):
x = F.relu(self.batch_norm(x), inplace=True)
return x
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int:
return input_height // 2
@block_registry.register(FreqCoordConvDownConfig)
@staticmethod
def from_config(
config: FreqCoordConvDownConfig,
input_channels: int,
input_height: int,
):
return FreqCoordConvDownBlock(
in_channels=input_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class StandardConvDownConfig(BaseConfig):
"""Configuration for a StandardConvDownBlock."""
@ -419,7 +509,7 @@ class StandardConvDownConfig(BaseConfig):
"""Padding size."""
class StandardConvDownBlock(nn.Module):
class StandardConvDownBlock(Block):
"""Standard Downsampling Convolutional Block.
A basic downsampling block consisting of a 2D convolution, followed by
@ -472,6 +562,26 @@ class StandardConvDownBlock(nn.Module):
x = F.max_pool2d(self.conv(x), 2, 2)
return F.relu(self.batch_norm(x), inplace=True)
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int:
return input_height // 2
@block_registry.register(StandardConvDownConfig)
@staticmethod
def from_config(
config: StandardConvDownConfig,
input_channels: int,
input_height: int,
):
return StandardConvDownBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class FreqCoordConvUpConfig(BaseConfig):
"""Configuration for a FreqCoordConvUpBlock."""
@ -488,8 +598,14 @@ class FreqCoordConvUpConfig(BaseConfig):
pad_size: int = 1
"""Padding size."""
up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
class FreqCoordConvUpBlock(nn.Module):
up_scale: Tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling."""
class FreqCoordConvUpBlock(Block):
"""Upsampling Conv Block incorporating Frequency Coordinate features.
This block implements an upsampling step followed by a convolution,
@ -581,6 +697,29 @@ class FreqCoordConvUpBlock(nn.Module):
op = F.relu(self.batch_norm(op), inplace=True)
return op
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int:
return input_height * 2
@block_registry.register(FreqCoordConvUpConfig)
@staticmethod
def from_config(
config: FreqCoordConvUpConfig,
input_channels: int,
input_height: int,
):
return FreqCoordConvUpBlock(
in_channels=input_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
up_mode=config.up_mode,
up_scale=config.up_scale,
)
class StandardConvUpConfig(BaseConfig):
"""Configuration for a StandardConvUpBlock."""
@ -597,8 +736,14 @@ class StandardConvUpConfig(BaseConfig):
pad_size: int = 1
"""Padding size."""
up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
class StandardConvUpBlock(nn.Module):
up_scale: Tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling."""
class StandardConvUpBlock(Block):
"""Standard Upsampling Convolutional Block.
A basic upsampling block used in CNN decoders. It first upsamples the input
@ -669,6 +814,28 @@ class StandardConvUpBlock(nn.Module):
op = F.relu(self.batch_norm(op), inplace=True)
return op
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int:
return input_height * 2
@block_registry.register(StandardConvUpConfig)
@staticmethod
def from_config(
config: StandardConvUpConfig,
input_channels: int,
input_height: int,
):
return StandardConvUpBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
up_mode=config.up_mode,
up_scale=config.up_scale,
)
LayerConfig = Annotated[
Union[
@ -690,11 +857,61 @@ class LayerGroupConfig(BaseConfig):
layers: List[LayerConfig]
def build_layer_from_config(
class LayerGroup(Block):
    """Sequence of blocks that is applied in order and acts as one block.

    The group wraps its member blocks in an `nn.Sequential` and exposes the
    same `get_output_channels` / `get_output_height` interface as the
    individual blocks, so it can be nested wherever a single block is
    accepted.

    Parameters
    ----------
    layers : list[Block]
        Blocks to apply, in order.
    input_height : int
        Height of the input fed to the first block. Accepted for interface
        parity with the other blocks; not used by the group itself.
    input_channels : int
        Number of channels of the input fed to the first block. Reported as
        the output channel count when `layers` is empty.
    """

    def __init__(
        self,
        layers: list[Block],
        input_height: int,
        input_channels: int,
    ):
        super().__init__()
        # Keep the plain list for shape bookkeeping; `nn.Sequential` is what
        # registers the blocks as submodules and runs them.
        self.blocks = layers
        self.layers = nn.Sequential(*layers)
        self.input_channels = input_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every block in the group, in order, to `x`."""
        return self.layers(x)

    def get_output_channels(self) -> int:
        """Return the channel count produced by the last block.

        An empty group applies no block, so the input channel count is
        returned unchanged.
        """
        if not self.blocks:
            return self.input_channels
        return self.blocks[-1].get_output_channels()

    def get_output_height(self, input_height: int) -> int:
        """Return the height obtained by chaining every block's height map."""
        for block in self.blocks:
            input_height = block.get_output_height(input_height)
        return input_height

    @block_registry.register(LayerGroupConfig)
    @staticmethod
    def from_config(
        config: LayerGroupConfig,
        input_channels: int,
        input_height: int,
    ) -> "LayerGroup":
        """Build a `LayerGroup` from its configuration.

        The positional order `(config, input_channels, input_height)` matches
        the other registered block builders and the positional call made by
        `build_layer` through `block_registry.build`.

        Parameters
        ----------
        config : LayerGroupConfig
            Configuration listing the layers of the group.
        input_channels : int
            Number of channels of the input to the first layer.
        input_height : int
            Height of the input to the first layer.

        Returns
        -------
        LayerGroup
            Group containing one block per entry of `config.layers`.
        """
        # Track the running shape separately so the original input shape can
        # still be handed to the constructor.
        current_channels = input_channels
        current_height = input_height

        layers: list[Block] = []
        for layer_config in config.layers:
            layer = build_layer(
                input_height=current_height,
                in_channels=current_channels,
                config=layer_config,
            )
            layers.append(layer)
            current_height = layer.get_output_height(current_height)
            current_channels = layer.get_output_channels()

        return LayerGroup(
            layers=layers,
            input_height=input_height,
            input_channels=input_channels,
        )
def build_layer(
input_height: int,
in_channels: int,
config: LayerConfig,
) -> Tuple[nn.Module, int, int]:
) -> Block:
"""Factory function to build a specific nn.Module block from its config.
Takes configuration object (one of the types included in the `LayerConfig`
@ -731,93 +948,4 @@ def build_layer_from_config(
ValueError
If parameters derived from the config are invalid for the block.
"""
if config.name == "ConvBlock":
return (
ConvBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height,
)
if config.name == "FreqCoordConvDown":
return (
FreqCoordConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
if config.name == "StandardConvDown":
return (
StandardConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
if config.name == "FreqCoordConvUp":
return (
FreqCoordConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
if config.name == "StandardConvUp":
return (
StandardConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
if config.name == "SelfAttention":
return (
SelfAttention(
in_channels=in_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
),
config.attention_channels,
input_height,
)
if config.name == "LayerGroup":
current_channels = in_channels
current_height = input_height
blocks = []
for block_config in config.layers:
block, current_channels, current_height = build_layer_from_config(
input_height=current_height,
in_channels=current_channels,
config=block_config,
)
blocks.append(block)
return nn.Sequential(*blocks), current_channels, current_height
raise NotImplementedError(f"Unknown block type {config.name}")
return block_registry.build(config, in_channels, input_height)

View File

@ -22,9 +22,10 @@ from torch import nn
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
Block,
SelfAttentionConfig,
VerticalConv,
build_layer_from_config,
build_layer,
)
__all__ = [
@ -34,7 +35,7 @@ __all__ = [
]
class Bottleneck(nn.Module):
class Bottleneck(Block):
"""Base Bottleneck module for Encoder-Decoder architectures.
This implementation represents the simplest bottleneck structure
@ -86,6 +87,7 @@ class Bottleneck(nn.Module):
self.in_channels = in_channels
self.input_height = input_height
self.out_channels = out_channels
self.bottleneck_channels = (
bottleneck_channels
if bottleneck_channels is not None
@ -172,7 +174,7 @@ def build_bottleneck(
input_height: int,
in_channels: int,
config: BottleneckConfig | None = None,
) -> nn.Module:
) -> Block:
"""Factory function to build the Bottleneck module from configuration.
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
@ -206,11 +208,13 @@ def build_bottleneck(
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
layer = build_layer(
input_height=current_height,
in_channels=current_channels,
config=layer_config,
)
current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels()
assert current_height == input_height, (
"Bottleneck layers should not change the spectrogram height"
)

View File

@ -30,7 +30,7 @@ from batdetect2.models.blocks import (
FreqCoordConvUpConfig,
LayerGroupConfig,
StandardConvUpConfig,
build_layer_from_config,
build_layer,
)
__all__ = [
@ -259,11 +259,13 @@ def build_decoder(
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
layer = build_layer(
in_channels=current_channels,
input_height=current_height,
config=layer_config,
)
current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels()
layers.append(layer)
return Decoder(

View File

@ -32,7 +32,7 @@ from batdetect2.models.blocks import (
FreqCoordConvDownConfig,
LayerGroupConfig,
StandardConvDownConfig,
build_layer_from_config,
build_layer,
)
__all__ = [
@ -300,12 +300,14 @@ def build_encoder(
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
layer = build_layer(
in_channels=current_channels,
input_height=current_height,
config=layer_config,
)
layers.append(layer)
current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels()
return Encoder(
input_height=input_height,

View File

@ -19,9 +19,9 @@ def create_ax(
) -> axes.Axes:
"""Create a new axis if none is provided"""
if ax is None:
_, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs) # type: ignore
_, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs)
return ax # type: ignore
return ax
def plot_spectrogram(

View File

@ -66,7 +66,7 @@ def plot_clip_detections(
if m.gt is not None:
color = gt_color if is_match else missed_gt_color
plot_geometry(
m.gt.sound_event.geometry, # type: ignore
m.gt.sound_event.geometry,
ax=ax,
add_points=False,
linewidth=linewidth,

View File

@ -100,7 +100,7 @@ def plot_classification_heatmap(
class_heatmap,
vmax=1,
vmin=0,
cmap=create_colormap(color), # type: ignore
cmap=create_colormap(color),
alpha=alpha,
)

View File

@ -301,4 +301,4 @@ def _get_marker_positions(
size = len(thresholds)
cut_points = np.linspace(0, 1, n_points)
indices = np.searchsorted(thresholds[::-1], cut_points)
return np.clip(size - indices, 0, size - 1) # type: ignore
return np.clip(size - indices, 0, size - 1)

View File

@ -72,7 +72,7 @@ class TargetClassConfig(BaseConfig):
DEFAULT_DETECTION_CLASS = TargetClassConfig(
name="bat",
match_if=AllOfConfig(
condition_input=AllOfConfig(
conditions=[
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
NotConfig(

View File

@ -27,6 +27,7 @@ from pydantic import Field
from soundevent import data
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.core.arrays import spec_to_xarray
from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
@ -88,6 +89,9 @@ DEFAULT_ANCHOR = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""
roi_mapper_registry: Registry[ROITargetMapper, []] = Registry("roi_mapper")
class AnchorBBoxMapperConfig(BaseConfig):
"""Configuration for `AnchorBBoxMapper`.
@ -244,6 +248,15 @@ class AnchorBBoxMapper(ROITargetMapper):
anchor=self.anchor,
)
@roi_mapper_registry.register(AnchorBBoxMapperConfig)
@staticmethod
def from_config(config: AnchorBBoxMapperConfig):
return AnchorBBoxMapper(
anchor=config.anchor,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
)
class PeakEnergyBBoxMapperConfig(BaseConfig):
"""Configuration for `PeakEnergyBBoxMapper`.
@ -412,6 +425,22 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
]
)
@roi_mapper_registry.register(PeakEnergyBBoxMapperConfig)
@staticmethod
def from_config(config: PeakEnergyBBoxMapperConfig):
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return PeakEnergyBBoxMapper(
preprocessor=preprocessor,
audio_loader=audio_loader,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
loading_buffer=config.loading_buffer,
)
ROIMapperConfig = Annotated[
AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig,
@ -445,31 +474,7 @@ def build_roi_mapper(
If the `name` in the config does not correspond to a known mapper.
"""
config = config or AnchorBBoxMapperConfig()
if config.name == "anchor_bbox":
return AnchorBBoxMapper(
anchor=config.anchor,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
)
if config.name == "peak_energy_bbox":
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return PeakEnergyBBoxMapper(
preprocessor=preprocessor,
audio_loader=audio_loader,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
loading_buffer=config.loading_buffer,
)
raise NotImplementedError(
f"No ROI mapper of name '{config.name}' is implemented"
)
return roi_mapper_registry.build(config)
VALID_ANCHORS = [
@ -636,7 +641,7 @@ def get_peak_energy_coordinates(
)
index = selection.argmax(dim=["time", "frequency"])
point = selection.isel(index) # type: ignore
point = selection.isel(index)
peak_time: float = point.time.item()
peak_freq: float = point.frequency.item()
return peak_time, peak_freq

View File

@ -548,63 +548,6 @@ class MaybeApply(torch.nn.Module):
return self.augmentation(tensor, clip_annotation)
def build_augmentation_from_config(
config: AugmentationConfig,
samplerate: int,
audio_source: AudioSource | None = None,
) -> Augmentation | None:
"""Factory function to build a single augmentation from its config."""
if config.name == "mix_audio":
if audio_source is None:
warnings.warn(
"Mix audio augmentation ('mix_audio') requires an "
"'example_source' callable to be provided.",
stacklevel=2,
)
return None
return MixAudio(
example_source=audio_source,
min_weight=config.min_weight,
max_weight=config.max_weight,
)
if config.name == "add_echo":
return AddEcho(
max_delay=int(config.max_delay * samplerate),
min_weight=config.min_weight,
max_weight=config.max_weight,
)
if config.name == "scale_volume":
return ScaleVolume(
max_scaling=config.max_scaling,
min_scaling=config.min_scaling,
)
if config.name == "warp":
return Warp(
delta=config.delta,
)
if config.name == "mask_time":
return MaskTime(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
if config.name == "mask_freq":
return MaskFrequency(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
raise NotImplementedError(
"Invalid or unimplemented augmentation type: "
f"{config.augmentation_type}"
)
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
enabled=True,
audio=[

View File

@ -61,8 +61,8 @@ class TrainingDataset(Dataset):
def __len__(self):
return len(self.clip_annotations)
def __getitem__(self, idx) -> TrainExample:
clip_annotation = self.clip_annotations[idx]
def __getitem__(self, index) -> TrainExample:
clip_annotation = self.clip_annotations[index]
if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation)
@ -95,7 +95,7 @@ class TrainingDataset(Dataset):
detection_heatmap=heatmaps.detection,
class_heatmap=heatmaps.classes,
size_heatmap=heatmaps.size,
idx=torch.tensor(idx),
idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
@ -121,8 +121,8 @@ class ValidationDataset(Dataset):
def __len__(self):
return len(self.clip_annotations)
def __getitem__(self, idx) -> TrainExample:
clip_annotation = self.clip_annotations[idx]
def __getitem__(self, index) -> TrainExample:
clip_annotation = self.clip_annotations[index]
if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation)
@ -141,7 +141,7 @@ class ValidationDataset(Dataset):
detection_heatmap=heatmaps.detection,
class_heatmap=heatmaps.classes,
size_heatmap=heatmaps.size,
idx=torch.tensor(idx),
idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)

View File

@ -472,7 +472,7 @@ def build_loss(
size_loss_fn = BBoxLoss()
return LossFunction( # type: ignore
return LossFunction(
size_loss=size_loss_fn,
classification_loss=classification_loss_fn,
detection_loss=detection_loss_fn,