mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Fix type errors
This commit is contained in:
parent
cce1b49a8d
commit
d52e988b8f
53
justfile
53
justfile
@ -14,60 +14,67 @@ HTML_COVERAGE_DIR := "htmlcov"
|
|||||||
help:
|
help:
|
||||||
@just --list
|
@just --list
|
||||||
|
|
||||||
|
install:
|
||||||
|
uv sync
|
||||||
|
|
||||||
# Testing & Coverage
|
# Testing & Coverage
|
||||||
# Run tests using pytest.
|
# Run tests using pytest.
|
||||||
test:
|
test:
|
||||||
pytest {{TESTS_DIR}}
|
uv run pytest {{TESTS_DIR}}
|
||||||
|
|
||||||
# Run tests and generate coverage data.
|
# Run tests and generate coverage data.
|
||||||
coverage:
|
coverage:
|
||||||
pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
|
uv run pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
|
||||||
|
|
||||||
# Generate an HTML coverage report.
|
# Generate an HTML coverage report.
|
||||||
coverage-html: coverage
|
coverage-html: coverage
|
||||||
@echo "Generating HTML coverage report..."
|
@echo "Generating HTML coverage report..."
|
||||||
coverage html -d {{HTML_COVERAGE_DIR}}
|
uv run coverage html -d {{HTML_COVERAGE_DIR}}
|
||||||
@echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"
|
@echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"
|
||||||
|
|
||||||
# Serve the HTML coverage report locally.
|
# Serve the HTML coverage report locally.
|
||||||
coverage-serve: coverage-html
|
coverage-serve: coverage-html
|
||||||
@echo "Serving report at http://localhost:8000/ ..."
|
@echo "Serving report at http://localhost:8000/ ..."
|
||||||
python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
|
uv run python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
|
||||||
|
|
||||||
# Documentation
|
# Documentation
|
||||||
# Build documentation using Sphinx.
|
# Build documentation using Sphinx.
|
||||||
docs:
|
docs:
|
||||||
sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
|
uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
|
||||||
|
|
||||||
# Serve documentation with live reload.
|
# Serve documentation with live reload.
|
||||||
docs-serve:
|
docs-serve:
|
||||||
sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
|
uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
|
||||||
|
|
||||||
# Formatting & Linting
|
# Formatting & Linting
|
||||||
# Format code using ruff.
|
# Format code using ruff.
|
||||||
format:
|
fix-format:
|
||||||
ruff format {{PYTHON_DIRS}}
|
uv run ruff format {{PYTHON_DIRS}}
|
||||||
|
|
||||||
# Check code formatting using ruff.
|
|
||||||
format-check:
|
|
||||||
ruff format --check {{PYTHON_DIRS}}
|
|
||||||
|
|
||||||
# Lint code using ruff.
|
|
||||||
lint:
|
|
||||||
ruff check {{PYTHON_DIRS}}
|
|
||||||
|
|
||||||
# Lint code using ruff and apply automatic fixes.
|
# Lint code using ruff and apply automatic fixes.
|
||||||
lint-fix:
|
fix-lint:
|
||||||
ruff check --fix {{PYTHON_DIRS}}
|
uv run ruff check --fix {{PYTHON_DIRS}}
|
||||||
|
|
||||||
|
# Combined Formatting & Linting
|
||||||
|
fix: fix-format fix-lint
|
||||||
|
|
||||||
|
# Checking tasks
|
||||||
|
# Check code formatting using ruff.
|
||||||
|
check-format:
|
||||||
|
uv run ruff format --check {{PYTHON_DIRS}}
|
||||||
|
|
||||||
|
# Lint code using ruff.
|
||||||
|
check-lint:
|
||||||
|
uv run ruff check {{PYTHON_DIRS}}
|
||||||
|
|
||||||
# Type Checking
|
# Type Checking
|
||||||
# Type check code using pyright.
|
# Type check code using ty.
|
||||||
typecheck:
|
check-types:
|
||||||
pyright {{PYTHON_DIRS}}
|
uv run ty check {{PYTHON_DIRS}}
|
||||||
|
|
||||||
# Combined Checks
|
# Combined Checks
|
||||||
# Run all checks (format-check, lint, typecheck).
|
# Run all checks (check-format, check-lint, check-types).
|
||||||
check: format-check lint typecheck test
|
check: check-format check-lint check-types
|
||||||
|
|
||||||
# Cleaning tasks
|
# Cleaning tasks
|
||||||
# Remove Python bytecode and cache.
|
# Remove Python bytecode and cache.
|
||||||
@ -95,7 +102,7 @@ clean: clean-build clean-pyc clean-test clean-docs
|
|||||||
|
|
||||||
# Train on example data.
|
# Train on example data.
|
||||||
example-train OPTIONS="":
|
example-train OPTIONS="":
|
||||||
batdetect2 train \
|
uv run batdetect2 train \
|
||||||
--val-dataset example_data/dataset.yaml \
|
--val-dataset example_data/dataset.yaml \
|
||||||
--config example_data/config.yaml \
|
--config example_data/config.yaml \
|
||||||
{{OPTIONS}} \
|
{{OPTIONS}} \
|
||||||
|
|||||||
@ -75,7 +75,6 @@ dev = [
|
|||||||
"ruff>=0.7.3",
|
"ruff>=0.7.3",
|
||||||
"ipykernel>=6.29.4",
|
"ipykernel>=6.29.4",
|
||||||
"setuptools>=69.5.1",
|
"setuptools>=69.5.1",
|
||||||
"basedpyright>=1.31.0",
|
|
||||||
"myst-parser>=3.0.1",
|
"myst-parser>=3.0.1",
|
||||||
"sphinx-autobuild>=2024.10.3",
|
"sphinx-autobuild>=2024.10.3",
|
||||||
"numpydoc>=1.8.0",
|
"numpydoc>=1.8.0",
|
||||||
@ -94,6 +93,12 @@ mlflow = ["mlflow>=3.1.1"]
|
|||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 79
|
line-length = 79
|
||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
|
exclude = [
|
||||||
|
"src/batdetect2/train/legacy",
|
||||||
|
"src/batdetect2/plotting/legacy",
|
||||||
|
"src/batdetect2/evaluate/legacy",
|
||||||
|
"src/batdetect2/finetune",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
docstring-code-format = true
|
docstring-code-format = true
|
||||||
@ -105,15 +110,11 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
|
|||||||
[tool.ruff.lint.pydocstyle]
|
[tool.ruff.lint.pydocstyle]
|
||||||
convention = "numpy"
|
convention = "numpy"
|
||||||
|
|
||||||
[tool.pyright]
|
[tool.ty.src]
|
||||||
include = ["src", "tests"]
|
include = ["src", "tests"]
|
||||||
pythonVersion = "3.10"
|
|
||||||
pythonPlatform = "All"
|
|
||||||
exclude = [
|
exclude = [
|
||||||
"src/batdetect2/detector/",
|
|
||||||
"src/batdetect2/finetune",
|
|
||||||
"src/batdetect2/utils",
|
|
||||||
"src/batdetect2/plot",
|
|
||||||
"src/batdetect2/evaluate/legacy",
|
|
||||||
"src/batdetect2/train/legacy",
|
"src/batdetect2/train/legacy",
|
||||||
|
"src/batdetect2/plotting/legacy",
|
||||||
|
"src/batdetect2/evaluate/legacy",
|
||||||
|
"src/batdetect2/finetune",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -113,7 +113,7 @@ def load_file_audio(
|
|||||||
samplerate: int | None = None,
|
samplerate: int | None = None,
|
||||||
config: ResampleConfig | None = None,
|
config: ResampleConfig | None = None,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
dtype: DTypeLike = np.float32, # type: ignore
|
dtype: DTypeLike = np.float32,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess audio from a file path using specified config."""
|
"""Load and preprocess audio from a file path using specified config."""
|
||||||
try:
|
try:
|
||||||
@ -137,7 +137,7 @@ def load_recording_audio(
|
|||||||
samplerate: int | None = None,
|
samplerate: int | None = None,
|
||||||
config: ResampleConfig | None = None,
|
config: ResampleConfig | None = None,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
dtype: DTypeLike = np.float32, # type: ignore
|
dtype: DTypeLike = np.float32,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess the entire audio content of a recording using config."""
|
"""Load and preprocess the entire audio content of a recording using config."""
|
||||||
clip = data.Clip(
|
clip = data.Clip(
|
||||||
@ -159,7 +159,7 @@ def load_clip_audio(
|
|||||||
samplerate: int | None = None,
|
samplerate: int | None = None,
|
||||||
config: ResampleConfig | None = None,
|
config: ResampleConfig | None = None,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
dtype: DTypeLike = np.float32, # type: ignore
|
dtype: DTypeLike = np.float32,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess a specific audio clip segment based on config."""
|
"""Load and preprocess a specific audio clip segment based on config."""
|
||||||
try:
|
try:
|
||||||
@ -286,7 +286,7 @@ def resample_audio_fourier(
|
|||||||
If `num` is negative.
|
If `num` is negative.
|
||||||
"""
|
"""
|
||||||
ratio = sr_new / sr_orig
|
ratio = sr_new / sr_orig
|
||||||
return resample( # type: ignore
|
return resample(
|
||||||
array,
|
array,
|
||||||
int(array.shape[axis] * ratio),
|
int(array.shape[axis] * ratio),
|
||||||
axis=axis,
|
axis=axis,
|
||||||
|
|||||||
@ -103,17 +103,17 @@ def convert_to_annotation_group(
|
|||||||
y_inds.append(0)
|
y_inds.append(0)
|
||||||
|
|
||||||
annotations.append(
|
annotations.append(
|
||||||
{
|
Annotation(
|
||||||
"start_time": start_time,
|
start_time=start_time,
|
||||||
"end_time": end_time,
|
end_time=end_time,
|
||||||
"low_freq": low_freq,
|
low_freq=low_freq,
|
||||||
"high_freq": high_freq,
|
high_freq=high_freq,
|
||||||
"class_prob": 1.0,
|
class_prob=1.0,
|
||||||
"det_prob": 1.0,
|
det_prob=1.0,
|
||||||
"individual": "0",
|
individual="0",
|
||||||
"event": event,
|
event=event,
|
||||||
"class_id": class_id, # type: ignore
|
class_id=class_id,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -89,7 +89,7 @@ class FileAnnotation(TypedDict):
|
|||||||
annotation: List[Annotation]
|
annotation: List[Annotation]
|
||||||
"""List of annotations."""
|
"""List of annotations."""
|
||||||
|
|
||||||
file_path: NotRequired[str]
|
file_path: NotRequired[str] # ty: ignore[invalid-type-form]
|
||||||
"""Path to file."""
|
"""Path to file."""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -26,15 +26,15 @@ def split_dataset_by_recordings(
|
|||||||
)
|
)
|
||||||
|
|
||||||
majority_class = (
|
majority_class = (
|
||||||
sound_events.groupby("recording_id")
|
sound_events.groupby("recording_id") # type: ignore
|
||||||
.apply(
|
.apply(
|
||||||
lambda group: (
|
lambda group: (
|
||||||
group["class_name"] # type: ignore
|
group["class_name"]
|
||||||
.value_counts()
|
.value_counts()
|
||||||
.sort_values(ascending=False)
|
.sort_values(ascending=False)
|
||||||
.index[0]
|
.index[0]
|
||||||
),
|
),
|
||||||
include_groups=False, # type: ignore
|
include_groups=False,
|
||||||
)
|
)
|
||||||
.rename("class_name")
|
.rename("class_name")
|
||||||
.to_frame()
|
.to_frame()
|
||||||
@ -48,8 +48,8 @@ def split_dataset_by_recordings(
|
|||||||
random_state=random_state,
|
random_state=random_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
train_ids_set = set(train.values) # type: ignore
|
train_ids_set = set(train.values)
|
||||||
test_ids_set = set(test.values) # type: ignore
|
test_ids_set = set(test.values)
|
||||||
|
|
||||||
extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set
|
extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set
|
||||||
|
|
||||||
|
|||||||
@ -175,14 +175,14 @@ def compute_class_summary(
|
|||||||
.rename("num recordings")
|
.rename("num recordings")
|
||||||
)
|
)
|
||||||
durations = (
|
durations = (
|
||||||
sound_events.groupby("class_name")
|
sound_events.groupby("class_name") # ty: ignore[no-matching-overload]
|
||||||
.apply(
|
.apply(
|
||||||
lambda group: recordings[
|
lambda group: recordings[
|
||||||
recordings["clip_annotation_id"].isin(
|
recordings["clip_annotation_id"].isin(
|
||||||
group["clip_annotation_id"] # type: ignore
|
group["clip_annotation_id"]
|
||||||
)
|
)
|
||||||
]["duration"].sum(),
|
]["duration"].sum(),
|
||||||
include_groups=False, # type: ignore
|
include_groups=False,
|
||||||
)
|
)
|
||||||
.sort_values(ascending=False)
|
.sort_values(ascending=False)
|
||||||
.rename("duration")
|
.rename("duration")
|
||||||
|
|||||||
@ -176,12 +176,7 @@ class MapTagValue:
|
|||||||
if self.target_key is None:
|
if self.target_key is None:
|
||||||
tags.append(tag.model_copy(update=dict(value=value)))
|
tags.append(tag.model_copy(update=dict(value=value)))
|
||||||
else:
|
else:
|
||||||
tags.append(
|
tags.append(data.Tag(key=self.target_key, value=value))
|
||||||
data.Tag(
|
|
||||||
key=self.target_key, # type: ignore
|
|
||||||
value=value,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||||
|
|
||||||
|
|||||||
@ -120,11 +120,11 @@ class BBoxIOU(AffinityFunction):
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prediction: Detection,
|
detection: Detection,
|
||||||
gt: data.SoundEventAnnotation,
|
ground_truth: data.SoundEventAnnotation,
|
||||||
):
|
):
|
||||||
target_geometry = gt.sound_event.geometry
|
target_geometry = ground_truth.sound_event.geometry
|
||||||
source_geometry = prediction.geometry
|
source_geometry = detection.geometry
|
||||||
|
|
||||||
if self.time_buffer > 0 or self.freq_buffer > 0:
|
if self.time_buffer > 0 or self.freq_buffer > 0:
|
||||||
target_geometry = buffer_geometry(
|
target_geometry = buffer_geometry(
|
||||||
@ -168,11 +168,11 @@ class GeometricIOU(AffinityFunction):
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prediction: Detection,
|
detection: Detection,
|
||||||
gt: data.SoundEventAnnotation,
|
ground_truth: data.SoundEventAnnotation,
|
||||||
):
|
):
|
||||||
target_geometry = gt.sound_event.geometry
|
target_geometry = ground_truth.sound_event.geometry
|
||||||
source_geometry = prediction.geometry
|
source_geometry = detection.geometry
|
||||||
|
|
||||||
if self.time_buffer > 0 or self.freq_buffer > 0:
|
if self.time_buffer > 0 or self.freq_buffer > 0:
|
||||||
target_geometry = buffer_geometry(
|
target_geometry = buffer_geometry(
|
||||||
|
|||||||
@ -51,8 +51,8 @@ class TestDataset(Dataset[TestExample]):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.clip_annotations)
|
return len(self.clip_annotations)
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> TestExample:
|
def __getitem__(self, index: int) -> TestExample:
|
||||||
clip_annotation = self.clip_annotations[idx]
|
clip_annotation = self.clip_annotations[index]
|
||||||
|
|
||||||
if self.clipper is not None:
|
if self.clipper is not None:
|
||||||
clip_annotation = self.clipper(clip_annotation)
|
clip_annotation = self.clipper(clip_annotation)
|
||||||
@ -63,7 +63,7 @@ class TestDataset(Dataset[TestExample]):
|
|||||||
spectrogram = self.preprocessor(wav_tensor)
|
spectrogram = self.preprocessor(wav_tensor)
|
||||||
return TestExample(
|
return TestExample(
|
||||||
spec=spectrogram,
|
spec=spectrogram,
|
||||||
idx=torch.tensor(idx),
|
idx=torch.tensor(index),
|
||||||
start_time=torch.tensor(clip.start_time),
|
start_time=torch.tensor(clip.start_time),
|
||||||
end_time=torch.tensor(clip.end_time),
|
end_time=torch.tensor(clip.end_time),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -109,7 +109,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
prediction: ClipDetections,
|
prediction: ClipDetections,
|
||||||
) -> T_Output: ...
|
) -> T_Output: ... # ty: ignore[empty-body]
|
||||||
|
|
||||||
def include_sound_event_annotation(
|
def include_sound_event_annotation(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -46,14 +46,14 @@ class InferenceDataset(Dataset[DatasetItem]):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.clips)
|
return len(self.clips)
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> DatasetItem:
|
def __getitem__(self, index: int) -> DatasetItem:
|
||||||
clip = self.clips[idx]
|
clip = self.clips[index]
|
||||||
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
|
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
|
||||||
wav_tensor = torch.tensor(wav).unsqueeze(0)
|
wav_tensor = torch.tensor(wav).unsqueeze(0)
|
||||||
spectrogram = self.preprocessor(wav_tensor)
|
spectrogram = self.preprocessor(wav_tensor)
|
||||||
return DatasetItem(
|
return DatasetItem(
|
||||||
spec=spectrogram,
|
spec=spectrogram,
|
||||||
idx=torch.tensor(idx),
|
idx=torch.tensor(index),
|
||||||
start_time=torch.tensor(clip.start_time),
|
start_time=torch.tensor(clip.start_time),
|
||||||
end_time=torch.tensor(clip.end_time),
|
end_time=torch.tensor(clip.end_time),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -185,7 +185,7 @@ def build_backbone(config: BackboneConfig) -> BackboneModel:
|
|||||||
)
|
)
|
||||||
|
|
||||||
decoder = build_decoder(
|
decoder = build_decoder(
|
||||||
in_channels=bottleneck.out_channels,
|
in_channels=bottleneck.get_output_channels(),
|
||||||
input_height=encoder.output_height,
|
input_height=encoder.output_height,
|
||||||
config=config.decoder,
|
config=config.decoder,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -20,20 +20,21 @@ research:
|
|||||||
spatial frequency information to filters, potentially enabling them to learn
|
spatial frequency information to filters, potentially enabling them to learn
|
||||||
frequency-dependent patterns more effectively.
|
frequency-dependent patterns more effectively.
|
||||||
|
|
||||||
These blocks can be utilized directly in custom PyTorch model definitions or
|
These blocks can be used directly in custom PyTorch model definitions or
|
||||||
assembled into larger architectures.
|
assembled into larger architectures.
|
||||||
|
|
||||||
A unified factory function `build_layer_from_config` allows creating instances
|
A unified factory function `build_layer` allows creating instances
|
||||||
of these blocks based on configuration objects.
|
of these blocks based on configuration objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, List, Literal, Tuple, Union
|
from typing import Annotated, List, Literal, Protocol, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from batdetect2.core import Registry
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -51,17 +52,30 @@ __all__ = [
|
|||||||
"FreqCoordConvUpConfig",
|
"FreqCoordConvUpConfig",
|
||||||
"StandardConvUpConfig",
|
"StandardConvUpConfig",
|
||||||
"LayerConfig",
|
"LayerConfig",
|
||||||
"build_layer_from_config",
|
"build_layer",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BlockProtocol(Protocol):
|
||||||
|
def get_output_channels(self) -> int: ...
|
||||||
|
|
||||||
|
def get_output_height(self, input_height: int) -> int:
|
||||||
|
return input_height
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module, BlockProtocol): ...
|
||||||
|
|
||||||
|
|
||||||
|
block_registry: Registry[Block, [int, int]] = Registry("block")
|
||||||
|
|
||||||
|
|
||||||
class SelfAttentionConfig(BaseConfig):
|
class SelfAttentionConfig(BaseConfig):
|
||||||
name: Literal["SelfAttention"] = "SelfAttention"
|
name: Literal["SelfAttention"] = "SelfAttention"
|
||||||
attention_channels: int
|
attention_channels: int
|
||||||
temperature: float = 1
|
temperature: float = 1
|
||||||
|
|
||||||
|
|
||||||
class SelfAttention(nn.Module):
|
class SelfAttention(Block):
|
||||||
"""Self-Attention mechanism operating along the time dimension.
|
"""Self-Attention mechanism operating along the time dimension.
|
||||||
|
|
||||||
This module implements a scaled dot-product self-attention mechanism,
|
This module implements a scaled dot-product self-attention mechanism,
|
||||||
@ -121,6 +135,7 @@ class SelfAttention(nn.Module):
|
|||||||
# Note, does not encode position information (absolute or relative)
|
# Note, does not encode position information (absolute or relative)
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.att_dim = attention_channels
|
self.att_dim = attention_channels
|
||||||
|
self.output_channels = in_channels
|
||||||
|
|
||||||
self.key_fun = nn.Linear(in_channels, attention_channels)
|
self.key_fun = nn.Linear(in_channels, attention_channels)
|
||||||
self.value_fun = nn.Linear(in_channels, attention_channels)
|
self.value_fun = nn.Linear(in_channels, attention_channels)
|
||||||
@ -190,6 +205,22 @@ class SelfAttention(nn.Module):
|
|||||||
att_weights = F.softmax(kk_qq, 1)
|
att_weights = F.softmax(kk_qq, 1)
|
||||||
return att_weights
|
return att_weights
|
||||||
|
|
||||||
|
def get_output_channels(self) -> int:
|
||||||
|
return self.output_channels
|
||||||
|
|
||||||
|
@block_registry.register(SelfAttentionConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: SelfAttentionConfig,
|
||||||
|
input_channels: int,
|
||||||
|
input_height: int,
|
||||||
|
) -> "SelfAttention":
|
||||||
|
return SelfAttention(
|
||||||
|
in_channels=input_channels,
|
||||||
|
attention_channels=config.attention_channels,
|
||||||
|
temperature=config.temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConvConfig(BaseConfig):
|
class ConvConfig(BaseConfig):
|
||||||
"""Configuration for a basic ConvBlock."""
|
"""Configuration for a basic ConvBlock."""
|
||||||
@ -207,7 +238,7 @@ class ConvConfig(BaseConfig):
|
|||||||
"""Padding size."""
|
"""Padding size."""
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
class ConvBlock(Block):
|
||||||
"""Basic Convolutional Block.
|
"""Basic Convolutional Block.
|
||||||
|
|
||||||
A standard building block consisting of a 2D convolution, followed by
|
A standard building block consisting of a 2D convolution, followed by
|
||||||
@ -258,8 +289,30 @@ class ConvBlock(nn.Module):
|
|||||||
"""
|
"""
|
||||||
return F.relu_(self.batch_norm(self.conv(x)))
|
return F.relu_(self.batch_norm(self.conv(x)))
|
||||||
|
|
||||||
|
def get_output_channels(self) -> int:
|
||||||
|
return self.conv.out_channels
|
||||||
|
|
||||||
class VerticalConv(nn.Module):
|
@block_registry.register(ConvConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: ConvConfig,
|
||||||
|
input_channels: int,
|
||||||
|
input_height: int,
|
||||||
|
):
|
||||||
|
return ConvBlock(
|
||||||
|
in_channels=input_channels,
|
||||||
|
out_channels=config.out_channels,
|
||||||
|
kernel_size=config.kernel_size,
|
||||||
|
pad_size=config.pad_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VerticalConvConfig(BaseConfig):
|
||||||
|
name: Literal["VerticalConv"] = "VerticalConv"
|
||||||
|
channels: int
|
||||||
|
|
||||||
|
|
||||||
|
class VerticalConv(Block):
|
||||||
"""Convolutional layer that aggregates features across the entire height.
|
"""Convolutional layer that aggregates features across the entire height.
|
||||||
|
|
||||||
Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
|
Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
|
||||||
@ -312,6 +365,22 @@ class VerticalConv(nn.Module):
|
|||||||
"""
|
"""
|
||||||
return F.relu_(self.bn(self.conv(x)))
|
return F.relu_(self.bn(self.conv(x)))
|
||||||
|
|
||||||
|
def get_output_channels(self) -> int:
|
||||||
|
return self.conv.out_channels
|
||||||
|
|
||||||
|
@block_registry.register(VerticalConvConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: VerticalConvConfig,
|
||||||
|
input_channels: int,
|
||||||
|
input_height: int,
|
||||||
|
):
|
||||||
|
return VerticalConv(
|
||||||
|
in_channels=input_channels,
|
||||||
|
out_channels=config.channels,
|
||||||
|
input_height=input_height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FreqCoordConvDownConfig(BaseConfig):
|
class FreqCoordConvDownConfig(BaseConfig):
|
||||||
"""Configuration for a FreqCoordConvDownBlock."""
|
"""Configuration for a FreqCoordConvDownBlock."""
|
||||||
@ -329,7 +398,7 @@ class FreqCoordConvDownConfig(BaseConfig):
|
|||||||
"""Padding size."""
|
"""Padding size."""
|
||||||
|
|
||||||
|
|
||||||
class FreqCoordConvDownBlock(nn.Module):
|
class FreqCoordConvDownBlock(Block):
|
||||||
"""Downsampling Conv Block incorporating Frequency Coordinate features.
|
"""Downsampling Conv Block incorporating Frequency Coordinate features.
|
||||||
|
|
||||||
This block implements a downsampling step (Conv2d + MaxPool2d) commonly
|
This block implements a downsampling step (Conv2d + MaxPool2d) commonly
|
||||||
@ -402,6 +471,27 @@ class FreqCoordConvDownBlock(nn.Module):
|
|||||||
x = F.relu(self.batch_norm(x), inplace=True)
|
x = F.relu(self.batch_norm(x), inplace=True)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def get_output_channels(self) -> int:
|
||||||
|
return self.conv.out_channels
|
||||||
|
|
||||||
|
def get_output_height(self, input_height: int) -> int:
|
||||||
|
return input_height // 2
|
||||||
|
|
||||||
|
@block_registry.register(FreqCoordConvDownConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: FreqCoordConvDownConfig,
|
||||||
|
input_channels: int,
|
||||||
|
input_height: int,
|
||||||
|
):
|
||||||
|
return FreqCoordConvDownBlock(
|
||||||
|
in_channels=input_channels,
|
||||||
|
out_channels=config.out_channels,
|
||||||
|
input_height=input_height,
|
||||||
|
kernel_size=config.kernel_size,
|
||||||
|
pad_size=config.pad_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StandardConvDownConfig(BaseConfig):
|
class StandardConvDownConfig(BaseConfig):
|
||||||
"""Configuration for a StandardConvDownBlock."""
|
"""Configuration for a StandardConvDownBlock."""
|
||||||
@ -419,7 +509,7 @@ class StandardConvDownConfig(BaseConfig):
|
|||||||
"""Padding size."""
|
"""Padding size."""
|
||||||
|
|
||||||
|
|
||||||
class StandardConvDownBlock(nn.Module):
|
class StandardConvDownBlock(Block):
|
||||||
"""Standard Downsampling Convolutional Block.
|
"""Standard Downsampling Convolutional Block.
|
||||||
|
|
||||||
A basic downsampling block consisting of a 2D convolution, followed by
|
A basic downsampling block consisting of a 2D convolution, followed by
|
||||||
@ -472,6 +562,26 @@ class StandardConvDownBlock(nn.Module):
|
|||||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||||
return F.relu(self.batch_norm(x), inplace=True)
|
return F.relu(self.batch_norm(x), inplace=True)
|
||||||
|
|
||||||
|
def get_output_channels(self) -> int:
|
||||||
|
return self.conv.out_channels
|
||||||
|
|
||||||
|
def get_output_height(self, input_height: int) -> int:
|
||||||
|
return input_height // 2
|
||||||
|
|
||||||
|
@block_registry.register(StandardConvDownConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: StandardConvDownConfig,
|
||||||
|
input_channels: int,
|
||||||
|
input_height: int,
|
||||||
|
):
|
||||||
|
return StandardConvDownBlock(
|
||||||
|
in_channels=input_channels,
|
||||||
|
out_channels=config.out_channels,
|
||||||
|
kernel_size=config.kernel_size,
|
||||||
|
pad_size=config.pad_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FreqCoordConvUpConfig(BaseConfig):
|
class FreqCoordConvUpConfig(BaseConfig):
|
||||||
"""Configuration for a FreqCoordConvUpBlock."""
|
"""Configuration for a FreqCoordConvUpBlock."""
|
||||||
@ -488,8 +598,14 @@ class FreqCoordConvUpConfig(BaseConfig):
|
|||||||
pad_size: int = 1
|
pad_size: int = 1
|
||||||
"""Padding size."""
|
"""Padding size."""
|
||||||
|
|
||||||
|
up_mode: str = "bilinear"
|
||||||
|
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
||||||
|
|
||||||
class FreqCoordConvUpBlock(nn.Module):
|
up_scale: Tuple[int, int] = (2, 2)
|
||||||
|
"""Scaling factor for height and width during upsampling."""
|
||||||
|
|
||||||
|
|
||||||
|
class FreqCoordConvUpBlock(Block):
|
||||||
"""Upsampling Conv Block incorporating Frequency Coordinate features.
|
"""Upsampling Conv Block incorporating Frequency Coordinate features.
|
||||||
|
|
||||||
This block implements an upsampling step followed by a convolution,
|
This block implements an upsampling step followed by a convolution,
|
||||||
@ -581,6 +697,29 @@ class FreqCoordConvUpBlock(nn.Module):
|
|||||||
op = F.relu(self.batch_norm(op), inplace=True)
|
op = F.relu(self.batch_norm(op), inplace=True)
|
||||||
return op
|
return op
|
||||||
|
|
||||||
|
def get_output_channels(self) -> int:
|
||||||
|
return self.conv.out_channels
|
||||||
|
|
||||||
|
def get_output_height(self, input_height: int) -> int:
|
||||||
|
return input_height * 2
|
||||||
|
|
||||||
|
@block_registry.register(FreqCoordConvUpConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: FreqCoordConvUpConfig,
|
||||||
|
input_channels: int,
|
||||||
|
input_height: int,
|
||||||
|
):
|
||||||
|
return FreqCoordConvUpBlock(
|
||||||
|
in_channels=input_channels,
|
||||||
|
out_channels=config.out_channels,
|
||||||
|
input_height=input_height,
|
||||||
|
kernel_size=config.kernel_size,
|
||||||
|
pad_size=config.pad_size,
|
||||||
|
up_mode=config.up_mode,
|
||||||
|
up_scale=config.up_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StandardConvUpConfig(BaseConfig):
|
class StandardConvUpConfig(BaseConfig):
|
||||||
"""Configuration for a StandardConvUpBlock."""
|
"""Configuration for a StandardConvUpBlock."""
|
||||||
@ -597,8 +736,14 @@ class StandardConvUpConfig(BaseConfig):
|
|||||||
pad_size: int = 1
|
pad_size: int = 1
|
||||||
"""Padding size."""
|
"""Padding size."""
|
||||||
|
|
||||||
|
up_mode: str = "bilinear"
|
||||||
|
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
||||||
|
|
||||||
class StandardConvUpBlock(nn.Module):
|
up_scale: Tuple[int, int] = (2, 2)
|
||||||
|
"""Scaling factor for height and width during upsampling."""
|
||||||
|
|
||||||
|
|
||||||
|
class StandardConvUpBlock(Block):
|
||||||
"""Standard Upsampling Convolutional Block.
|
"""Standard Upsampling Convolutional Block.
|
||||||
|
|
||||||
A basic upsampling block used in CNN decoders. It first upsamples the input
|
A basic upsampling block used in CNN decoders. It first upsamples the input
|
||||||
@ -669,6 +814,28 @@ class StandardConvUpBlock(nn.Module):
|
|||||||
op = F.relu(self.batch_norm(op), inplace=True)
|
op = F.relu(self.batch_norm(op), inplace=True)
|
||||||
return op
|
return op
|
||||||
|
|
||||||
|
def get_output_channels(self) -> int:
|
||||||
|
return self.conv.out_channels
|
||||||
|
|
||||||
|
def get_output_height(self, input_height: int) -> int:
|
||||||
|
return input_height * 2
|
||||||
|
|
||||||
|
@block_registry.register(StandardConvUpConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: StandardConvUpConfig,
|
||||||
|
input_channels: int,
|
||||||
|
input_height: int,
|
||||||
|
):
|
||||||
|
return StandardConvUpBlock(
|
||||||
|
in_channels=input_channels,
|
||||||
|
out_channels=config.out_channels,
|
||||||
|
kernel_size=config.kernel_size,
|
||||||
|
pad_size=config.pad_size,
|
||||||
|
up_mode=config.up_mode,
|
||||||
|
up_scale=config.up_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
LayerConfig = Annotated[
|
LayerConfig = Annotated[
|
||||||
Union[
|
Union[
|
||||||
@ -690,11 +857,61 @@ class LayerGroupConfig(BaseConfig):
|
|||||||
layers: List[LayerConfig]
|
layers: List[LayerConfig]
|
||||||
|
|
||||||
|
|
||||||
def build_layer_from_config(
|
class LayerGroup(nn.Module):
    """Sequential group of blocks that can itself be used as a block.

    The wrapped blocks are applied one after another. The group exposes
    the same `get_output_channels` / `get_output_height` helpers as the
    individual blocks so that groups can be chained or nested.

    Parameters
    ----------
    layers : list[Block]
        Blocks to apply, in order.
    input_height : int
        Height of the input the group was built for. Accepted for
        interface symmetry with the other blocks; not used here.
    input_channels : int
        Number of channels of the input the group was built for. Used as
        the output channel count when the group has no layers.
    """

    def __init__(
        self,
        layers: list[Block],
        input_height: int,
        input_channels: int,
    ):
        super().__init__()
        self.blocks = layers
        self.layers = nn.Sequential(*layers)
        # Remembered so an empty group can still report its channels.
        self.input_channels = input_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every block of the group to `x`, in order."""
        return self.layers(x)

    def get_output_channels(self) -> int:
        """Return the channel count produced by the last block.

        An empty group leaves its input untouched, so the input channel
        count is returned instead of raising `IndexError`.
        """
        if not self.blocks:
            return self.input_channels
        return self.blocks[-1].get_output_channels()

    def get_output_height(self, input_height: int) -> int:
        """Return the height after passing `input_height` through all blocks."""
        for block in self.blocks:
            input_height = block.get_output_height(input_height)
        return input_height

    # NOTE: the parameter order (config, input_channels, input_height)
    # matches every other registered factory in this module and the
    # positional call `block_registry.build(config, in_channels,
    # input_height)` made by `build_layer`. The previous order received
    # channels and height swapped when dispatched through the registry.
    @block_registry.register(LayerGroupConfig)
    @staticmethod
    def from_config(
        config: LayerGroupConfig,
        input_channels: int,
        input_height: int,
    ):
        """Build a `LayerGroup` from its configuration.

        Each layer is built with the output height and channels of the
        layer before it.

        Parameters
        ----------
        config : LayerGroupConfig
            Configuration listing the layers of the group.
        input_channels : int
            Number of channels of the tensor fed to the first layer.
        input_height : int
            Height of the tensor fed to the first layer.
        """
        layers = []

        # Track the running shape separately so the group is constructed
        # with its real *input* shape, not the shape after the last layer.
        height = input_height
        channels = input_channels

        for layer_config in config.layers:
            layer = build_layer(
                input_height=height,
                in_channels=channels,
                config=layer_config,
            )
            layers.append(layer)
            height = layer.get_output_height(height)
            channels = layer.get_output_channels()

        return LayerGroup(
            layers=layers,
            input_height=input_height,
            input_channels=input_channels,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def build_layer(
|
||||||
input_height: int,
|
input_height: int,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
config: LayerConfig,
|
config: LayerConfig,
|
||||||
) -> Tuple[nn.Module, int, int]:
|
) -> Block:
|
||||||
"""Factory function to build a specific nn.Module block from its config.
|
"""Factory function to build a specific nn.Module block from its config.
|
||||||
|
|
||||||
Takes configuration object (one of the types included in the `LayerConfig`
|
Takes configuration object (one of the types included in the `LayerConfig`
|
||||||
@ -731,93 +948,4 @@ def build_layer_from_config(
|
|||||||
ValueError
|
ValueError
|
||||||
If parameters derived from the config are invalid for the block.
|
If parameters derived from the config are invalid for the block.
|
||||||
"""
|
"""
|
||||||
if config.name == "ConvBlock":
|
return block_registry.build(config, in_channels, input_height)
|
||||||
return (
|
|
||||||
ConvBlock(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
kernel_size=config.kernel_size,
|
|
||||||
pad_size=config.pad_size,
|
|
||||||
),
|
|
||||||
config.out_channels,
|
|
||||||
input_height,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "FreqCoordConvDown":
|
|
||||||
return (
|
|
||||||
FreqCoordConvDownBlock(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
input_height=input_height,
|
|
||||||
kernel_size=config.kernel_size,
|
|
||||||
pad_size=config.pad_size,
|
|
||||||
),
|
|
||||||
config.out_channels,
|
|
||||||
input_height // 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "StandardConvDown":
|
|
||||||
return (
|
|
||||||
StandardConvDownBlock(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
kernel_size=config.kernel_size,
|
|
||||||
pad_size=config.pad_size,
|
|
||||||
),
|
|
||||||
config.out_channels,
|
|
||||||
input_height // 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "FreqCoordConvUp":
|
|
||||||
return (
|
|
||||||
FreqCoordConvUpBlock(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
input_height=input_height,
|
|
||||||
kernel_size=config.kernel_size,
|
|
||||||
pad_size=config.pad_size,
|
|
||||||
),
|
|
||||||
config.out_channels,
|
|
||||||
input_height * 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "StandardConvUp":
|
|
||||||
return (
|
|
||||||
StandardConvUpBlock(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
kernel_size=config.kernel_size,
|
|
||||||
pad_size=config.pad_size,
|
|
||||||
),
|
|
||||||
config.out_channels,
|
|
||||||
input_height * 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "SelfAttention":
|
|
||||||
return (
|
|
||||||
SelfAttention(
|
|
||||||
in_channels=in_channels,
|
|
||||||
attention_channels=config.attention_channels,
|
|
||||||
temperature=config.temperature,
|
|
||||||
),
|
|
||||||
config.attention_channels,
|
|
||||||
input_height,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "LayerGroup":
|
|
||||||
current_channels = in_channels
|
|
||||||
current_height = input_height
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
|
|
||||||
for block_config in config.layers:
|
|
||||||
block, current_channels, current_height = build_layer_from_config(
|
|
||||||
input_height=current_height,
|
|
||||||
in_channels=current_channels,
|
|
||||||
config=block_config,
|
|
||||||
)
|
|
||||||
blocks.append(block)
|
|
||||||
|
|
||||||
return nn.Sequential(*blocks), current_channels, current_height
|
|
||||||
|
|
||||||
raise NotImplementedError(f"Unknown block type {config.name}")
|
|
||||||
|
|||||||
@ -22,9 +22,10 @@ from torch import nn
|
|||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
|
Block,
|
||||||
SelfAttentionConfig,
|
SelfAttentionConfig,
|
||||||
VerticalConv,
|
VerticalConv,
|
||||||
build_layer_from_config,
|
build_layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -34,7 +35,7 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Bottleneck(nn.Module):
|
class Bottleneck(Block):
|
||||||
"""Base Bottleneck module for Encoder-Decoder architectures.
|
"""Base Bottleneck module for Encoder-Decoder architectures.
|
||||||
|
|
||||||
This implementation represents the simplest bottleneck structure
|
This implementation represents the simplest bottleneck structure
|
||||||
@ -86,6 +87,7 @@ class Bottleneck(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.input_height = input_height
|
self.input_height = input_height
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
|
||||||
self.bottleneck_channels = (
|
self.bottleneck_channels = (
|
||||||
bottleneck_channels
|
bottleneck_channels
|
||||||
if bottleneck_channels is not None
|
if bottleneck_channels is not None
|
||||||
@ -172,7 +174,7 @@ def build_bottleneck(
|
|||||||
input_height: int,
|
input_height: int,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
config: BottleneckConfig | None = None,
|
config: BottleneckConfig | None = None,
|
||||||
) -> nn.Module:
|
) -> Block:
|
||||||
"""Factory function to build the Bottleneck module from configuration.
|
"""Factory function to build the Bottleneck module from configuration.
|
||||||
|
|
||||||
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
|
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
|
||||||
@ -206,11 +208,13 @@ def build_bottleneck(
|
|||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for layer_config in config.layers:
|
for layer_config in config.layers:
|
||||||
layer, current_channels, current_height = build_layer_from_config(
|
layer = build_layer(
|
||||||
input_height=current_height,
|
input_height=current_height,
|
||||||
in_channels=current_channels,
|
in_channels=current_channels,
|
||||||
config=layer_config,
|
config=layer_config,
|
||||||
)
|
)
|
||||||
|
current_height = layer.get_output_height(current_height)
|
||||||
|
current_channels = layer.get_output_channels()
|
||||||
assert current_height == input_height, (
|
assert current_height == input_height, (
|
||||||
"Bottleneck layers should not change the spectrogram height"
|
"Bottleneck layers should not change the spectrogram height"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -30,7 +30,7 @@ from batdetect2.models.blocks import (
|
|||||||
FreqCoordConvUpConfig,
|
FreqCoordConvUpConfig,
|
||||||
LayerGroupConfig,
|
LayerGroupConfig,
|
||||||
StandardConvUpConfig,
|
StandardConvUpConfig,
|
||||||
build_layer_from_config,
|
build_layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -259,11 +259,13 @@ def build_decoder(
|
|||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for layer_config in config.layers:
|
for layer_config in config.layers:
|
||||||
layer, current_channels, current_height = build_layer_from_config(
|
layer = build_layer(
|
||||||
in_channels=current_channels,
|
in_channels=current_channels,
|
||||||
input_height=current_height,
|
input_height=current_height,
|
||||||
config=layer_config,
|
config=layer_config,
|
||||||
)
|
)
|
||||||
|
current_height = layer.get_output_height(current_height)
|
||||||
|
current_channels = layer.get_output_channels()
|
||||||
layers.append(layer)
|
layers.append(layer)
|
||||||
|
|
||||||
return Decoder(
|
return Decoder(
|
||||||
|
|||||||
@ -32,7 +32,7 @@ from batdetect2.models.blocks import (
|
|||||||
FreqCoordConvDownConfig,
|
FreqCoordConvDownConfig,
|
||||||
LayerGroupConfig,
|
LayerGroupConfig,
|
||||||
StandardConvDownConfig,
|
StandardConvDownConfig,
|
||||||
build_layer_from_config,
|
build_layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -300,12 +300,14 @@ def build_encoder(
|
|||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for layer_config in config.layers:
|
for layer_config in config.layers:
|
||||||
layer, current_channels, current_height = build_layer_from_config(
|
layer = build_layer(
|
||||||
in_channels=current_channels,
|
in_channels=current_channels,
|
||||||
input_height=current_height,
|
input_height=current_height,
|
||||||
config=layer_config,
|
config=layer_config,
|
||||||
)
|
)
|
||||||
layers.append(layer)
|
layers.append(layer)
|
||||||
|
current_height = layer.get_output_height(current_height)
|
||||||
|
current_channels = layer.get_output_channels()
|
||||||
|
|
||||||
return Encoder(
|
return Encoder(
|
||||||
input_height=input_height,
|
input_height=input_height,
|
||||||
|
|||||||
@ -19,9 +19,9 @@ def create_ax(
|
|||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
"""Create a new axis if none is provided"""
|
"""Create a new axis if none is provided"""
|
||||||
if ax is None:
|
if ax is None:
|
||||||
_, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs) # type: ignore
|
_, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs)
|
||||||
|
|
||||||
return ax # type: ignore
|
return ax
|
||||||
|
|
||||||
|
|
||||||
def plot_spectrogram(
|
def plot_spectrogram(
|
||||||
|
|||||||
@ -66,7 +66,7 @@ def plot_clip_detections(
|
|||||||
if m.gt is not None:
|
if m.gt is not None:
|
||||||
color = gt_color if is_match else missed_gt_color
|
color = gt_color if is_match else missed_gt_color
|
||||||
plot_geometry(
|
plot_geometry(
|
||||||
m.gt.sound_event.geometry, # type: ignore
|
m.gt.sound_event.geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=False,
|
add_points=False,
|
||||||
linewidth=linewidth,
|
linewidth=linewidth,
|
||||||
|
|||||||
@ -100,7 +100,7 @@ def plot_classification_heatmap(
|
|||||||
class_heatmap,
|
class_heatmap,
|
||||||
vmax=1,
|
vmax=1,
|
||||||
vmin=0,
|
vmin=0,
|
||||||
cmap=create_colormap(color), # type: ignore
|
cmap=create_colormap(color),
|
||||||
alpha=alpha,
|
alpha=alpha,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -301,4 +301,4 @@ def _get_marker_positions(
|
|||||||
size = len(thresholds)
|
size = len(thresholds)
|
||||||
cut_points = np.linspace(0, 1, n_points)
|
cut_points = np.linspace(0, 1, n_points)
|
||||||
indices = np.searchsorted(thresholds[::-1], cut_points)
|
indices = np.searchsorted(thresholds[::-1], cut_points)
|
||||||
return np.clip(size - indices, 0, size - 1) # type: ignore
|
return np.clip(size - indices, 0, size - 1)
|
||||||
|
|||||||
@ -72,7 +72,7 @@ class TargetClassConfig(BaseConfig):
|
|||||||
|
|
||||||
DEFAULT_DETECTION_CLASS = TargetClassConfig(
|
DEFAULT_DETECTION_CLASS = TargetClassConfig(
|
||||||
name="bat",
|
name="bat",
|
||||||
match_if=AllOfConfig(
|
condition_input=AllOfConfig(
|
||||||
conditions=[
|
conditions=[
|
||||||
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
|
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
|
||||||
NotConfig(
|
NotConfig(
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
|
from batdetect2.core import Registry
|
||||||
from batdetect2.core.arrays import spec_to_xarray
|
from batdetect2.core.arrays import spec_to_xarray
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
@ -88,6 +89,9 @@ DEFAULT_ANCHOR = "bottom-left"
|
|||||||
"""Default reference position within the geometry ('bottom-left' corner)."""
|
"""Default reference position within the geometry ('bottom-left' corner)."""
|
||||||
|
|
||||||
|
|
||||||
|
roi_mapper_registry: Registry[ROITargetMapper, []] = Registry("roi_mapper")
|
||||||
|
|
||||||
|
|
||||||
class AnchorBBoxMapperConfig(BaseConfig):
|
class AnchorBBoxMapperConfig(BaseConfig):
|
||||||
"""Configuration for `AnchorBBoxMapper`.
|
"""Configuration for `AnchorBBoxMapper`.
|
||||||
|
|
||||||
@ -244,6 +248,15 @@ class AnchorBBoxMapper(ROITargetMapper):
|
|||||||
anchor=self.anchor,
|
anchor=self.anchor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@roi_mapper_registry.register(AnchorBBoxMapperConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: AnchorBBoxMapperConfig):
|
||||||
|
return AnchorBBoxMapper(
|
||||||
|
anchor=config.anchor,
|
||||||
|
time_scale=config.time_scale,
|
||||||
|
frequency_scale=config.frequency_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PeakEnergyBBoxMapperConfig(BaseConfig):
|
class PeakEnergyBBoxMapperConfig(BaseConfig):
|
||||||
"""Configuration for `PeakEnergyBBoxMapper`.
|
"""Configuration for `PeakEnergyBBoxMapper`.
|
||||||
@ -412,6 +425,22 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@roi_mapper_registry.register(PeakEnergyBBoxMapperConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: PeakEnergyBBoxMapperConfig):
|
||||||
|
audio_loader = build_audio_loader(config=config.audio)
|
||||||
|
preprocessor = build_preprocessor(
|
||||||
|
config.preprocessing,
|
||||||
|
input_samplerate=audio_loader.samplerate,
|
||||||
|
)
|
||||||
|
return PeakEnergyBBoxMapper(
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
time_scale=config.time_scale,
|
||||||
|
frequency_scale=config.frequency_scale,
|
||||||
|
loading_buffer=config.loading_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
ROIMapperConfig = Annotated[
|
ROIMapperConfig = Annotated[
|
||||||
AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig,
|
AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig,
|
||||||
@ -445,31 +474,7 @@ def build_roi_mapper(
|
|||||||
If the `name` in the config does not correspond to a known mapper.
|
If the `name` in the config does not correspond to a known mapper.
|
||||||
"""
|
"""
|
||||||
config = config or AnchorBBoxMapperConfig()
|
config = config or AnchorBBoxMapperConfig()
|
||||||
|
return roi_mapper_registry.build(config)
|
||||||
if config.name == "anchor_bbox":
|
|
||||||
return AnchorBBoxMapper(
|
|
||||||
anchor=config.anchor,
|
|
||||||
time_scale=config.time_scale,
|
|
||||||
frequency_scale=config.frequency_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "peak_energy_bbox":
|
|
||||||
audio_loader = build_audio_loader(config=config.audio)
|
|
||||||
preprocessor = build_preprocessor(
|
|
||||||
config.preprocessing,
|
|
||||||
input_samplerate=audio_loader.samplerate,
|
|
||||||
)
|
|
||||||
return PeakEnergyBBoxMapper(
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
time_scale=config.time_scale,
|
|
||||||
frequency_scale=config.frequency_scale,
|
|
||||||
loading_buffer=config.loading_buffer,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"No ROI mapper of name '{config.name}' is implemented"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
VALID_ANCHORS = [
|
VALID_ANCHORS = [
|
||||||
@ -636,7 +641,7 @@ def get_peak_energy_coordinates(
|
|||||||
)
|
)
|
||||||
|
|
||||||
index = selection.argmax(dim=["time", "frequency"])
|
index = selection.argmax(dim=["time", "frequency"])
|
||||||
point = selection.isel(index) # type: ignore
|
point = selection.isel(index)
|
||||||
peak_time: float = point.time.item()
|
peak_time: float = point.time.item()
|
||||||
peak_freq: float = point.frequency.item()
|
peak_freq: float = point.frequency.item()
|
||||||
return peak_time, peak_freq
|
return peak_time, peak_freq
|
||||||
|
|||||||
@ -548,63 +548,6 @@ class MaybeApply(torch.nn.Module):
|
|||||||
return self.augmentation(tensor, clip_annotation)
|
return self.augmentation(tensor, clip_annotation)
|
||||||
|
|
||||||
|
|
||||||
def build_augmentation_from_config(
|
|
||||||
config: AugmentationConfig,
|
|
||||||
samplerate: int,
|
|
||||||
audio_source: AudioSource | None = None,
|
|
||||||
) -> Augmentation | None:
|
|
||||||
"""Factory function to build a single augmentation from its config."""
|
|
||||||
if config.name == "mix_audio":
|
|
||||||
if audio_source is None:
|
|
||||||
warnings.warn(
|
|
||||||
"Mix audio augmentation ('mix_audio') requires an "
|
|
||||||
"'example_source' callable to be provided.",
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return MixAudio(
|
|
||||||
example_source=audio_source,
|
|
||||||
min_weight=config.min_weight,
|
|
||||||
max_weight=config.max_weight,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "add_echo":
|
|
||||||
return AddEcho(
|
|
||||||
max_delay=int(config.max_delay * samplerate),
|
|
||||||
min_weight=config.min_weight,
|
|
||||||
max_weight=config.max_weight,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "scale_volume":
|
|
||||||
return ScaleVolume(
|
|
||||||
max_scaling=config.max_scaling,
|
|
||||||
min_scaling=config.min_scaling,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "warp":
|
|
||||||
return Warp(
|
|
||||||
delta=config.delta,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "mask_time":
|
|
||||||
return MaskTime(
|
|
||||||
max_perc=config.max_perc,
|
|
||||||
max_masks=config.max_masks,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == "mask_freq":
|
|
||||||
return MaskFrequency(
|
|
||||||
max_perc=config.max_perc,
|
|
||||||
max_masks=config.max_masks,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Invalid or unimplemented augmentation type: "
|
|
||||||
f"{config.augmentation_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
|
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
|
||||||
enabled=True,
|
enabled=True,
|
||||||
audio=[
|
audio=[
|
||||||
|
|||||||
@ -61,8 +61,8 @@ class TrainingDataset(Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.clip_annotations)
|
return len(self.clip_annotations)
|
||||||
|
|
||||||
def __getitem__(self, idx) -> TrainExample:
|
def __getitem__(self, index) -> TrainExample:
|
||||||
clip_annotation = self.clip_annotations[idx]
|
clip_annotation = self.clip_annotations[index]
|
||||||
|
|
||||||
if self.clipper is not None:
|
if self.clipper is not None:
|
||||||
clip_annotation = self.clipper(clip_annotation)
|
clip_annotation = self.clipper(clip_annotation)
|
||||||
@ -95,7 +95,7 @@ class TrainingDataset(Dataset):
|
|||||||
detection_heatmap=heatmaps.detection,
|
detection_heatmap=heatmaps.detection,
|
||||||
class_heatmap=heatmaps.classes,
|
class_heatmap=heatmaps.classes,
|
||||||
size_heatmap=heatmaps.size,
|
size_heatmap=heatmaps.size,
|
||||||
idx=torch.tensor(idx),
|
idx=torch.tensor(index),
|
||||||
start_time=torch.tensor(clip.start_time),
|
start_time=torch.tensor(clip.start_time),
|
||||||
end_time=torch.tensor(clip.end_time),
|
end_time=torch.tensor(clip.end_time),
|
||||||
)
|
)
|
||||||
@ -121,8 +121,8 @@ class ValidationDataset(Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.clip_annotations)
|
return len(self.clip_annotations)
|
||||||
|
|
||||||
def __getitem__(self, idx) -> TrainExample:
|
def __getitem__(self, index) -> TrainExample:
|
||||||
clip_annotation = self.clip_annotations[idx]
|
clip_annotation = self.clip_annotations[index]
|
||||||
|
|
||||||
if self.clipper is not None:
|
if self.clipper is not None:
|
||||||
clip_annotation = self.clipper(clip_annotation)
|
clip_annotation = self.clipper(clip_annotation)
|
||||||
@ -141,7 +141,7 @@ class ValidationDataset(Dataset):
|
|||||||
detection_heatmap=heatmaps.detection,
|
detection_heatmap=heatmaps.detection,
|
||||||
class_heatmap=heatmaps.classes,
|
class_heatmap=heatmaps.classes,
|
||||||
size_heatmap=heatmaps.size,
|
size_heatmap=heatmaps.size,
|
||||||
idx=torch.tensor(idx),
|
idx=torch.tensor(index),
|
||||||
start_time=torch.tensor(clip.start_time),
|
start_time=torch.tensor(clip.start_time),
|
||||||
end_time=torch.tensor(clip.end_time),
|
end_time=torch.tensor(clip.end_time),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -472,7 +472,7 @@ def build_loss(
|
|||||||
|
|
||||||
size_loss_fn = BBoxLoss()
|
size_loss_fn = BBoxLoss()
|
||||||
|
|
||||||
return LossFunction( # type: ignore
|
return LossFunction(
|
||||||
size_loss=size_loss_fn,
|
size_loss=size_loss_fn,
|
||||||
classification_loss=classification_loss_fn,
|
classification_loss=classification_loss_fn,
|
||||||
detection_loss=detection_loss_fn,
|
detection_loss=detection_loss_fn,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user