Fix type errors

This commit is contained in:
mbsantiago 2026-03-08 12:55:36 +00:00
parent cce1b49a8d
commit d52e988b8f
26 changed files with 371 additions and 284 deletions

View File

@ -14,60 +14,67 @@ HTML_COVERAGE_DIR := "htmlcov"
help: help:
@just --list @just --list
install:
uv sync
# Testing & Coverage # Testing & Coverage
# Run tests using pytest. # Run tests using pytest.
test: test:
pytest {{TESTS_DIR}} uv run pytest {{TESTS_DIR}}
# Run tests and generate coverage data. # Run tests and generate coverage data.
coverage: coverage:
pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}} uv run pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
# Generate an HTML coverage report. # Generate an HTML coverage report.
coverage-html: coverage coverage-html: coverage
@echo "Generating HTML coverage report..." @echo "Generating HTML coverage report..."
coverage html -d {{HTML_COVERAGE_DIR}} uv run coverage html -d {{HTML_COVERAGE_DIR}}
@echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/" @echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"
# Serve the HTML coverage report locally. # Serve the HTML coverage report locally.
coverage-serve: coverage-html coverage-serve: coverage-html
@echo "Serving report at http://localhost:8000/ ..." @echo "Serving report at http://localhost:8000/ ..."
python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000 uv run python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
# Documentation # Documentation
# Build documentation using Sphinx. # Build documentation using Sphinx.
docs: docs:
sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}} uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
# Serve documentation with live reload. # Serve documentation with live reload.
docs-serve: docs-serve:
sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
# Formatting & Linting # Formatting & Linting
# Format code using ruff. # Format code using ruff.
format: fix-format:
ruff format {{PYTHON_DIRS}} uv run ruff format {{PYTHON_DIRS}}
# Check code formatting using ruff.
format-check:
ruff format --check {{PYTHON_DIRS}}
# Lint code using ruff.
lint:
ruff check {{PYTHON_DIRS}}
# Lint code using ruff and apply automatic fixes. # Lint code using ruff and apply automatic fixes.
lint-fix: fix-lint:
ruff check --fix {{PYTHON_DIRS}} uv run ruff check --fix {{PYTHON_DIRS}}
# Combined Formatting & Linting
fix: fix-format fix-lint
# Checking tasks
# Check code formatting using ruff.
check-format:
uv run ruff format --check {{PYTHON_DIRS}}
# Lint code using ruff.
check-lint:
uv run ruff check {{PYTHON_DIRS}}
# Type Checking # Type Checking
# Type check code using pyright. # Type check code using ty.
typecheck: check-types:
pyright {{PYTHON_DIRS}} uv run ty check {{PYTHON_DIRS}}
# Combined Checks # Combined Checks
# Run all checks (format-check, lint, typecheck). # Run all checks (check-format, check-lint, check-types).
check: format-check lint typecheck test check: check-format check-lint check-types
# Cleaning tasks # Cleaning tasks
# Remove Python bytecode and cache. # Remove Python bytecode and cache.
@ -95,7 +102,7 @@ clean: clean-build clean-pyc clean-test clean-docs
# Train on example data. # Train on example data.
example-train OPTIONS="": example-train OPTIONS="":
batdetect2 train \ uv run batdetect2 train \
--val-dataset example_data/dataset.yaml \ --val-dataset example_data/dataset.yaml \
--config example_data/config.yaml \ --config example_data/config.yaml \
{{OPTIONS}} \ {{OPTIONS}} \

View File

