mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Fix type errors
This commit is contained in:
parent
cce1b49a8d
commit
d52e988b8f
53
justfile
53
justfile
@ -14,60 +14,67 @@ HTML_COVERAGE_DIR := "htmlcov"
|
||||
help:
|
||||
@just --list
|
||||
|
||||
install:
|
||||
uv sync
|
||||
|
||||
# Testing & Coverage
|
||||
# Run tests using pytest.
|
||||
test:
|
||||
pytest {{TESTS_DIR}}
|
||||
uv run pytest {{TESTS_DIR}}
|
||||
|
||||
# Run tests and generate coverage data.
|
||||
coverage:
|
||||
pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
|
||||
uv run pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
|
||||
|
||||
# Generate an HTML coverage report.
|
||||
coverage-html: coverage
|
||||
@echo "Generating HTML coverage report..."
|
||||
coverage html -d {{HTML_COVERAGE_DIR}}
|
||||
uv run coverage html -d {{HTML_COVERAGE_DIR}}
|
||||
@echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"
|
||||
|
||||
# Serve the HTML coverage report locally.
|
||||
coverage-serve: coverage-html
|
||||
@echo "Serving report at http://localhost:8000/ ..."
|
||||
python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
|
||||
uv run python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
|
||||
|
||||
# Documentation
|
||||
# Build documentation using Sphinx.
|
||||
docs:
|
||||
sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
|
||||
uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
|
||||
|
||||
# Serve documentation with live reload.
|
||||
docs-serve:
|
||||
sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
|
||||
uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
|
||||
|
||||
# Formatting & Linting
|
||||
# Format code using ruff.
|
||||
format:
|
||||
ruff format {{PYTHON_DIRS}}
|
||||
|
||||
# Check code formatting using ruff.
|
||||
format-check:
|
||||
ruff format --check {{PYTHON_DIRS}}
|
||||
|
||||
# Lint code using ruff.
|
||||
lint:
|
||||
ruff check {{PYTHON_DIRS}}
|
||||
fix-format:
|
||||
uv run ruff format {{PYTHON_DIRS}}
|
||||
|
||||
# Lint code using ruff and apply automatic fixes.
|
||||
lint-fix:
|
||||
ruff check --fix {{PYTHON_DIRS}}
|
||||
fix-lint:
|
||||
uv run ruff check --fix {{PYTHON_DIRS}}
|
||||
|
||||
# Combined Formatting & Linting
|
||||
fix: fix-format fix-lint
|
||||
|
||||
# Checking tasks
|
||||
# Check code formatting using ruff.
|
||||
check-format:
|
||||
uv run ruff format --check {{PYTHON_DIRS}}
|
||||
|
||||
# Lint code using ruff.
|
||||
check-lint:
|
||||
uv run ruff check {{PYTHON_DIRS}}
|
||||
|
||||
# Type Checking
|
||||
# Type check code using pyright.
|
||||
typecheck:
|
||||
pyright {{PYTHON_DIRS}}
|
||||
# Type check code using ty.
|
||||
check-types:
|
||||
uv run ty check {{PYTHON_DIRS}}
|
||||
|
||||
# Combined Checks
|
||||
# Run all checks (check-format, check-lint, check-types).
|
||||
check: format-check lint typecheck test
|
||||
check: check-format check-lint check-types
|
||||
|
||||
# Cleaning tasks
|
||||
# Remove Python bytecode and cache.
|
||||
@ -95,7 +102,7 @@ clean: clean-build clean-pyc clean-test clean-docs
|
||||
|
||||
# Train on example data.
|
||||
example-train OPTIONS="":
|
||||
batdetect2 train \
|
||||
uv run batdetect2 train \
|
||||
--val-dataset example_data/dataset.yaml \
|
||||
--config example_data/config.yaml \
|
||||
{{OPTIONS}} \
|
||||
|
||||
@ -75,7 +75,6 @@ dev = [
|
||||
"ruff>=0.7.3",
|
||||
"ipykernel>=6.29.4",
|
||||
"setuptools>=69.5.1",
|
||||
"basedpyright>=1.31.0",
|
||||
"myst-parser>=3.0.1",
|
||||
"sphinx-autobuild>=2024.10.3",
|
||||
"numpydoc>=1.8.0",
|
||||
@ -94,6 +93,12 @@ mlflow = ["mlflow>=3.1.1"]
|
||||
[tool.ruff]
|
||||
line-length = 79
|
||||
target-version = "py310"
|
||||
exclude = [
|
||||
"src/batdetect2/train/legacy",
|
||||
"src/batdetect2/plotting/legacy",
|
||||
"src/batdetect2/evaluate/legacy",
|
||||
"src/batdetect2/finetune",
|
||||
]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
@ -105,15 +110,11 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "numpy"
|
||||
|
||||
[tool.pyright]
|
||||
[tool.ty.src]
|
||||
include = ["src", "tests"]
|
||||
pythonVersion = "3.10"
|
||||
pythonPlatform = "All"
|
||||
exclude = [
|
||||
"src/batdetect2/detector/",
|
||||
"src/batdetect2/finetune",
|
||||
"src/batdetect2/utils",
|
||||
"src/batdetect2/plot",
|
||||
"src/batdetect2/evaluate/legacy",
|
||||
"src/batdetect2/train/legacy",
|
||||
"src/batdetect2/plotting/legacy",
|
||||
"src/batdetect2/evaluate/legacy",
|
||||
"src/batdetect2/finetune",
|
||||
]
|
||||
|
||||
@ -113,7 +113,7 @@ def load_file_audio(
|
||||
samplerate: int | None = None,
|
||||
config: ResampleConfig | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess audio from a file path using specified config."""
|
||||
try:
|
||||
@ -137,7 +137,7 @@ def load_recording_audio(
|
||||
samplerate: int | None = None,
|
||||
config: ResampleConfig | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess the entire audio content of a recording using config."""
|
||||
clip = data.Clip(
|
||||
@ -159,7 +159,7 @@ def load_clip_audio(
|
||||
samplerate: int | None = None,
|
||||
config: ResampleConfig | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess a specific audio clip segment based on config."""
|
||||
try:
|
||||
@ -286,7 +286,7 @@ def resample_audio_fourier(
|
||||
If `num` is negative.
|
||||
"""
|
||||
ratio = sr_new / sr_orig
|
||||
return resample( # type: ignore
|
||||
return resample(
|
||||
array,
|
||||
int(array.shape[axis] * ratio),
|
||||
axis=axis,
|
||||
|
||||
@ -103,17 +103,17 @@ def convert_to_annotation_group(
|
||||
y_inds.append(0)
|
||||
|
||||
annotations.append(
|
||||
{
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"low_freq": low_freq,
|
||||
"high_freq": high_freq,
|
||||
"class_prob": 1.0,
|
||||
"det_prob": 1.0,
|
||||
"individual": "0",
|
||||
"event": event,
|
||||
"class_id": class_id, # type: ignore
|
||||
}
|
||||
Annotation(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
low_freq=low_freq,
|
||||
high_freq=high_freq,
|
||||
class_prob=1.0,
|
||||
det_prob=1.0,
|
||||
individual="0",
|
||||
event=event,
|
||||
class_id=class_id,
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
|
||||
@ -89,7 +89,7 @@ class FileAnnotation(TypedDict):
|
||||
annotation: List[Annotation]
|
||||
"""List of annotations."""
|
||||
|
||||
file_path: NotRequired[str]
|
||||
file_path: NotRequired[str] # ty: ignore[invalid-type-form]
|
||||
"""Path to file."""
|
||||
|
||||
|
||||
|
||||
@ -26,15 +26,15 @@ def split_dataset_by_recordings(
|
||||
)
|
||||
|
||||
majority_class = (
|
||||
sound_events.groupby("recording_id")
|
||||
sound_events.groupby("recording_id") # type: ignore
|
||||
.apply(
|
||||
lambda group: (
|
||||
group["class_name"] # type: ignore
|
||||
group["class_name"]
|
||||
.value_counts()
|
||||
.sort_values(ascending=False)
|
||||
.index[0]
|
||||
),
|
||||
include_groups=False, # type: ignore
|
||||
include_groups=False,
|
||||
)
|
||||
.rename("class_name")
|
||||
.to_frame()
|
||||
@ -48,8 +48,8 @@ def split_dataset_by_recordings(
|
||||
random_state=random_state,
|
||||
)
|
||||
|
||||
train_ids_set = set(train.values) # type: ignore
|
||||
test_ids_set = set(test.values) # type: ignore
|
||||
train_ids_set = set(train.values)
|
||||
test_ids_set = set(test.values)
|
||||
|
||||
extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set
|
||||
|
||||
|
||||
@ -175,14 +175,14 @@ def compute_class_summary(
|
||||
.rename("num recordings")
|
||||
)
|
||||
durations = (
|
||||
sound_events.groupby("class_name")
|
||||
sound_events.groupby("class_name") # ty: ignore[no-matching-overload]
|
||||
.apply(
|
||||
lambda group: recordings[
|
||||
recordings["clip_annotation_id"].isin(
|
||||
group["clip_annotation_id"] # type: ignore
|
||||
group["clip_annotation_id"]
|
||||
)
|
||||
]["duration"].sum(),
|
||||
include_groups=False, # type: ignore
|
||||
include_groups=False,
|
||||
)
|
||||
.sort_values(ascending=False)
|
||||
.rename("duration")
|
||||
|
||||
@ -176,12 +176,7 @@ class MapTagValue:
|
||||
if self.target_key is None:
|
||||
tags.append(tag.model_copy(update=dict(value=value)))
|
||||
else:
|
||||
tags.append(
|
||||
data.Tag(
|
||||
key=self.target_key, # type: ignore
|
||||
value=value,
|
||||
)
|
||||
)
|
||||
tags.append(data.Tag(key=self.target_key, value=value))
|
||||
|
||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||
|
||||
|
||||
@ -120,11 +120,11 @@ class BBoxIOU(AffinityFunction):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prediction: Detection,
|
||||
gt: data.SoundEventAnnotation,
|
||||
detection: Detection,
|
||||
ground_truth: data.SoundEventAnnotation,
|
||||
):
|
||||
target_geometry = gt.sound_event.geometry
|
||||
source_geometry = prediction.geometry
|
||||
target_geometry = ground_truth.sound_event.geometry
|
||||
source_geometry = detection.geometry
|
||||
|
||||
if self.time_buffer > 0 or self.freq_buffer > 0:
|
||||
target_geometry = buffer_geometry(
|
||||
@ -168,11 +168,11 @@ class GeometricIOU(AffinityFunction):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prediction: Detection,
|
||||
gt: data.SoundEventAnnotation,
|
||||
detection: Detection,
|
||||
ground_truth: data.SoundEventAnnotation,
|
||||
):
|
||||
target_geometry = gt.sound_event.geometry
|
||||
source_geometry = prediction.geometry
|
||||
target_geometry = ground_truth.sound_event.geometry
|
||||
source_geometry = detection.geometry
|
||||
|
||||
if self.time_buffer > 0 or self.freq_buffer > 0:
|
||||
target_geometry = buffer_geometry(
|
||||
|
||||
@ -51,8 +51,8 @@ class TestDataset(Dataset[TestExample]):
|
||||
def __len__(self):
|
||||
return len(self.clip_annotations)
|
||||
|
||||
def __getitem__(self, idx: int) -> TestExample:
|
||||
clip_annotation = self.clip_annotations[idx]
|
||||
def __getitem__(self, index: int) -> TestExample:
|
||||
clip_annotation = self.clip_annotations[index]
|
||||
|
||||
if self.clipper is not None:
|
||||
clip_annotation = self.clipper(clip_annotation)
|
||||
@ -63,7 +63,7 @@ class TestDataset(Dataset[TestExample]):
|
||||
spectrogram = self.preprocessor(wav_tensor)
|
||||
return TestExample(
|
||||
spec=spectrogram,
|
||||
idx=torch.tensor(idx),
|
||||
idx=torch.tensor(index),
|
||||
start_time=torch.tensor(clip.start_time),
|
||||
end_time=torch.tensor(clip.end_time),
|
||||
)
|
||||
|
||||
@ -109,7 +109,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: ClipDetections,
|
||||
) -> T_Output: ...
|
||||
) -> T_Output: ... # ty: ignore[empty-body]
|
||||
|
||||
def include_sound_event_annotation(
|
||||
self,
|
||||
|
||||
@ -46,14 +46,14 @@ class InferenceDataset(Dataset[DatasetItem]):
|
||||
def __len__(self):
|
||||
return len(self.clips)
|
||||
|
||||
def __getitem__(self, idx: int) -> DatasetItem:
|
||||
clip = self.clips[idx]
|
||||
def __getitem__(self, index: int) -> DatasetItem:
|
||||
clip = self.clips[index]
|
||||
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
|
||||
wav_tensor = torch.tensor(wav).unsqueeze(0)
|
||||
spectrogram = self.preprocessor(wav_tensor)
|
||||
return DatasetItem(
|
||||
spec=spectrogram,
|
||||
idx=torch.tensor(idx),
|
||||
idx=torch.tensor(index),
|
||||
start_time=torch.tensor(clip.start_time),
|
||||
end_time=torch.tensor(clip.end_time),
|
||||
)
|
||||
|
||||
@ -185,7 +185,7 @@ def build_backbone(config: BackboneConfig) -> BackboneModel:
|
||||
)
|
||||
|
||||
decoder = build_decoder(
|
||||
in_channels=bottleneck.out_channels,
|
||||
in_channels=bottleneck.get_output_channels(),
|
||||
input_height=encoder.output_height,
|
||||
config=config.decoder,
|
||||
)
|
||||
|
||||
@ -20,20 +20,21 @@ research:
|
||||
spatial frequency information to filters, potentially enabling them to learn
|
||||
frequency-dependent patterns more effectively.
|
||||
|
||||
These blocks can be utilized directly in custom PyTorch model definitions or
|
||||
These blocks can be used directly in custom PyTorch model definitions or
|
||||
assembled into larger architectures.
|
||||
|
||||
A unified factory function `build_layer_from_config` allows creating instances
|
||||
of these blocks based on configuration objects.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Literal, Tuple, Union
|
||||
from typing import Annotated, List, Literal, Protocol, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
@ -51,17 +52,30 @@ __all__ = [
|
||||
"FreqCoordConvUpConfig",
|
||||
"StandardConvUpConfig",
|
||||
"LayerConfig",
|
||||
"build_layer_from_config",
|
||||
"build_layer",
|
||||
]
|
||||
|
||||
|
||||
class BlockProtocol(Protocol):
    """Structural interface shared by the network building blocks.

    A block reports how many channels it outputs and how it transforms the
    height of its input.
    """

    def get_output_channels(self) -> int:
        """Return the number of channels produced by the block."""
        ...

    def get_output_height(self, input_height: int) -> int:
        """Return the output height for the given input height.

        The default implementation leaves the height unchanged; blocks that
        resize their input override it.
        """
        return input_height
|
||||
|
||||
|
||||
class Block(nn.Module, BlockProtocol):
    """Base class for blocks: an `nn.Module` implementing `BlockProtocol`."""
|
||||
|
||||
|
||||
# Registry of block builders. Builders are attached to a config class with
# `@block_registry.register(<ConfigClass>)` and are invoked through
# `block_registry.build(config, in_channels, input_height)`, so the two extra
# positional arguments are (input channels, input height), in that order.
block_registry: Registry[Block, [int, int]] = Registry("block")
|
||||
|
||||
|
||||
class SelfAttentionConfig(BaseConfig):
|
||||
name: Literal["SelfAttention"] = "SelfAttention"
|
||||
attention_channels: int
|
||||
temperature: float = 1
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
class SelfAttention(Block):
|
||||
"""Self-Attention mechanism operating along the time dimension.
|
||||
|
||||
This module implements a scaled dot-product self-attention mechanism,
|
||||
@ -121,6 +135,7 @@ class SelfAttention(nn.Module):
|
||||
# Note, does not encode position information (absolute or relative)
|
||||
self.temperature = temperature
|
||||
self.att_dim = attention_channels
|
||||
self.output_channels = in_channels
|
||||
|
||||
self.key_fun = nn.Linear(in_channels, attention_channels)
|
||||
self.value_fun = nn.Linear(in_channels, attention_channels)
|
||||
@ -190,6 +205,22 @@ class SelfAttention(nn.Module):
|
||||
att_weights = F.softmax(kk_qq, 1)
|
||||
return att_weights
|
||||
|
||||
def get_output_channels(self) -> int:
    """Return the number of channels produced by the attention block.

    This is the value stored in `self.output_channels`, which the
    constructor sets to the number of input channels.
    """
    return self.output_channels
|
||||
|
||||
@block_registry.register(SelfAttentionConfig)
@staticmethod
def from_config(
    config: SelfAttentionConfig,
    input_channels: int,
    input_height: int,
) -> "SelfAttention":
    """Build a `SelfAttention` block from its configuration.

    Parameters
    ----------
    config : SelfAttentionConfig
        Provides `attention_channels` and `temperature`.
    input_channels : int
        Number of channels of the incoming feature map.
    input_height : int
        Unused by this block; accepted so that every registered builder
        shares the same signature.
    """
    return SelfAttention(
        in_channels=input_channels,
        attention_channels=config.attention_channels,
        temperature=config.temperature,
    )
|
||||
|
||||
|
||||
class ConvConfig(BaseConfig):
|
||||
"""Configuration for a basic ConvBlock."""
|
||||
@ -207,7 +238,7 @@ class ConvConfig(BaseConfig):
|
||||
"""Padding size."""
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
class ConvBlock(Block):
|
||||
"""Basic Convolutional Block.
|
||||
|
||||
A standard building block consisting of a 2D convolution, followed by
|
||||
@ -258,8 +289,30 @@ class ConvBlock(nn.Module):
|
||||
"""
|
||||
return F.relu_(self.batch_norm(self.conv(x)))
|
||||
|
||||
def get_output_channels(self) -> int:
    """Return the number of output channels of the block's convolution."""
    return self.conv.out_channels
|
||||
|
||||
class VerticalConv(nn.Module):
|
||||
@block_registry.register(ConvConfig)
@staticmethod
def from_config(
    config: ConvConfig,
    input_channels: int,
    input_height: int,
) -> "ConvBlock":
    """Build a `ConvBlock` from its configuration.

    Parameters
    ----------
    config : ConvConfig
        Provides `out_channels`, `kernel_size` and `pad_size`.
    input_channels : int
        Number of channels of the incoming feature map.
    input_height : int
        Unused by this block; accepted so that every registered builder
        shares the same signature.
    """
    return ConvBlock(
        in_channels=input_channels,
        out_channels=config.out_channels,
        kernel_size=config.kernel_size,
        pad_size=config.pad_size,
    )
|
||||
|
||||
|
||||
class VerticalConvConfig(BaseConfig):
    """Configuration for a VerticalConv block."""

    name: Literal["VerticalConv"] = "VerticalConv"
    """Discriminator identifying this block type."""

    channels: int
    """Number of output channels of the convolution."""
|
||||
|
||||
|
||||
class VerticalConv(Block):
|
||||
"""Convolutional layer that aggregates features across the entire height.
|
||||
|
||||
Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
|
||||
@ -312,6 +365,22 @@ class VerticalConv(nn.Module):
|
||||
"""
|
||||
return F.relu_(self.bn(self.conv(x)))
|
||||
|
||||
def get_output_channels(self) -> int:
    """Return the number of output channels of the block's convolution."""
    return self.conv.out_channels
|
||||
|
||||
@block_registry.register(VerticalConvConfig)
@staticmethod
def from_config(
    config: VerticalConvConfig,
    input_channels: int,
    input_height: int,
) -> "VerticalConv":
    """Build a `VerticalConv` block from its configuration.

    Parameters
    ----------
    config : VerticalConvConfig
        Provides `channels`, used as the number of output channels.
    input_channels : int
        Number of channels of the incoming feature map.
    input_height : int
        Height of the incoming feature map, forwarded to the block.
    """
    return VerticalConv(
        in_channels=input_channels,
        out_channels=config.channels,
        input_height=input_height,
    )
|
||||
|
||||
|
||||
class FreqCoordConvDownConfig(BaseConfig):
|
||||
"""Configuration for a FreqCoordConvDownBlock."""
|
||||
@ -329,7 +398,7 @@ class FreqCoordConvDownConfig(BaseConfig):
|
||||
"""Padding size."""
|
||||
|
||||
|
||||
class FreqCoordConvDownBlock(nn.Module):
|
||||
class FreqCoordConvDownBlock(Block):
|
||||
"""Downsampling Conv Block incorporating Frequency Coordinate features.
|
||||
|
||||
This block implements a downsampling step (Conv2d + MaxPool2d) commonly
|
||||
@ -402,6 +471,27 @@ class FreqCoordConvDownBlock(nn.Module):
|
||||
x = F.relu(self.batch_norm(x), inplace=True)
|
||||
return x
|
||||
|
||||
def get_output_channels(self) -> int:
    """Return the number of output channels of the block's convolution."""
    return self.conv.out_channels
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
    """Return the height after downsampling: half the input, floored."""
    return input_height // 2
|
||||
|
||||
@block_registry.register(FreqCoordConvDownConfig)
@staticmethod
def from_config(
    config: FreqCoordConvDownConfig,
    input_channels: int,
    input_height: int,
) -> "FreqCoordConvDownBlock":
    """Build a `FreqCoordConvDownBlock` from its configuration.

    Parameters
    ----------
    config : FreqCoordConvDownConfig
        Provides `out_channels`, `kernel_size` and `pad_size`.
    input_channels : int
        Number of channels of the incoming feature map.
    input_height : int
        Height of the incoming feature map, forwarded to the block.
    """
    return FreqCoordConvDownBlock(
        in_channels=input_channels,
        out_channels=config.out_channels,
        input_height=input_height,
        kernel_size=config.kernel_size,
        pad_size=config.pad_size,
    )
|
||||
|
||||
|
||||
class StandardConvDownConfig(BaseConfig):
|
||||
"""Configuration for a StandardConvDownBlock."""
|
||||
@ -419,7 +509,7 @@ class StandardConvDownConfig(BaseConfig):
|
||||
"""Padding size."""
|
||||
|
||||
|
||||
class StandardConvDownBlock(nn.Module):
|
||||
class StandardConvDownBlock(Block):
|
||||
"""Standard Downsampling Convolutional Block.
|
||||
|
||||
A basic downsampling block consisting of a 2D convolution, followed by
|
||||
@ -472,6 +562,26 @@ class StandardConvDownBlock(nn.Module):
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
return F.relu(self.batch_norm(x), inplace=True)
|
||||
|
||||
def get_output_channels(self) -> int:
    """Return the number of output channels of the block's convolution."""
    return self.conv.out_channels
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
    """Return the height after downsampling: half the input, floored.

    Mirrors the 2x2 max pooling with stride 2 applied in `forward`.
    """
    return input_height // 2
|
||||
|
||||
@block_registry.register(StandardConvDownConfig)
@staticmethod
def from_config(
    config: StandardConvDownConfig,
    input_channels: int,
    input_height: int,
) -> "StandardConvDownBlock":
    """Build a `StandardConvDownBlock` from its configuration.

    Parameters
    ----------
    config : StandardConvDownConfig
        Provides `out_channels`, `kernel_size` and `pad_size`.
    input_channels : int
        Number of channels of the incoming feature map.
    input_height : int
        Unused by this block; accepted so that every registered builder
        shares the same signature.
    """
    return StandardConvDownBlock(
        in_channels=input_channels,
        out_channels=config.out_channels,
        kernel_size=config.kernel_size,
        pad_size=config.pad_size,
    )
|
||||
|
||||
|
||||
class FreqCoordConvUpConfig(BaseConfig):
|
||||
"""Configuration for a FreqCoordConvUpBlock."""
|
||||
@ -488,8 +598,14 @@ class FreqCoordConvUpConfig(BaseConfig):
|
||||
pad_size: int = 1
|
||||
"""Padding size."""
|
||||
|
||||
up_mode: str = "bilinear"
|
||||
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
||||
|
||||
class FreqCoordConvUpBlock(nn.Module):
|
||||
up_scale: Tuple[int, int] = (2, 2)
|
||||
"""Scaling factor for height and width during upsampling."""
|
||||
|
||||
|
||||
class FreqCoordConvUpBlock(Block):
|
||||
"""Upsampling Conv Block incorporating Frequency Coordinate features.
|
||||
|
||||
This block implements an upsampling step followed by a convolution,
|
||||
@ -581,6 +697,29 @@ class FreqCoordConvUpBlock(nn.Module):
|
||||
op = F.relu(self.batch_norm(op), inplace=True)
|
||||
return op
|
||||
|
||||
def get_output_channels(self) -> int:
    """Return the number of output channels of the block's convolution."""
    return self.conv.out_channels
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
    """Return the height after upsampling: twice the input height."""
    # NOTE(review): the factor 2 matches the default `up_scale=(2, 2)` of
    # the config, but `from_config` forwards a configurable `up_scale` to
    # the block. Confirm whether this should use the block's vertical scale
    # factor instead of a hard-coded 2.
    return input_height * 2
|
||||
|
||||
@block_registry.register(FreqCoordConvUpConfig)
@staticmethod
def from_config(
    config: FreqCoordConvUpConfig,
    input_channels: int,
    input_height: int,
) -> "FreqCoordConvUpBlock":
    """Build a `FreqCoordConvUpBlock` from its configuration.

    Parameters
    ----------
    config : FreqCoordConvUpConfig
        Provides `out_channels`, `kernel_size`, `pad_size`, `up_mode` and
        `up_scale`.
    input_channels : int
        Number of channels of the incoming feature map.
    input_height : int
        Height of the incoming feature map, forwarded to the block.
    """
    return FreqCoordConvUpBlock(
        in_channels=input_channels,
        out_channels=config.out_channels,
        input_height=input_height,
        kernel_size=config.kernel_size,
        pad_size=config.pad_size,
        up_mode=config.up_mode,
        up_scale=config.up_scale,
    )
|
||||
|
||||
|
||||
class StandardConvUpConfig(BaseConfig):
|
||||
"""Configuration for a StandardConvUpBlock."""
|
||||
@ -597,8 +736,14 @@ class StandardConvUpConfig(BaseConfig):
|
||||
pad_size: int = 1
|
||||
"""Padding size."""
|
||||
|
||||
up_mode: str = "bilinear"
|
||||
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
||||
|
||||
class StandardConvUpBlock(nn.Module):
|
||||
up_scale: Tuple[int, int] = (2, 2)
|
||||
"""Scaling factor for height and width during upsampling."""
|
||||
|
||||
|
||||
class StandardConvUpBlock(Block):
|
||||
"""Standard Upsampling Convolutional Block.
|
||||
|
||||
A basic upsampling block used in CNN decoders. It first upsamples the input
|
||||
@ -669,6 +814,28 @@ class StandardConvUpBlock(nn.Module):
|
||||
op = F.relu(self.batch_norm(op), inplace=True)
|
||||
return op
|
||||
|
||||
def get_output_channels(self) -> int:
    """Return the number of output channels of the block's convolution."""
    return self.conv.out_channels
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
    """Return the height after upsampling: twice the input height."""
    # NOTE(review): the factor 2 matches the default `up_scale=(2, 2)` of
    # the config, but `from_config` forwards a configurable `up_scale` to
    # the block. Confirm whether this should use the block's vertical scale
    # factor instead of a hard-coded 2.
    return input_height * 2
|
||||
|
||||
@block_registry.register(StandardConvUpConfig)
@staticmethod
def from_config(
    config: StandardConvUpConfig,
    input_channels: int,
    input_height: int,
) -> "StandardConvUpBlock":
    """Build a `StandardConvUpBlock` from its configuration.

    Parameters
    ----------
    config : StandardConvUpConfig
        Provides `out_channels`, `kernel_size`, `pad_size`, `up_mode` and
        `up_scale`.
    input_channels : int
        Number of channels of the incoming feature map.
    input_height : int
        Unused by this block; accepted so that every registered builder
        shares the same signature.
    """
    return StandardConvUpBlock(
        in_channels=input_channels,
        out_channels=config.out_channels,
        kernel_size=config.kernel_size,
        pad_size=config.pad_size,
        up_mode=config.up_mode,
        up_scale=config.up_scale,
    )
|
||||
|
||||
|
||||
LayerConfig = Annotated[
|
||||
Union[
|
||||
@ -690,11 +857,61 @@ class LayerGroupConfig(BaseConfig):
|
||||
layers: List[LayerConfig]
|
||||
|
||||
|
||||
def build_layer_from_config(
|
||||
class LayerGroup(Block):
    """Sequence of blocks applied one after another as a single block.

    Parameters
    ----------
    layers : list[Block]
        Blocks to apply, in order.
    input_height : int
        Height of the feature map entering the group.
    input_channels : int
        Number of channels of the feature map entering the group.
    """

    def __init__(
        self,
        layers: list[Block],
        input_height: int,
        input_channels: int,
    ):
        super().__init__()
        self.input_height = input_height
        self.input_channels = input_channels
        # Kept alongside the `nn.Sequential` so the `Block` interface of each
        # layer stays directly accessible.
        self.blocks = layers
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every block of the group in order."""
        return self.layers(x)

    def get_output_channels(self) -> int:
        """Return the channels produced by the last block of the group.

        An empty group leaves its input untouched, so it reports the number
        of input channels.
        """
        if not self.blocks:
            return self.input_channels
        return self.blocks[-1].get_output_channels()

    def get_output_height(self, input_height: int) -> int:
        """Return the height obtained by chaining every block's transform."""
        for block in self.blocks:
            input_height = block.get_output_height(input_height)
        return input_height

    @block_registry.register(LayerGroupConfig)
    @staticmethod
    def from_config(
        config: LayerGroupConfig,
        input_channels: int,
        input_height: int,
    ) -> "LayerGroup":
        """Build a `LayerGroup` from its configuration.

        The positional order `(config, input_channels, input_height)` matches
        the other registered builders and the call made by `build_layer`,
        `block_registry.build(config, in_channels, input_height)`.

        Parameters
        ----------
        config : LayerGroupConfig
            Provides the ordered list of layer configurations.
        input_channels : int
            Number of channels of the feature map entering the group.
        input_height : int
            Height of the feature map entering the group.
        """
        # Track the running dimensions separately so the group is built with
        # its true input dimensions rather than the final output ones.
        current_channels = input_channels
        current_height = input_height
        layers = []

        for layer_config in config.layers:
            layer = build_layer(
                input_height=current_height,
                in_channels=current_channels,
                config=layer_config,
            )
            layers.append(layer)
            current_height = layer.get_output_height(current_height)
            current_channels = layer.get_output_channels()

        return LayerGroup(
            layers=layers,
            input_height=input_height,
            input_channels=input_channels,
        )
|
||||
|
||||
|
||||
def build_layer(
|
||||
input_height: int,
|
||||
in_channels: int,
|
||||
config: LayerConfig,
|
||||
) -> Tuple[nn.Module, int, int]:
|
||||
) -> Block:
|
||||
"""Factory function to build a specific nn.Module block from its config.
|
||||
|
||||
Takes configuration object (one of the types included in the `LayerConfig`
|
||||
@ -731,93 +948,4 @@ def build_layer_from_config(
|
||||
ValueError
|
||||
If parameters derived from the config are invalid for the block.
|
||||
"""
|
||||
if config.name == "ConvBlock":
|
||||
return (
|
||||
ConvBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=config.out_channels,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
),
|
||||
config.out_channels,
|
||||
input_height,
|
||||
)
|
||||
|
||||
if config.name == "FreqCoordConvDown":
|
||||
return (
|
||||
FreqCoordConvDownBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=config.out_channels,
|
||||
input_height=input_height,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
),
|
||||
config.out_channels,
|
||||
input_height // 2,
|
||||
)
|
||||
|
||||
if config.name == "StandardConvDown":
|
||||
return (
|
||||
StandardConvDownBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=config.out_channels,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
),
|
||||
config.out_channels,
|
||||
input_height // 2,
|
||||
)
|
||||
|
||||
if config.name == "FreqCoordConvUp":
|
||||
return (
|
||||
FreqCoordConvUpBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=config.out_channels,
|
||||
input_height=input_height,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
),
|
||||
config.out_channels,
|
||||
input_height * 2,
|
||||
)
|
||||
|
||||
if config.name == "StandardConvUp":
|
||||
return (
|
||||
StandardConvUpBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=config.out_channels,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
),
|
||||
config.out_channels,
|
||||
input_height * 2,
|
||||
)
|
||||
|
||||
if config.name == "SelfAttention":
|
||||
return (
|
||||
SelfAttention(
|
||||
in_channels=in_channels,
|
||||
attention_channels=config.attention_channels,
|
||||
temperature=config.temperature,
|
||||
),
|
||||
config.attention_channels,
|
||||
input_height,
|
||||
)
|
||||
|
||||
if config.name == "LayerGroup":
|
||||
current_channels = in_channels
|
||||
current_height = input_height
|
||||
|
||||
blocks = []
|
||||
|
||||
for block_config in config.layers:
|
||||
block, current_channels, current_height = build_layer_from_config(
|
||||
input_height=current_height,
|
||||
in_channels=current_channels,
|
||||
config=block_config,
|
||||
)
|
||||
blocks.append(block)
|
||||
|
||||
return nn.Sequential(*blocks), current_channels, current_height
|
||||
|
||||
raise NotImplementedError(f"Unknown block type {config.name}")
|
||||
return block_registry.build(config, in_channels, input_height)
|
||||
|
||||
@ -22,9 +22,10 @@ from torch import nn
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
Block,
|
||||
SelfAttentionConfig,
|
||||
VerticalConv,
|
||||
build_layer_from_config,
|
||||
build_layer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -34,7 +35,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
class Bottleneck(Block):
|
||||
"""Base Bottleneck module for Encoder-Decoder architectures.
|
||||
|
||||
This implementation represents the simplest bottleneck structure
|
||||
@ -86,6 +87,7 @@ class Bottleneck(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.input_height = input_height
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.bottleneck_channels = (
|
||||
bottleneck_channels
|
||||
if bottleneck_channels is not None
|
||||
@ -172,7 +174,7 @@ def build_bottleneck(
|
||||
input_height: int,
|
||||
in_channels: int,
|
||||
config: BottleneckConfig | None = None,
|
||||
) -> nn.Module:
|
||||
) -> Block:
|
||||
"""Factory function to build the Bottleneck module from configuration.
|
||||
|
||||
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
|
||||
@ -206,11 +208,13 @@ def build_bottleneck(
|
||||
layers = []
|
||||
|
||||
for layer_config in config.layers:
|
||||
layer, current_channels, current_height = build_layer_from_config(
|
||||
layer = build_layer(
|
||||
input_height=current_height,
|
||||
in_channels=current_channels,
|
||||
config=layer_config,
|
||||
)
|
||||
current_height = layer.get_output_height(current_height)
|
||||
current_channels = layer.get_output_channels()
|
||||
assert current_height == input_height, (
|
||||
"Bottleneck layers should not change the spectrogram height"
|
||||
)
|
||||
|
||||
@ -30,7 +30,7 @@ from batdetect2.models.blocks import (
|
||||
FreqCoordConvUpConfig,
|
||||
LayerGroupConfig,
|
||||
StandardConvUpConfig,
|
||||
build_layer_from_config,
|
||||
build_layer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -259,11 +259,13 @@ def build_decoder(
|
||||
layers = []
|
||||
|
||||
for layer_config in config.layers:
|
||||
layer, current_channels, current_height = build_layer_from_config(
|
||||
layer = build_layer(
|
||||
in_channels=current_channels,
|
||||
input_height=current_height,
|
||||
config=layer_config,
|
||||
)
|
||||
current_height = layer.get_output_height(current_height)
|
||||
current_channels = layer.get_output_channels()
|
||||
layers.append(layer)
|
||||
|
||||
return Decoder(
|
||||
|
||||
@ -32,7 +32,7 @@ from batdetect2.models.blocks import (
|
||||
FreqCoordConvDownConfig,
|
||||
LayerGroupConfig,
|
||||
StandardConvDownConfig,
|
||||
build_layer_from_config,
|
||||
build_layer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -300,12 +300,14 @@ def build_encoder(
|
||||
layers = []
|
||||
|
||||
for layer_config in config.layers:
|
||||
layer, current_channels, current_height = build_layer_from_config(
|
||||
layer = build_layer(
|
||||
in_channels=current_channels,
|
||||
input_height=current_height,
|
||||
config=layer_config,
|
||||
)
|
||||
layers.append(layer)
|
||||
current_height = layer.get_output_height(current_height)
|
||||
current_channels = layer.get_output_channels()
|
||||
|
||||
return Encoder(
|
||||
input_height=input_height,
|
||||
|
||||
@ -19,9 +19,9 @@ def create_ax(
|
||||
) -> axes.Axes:
|
||||
"""Create a new axis if none is provided"""
|
||||
if ax is None:
|
||||
_, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs) # type: ignore
|
||||
_, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs)
|
||||
|
||||
return ax # type: ignore
|
||||
return ax
|
||||
|
||||
|
||||
def plot_spectrogram(
|
||||
|
||||
@ -66,7 +66,7 @@ def plot_clip_detections(
|
||||
if m.gt is not None:
|
||||
color = gt_color if is_match else missed_gt_color
|
||||
plot_geometry(
|
||||
m.gt.sound_event.geometry, # type: ignore
|
||||
m.gt.sound_event.geometry,
|
||||
ax=ax,
|
||||
add_points=False,
|
||||
linewidth=linewidth,
|
||||
|
||||
@ -100,7 +100,7 @@ def plot_classification_heatmap(
|
||||
class_heatmap,
|
||||
vmax=1,
|
||||
vmin=0,
|
||||
cmap=create_colormap(color), # type: ignore
|
||||
cmap=create_colormap(color),
|
||||
alpha=alpha,
|
||||
)
|
||||
|
||||
|
||||
@ -301,4 +301,4 @@ def _get_marker_positions(
|
||||
size = len(thresholds)
|
||||
cut_points = np.linspace(0, 1, n_points)
|
||||
indices = np.searchsorted(thresholds[::-1], cut_points)
|
||||
return np.clip(size - indices, 0, size - 1) # type: ignore
|
||||
return np.clip(size - indices, 0, size - 1)
|
||||
|
||||
@ -72,7 +72,7 @@ class TargetClassConfig(BaseConfig):
|
||||
|
||||
DEFAULT_DETECTION_CLASS = TargetClassConfig(
|
||||
name="bat",
|
||||
match_if=AllOfConfig(
|
||||
condition_input=AllOfConfig(
|
||||
conditions=[
|
||||
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
|
||||
NotConfig(
|
||||
|
||||
@ -27,6 +27,7 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.core.arrays import spec_to_xarray
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
@ -88,6 +89,9 @@ DEFAULT_ANCHOR = "bottom-left"
|
||||
"""Default reference position within the geometry ('bottom-left' corner)."""
|
||||
|
||||
|
||||
roi_mapper_registry: Registry[ROITargetMapper, []] = Registry("roi_mapper")
|
||||
|
||||
|
||||
class AnchorBBoxMapperConfig(BaseConfig):
|
||||
"""Configuration for `AnchorBBoxMapper`.
|
||||
|
||||
@ -244,6 +248,15 @@ class AnchorBBoxMapper(ROITargetMapper):
|
||||
anchor=self.anchor,
|
||||
)
|
||||
|
||||
@roi_mapper_registry.register(AnchorBBoxMapperConfig)
|
||||
@staticmethod
|
||||
def from_config(config: AnchorBBoxMapperConfig):
|
||||
return AnchorBBoxMapper(
|
||||
anchor=config.anchor,
|
||||
time_scale=config.time_scale,
|
||||
frequency_scale=config.frequency_scale,
|
||||
)
|
||||
|
||||
|
||||
class PeakEnergyBBoxMapperConfig(BaseConfig):
|
||||
"""Configuration for `PeakEnergyBBoxMapper`.
|
||||
@ -412,6 +425,22 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
||||
]
|
||||
)
|
||||
|
||||
@roi_mapper_registry.register(PeakEnergyBBoxMapperConfig)
|
||||
@staticmethod
|
||||
def from_config(config: PeakEnergyBBoxMapperConfig):
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
preprocessor = build_preprocessor(
|
||||
config.preprocessing,
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
)
|
||||
return PeakEnergyBBoxMapper(
|
||||
preprocessor=preprocessor,
|
||||
audio_loader=audio_loader,
|
||||
time_scale=config.time_scale,
|
||||
frequency_scale=config.frequency_scale,
|
||||
loading_buffer=config.loading_buffer,
|
||||
)
|
||||
|
||||
|
||||
ROIMapperConfig = Annotated[
|
||||
AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig,
|
||||
@ -445,31 +474,7 @@ def build_roi_mapper(
|
||||
If the `name` in the config does not correspond to a known mapper.
|
||||
"""
|
||||
config = config or AnchorBBoxMapperConfig()
|
||||
|
||||
if config.name == "anchor_bbox":
|
||||
return AnchorBBoxMapper(
|
||||
anchor=config.anchor,
|
||||
time_scale=config.time_scale,
|
||||
frequency_scale=config.frequency_scale,
|
||||
)
|
||||
|
||||
if config.name == "peak_energy_bbox":
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
preprocessor = build_preprocessor(
|
||||
config.preprocessing,
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
)
|
||||
return PeakEnergyBBoxMapper(
|
||||
preprocessor=preprocessor,
|
||||
audio_loader=audio_loader,
|
||||
time_scale=config.time_scale,
|
||||
frequency_scale=config.frequency_scale,
|
||||
loading_buffer=config.loading_buffer,
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"No ROI mapper of name '{config.name}' is implemented"
|
||||
)
|
||||
return roi_mapper_registry.build(config)
|
||||
|
||||
|
||||
VALID_ANCHORS = [
|
||||
@ -636,7 +641,7 @@ def get_peak_energy_coordinates(
|
||||
)
|
||||
|
||||
index = selection.argmax(dim=["time", "frequency"])
|
||||
point = selection.isel(index) # type: ignore
|
||||
point = selection.isel(index)
|
||||
peak_time: float = point.time.item()
|
||||
peak_freq: float = point.frequency.item()
|
||||
return peak_time, peak_freq
|
||||
|
||||
@ -548,63 +548,6 @@ class MaybeApply(torch.nn.Module):
|
||||
return self.augmentation(tensor, clip_annotation)
|
||||
|
||||
|
||||
def build_augmentation_from_config(
|
||||
config: AugmentationConfig,
|
||||
samplerate: int,
|
||||
audio_source: AudioSource | None = None,
|
||||
) -> Augmentation | None:
|
||||
"""Factory function to build a single augmentation from its config."""
|
||||
if config.name == "mix_audio":
|
||||
if audio_source is None:
|
||||
warnings.warn(
|
||||
"Mix audio augmentation ('mix_audio') requires an "
|
||||
"'example_source' callable to be provided.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return None
|
||||
|
||||
return MixAudio(
|
||||
example_source=audio_source,
|
||||
min_weight=config.min_weight,
|
||||
max_weight=config.max_weight,
|
||||
)
|
||||
|
||||
if config.name == "add_echo":
|
||||
return AddEcho(
|
||||
max_delay=int(config.max_delay * samplerate),
|
||||
min_weight=config.min_weight,
|
||||
max_weight=config.max_weight,
|
||||
)
|
||||
|
||||
if config.name == "scale_volume":
|
||||
return ScaleVolume(
|
||||
max_scaling=config.max_scaling,
|
||||
min_scaling=config.min_scaling,
|
||||
)
|
||||
|
||||
if config.name == "warp":
|
||||
return Warp(
|
||||
delta=config.delta,
|
||||
)
|
||||
|
||||
if config.name == "mask_time":
|
||||
return MaskTime(
|
||||
max_perc=config.max_perc,
|
||||
max_masks=config.max_masks,
|
||||
)
|
||||
|
||||
if config.name == "mask_freq":
|
||||
return MaskFrequency(
|
||||
max_perc=config.max_perc,
|
||||
max_masks=config.max_masks,
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
"Invalid or unimplemented augmentation type: "
|
||||
f"{config.augmentation_type}"
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
|
||||
enabled=True,
|
||||
audio=[
|
||||
|
||||
@ -61,8 +61,8 @@ class TrainingDataset(Dataset):
|
||||
def __len__(self):
|
||||
return len(self.clip_annotations)
|
||||
|
||||
def __getitem__(self, idx) -> TrainExample:
|
||||
clip_annotation = self.clip_annotations[idx]
|
||||
def __getitem__(self, index) -> TrainExample:
|
||||
clip_annotation = self.clip_annotations[index]
|
||||
|
||||
if self.clipper is not None:
|
||||
clip_annotation = self.clipper(clip_annotation)
|
||||
@ -95,7 +95,7 @@ class TrainingDataset(Dataset):
|
||||
detection_heatmap=heatmaps.detection,
|
||||
class_heatmap=heatmaps.classes,
|
||||
size_heatmap=heatmaps.size,
|
||||
idx=torch.tensor(idx),
|
||||
idx=torch.tensor(index),
|
||||
start_time=torch.tensor(clip.start_time),
|
||||
end_time=torch.tensor(clip.end_time),
|
||||
)
|
||||
@ -121,8 +121,8 @@ class ValidationDataset(Dataset):
|
||||
def __len__(self):
|
||||
return len(self.clip_annotations)
|
||||
|
||||
def __getitem__(self, idx) -> TrainExample:
|
||||
clip_annotation = self.clip_annotations[idx]
|
||||
def __getitem__(self, index) -> TrainExample:
|
||||
clip_annotation = self.clip_annotations[index]
|
||||
|
||||
if self.clipper is not None:
|
||||
clip_annotation = self.clipper(clip_annotation)
|
||||
@ -141,7 +141,7 @@ class ValidationDataset(Dataset):
|
||||
detection_heatmap=heatmaps.detection,
|
||||
class_heatmap=heatmaps.classes,
|
||||
size_heatmap=heatmaps.size,
|
||||
idx=torch.tensor(idx),
|
||||
idx=torch.tensor(index),
|
||||
start_time=torch.tensor(clip.start_time),
|
||||
end_time=torch.tensor(clip.end_time),
|
||||
)
|
||||
|
||||
@ -472,7 +472,7 @@ def build_loss(
|
||||
|
||||
size_loss_fn = BBoxLoss()
|
||||
|
||||
return LossFunction( # type: ignore
|
||||
return LossFunction(
|
||||
size_loss=size_loss_fn,
|
||||
classification_loss=classification_loss_fn,
|
||||
detection_loss=detection_loss_fn,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user