@ -75,7 +75,6 @@ dev = [
"ruff>=0.7.3", "ruff>=0.7.3",
"ipykernel>=6.29.4", "ipykernel>=6.29.4",
"setuptools>=69.5.1", "setuptools>=69.5.1",
"basedpyright>=1.31.0",
"myst-parser>=3.0.1", "myst-parser>=3.0.1",
"sphinx-autobuild>=2024.10.3", "sphinx-autobuild>=2024.10.3",
"numpydoc>=1.8.0", "numpydoc>=1.8.0",
@ -94,6 +93,12 @@ mlflow = ["mlflow>=3.1.1"]
[tool.ruff] [tool.ruff]
line-length = 79 line-length = 79
target-version = "py310" target-version = "py310"
exclude = [
"src/batdetect2/train/legacy",
"src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune",
]
[tool.ruff.format] [tool.ruff.format]
docstring-code-format = true docstring-code-format = true
@ -105,15 +110,11 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
[tool.ruff.lint.pydocstyle] [tool.ruff.lint.pydocstyle]
convention = "numpy" convention = "numpy"
[tool.pyright] [tool.ty.src]
include = ["src", "tests"] include = ["src", "tests"]
pythonVersion = "3.10"
pythonPlatform = "All"
exclude = [ exclude = [
"src/batdetect2/detector/",
"src/batdetect2/finetune",
"src/batdetect2/utils",
"src/batdetect2/plot",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/train/legacy", "src/batdetect2/train/legacy",
"src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune",
] ]

View File

@ -113,7 +113,7 @@ def load_file_audio(
samplerate: int | None = None, samplerate: int | None = None,
config: ResampleConfig | None = None, config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore dtype: DTypeLike = np.float32,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config.""" """Load and preprocess audio from a file path using specified config."""
try: try:
@ -137,7 +137,7 @@ def load_recording_audio(
samplerate: int | None = None, samplerate: int | None = None,
config: ResampleConfig | None = None, config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore dtype: DTypeLike = np.float32,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config.""" """Load and preprocess the entire audio content of a recording using config."""
clip = data.Clip( clip = data.Clip(
@ -159,7 +159,7 @@ def load_clip_audio(
samplerate: int | None = None, samplerate: int | None = None,
config: ResampleConfig | None = None, config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore dtype: DTypeLike = np.float32,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config.""" """Load and preprocess a specific audio clip segment based on config."""
try: try:
@ -286,7 +286,7 @@ def resample_audio_fourier(
If `num` is negative. If `num` is negative.
""" """
ratio = sr_new / sr_orig ratio = sr_new / sr_orig
return resample( # type: ignore return resample(
array, array,
int(array.shape[axis] * ratio), int(array.shape[axis] * ratio),
axis=axis, axis=axis,

View File

@ -103,17 +103,17 @@ def convert_to_annotation_group(
y_inds.append(0) y_inds.append(0)
annotations.append( annotations.append(
{ Annotation(
"start_time": start_time, start_time=start_time,
"end_time": end_time, end_time=end_time,
"low_freq": low_freq, low_freq=low_freq,
"high_freq": high_freq, high_freq=high_freq,
"class_prob": 1.0, class_prob=1.0,
"det_prob": 1.0, det_prob=1.0,
"individual": "0", individual="0",
"event": event, event=event,
"class_id": class_id, # type: ignore class_id=class_id,
} )
) )
return { return {

View File

@ -89,7 +89,7 @@ class FileAnnotation(TypedDict):
annotation: List[Annotation] annotation: List[Annotation]
"""List of annotations.""" """List of annotations."""
file_path: NotRequired[str] file_path: NotRequired[str] # ty: ignore[invalid-type-form]
"""Path to file.""" """Path to file."""

View File

@ -26,15 +26,15 @@ def split_dataset_by_recordings(
) )
majority_class = ( majority_class = (
sound_events.groupby("recording_id") sound_events.groupby("recording_id") # type: ignore
.apply( .apply(
lambda group: ( lambda group: (
group["class_name"] # type: ignore group["class_name"]
.value_counts() .value_counts()
.sort_values(ascending=False) .sort_values(ascending=False)
.index[0] .index[0]
), ),
include_groups=False, # type: ignore include_groups=False,
) )
.rename("class_name") .rename("class_name")
.to_frame() .to_frame()
@ -48,8 +48,8 @@ def split_dataset_by_recordings(
random_state=random_state, random_state=random_state,
) )
train_ids_set = set(train.values) # type: ignore train_ids_set = set(train.values)
test_ids_set = set(test.values) # type: ignore test_ids_set = set(test.values)
extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set

View File

@ -175,14 +175,14 @@ def compute_class_summary(
.rename("num recordings") .rename("num recordings")
) )
durations = ( durations = (
sound_events.groupby("class_name") sound_events.groupby("class_name") # ty: ignore[no-matching-overload]
.apply( .apply(
lambda group: recordings[ lambda group: recordings[
recordings["clip_annotation_id"].isin( recordings["clip_annotation_id"].isin(
group["clip_annotation_id"] # type: ignore group["clip_annotation_id"]
) )
]["duration"].sum(), ]["duration"].sum(),
include_groups=False, # type: ignore include_groups=False,
) )
.sort_values(ascending=False) .sort_values(ascending=False)
.rename("duration") .rename("duration")

View File

@ -176,12 +176,7 @@ class MapTagValue:
if self.target_key is None: if self.target_key is None:
tags.append(tag.model_copy(update=dict(value=value))) tags.append(tag.model_copy(update=dict(value=value)))
else: else:
tags.append( tags.append(data.Tag(key=self.target_key, value=value))
data.Tag(
key=self.target_key, # type: ignore
value=value,
)
)
return sound_event_annotation.model_copy(update=dict(tags=tags)) return sound_event_annotation.model_copy(update=dict(tags=tags))

View File

@ -120,11 +120,11 @@ class BBoxIOU(AffinityFunction):
def __call__( def __call__(
self, self,
prediction: Detection, detection: Detection,
gt: data.SoundEventAnnotation, ground_truth: data.SoundEventAnnotation,
): ):
target_geometry = gt.sound_event.geometry target_geometry = ground_truth.sound_event.geometry
source_geometry = prediction.geometry source_geometry = detection.geometry
if self.time_buffer > 0 or self.freq_buffer > 0: if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry( target_geometry = buffer_geometry(
@ -168,11 +168,11 @@ class GeometricIOU(AffinityFunction):
def __call__( def __call__(
self, self,
prediction: Detection, detection: Detection,
gt: data.SoundEventAnnotation, ground_truth: data.SoundEventAnnotation,
): ):
target_geometry = gt.sound_event.geometry target_geometry = ground_truth.sound_event.geometry
source_geometry = prediction.geometry source_geometry = detection.geometry
if self.time_buffer > 0 or self.freq_buffer > 0: if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry( target_geometry = buffer_geometry(

View File

@ -51,8 +51,8 @@ class TestDataset(Dataset[TestExample]):
def __len__(self): def __len__(self):
return len(self.clip_annotations) return len(self.clip_annotations)
def __getitem__(self, idx: int) -> TestExample: def __getitem__(self, index: int) -> TestExample:
clip_annotation = self.clip_annotations[idx] clip_annotation = self.clip_annotations[index]
if self.clipper is not None: if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation) clip_annotation = self.clipper(clip_annotation)
@ -63,7 +63,7 @@ class TestDataset(Dataset[TestExample]):
spectrogram = self.preprocessor(wav_tensor) spectrogram = self.preprocessor(wav_tensor)
return TestExample( return TestExample(
spec=spectrogram, spec=spectrogram,
idx=torch.tensor(idx), idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time), start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time), end_time=torch.tensor(clip.end_time),
) )

View File

@ -109,7 +109,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: ClipDetections, prediction: ClipDetections,
) -> T_Output: ... ) -> T_Output: ... # ty: ignore[empty-body]
def include_sound_event_annotation( def include_sound_event_annotation(
self, self,

View File

@ -46,14 +46,14 @@ class InferenceDataset(Dataset[DatasetItem]):
def __len__(self): def __len__(self):
return len(self.clips) return len(self.clips)
def __getitem__(self, idx: int) -> DatasetItem: def __getitem__(self, index: int) -> DatasetItem:
clip = self.clips[idx] clip = self.clips[index]
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir) wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
wav_tensor = torch.tensor(wav).unsqueeze(0) wav_tensor = torch.tensor(wav).unsqueeze(0)
spectrogram = self.preprocessor(wav_tensor) spectrogram = self.preprocessor(wav_tensor)
return DatasetItem( return DatasetItem(
spec=spectrogram, spec=spectrogram,
idx=torch.tensor(idx), idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time), start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time), end_time=torch.tensor(clip.end_time),
) )

View File

@ -185,7 +185,7 @@ def build_backbone(config: BackboneConfig) -> BackboneModel:
) )
decoder = build_decoder( decoder = build_decoder(
in_channels=bottleneck.out_channels, in_channels=bottleneck.get_output_channels(),
input_height=encoder.output_height, input_height=encoder.output_height,
config=config.decoder, config=config.decoder,
) )

View File

@ -20,20 +20,21 @@ research:
spatial frequency information to filters, potentially enabling them to learn spatial frequency information to filters, potentially enabling them to learn
frequency-dependent patterns more effectively. frequency-dependent patterns more effectively.
These blocks can be utilized directly in custom PyTorch model definitions or These blocks can be used directly in custom PyTorch model definitions or
assembled into larger architectures. assembled into larger architectures.
A unified factory function `build_layer_from_config` allows creating instances A unified factory function `build_layer_from_config` allows creating instances
of these blocks based on configuration objects. of these blocks based on configuration objects.
""" """
from typing import Annotated, List, Literal, Tuple, Union from typing import Annotated, List, Literal, Protocol, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pydantic import Field from pydantic import Field
from torch import nn from torch import nn
from batdetect2.core import Registry
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
__all__ = [ __all__ = [
@ -51,17 +52,30 @@ __all__ = [
"FreqCoordConvUpConfig", "FreqCoordConvUpConfig",
"StandardConvUpConfig", "StandardConvUpConfig",
"LayerConfig", "LayerConfig",
"build_layer_from_config", "build_layer",
] ]
# Structural interface shared by model blocks: lets builder code query the
# channel count and height a block produces while chaining blocks together.
class BlockProtocol(Protocol):
# Number of channels in the tensor produced by this block.
def get_output_channels(self) -> int: ...
# Height of the block's output for a given input height. The default
# implementation returns the input height unchanged; blocks that resize
# the height override it.
def get_output_height(self, input_height: int) -> int:
return input_height
class Block(nn.Module, BlockProtocol): ...
block_registry: Registry[Block, [int, int]] = Registry("block")
class SelfAttentionConfig(BaseConfig): class SelfAttentionConfig(BaseConfig):
name: Literal["SelfAttention"] = "SelfAttention" name: Literal["SelfAttention"] = "SelfAttention"
attention_channels: int attention_channels: int
temperature: float = 1 temperature: float = 1
class SelfAttention(nn.Module): class SelfAttention(Block):
"""Self-Attention mechanism operating along the time dimension. """Self-Attention mechanism operating along the time dimension.
This module implements a scaled dot-product self-attention mechanism, This module implements a scaled dot-product self-attention mechanism,
@ -121,6 +135,7 @@ class SelfAttention(nn.Module):
# Note, does not encode position information (absolute or relative) # Note, does not encode position information (absolute or relative)
self.temperature = temperature self.temperature = temperature
self.att_dim = attention_channels self.att_dim = attention_channels
self.output_channels = in_channels
self.key_fun = nn.Linear(in_channels, attention_channels) self.key_fun = nn.Linear(in_channels, attention_channels)
self.value_fun = nn.Linear(in_channels, attention_channels) self.value_fun = nn.Linear(in_channels, attention_channels)
@ -190,6 +205,22 @@ class SelfAttention(nn.Module):
att_weights = F.softmax(kk_qq, 1) att_weights = F.softmax(kk_qq, 1)
return att_weights return att_weights
# Channel count reported to builders for the tensor this block outputs.
# `self.output_channels` is set from `in_channels` in `__init__`, so this
# reports the input channel count, not `attention_channels`.
# NOTE(review): confirm this matches the channel dimension of what
# `forward()` actually returns.
def get_output_channels(self) -> int:
return self.output_channels
@block_registry.register(SelfAttentionConfig)
@staticmethod
def from_config(
config: SelfAttentionConfig,
input_channels: int,
input_height: int,
) -> "SelfAttention":
return SelfAttention(
in_channels=input_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
)
class ConvConfig(BaseConfig): class ConvConfig(BaseConfig):
"""Configuration for a basic ConvBlock.""" """Configuration for a basic ConvBlock."""
@ -207,7 +238,7 @@ class ConvConfig(BaseConfig):
"""Padding size.""" """Padding size."""
class ConvBlock(nn.Module): class ConvBlock(Block):
"""Basic Convolutional Block. """Basic Convolutional Block.
A standard building block consisting of a 2D convolution, followed by A standard building block consisting of a 2D convolution, followed by
@ -258,8 +289,30 @@ class ConvBlock(nn.Module):
""" """
return F.relu_(self.batch_norm(self.conv(x))) return F.relu_(self.batch_norm(self.conv(x)))
def get_output_channels(self) -> int:
return self.conv.out_channels
class VerticalConv(nn.Module): @block_registry.register(ConvConfig)
@staticmethod
def from_config(
config: ConvConfig,
input_channels: int,
input_height: int,
):
return ConvBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class VerticalConvConfig(BaseConfig):
name: Literal["VerticalConv"] = "VerticalConv"
channels: int
class VerticalConv(Block):
"""Convolutional layer that aggregates features across the entire height. """Convolutional layer that aggregates features across the entire height.
Applies a 2D convolution using a kernel with shape `(input_height, 1)`. Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
@ -312,6 +365,22 @@ class VerticalConv(nn.Module):
""" """
return F.relu_(self.bn(self.conv(x))) return F.relu_(self.bn(self.conv(x)))
def get_output_channels(self) -> int:
return self.conv.out_channels
@block_registry.register(VerticalConvConfig)
@staticmethod
def from_config(
config: VerticalConvConfig,
input_channels: int,
input_height: int,
):
return VerticalConv(
in_channels=input_channels,
out_channels=config.channels,
input_height=input_height,
)
class FreqCoordConvDownConfig(BaseConfig): class FreqCoordConvDownConfig(BaseConfig):
"""Configuration for a FreqCoordConvDownBlock.""" """Configuration for a FreqCoordConvDownBlock."""
@ -329,7 +398,7 @@ class FreqCoordConvDownConfig(BaseConfig):
"""Padding size.""" """Padding size."""
class FreqCoordConvDownBlock(nn.Module): class FreqCoordConvDownBlock(Block):
"""Downsampling Conv Block incorporating Frequency Coordinate features. """Downsampling Conv Block incorporating Frequency Coordinate features.
This block implements a downsampling step (Conv2d + MaxPool2d) commonly This block implements a downsampling step (Conv2d + MaxPool2d) commonly
@ -402,6 +471,27 @@ class FreqCoordConvDownBlock(nn.Module):
x = F.relu(self.batch_norm(x), inplace=True) x = F.relu(self.batch_norm(x), inplace=True)
return x return x
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int:
return input_height // 2
@block_registry.register(FreqCoordConvDownConfig)
@staticmethod
def from_config(
config: FreqCoordConvDownConfig,
input_channels: int,
input_height: int,
):
return FreqCoordConvDownBlock(
in_channels=input_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class StandardConvDownConfig(BaseConfig): class StandardConvDownConfig(BaseConfig):
"""Configuration for a StandardConvDownBlock.""" """Configuration for a StandardConvDownBlock."""
@ -419,7 +509,7 @@ class StandardConvDownConfig(BaseConfig):
"""Padding size.""" """Padding size."""
class StandardConvDownBlock(nn.Module): class StandardConvDownBlock(Block):
"""Standard Downsampling Convolutional Block. """Standard Downsampling Convolutional Block.
A basic downsampling block consisting of a 2D convolution, followed by A basic downsampling block consisting of a 2D convolution, followed by
@ -472,6 +562,26 @@ class StandardConvDownBlock(nn.Module):
x = F.max_pool2d(self.conv(x), 2, 2) x = F.max_pool2d(self.conv(x), 2, 2)
return F.relu(self.batch_norm(x), inplace=True) return F.relu(self.batch_norm(x), inplace=True)
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int:
return input_height // 2
@block_registry.register(StandardConvDownConfig)
@staticmethod
def from_config(
config: StandardConvDownConfig,
input_channels: int,
input_height: int,
):
return StandardConvDownBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class FreqCoordConvUpConfig(BaseConfig): class FreqCoordConvUpConfig(BaseConfig):
"""Configuration for a FreqCoordConvUpBlock.""" """Configuration for a FreqCoordConvUpBlock."""
@ -488,8 +598,14 @@ class FreqCoordConvUpConfig(BaseConfig):
pad_size: int = 1 pad_size: int = 1
"""Padding size.""" """Padding size."""
up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
class FreqCoordConvUpBlock(nn.Module): up_scale: Tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling."""
class FreqCoordConvUpBlock(Block):
"""Upsampling Conv Block incorporating Frequency Coordinate features. """Upsampling Conv Block incorporating Frequency Coordinate features.
This block implements an upsampling step followed by a convolution, This block implements an upsampling step followed by a convolution,
@ -581,6 +697,29 @@ class FreqCoordConvUpBlock(nn.Module):
op = F.relu(self.batch_norm(op), inplace=True) op = F.relu(self.batch_norm(op), inplace=True)
return op return op
def get_output_channels(self) -> int:
return self.conv.out_channels
# Height of the block's output for a given input height.
# NOTE(review): the factor 2 is hard-coded, but `from_config` forwards a
# configurable `up_scale` (default `(2, 2)`) to the constructor. For any
# other vertical scale this would presumably report the wrong height;
# confirm against `__init__`/`forward` and use the stored scale if so.
def get_output_height(self, input_height: int) -> int:
return input_height * 2
@block_registry.register(FreqCoordConvUpConfig)
@staticmethod
def from_config(
config: FreqCoordConvUpConfig,
input_channels: int,
input_height: int,
):
return FreqCoordConvUpBlock(
in_channels=input_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
up_mode=config.up_mode,
up_scale=config.up_scale,
)
class StandardConvUpConfig(BaseConfig): class StandardConvUpConfig(BaseConfig):
"""Configuration for a StandardConvUpBlock.""" """Configuration for a StandardConvUpBlock."""
@ -597,8 +736,14 @@ class StandardConvUpConfig(BaseConfig):
pad_size: int = 1 pad_size: int = 1
"""Padding size.""" """Padding size."""
up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
class StandardConvUpBlock(nn.Module): up_scale: Tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling."""
class StandardConvUpBlock(Block):
"""Standard Upsampling Convolutional Block. """Standard Upsampling Convolutional Block.
A basic upsampling block used in CNN decoders. It first upsamples the input A basic upsampling block used in CNN decoders. It first upsamples the input
@ -669,6 +814,28 @@ class StandardConvUpBlock(nn.Module):
op = F.relu(self.batch_norm(op), inplace=True) op = F.relu(self.batch_norm(op), inplace=True)
return op return op
def get_output_channels(self) -> int:
return self.conv.out_channels
# Height of the block's output for a given input height.
# NOTE(review): the factor 2 is hard-coded, but `from_config` forwards a
# configurable `up_scale` (default `(2, 2)`) to the constructor. For any
# other vertical scale this would presumably report the wrong height;
# confirm against `__init__`/`forward` and use the stored scale if so.
def get_output_height(self, input_height: int) -> int:
return input_height * 2
@block_registry.register(StandardConvUpConfig)
@staticmethod
def from_config(
config: StandardConvUpConfig,
input_channels: int,
input_height: int,
):
return StandardConvUpBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
up_mode=config.up_mode,
up_scale=config.up_scale,
)
LayerConfig = Annotated[ LayerConfig = Annotated[
Union[ Union[
@ -690,11 +857,61 @@ class LayerGroupConfig(BaseConfig):
layers: List[LayerConfig] layers: List[LayerConfig]
def build_layer_from_config( class LayerGroup(nn.Module):
"""Standard implementation of the `LayerGroup` architecture."""
def __init__(
self,
layers: list[Block],
input_height: int,
input_channels: int,
):
super().__init__()
self.blocks = layers
self.layers = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layers(x)
def get_output_channels(self) -> int:
return self.blocks[-1].get_output_channels()
def get_output_height(self, input_height: int) -> int:
for block in self.blocks:
input_height = block.get_output_height(input_height)
return input_height
@block_registry.register(LayerGroupConfig)
@staticmethod
def from_config(
    config: LayerGroupConfig,
    input_channels: int,
    input_height: int,
) -> "LayerGroup":
    """Build a `LayerGroup` by chaining the layers listed in ``config``.

    The positional order ``(config, input_channels, input_height)``
    matches every other ``from_config`` registered in ``block_registry``
    and the positional call made by ``build_layer``
    (``block_registry.build(config, in_channels, input_height)``).
    With the previous ``(config, input_height, input_channels)`` order a
    positional dispatch swapped the two values.

    Parameters
    ----------
    config : LayerGroupConfig
        Configuration holding the ordered list of layer configs.
    input_channels : int
        Number of channels of the tensor fed to the first layer.
    input_height : int
        Height of the tensor fed to the first layer.

    Returns
    -------
    LayerGroup
        Group wrapping the built layers, applied in order.
    """
    # Track the running dimensions separately so the group's own input
    # dimensions are still available when constructing the LayerGroup.
    current_channels = input_channels
    current_height = input_height

    layers = []
    for layer_config in config.layers:
        layer = build_layer(
            input_height=current_height,
            in_channels=current_channels,
            config=layer_config,
        )
        layers.append(layer)
        # Each layer's output dimensions are the next layer's input.
        current_height = layer.get_output_height(current_height)
        current_channels = layer.get_output_channels()

    return LayerGroup(
        layers=layers,
        input_height=input_height,
        input_channels=input_channels,
    )
def build_layer(
input_height: int, input_height: int,
in_channels: int, in_channels: int,
config: LayerConfig, config: LayerConfig,
) -> Tuple[nn.Module, int, int]: ) -> Block:
"""Factory function to build a specific nn.Module block from its config. """Factory function to build a specific nn.Module block from its config.
Takes configuration object (one of the types included in the `LayerConfig` Takes configuration object (one of the types included in the `LayerConfig`
@ -731,93 +948,4 @@ def build_layer_from_config(
ValueError ValueError
If parameters derived from the config are invalid for the block. If parameters derived from the config are invalid for the block.
""" """
if config.name == "ConvBlock": return block_registry.build(config, in_channels, input_height)
return (
ConvBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height,
)
if config.name == "FreqCoordConvDown":
return (
FreqCoordConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
if config.name == "StandardConvDown":
return (
StandardConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
if config.name == "FreqCoordConvUp":
return (
FreqCoordConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
if config.name == "StandardConvUp":
return (
StandardConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
if config.name == "SelfAttention":
return (
SelfAttention(
in_channels=in_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
),
config.attention_channels,
input_height,
)
if config.name == "LayerGroup":
current_channels = in_channels
current_height = input_height
blocks = []
for block_config in config.layers:
block, current_channels, current_height = build_layer_from_config(
input_height=current_height,
in_channels=current_channels,
config=block_config,
)
blocks.append(block)
return nn.Sequential(*blocks), current_channels, current_height
raise NotImplementedError(f"Unknown block type {config.name}")

View File

@ -22,9 +22,10 @@ from torch import nn
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import ( from batdetect2.models.blocks import (
Block,
SelfAttentionConfig, SelfAttentionConfig,
VerticalConv, VerticalConv,
build_layer_from_config, build_layer,
) )
__all__ = [ __all__ = [
@ -34,7 +35,7 @@ __all__ = [
] ]
class Bottleneck(nn.Module): class Bottleneck(Block):
"""Base Bottleneck module for Encoder-Decoder architectures. """Base Bottleneck module for Encoder-Decoder architectures.
This implementation represents the simplest bottleneck structure This implementation represents the simplest bottleneck structure
@ -86,6 +87,7 @@ class Bottleneck(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.input_height = input_height self.input_height = input_height
self.out_channels = out_channels self.out_channels = out_channels
self.bottleneck_channels = ( self.bottleneck_channels = (
bottleneck_channels bottleneck_channels
if bottleneck_channels is not None if bottleneck_channels is not None
@ -172,7 +174,7 @@ def build_bottleneck(
input_height: int, input_height: int,
in_channels: int, in_channels: int,
config: BottleneckConfig | None = None, config: BottleneckConfig | None = None,
) -> nn.Module: ) -> Block:
"""Factory function to build the Bottleneck module from configuration. """Factory function to build the Bottleneck module from configuration.
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
@ -206,11 +208,13 @@ def build_bottleneck(
layers = [] layers = []
for layer_config in config.layers: for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config( layer = build_layer(
input_height=current_height, input_height=current_height,
in_channels=current_channels, in_channels=current_channels,
config=layer_config, config=layer_config,
) )
current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels()
assert current_height == input_height, ( assert current_height == input_height, (
"Bottleneck layers should not change the spectrogram height" "Bottleneck layers should not change the spectrogram height"
) )

View File

@ -30,7 +30,7 @@ from batdetect2.models.blocks import (
FreqCoordConvUpConfig, FreqCoordConvUpConfig,
LayerGroupConfig, LayerGroupConfig,
StandardConvUpConfig, StandardConvUpConfig,
build_layer_from_config, build_layer,
) )
__all__ = [ __all__ = [
@ -259,11 +259,13 @@ def build_decoder(
layers = [] layers = []
for layer_config in config.layers: for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config( layer = build_layer(
in_channels=current_channels, in_channels=current_channels,
input_height=current_height, input_height=current_height,
config=layer_config, config=layer_config,
) )
current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels()
layers.append(layer) layers.append(layer)
return Decoder( return Decoder(

View File

@ -32,7 +32,7 @@ from batdetect2.models.blocks import (
FreqCoordConvDownConfig, FreqCoordConvDownConfig,
LayerGroupConfig, LayerGroupConfig,
StandardConvDownConfig, StandardConvDownConfig,
build_layer_from_config, build_layer,
) )
__all__ = [ __all__ = [
@ -300,12 +300,14 @@ def build_encoder(
layers = [] layers = []
for layer_config in config.layers: for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config( layer = build_layer(
in_channels=current_channels, in_channels=current_channels,
input_height=current_height, input_height=current_height,
config=layer_config, config=layer_config,
) )
layers.append(layer) layers.append(layer)
current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels()
return Encoder( return Encoder(
input_height=input_height, input_height=input_height,

View File

@ -19,9 +19,9 @@ def create_ax(
) -> axes.Axes: ) -> axes.Axes:
"""Create a new axis if none is provided""" """Create a new axis if none is provided"""
if ax is None: if ax is None:
_, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs) # type: ignore _, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs)
return ax # type: ignore return ax
def plot_spectrogram( def plot_spectrogram(

View File

@ -66,7 +66,7 @@ def plot_clip_detections(
if m.gt is not None: if m.gt is not None:
color = gt_color if is_match else missed_gt_color color = gt_color if is_match else missed_gt_color
plot_geometry( plot_geometry(
m.gt.sound_event.geometry, # type: ignore m.gt.sound_event.geometry,
ax=ax, ax=ax,
add_points=False, add_points=False,
linewidth=linewidth, linewidth=linewidth,

View File

@ -100,7 +100,7 @@ def plot_classification_heatmap(
class_heatmap, class_heatmap,
vmax=1, vmax=1,
vmin=0, vmin=0,
cmap=create_colormap(color), # type: ignore cmap=create_colormap(color),
alpha=alpha, alpha=alpha,
) )

View File

@ -301,4 +301,4 @@ def _get_marker_positions(
size = len(thresholds) size = len(thresholds)
cut_points = np.linspace(0, 1, n_points) cut_points = np.linspace(0, 1, n_points)
indices = np.searchsorted(thresholds[::-1], cut_points) indices = np.searchsorted(thresholds[::-1], cut_points)
return np.clip(size - indices, 0, size - 1) # type: ignore return np.clip(size - indices, 0, size - 1)

View File

@ -72,7 +72,7 @@ class TargetClassConfig(BaseConfig):
DEFAULT_DETECTION_CLASS = TargetClassConfig( DEFAULT_DETECTION_CLASS = TargetClassConfig(
name="bat", name="bat",
match_if=AllOfConfig( condition_input=AllOfConfig(
conditions=[ conditions=[
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")), HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
NotConfig( NotConfig(

View File

@ -27,6 +27,7 @@ from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.core.arrays import spec_to_xarray from batdetect2.core.arrays import spec_to_xarray
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
@ -88,6 +89,9 @@ DEFAULT_ANCHOR = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner).""" """Default reference position within the geometry ('bottom-left' corner)."""
roi_mapper_registry: Registry[ROITargetMapper, []] = Registry("roi_mapper")
class AnchorBBoxMapperConfig(BaseConfig): class AnchorBBoxMapperConfig(BaseConfig):
"""Configuration for `AnchorBBoxMapper`. """Configuration for `AnchorBBoxMapper`.
@ -244,6 +248,15 @@ class AnchorBBoxMapper(ROITargetMapper):
anchor=self.anchor, anchor=self.anchor,
) )
@roi_mapper_registry.register(AnchorBBoxMapperConfig)
@staticmethod
def from_config(config: AnchorBBoxMapperConfig):
return AnchorBBoxMapper(
anchor=config.anchor,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
)
class PeakEnergyBBoxMapperConfig(BaseConfig): class PeakEnergyBBoxMapperConfig(BaseConfig):
"""Configuration for `PeakEnergyBBoxMapper`. """Configuration for `PeakEnergyBBoxMapper`.
@ -412,6 +425,22 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
] ]
) )
@roi_mapper_registry.register(PeakEnergyBBoxMapperConfig)
@staticmethod
def from_config(config: PeakEnergyBBoxMapperConfig):
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return PeakEnergyBBoxMapper(
preprocessor=preprocessor,
audio_loader=audio_loader,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
loading_buffer=config.loading_buffer,
)
ROIMapperConfig = Annotated[ ROIMapperConfig = Annotated[
AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig, AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig,
@ -445,31 +474,7 @@ def build_roi_mapper(
If the `name` in the config does not correspond to a known mapper. If the `name` in the config does not correspond to a known mapper.
""" """
config = config or AnchorBBoxMapperConfig() config = config or AnchorBBoxMapperConfig()
return roi_mapper_registry.build(config)
if config.name == "anchor_bbox":
return AnchorBBoxMapper(
anchor=config.anchor,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
)
if config.name == "peak_energy_bbox":
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return PeakEnergyBBoxMapper(
preprocessor=preprocessor,
audio_loader=audio_loader,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
loading_buffer=config.loading_buffer,
)
raise NotImplementedError(
f"No ROI mapper of name '{config.name}' is implemented"
)
VALID_ANCHORS = [ VALID_ANCHORS = [
@ -636,7 +641,7 @@ def get_peak_energy_coordinates(
) )
index = selection.argmax(dim=["time", "frequency"]) index = selection.argmax(dim=["time", "frequency"])
point = selection.isel(index) # type: ignore point = selection.isel(index)
peak_time: float = point.time.item() peak_time: float = point.time.item()
peak_freq: float = point.frequency.item() peak_freq: float = point.frequency.item()
return peak_time, peak_freq return peak_time, peak_freq

View File

@ -548,63 +548,6 @@ class MaybeApply(torch.nn.Module):
return self.augmentation(tensor, clip_annotation) return self.augmentation(tensor, clip_annotation)
def build_augmentation_from_config(
config: AugmentationConfig,
samplerate: int,
audio_source: AudioSource | None = None,
) -> Augmentation | None:
"""Factory function to build a single augmentation from its config."""
if config.name == "mix_audio":
if audio_source is None:
warnings.warn(
"Mix audio augmentation ('mix_audio') requires an "
"'example_source' callable to be provided.",
stacklevel=2,
)
return None
return MixAudio(
example_source=audio_source,
min_weight=config.min_weight,
max_weight=config.max_weight,
)
if config.name == "add_echo":
return AddEcho(
max_delay=int(config.max_delay * samplerate),
min_weight=config.min_weight,
max_weight=config.max_weight,
)
if config.name == "scale_volume":
return ScaleVolume(
max_scaling=config.max_scaling,
min_scaling=config.min_scaling,
)
if config.name == "warp":
return Warp(
delta=config.delta,
)
if config.name == "mask_time":
return MaskTime(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
if config.name == "mask_freq":
return MaskFrequency(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
raise NotImplementedError(
"Invalid or unimplemented augmentation type: "
f"{config.augmentation_type}"
)
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig( DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
enabled=True, enabled=True,
audio=[ audio=[

View File

@ -61,8 +61,8 @@ class TrainingDataset(Dataset):
def __len__(self): def __len__(self):
return len(self.clip_annotations) return len(self.clip_annotations)
def __getitem__(self, idx) -> TrainExample: def __getitem__(self, index) -> TrainExample:
clip_annotation = self.clip_annotations[idx] clip_annotation = self.clip_annotations[index]
if self.clipper is not None: if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation) clip_annotation = self.clipper(clip_annotation)
@ -95,7 +95,7 @@ class TrainingDataset(Dataset):
detection_heatmap=heatmaps.detection, detection_heatmap=heatmaps.detection,
class_heatmap=heatmaps.classes, class_heatmap=heatmaps.classes,
size_heatmap=heatmaps.size, size_heatmap=heatmaps.size,
idx=torch.tensor(idx), idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time), start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time), end_time=torch.tensor(clip.end_time),
) )
@ -121,8 +121,8 @@ class ValidationDataset(Dataset):
def __len__(self): def __len__(self):
return len(self.clip_annotations) return len(self.clip_annotations)
def __getitem__(self, idx) -> TrainExample: def __getitem__(self, index) -> TrainExample:
clip_annotation = self.clip_annotations[idx] clip_annotation = self.clip_annotations[index]
if self.clipper is not None: if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation) clip_annotation = self.clipper(clip_annotation)
@ -141,7 +141,7 @@ class ValidationDataset(Dataset):
detection_heatmap=heatmaps.detection, detection_heatmap=heatmaps.detection,
class_heatmap=heatmaps.classes, class_heatmap=heatmaps.classes,
size_heatmap=heatmaps.size, size_heatmap=heatmaps.size,
idx=torch.tensor(idx), idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time), start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time), end_time=torch.tensor(clip.end_time),
) )

View File

@ -472,7 +472,7 @@ def build_loss(
size_loss_fn = BBoxLoss() size_loss_fn = BBoxLoss()
return LossFunction( # type: ignore return LossFunction(
size_loss=size_loss_fn, size_loss=size_loss_fn,
classification_loss=classification_loss_fn, classification_loss=classification_loss_fn,
detection_loss=detection_loss_fn, detection_loss=detection_loss_fn,