From d52e988b8fb2368f811d83fe80bd74ab25e7d236 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 8 Mar 2026 12:55:36 +0000 Subject: [PATCH] Fix type errors --- justfile | 53 +-- pyproject.toml | 19 +- src/batdetect2/audio/loader.py | 8 +- src/batdetect2/compat/data.py | 22 +- src/batdetect2/data/predictions/batdetect2.py | 2 +- src/batdetect2/data/split.py | 10 +- src/batdetect2/data/summary.py | 6 +- src/batdetect2/data/transforms.py | 7 +- src/batdetect2/evaluate/affinity.py | 16 +- src/batdetect2/evaluate/dataset.py | 6 +- src/batdetect2/evaluate/tasks/base.py | 2 +- src/batdetect2/inference/dataset.py | 6 +- src/batdetect2/models/backbones.py | 2 +- src/batdetect2/models/blocks.py | 332 ++++++++++++------ src/batdetect2/models/bottleneck.py | 12 +- src/batdetect2/models/decoder.py | 6 +- src/batdetect2/models/encoder.py | 6 +- src/batdetect2/plotting/common.py | 4 +- src/batdetect2/plotting/detections.py | 2 +- src/batdetect2/plotting/heatmaps.py | 2 +- src/batdetect2/plotting/metrics.py | 2 +- src/batdetect2/targets/classes.py | 2 +- src/batdetect2/targets/rois.py | 57 +-- src/batdetect2/train/augmentations.py | 57 --- src/batdetect2/train/dataset.py | 12 +- src/batdetect2/train/losses.py | 2 +- 26 files changed, 371 insertions(+), 284 deletions(-) diff --git a/justfile b/justfile index 97016a4..d0ba419 100644 --- a/justfile +++ b/justfile @@ -14,60 +14,67 @@ HTML_COVERAGE_DIR := "htmlcov" help: @just --list +install: + uv sync + # Testing & Coverage # Run tests using pytest. test: - pytest {{TESTS_DIR}} + uv run pytest {{TESTS_DIR}} # Run tests and generate coverage data. coverage: - pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}} + uv run pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}} # Generate an HTML coverage report. coverage-html: coverage @echo "Generating HTML coverage report..." - coverage html -d {{HTML_COVERAGE_DIR}} + uv run coverage html -d {{HTML_COVERAGE_DIR}} @echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/" # Serve the HTML coverage report locally. coverage-serve: coverage-html @echo "Serving report at http://localhost:8000/ ..." - python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000 + uv run python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000 # Documentation # Build documentation using Sphinx. docs: - sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}} + uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}} # Serve documentation with live reload. docs-serve: - sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser + uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser # Formatting & Linting # Format code using ruff. -format: - ruff format {{PYTHON_DIRS}} - -# Check code formatting using ruff. -format-check: - ruff format --check {{PYTHON_DIRS}} - -# Lint code using ruff. -lint: - ruff check {{PYTHON_DIRS}} +fix-format: + uv run ruff format {{PYTHON_DIRS}} # Lint code using ruff and apply automatic fixes. -lint-fix: - ruff check --fix {{PYTHON_DIRS}} +fix-lint: + uv run ruff check --fix {{PYTHON_DIRS}} + +# Combined Formatting & Linting +fix: fix-format fix-lint + +# Checking tasks +# Check code formatting using ruff. +check-format: + uv run ruff format --check {{PYTHON_DIRS}} + +# Lint code using ruff. +check-lint: + uv run ruff check {{PYTHON_DIRS}} # Type Checking -# Type check code using pyright. -typecheck: - pyright {{PYTHON_DIRS}} +# Type check code using ty. +check-types: + uv run ty check {{PYTHON_DIRS}} # Combined Checks # Run all checks (format-check, lint, typecheck). -check: format-check lint typecheck test +check: check-format check-lint check-types # Cleaning tasks # Remove Python bytecode and cache. @@ -95,7 +102,7 @@ clean: clean-build clean-pyc clean-test clean-docs # Train on example data. example-train OPTIONS="": - batdetect2 train \ + uv run batdetect2 train \ --val-dataset example_data/dataset.yaml \ --config example_data/config.yaml \ {{OPTIONS}} \ diff --git a/pyproject.toml b/pyproject.toml index a0ad77c..03f3359 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,6 @@ dev = [ "ruff>=0.7.3", "ipykernel>=6.29.4", "setuptools>=69.5.1", - "basedpyright>=1.31.0", "myst-parser>=3.0.1", "sphinx-autobuild>=2024.10.3", "numpydoc>=1.8.0", @@ -94,6 +93,12 @@ mlflow = ["mlflow>=3.1.1"] [tool.ruff] line-length = 79 target-version = "py310" +exclude = [ + "src/batdetect2/train/legacy", + "src/batdetect2/plotting/legacy", + "src/batdetect2/evaluate/legacy", + "src/batdetect2/finetune", +] [tool.ruff.format] docstring-code-format = true @@ -105,15 +110,11 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"] [tool.ruff.lint.pydocstyle] convention = "numpy" -[tool.pyright] +[tool.ty.src] include = ["src", "tests"] -pythonVersion = "3.10" -pythonPlatform = "All" exclude = [ - "src/batdetect2/detector/", - "src/batdetect2/finetune", - "src/batdetect2/utils", - "src/batdetect2/plot", - "src/batdetect2/evaluate/legacy", "src/batdetect2/train/legacy", + "src/batdetect2/plotting/legacy", + "src/batdetect2/evaluate/legacy", + "src/batdetect2/finetune", ] diff --git a/src/batdetect2/audio/loader.py b/src/batdetect2/audio/loader.py index e8c46a0..c2536cf 100644 --- a/src/batdetect2/audio/loader.py +++ b/src/batdetect2/audio/loader.py @@ -113,7 +113,7 @@ def load_file_audio( samplerate: int | None = None, config: ResampleConfig | None = None, audio_dir: data.PathLike | None = None, - dtype: DTypeLike = np.float32, # type: ignore + dtype: DTypeLike = np.float32, ) -> np.ndarray: """Load and preprocess audio from a file path using specified config.""" try: @@ -137,7 +137,7 @@ def load_recording_audio( samplerate: int | None = None, config: ResampleConfig | None = None, audio_dir: data.PathLike | None = None, - dtype: DTypeLike = np.float32, # type: ignore + dtype: DTypeLike = np.float32, ) -> np.ndarray: """Load and preprocess the entire audio content of a recording using config.""" clip = data.Clip( @@ -159,7 +159,7 @@ def load_clip_audio( samplerate: int | None = None, config: ResampleConfig | None = None, audio_dir: data.PathLike | None = None, - dtype: DTypeLike = np.float32, # type: ignore + dtype: DTypeLike = np.float32, ) -> np.ndarray: """Load and preprocess a specific audio clip segment based on config.""" try: @@ -286,7 +286,7 @@ def resample_audio_fourier( If `num` is negative. """ ratio = sr_new / sr_orig - return resample( # type: ignore + return resample( array, int(array.shape[axis] * ratio), axis=axis, diff --git a/src/batdetect2/compat/data.py b/src/batdetect2/compat/data.py index 9cb540a..8803e31 100644 --- a/src/batdetect2/compat/data.py +++ b/src/batdetect2/compat/data.py @@ -103,17 +103,17 @@ def convert_to_annotation_group( y_inds.append(0) annotations.append( - { - "start_time": start_time, - "end_time": end_time, - "low_freq": low_freq, - "high_freq": high_freq, - "class_prob": 1.0, - "det_prob": 1.0, - "individual": "0", - "event": event, - "class_id": class_id, # type: ignore - } + Annotation( + start_time=start_time, + end_time=end_time, + low_freq=low_freq, + high_freq=high_freq, + class_prob=1.0, + det_prob=1.0, + individual="0", + event=event, + class_id=class_id, + ) ) return { diff --git a/src/batdetect2/data/predictions/batdetect2.py b/src/batdetect2/data/predictions/batdetect2.py index 6dcdc63..4a76541 100644 --- a/src/batdetect2/data/predictions/batdetect2.py +++ b/src/batdetect2/data/predictions/batdetect2.py @@ -89,7 +89,7 @@ class FileAnnotation(TypedDict): annotation: List[Annotation] """List of annotations.""" - file_path: NotRequired[str] + file_path: NotRequired[str] # ty: ignore[invalid-type-form] """Path to file.""" diff --git a/src/batdetect2/data/split.py b/src/batdetect2/data/split.py index 22bc6f1..47d3bd6 100644 --- a/src/batdetect2/data/split.py +++ b/src/batdetect2/data/split.py @@ -26,15 +26,15 @@ def split_dataset_by_recordings( ) majority_class = ( - sound_events.groupby("recording_id") + sound_events.groupby("recording_id") # type: ignore .apply( lambda group: ( - group["class_name"] # type: ignore + group["class_name"] .value_counts() .sort_values(ascending=False) .index[0] ), - include_groups=False, # type: ignore + include_groups=False, ) .rename("class_name") .to_frame() @@ -48,8 +48,8 @@ def split_dataset_by_recordings( random_state=random_state, ) - train_ids_set = set(train.values) # type: ignore - test_ids_set = set(test.values) # type: ignore + train_ids_set = set(train.values) + test_ids_set = set(test.values) extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set diff --git a/src/batdetect2/data/summary.py b/src/batdetect2/data/summary.py index d550d9d..1db0948 100644 --- a/src/batdetect2/data/summary.py +++ b/src/batdetect2/data/summary.py @@ -175,14 +175,14 @@ def compute_class_summary( .rename("num recordings") ) durations = ( - sound_events.groupby("class_name") + sound_events.groupby("class_name") # ty: ignore[no-matching-overload] .apply( lambda group: recordings[ recordings["clip_annotation_id"].isin( - group["clip_annotation_id"] # type: ignore + group["clip_annotation_id"] ) ]["duration"].sum(), - include_groups=False, # type: ignore + include_groups=False, ) .sort_values(ascending=False) .rename("duration") diff --git a/src/batdetect2/data/transforms.py b/src/batdetect2/data/transforms.py index c8c99cd..0e9b5d2 100644 --- a/src/batdetect2/data/transforms.py +++ b/src/batdetect2/data/transforms.py @@ -176,12 +176,7 @@ class MapTagValue: if self.target_key is None: tags.append(tag.model_copy(update=dict(value=value))) else: - tags.append( - data.Tag( - key=self.target_key, # type: ignore - value=value, - ) - ) + tags.append(data.Tag(key=self.target_key, value=value)) return sound_event_annotation.model_copy(update=dict(tags=tags)) diff --git a/src/batdetect2/evaluate/affinity.py b/src/batdetect2/evaluate/affinity.py index 976cb89..a6a4d6b 100644 --- a/src/batdetect2/evaluate/affinity.py +++ b/src/batdetect2/evaluate/affinity.py @@ -120,11 +120,11 @@ class BBoxIOU(AffinityFunction): def __call__( self, - prediction: Detection, - gt: data.SoundEventAnnotation, + detection: Detection, + ground_truth: data.SoundEventAnnotation, ): - target_geometry = gt.sound_event.geometry - source_geometry = prediction.geometry + target_geometry = ground_truth.sound_event.geometry + source_geometry = detection.geometry if self.time_buffer > 0 or self.freq_buffer > 0: target_geometry = buffer_geometry( @@ -168,11 +168,11 @@ class GeometricIOU(AffinityFunction): def __call__( self, - prediction: Detection, - gt: data.SoundEventAnnotation, + detection: Detection, + ground_truth: data.SoundEventAnnotation, ): - target_geometry = gt.sound_event.geometry - source_geometry = prediction.geometry + target_geometry = ground_truth.sound_event.geometry + source_geometry = detection.geometry if self.time_buffer > 0 or self.freq_buffer > 0: target_geometry = buffer_geometry( diff --git a/src/batdetect2/evaluate/dataset.py b/src/batdetect2/evaluate/dataset.py index c029636..9bf106a 100644 --- a/src/batdetect2/evaluate/dataset.py +++ b/src/batdetect2/evaluate/dataset.py @@ -51,8 +51,8 @@ class TestDataset(Dataset[TestExample]): def __len__(self): return len(self.clip_annotations) - def __getitem__(self, idx: int) -> TestExample: - clip_annotation = self.clip_annotations[idx] + def __getitem__(self, index: int) -> TestExample: + clip_annotation = self.clip_annotations[index] if self.clipper is not None: clip_annotation = self.clipper(clip_annotation) @@ -63,7 +63,7 @@ class TestDataset(Dataset[TestExample]): spectrogram = self.preprocessor(wav_tensor) return TestExample( spec=spectrogram, - idx=torch.tensor(idx), + idx=torch.tensor(index), start_time=torch.tensor(clip.start_time), end_time=torch.tensor(clip.end_time), ) diff --git a/src/batdetect2/evaluate/tasks/base.py b/src/batdetect2/evaluate/tasks/base.py index 2c0c941..bcfe38f 100644 --- a/src/batdetect2/evaluate/tasks/base.py +++ b/src/batdetect2/evaluate/tasks/base.py @@ -109,7 +109,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): self, clip_annotation: data.ClipAnnotation, prediction: ClipDetections, - ) -> T_Output: ... + ) -> T_Output: ... # ty: ignore[empty-body] def include_sound_event_annotation( self, diff --git a/src/batdetect2/inference/dataset.py b/src/batdetect2/inference/dataset.py index cf6ad09..62dcd36 100644 --- a/src/batdetect2/inference/dataset.py +++ b/src/batdetect2/inference/dataset.py @@ -46,14 +46,14 @@ class InferenceDataset(Dataset[DatasetItem]): def __len__(self): return len(self.clips) - def __getitem__(self, idx: int) -> DatasetItem: - clip = self.clips[idx] + def __getitem__(self, index: int) -> DatasetItem: + clip = self.clips[index] wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir) wav_tensor = torch.tensor(wav).unsqueeze(0) spectrogram = self.preprocessor(wav_tensor) return DatasetItem( spec=spectrogram, - idx=torch.tensor(idx), + idx=torch.tensor(index), start_time=torch.tensor(clip.start_time), end_time=torch.tensor(clip.end_time), ) diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index fd55c80..08b665c 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -185,7 +185,7 @@ def build_backbone(config: BackboneConfig) -> BackboneModel: ) decoder = build_decoder( - in_channels=bottleneck.out_channels, + in_channels=bottleneck.get_output_channels(), input_height=encoder.output_height, config=config.decoder, ) diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index 1e39031..3c64e2c 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -20,20 +20,21 @@ research: spatial frequency information to filters, potentially enabling them to learn frequency-dependent patterns more effectively. -These blocks can be utilized directly in custom PyTorch model definitions or +These blocks can be used directly in custom PyTorch model definitions or assembled into larger architectures. A unified factory function `build_layer_from_config` allows creating instances of these blocks based on configuration objects. """ -from typing import Annotated, List, Literal, Tuple, Union +from typing import Annotated, List, Literal, Protocol, Tuple, Union import torch import torch.nn.functional as F from pydantic import Field from torch import nn +from batdetect2.core import Registry from batdetect2.core.configs import BaseConfig __all__ = [ @@ -51,17 +52,30 @@ __all__ = [ "FreqCoordConvUpConfig", "StandardConvUpConfig", "LayerConfig", - "build_layer_from_config", + "build_layer", ] +class BlockProtocol(Protocol): + def get_output_channels(self) -> int: ... + + def get_output_height(self, input_height: int) -> int: + return input_height + + +class Block(nn.Module, BlockProtocol): ... + + +block_registry: Registry[Block, [int, int]] = Registry("block") + + class SelfAttentionConfig(BaseConfig): name: Literal["SelfAttention"] = "SelfAttention" attention_channels: int temperature: float = 1 -class SelfAttention(nn.Module): +class SelfAttention(Block): """Self-Attention mechanism operating along the time dimension. This module implements a scaled dot-product self-attention mechanism, @@ -121,6 +135,7 @@ class SelfAttention(nn.Module): # Note, does not encode position information (absolute or relative) self.temperature = temperature self.att_dim = attention_channels + self.output_channels = in_channels self.key_fun = nn.Linear(in_channels, attention_channels) self.value_fun = nn.Linear(in_channels, attention_channels) @@ -190,6 +205,22 @@ class SelfAttention(nn.Module): att_weights = F.softmax(kk_qq, 1) return att_weights + def get_output_channels(self) -> int: + return self.output_channels + + @block_registry.register(SelfAttentionConfig) + @staticmethod + def from_config( + config: SelfAttentionConfig, + input_channels: int, + input_height: int, + ) -> "SelfAttention": + return SelfAttention( + in_channels=input_channels, + attention_channels=config.attention_channels, + temperature=config.temperature, + ) + class ConvConfig(BaseConfig): """Configuration for a basic ConvBlock.""" @@ -207,7 +238,7 @@ class ConvConfig(BaseConfig): """Padding size.""" -class ConvBlock(nn.Module): +class ConvBlock(Block): """Basic Convolutional Block. A standard building block consisting of a 2D convolution, followed by @@ -258,8 +289,30 @@ class ConvBlock(nn.Module): """ return F.relu_(self.batch_norm(self.conv(x))) + def get_output_channels(self) -> int: + return self.conv.out_channels -class VerticalConv(nn.Module): + @block_registry.register(ConvConfig) + @staticmethod + def from_config( + config: ConvConfig, + input_channels: int, + input_height: int, + ): + return ConvBlock( + in_channels=input_channels, + out_channels=config.out_channels, + kernel_size=config.kernel_size, + pad_size=config.pad_size, + ) + + +class VerticalConvConfig(BaseConfig): + name: Literal["VerticalConv"] = "VerticalConv" + channels: int + + +class VerticalConv(Block): """Convolutional layer that aggregates features across the entire height. Applies a 2D convolution using a kernel with shape `(input_height, 1)`. @@ -312,6 +365,22 @@ class VerticalConv(nn.Module): """ return F.relu_(self.bn(self.conv(x))) + def get_output_channels(self) -> int: + return self.conv.out_channels + + @block_registry.register(VerticalConvConfig) + @staticmethod + def from_config( + config: VerticalConvConfig, + input_channels: int, + input_height: int, + ): + return VerticalConv( + in_channels=input_channels, + out_channels=config.channels, + input_height=input_height, + ) + class FreqCoordConvDownConfig(BaseConfig): """Configuration for a FreqCoordConvDownBlock.""" @@ -329,7 +398,7 @@ class FreqCoordConvDownConfig(BaseConfig): """Padding size.""" -class FreqCoordConvDownBlock(nn.Module): +class FreqCoordConvDownBlock(Block): """Downsampling Conv Block incorporating Frequency Coordinate features. This block implements a downsampling step (Conv2d + MaxPool2d) commonly @@ -402,6 +471,27 @@ class FreqCoordConvDownBlock(nn.Module): x = F.relu(self.batch_norm(x), inplace=True) return x + def get_output_channels(self) -> int: + return self.conv.out_channels + + def get_output_height(self, input_height: int) -> int: + return input_height // 2 + + @block_registry.register(FreqCoordConvDownConfig) + @staticmethod + def from_config( + config: FreqCoordConvDownConfig, + input_channels: int, + input_height: int, + ): + return FreqCoordConvDownBlock( + in_channels=input_channels, + out_channels=config.out_channels, + input_height=input_height, + kernel_size=config.kernel_size, + pad_size=config.pad_size, + ) + class StandardConvDownConfig(BaseConfig): """Configuration for a StandardConvDownBlock.""" @@ -419,7 +509,7 @@ class StandardConvDownConfig(BaseConfig): """Padding size.""" -class StandardConvDownBlock(nn.Module): +class StandardConvDownBlock(Block): """Standard Downsampling Convolutional Block. A basic downsampling block consisting of a 2D convolution, followed by @@ -472,6 +562,26 @@ class StandardConvDownBlock(nn.Module): x = F.max_pool2d(self.conv(x), 2, 2) return F.relu(self.batch_norm(x), inplace=True) + def get_output_channels(self) -> int: + return self.conv.out_channels + + def get_output_height(self, input_height: int) -> int: + return input_height // 2 + + @block_registry.register(StandardConvDownConfig) + @staticmethod + def from_config( + config: StandardConvDownConfig, + input_channels: int, + input_height: int, + ): + return StandardConvDownBlock( + in_channels=input_channels, + out_channels=config.out_channels, + kernel_size=config.kernel_size, + pad_size=config.pad_size, + ) + class FreqCoordConvUpConfig(BaseConfig): """Configuration for a FreqCoordConvUpBlock.""" @@ -488,8 +598,14 @@ class FreqCoordConvUpConfig(BaseConfig): pad_size: int = 1 """Padding size.""" + up_mode: str = "bilinear" + """Interpolation mode for upsampling (e.g., "nearest", "bilinear").""" -class FreqCoordConvUpBlock(nn.Module): + up_scale: Tuple[int, int] = (2, 2) + """Scaling factor for height and width during upsampling.""" + + +class FreqCoordConvUpBlock(Block): """Upsampling Conv Block incorporating Frequency Coordinate features. This block implements an upsampling step followed by a convolution, @@ -581,6 +697,29 @@ class FreqCoordConvUpBlock(nn.Module): op = F.relu(self.batch_norm(op), inplace=True) return op + def get_output_channels(self) -> int: + return self.conv.out_channels + + def get_output_height(self, input_height: int) -> int: + return input_height * 2 + + @block_registry.register(FreqCoordConvUpConfig) + @staticmethod + def from_config( + config: FreqCoordConvUpConfig, + input_channels: int, + input_height: int, + ): + return FreqCoordConvUpBlock( + in_channels=input_channels, + out_channels=config.out_channels, + input_height=input_height, + kernel_size=config.kernel_size, + pad_size=config.pad_size, + up_mode=config.up_mode, + up_scale=config.up_scale, + ) + class StandardConvUpConfig(BaseConfig): """Configuration for a StandardConvUpBlock.""" @@ -597,8 +736,14 @@ class StandardConvUpConfig(BaseConfig): pad_size: int = 1 """Padding size.""" + up_mode: str = "bilinear" + """Interpolation mode for upsampling (e.g., "nearest", "bilinear").""" -class StandardConvUpBlock(nn.Module): + up_scale: Tuple[int, int] = (2, 2) + """Scaling factor for height and width during upsampling.""" + + +class StandardConvUpBlock(Block): """Standard Upsampling Convolutional Block. A basic upsampling block used in CNN decoders. It first upsamples the input @@ -669,6 +814,28 @@ class StandardConvUpBlock(nn.Module): op = F.relu(self.batch_norm(op), inplace=True) return op + def get_output_channels(self) -> int: + return self.conv.out_channels + + def get_output_height(self, input_height: int) -> int: + return input_height * 2 + + @block_registry.register(StandardConvUpConfig) + @staticmethod + def from_config( + config: StandardConvUpConfig, + input_channels: int, + input_height: int, + ): + return StandardConvUpBlock( + in_channels=input_channels, + out_channels=config.out_channels, + kernel_size=config.kernel_size, + pad_size=config.pad_size, + up_mode=config.up_mode, + up_scale=config.up_scale, + ) + LayerConfig = Annotated[ Union[ @@ -690,11 +857,61 @@ class LayerGroupConfig(BaseConfig): layers: List[LayerConfig] -def build_layer_from_config( +class LayerGroup(nn.Module): + """Standard implementation of the `LayerGroup` architecture.""" + + def __init__( + self, + layers: list[Block], + input_height: int, + input_channels: int, + ): + super().__init__() + self.blocks = layers + self.layers = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layers(x) + + def get_output_channels(self) -> int: + return self.blocks[-1].get_output_channels() + + def get_output_height(self, input_height: int) -> int: + for block in self.blocks: + input_height = block.get_output_height(input_height) + return input_height + + @block_registry.register(LayerGroupConfig) + @staticmethod + def from_config( + config: LayerGroupConfig, + input_height: int, + input_channels: int, + ): + layers = [] + + for layer_config in config.layers: + layer = build_layer( + input_height=input_height, + in_channels=input_channels, + config=layer_config, + ) + layers.append(layer) + input_height = layer.get_output_height(input_height) + input_channels = layer.get_output_channels() + + return LayerGroup( + layers=layers, + input_height=input_height, + input_channels=input_channels, + ) + + +def build_layer( input_height: int, in_channels: int, config: LayerConfig, -) -> Tuple[nn.Module, int, int]: +) -> Block: """Factory function to build a specific nn.Module block from its config. Takes configuration object (one of the types included in the `LayerConfig` @@ -731,93 +948,4 @@ def build_layer_from_config( ValueError If parameters derived from the config are invalid for the block. """ - if config.name == "ConvBlock": - return ( - ConvBlock( - in_channels=in_channels, - out_channels=config.out_channels, - kernel_size=config.kernel_size, - pad_size=config.pad_size, - ), - config.out_channels, - input_height, - ) - - if config.name == "FreqCoordConvDown": - return ( - FreqCoordConvDownBlock( - in_channels=in_channels, - out_channels=config.out_channels, - input_height=input_height, - kernel_size=config.kernel_size, - pad_size=config.pad_size, - ), - config.out_channels, - input_height // 2, - ) - - if config.name == "StandardConvDown": - return ( - StandardConvDownBlock( - in_channels=in_channels, - out_channels=config.out_channels, - kernel_size=config.kernel_size, - pad_size=config.pad_size, - ), - config.out_channels, - input_height // 2, - ) - - if config.name == "FreqCoordConvUp": - return ( - FreqCoordConvUpBlock( - in_channels=in_channels, - out_channels=config.out_channels, - input_height=input_height, - kernel_size=config.kernel_size, - pad_size=config.pad_size, - ), - config.out_channels, - input_height * 2, - ) - - if config.name == "StandardConvUp": - return ( - StandardConvUpBlock( - in_channels=in_channels, - out_channels=config.out_channels, - kernel_size=config.kernel_size, - pad_size=config.pad_size, - ), - config.out_channels, - input_height * 2, - ) - - if config.name == "SelfAttention": - return ( - SelfAttention( - in_channels=in_channels, - attention_channels=config.attention_channels, - temperature=config.temperature, - ), - config.attention_channels, - input_height, - ) - - if config.name == "LayerGroup": - current_channels = in_channels - current_height = input_height - - blocks = [] - - for block_config in config.layers: - block, current_channels, current_height = build_layer_from_config( - input_height=current_height, - in_channels=current_channels, - config=block_config, - ) - blocks.append(block) - - return nn.Sequential(*blocks), current_channels, current_height - - raise NotImplementedError(f"Unknown block type {config.name}") + return block_registry.build(config, in_channels, input_height) diff --git a/src/batdetect2/models/bottleneck.py b/src/batdetect2/models/bottleneck.py index a04b18d..82f3971 100644 --- a/src/batdetect2/models/bottleneck.py +++ b/src/batdetect2/models/bottleneck.py @@ -22,9 +22,10 @@ from torch import nn from batdetect2.core.configs import BaseConfig from batdetect2.models.blocks import ( + Block, SelfAttentionConfig, VerticalConv, - build_layer_from_config, + build_layer, ) __all__ = [ @@ -34,7 +35,7 @@ __all__ = [ ] -class Bottleneck(nn.Module): +class Bottleneck(Block): """Base Bottleneck module for Encoder-Decoder architectures. This implementation represents the simplest bottleneck structure @@ -86,6 +87,7 @@ class Bottleneck(nn.Module): self.in_channels = in_channels self.input_height = input_height self.out_channels = out_channels + self.bottleneck_channels = ( bottleneck_channels if bottleneck_channels is not None @@ -172,7 +174,7 @@ def build_bottleneck( input_height: int, in_channels: int, config: BottleneckConfig | None = None, -) -> nn.Module: +) -> Block: """Factory function to build the Bottleneck module from configuration. Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based @@ -206,11 +208,13 @@ def build_bottleneck( layers = [] for layer_config in config.layers: - layer, current_channels, current_height = build_layer_from_config( + layer = build_layer( input_height=current_height, in_channels=current_channels, config=layer_config, ) + current_height = layer.get_output_height(current_height) + current_channels = layer.get_output_channels() assert current_height == input_height, ( "Bottleneck layers should not change the spectrogram height" ) diff --git a/src/batdetect2/models/decoder.py b/src/batdetect2/models/decoder.py index e7c3e91..8275cb5 100644 --- a/src/batdetect2/models/decoder.py +++ b/src/batdetect2/models/decoder.py @@ -30,7 +30,7 @@ from batdetect2.models.blocks import ( FreqCoordConvUpConfig, LayerGroupConfig, StandardConvUpConfig, - build_layer_from_config, + build_layer, ) __all__ = [ @@ -259,11 +259,13 @@ def build_decoder( layers = [] for layer_config in config.layers: - layer, current_channels, current_height = build_layer_from_config( + layer = build_layer( in_channels=current_channels, input_height=current_height, config=layer_config, ) + current_height = layer.get_output_height(current_height) + current_channels = layer.get_output_channels() layers.append(layer) return Decoder( diff --git a/src/batdetect2/models/encoder.py b/src/batdetect2/models/encoder.py index a302992..a9a0173 100644 --- a/src/batdetect2/models/encoder.py +++ b/src/batdetect2/models/encoder.py @@ -32,7 +32,7 @@ from batdetect2.models.blocks import ( FreqCoordConvDownConfig, LayerGroupConfig, StandardConvDownConfig, - build_layer_from_config, + build_layer, ) __all__ = [ @@ -300,12 +300,14 @@ def build_encoder( layers = [] for layer_config in config.layers: - layer, current_channels, current_height = build_layer_from_config( + layer = build_layer( in_channels=current_channels, input_height=current_height, config=layer_config, ) layers.append(layer) + current_height = layer.get_output_height(current_height) + current_channels = layer.get_output_channels() return Encoder( input_height=input_height, diff --git a/src/batdetect2/plotting/common.py b/src/batdetect2/plotting/common.py index 0bf6eca..286cab4 100644 --- a/src/batdetect2/plotting/common.py +++ b/src/batdetect2/plotting/common.py @@ -19,9 +19,9 @@ def create_ax( ) -> axes.Axes: """Create a new axis if none is provided""" if ax is None: - _, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs) # type: ignore + _, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs) - return ax # type: ignore + return ax def plot_spectrogram( diff --git a/src/batdetect2/plotting/detections.py b/src/batdetect2/plotting/detections.py index 797da6d..4bbaf27 100644 --- a/src/batdetect2/plotting/detections.py +++ b/src/batdetect2/plotting/detections.py @@ -66,7 +66,7 @@ def plot_clip_detections( if m.gt is not None: color = gt_color if is_match else missed_gt_color plot_geometry( - m.gt.sound_event.geometry, # type: ignore + m.gt.sound_event.geometry, ax=ax, add_points=False, linewidth=linewidth, diff --git a/src/batdetect2/plotting/heatmaps.py b/src/batdetect2/plotting/heatmaps.py index 9376590..2fd9187 100644 --- a/src/batdetect2/plotting/heatmaps.py +++ b/src/batdetect2/plotting/heatmaps.py @@ -100,7 +100,7 @@ def plot_classification_heatmap( class_heatmap, vmax=1, vmin=0, - cmap=create_colormap(color), # type: ignore + cmap=create_colormap(color), alpha=alpha, ) diff --git a/src/batdetect2/plotting/metrics.py b/src/batdetect2/plotting/metrics.py index 30f8244..78c73d2 100644 --- a/src/batdetect2/plotting/metrics.py +++ b/src/batdetect2/plotting/metrics.py @@ -301,4 +301,4 @@ def _get_marker_positions( size = len(thresholds) cut_points = np.linspace(0, 1, n_points) indices = np.searchsorted(thresholds[::-1], cut_points) - return np.clip(size - indices, 0, size - 1) # type: ignore + return np.clip(size - indices, 0, size - 1) diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index 30fd021..59dc6c3 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -72,7 +72,7 @@ class TargetClassConfig(BaseConfig): DEFAULT_DETECTION_CLASS = TargetClassConfig( name="bat", - match_if=AllOfConfig( + condition_input=AllOfConfig( conditions=[ HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")), NotConfig( diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py index 672f90b..4967710 100644 --- a/src/batdetect2/targets/rois.py +++ b/src/batdetect2/targets/rois.py @@ -27,6 +27,7 @@ from pydantic import Field from soundevent import data from batdetect2.audio import AudioConfig, build_audio_loader +from batdetect2.core import Registry from batdetect2.core.arrays import spec_to_xarray from batdetect2.core.configs import BaseConfig from batdetect2.preprocess import PreprocessingConfig, build_preprocessor @@ -88,6 +89,9 @@ DEFAULT_ANCHOR = "bottom-left" """Default reference position within the geometry ('bottom-left' corner).""" +roi_mapper_registry: Registry[ROITargetMapper, []] = Registry("roi_mapper") + + class AnchorBBoxMapperConfig(BaseConfig): """Configuration for `AnchorBBoxMapper`. @@ -244,6 +248,15 @@ class AnchorBBoxMapper(ROITargetMapper): anchor=self.anchor, ) + @roi_mapper_registry.register(AnchorBBoxMapperConfig) + @staticmethod + def from_config(config: AnchorBBoxMapperConfig): + return AnchorBBoxMapper( + anchor=config.anchor, + time_scale=config.time_scale, + frequency_scale=config.frequency_scale, + ) + class PeakEnergyBBoxMapperConfig(BaseConfig): """Configuration for `PeakEnergyBBoxMapper`. @@ -412,6 +425,22 @@ class PeakEnergyBBoxMapper(ROITargetMapper): ] ) + @roi_mapper_registry.register(PeakEnergyBBoxMapperConfig) + @staticmethod + def from_config(config: PeakEnergyBBoxMapperConfig): + audio_loader = build_audio_loader(config=config.audio) + preprocessor = build_preprocessor( + config.preprocessing, + input_samplerate=audio_loader.samplerate, + ) + return PeakEnergyBBoxMapper( + preprocessor=preprocessor, + audio_loader=audio_loader, + time_scale=config.time_scale, + frequency_scale=config.frequency_scale, + loading_buffer=config.loading_buffer, + ) + ROIMapperConfig = Annotated[ AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig, @@ -445,31 +474,7 @@ def build_roi_mapper( If the `name` in the config does not correspond to a known mapper. """ config = config or AnchorBBoxMapperConfig() - - if config.name == "anchor_bbox": - return AnchorBBoxMapper( - anchor=config.anchor, - time_scale=config.time_scale, - frequency_scale=config.frequency_scale, - ) - - if config.name == "peak_energy_bbox": - audio_loader = build_audio_loader(config=config.audio) - preprocessor = build_preprocessor( - config.preprocessing, - input_samplerate=audio_loader.samplerate, - ) - return PeakEnergyBBoxMapper( - preprocessor=preprocessor, - audio_loader=audio_loader, - time_scale=config.time_scale, - frequency_scale=config.frequency_scale, - loading_buffer=config.loading_buffer, - ) - - raise NotImplementedError( - f"No ROI mapper of name '{config.name}' is implemented" - ) + return roi_mapper_registry.build(config) VALID_ANCHORS = [ @@ -636,7 +641,7 @@ def get_peak_energy_coordinates( ) index = selection.argmax(dim=["time", "frequency"]) - point = selection.isel(index) # type: ignore + point = selection.isel(index) peak_time: float = point.time.item() peak_freq: float = point.frequency.item() return peak_time, peak_freq diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index ba2e9ca..67a5490 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -548,63 +548,6 @@ class MaybeApply(torch.nn.Module): return self.augmentation(tensor, clip_annotation) -def build_augmentation_from_config( - config: AugmentationConfig, - samplerate: int, - audio_source: AudioSource | None = None, -) -> Augmentation | None: - """Factory function to build a single augmentation from its config.""" - if config.name == "mix_audio": - if audio_source is None: - warnings.warn( - "Mix audio augmentation ('mix_audio') requires an " - "'example_source' callable to be provided.", - stacklevel=2, - ) - return None - - return MixAudio( - example_source=audio_source, - min_weight=config.min_weight, - max_weight=config.max_weight, - ) - - if config.name == "add_echo": - return AddEcho( - max_delay=int(config.max_delay * samplerate), - min_weight=config.min_weight, - max_weight=config.max_weight, - ) - - if config.name == "scale_volume": - return ScaleVolume( - max_scaling=config.max_scaling, - min_scaling=config.min_scaling, - ) - - if config.name == "warp": - return Warp( - delta=config.delta, - ) - - if config.name == "mask_time": - return MaskTime( - max_perc=config.max_perc, - max_masks=config.max_masks, - ) - - if config.name == "mask_freq": - return MaskFrequency( - max_perc=config.max_perc, - max_masks=config.max_masks, - ) - - raise NotImplementedError( - "Invalid or unimplemented augmentation type: " - f"{config.augmentation_type}" - ) - - DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig( enabled=True, audio=[ diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index 309b2a6..34898aa 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -61,8 +61,8 @@ class TrainingDataset(Dataset): def __len__(self): return len(self.clip_annotations) - def __getitem__(self, idx) -> TrainExample: - clip_annotation = self.clip_annotations[idx] + def __getitem__(self, index) -> TrainExample: + clip_annotation = self.clip_annotations[index] if self.clipper is not None: clip_annotation = self.clipper(clip_annotation) @@ -95,7 +95,7 @@ class TrainingDataset(Dataset): detection_heatmap=heatmaps.detection, class_heatmap=heatmaps.classes, size_heatmap=heatmaps.size, - idx=torch.tensor(idx), + idx=torch.tensor(index), start_time=torch.tensor(clip.start_time), end_time=torch.tensor(clip.end_time), ) @@ -121,8 +121,8 @@ class ValidationDataset(Dataset): def __len__(self): return len(self.clip_annotations) - def __getitem__(self, idx) -> TrainExample: - clip_annotation = self.clip_annotations[idx] + def __getitem__(self, index) -> TrainExample: + clip_annotation = self.clip_annotations[index] if self.clipper is not None: clip_annotation = self.clipper(clip_annotation) @@ -141,7 +141,7 @@ class ValidationDataset(Dataset): detection_heatmap=heatmaps.detection, class_heatmap=heatmaps.classes, size_heatmap=heatmaps.size, - idx=torch.tensor(idx), + idx=torch.tensor(index), start_time=torch.tensor(clip.start_time), end_time=torch.tensor(clip.end_time), ) diff --git a/src/batdetect2/train/losses.py b/src/batdetect2/train/losses.py index 119a654..2adfea6 100644 --- a/src/batdetect2/train/losses.py +++ b/src/batdetect2/train/losses.py @@ -472,7 +472,7 @@ def build_loss( size_loss_fn = BBoxLoss() - return LossFunction( # type: ignore + return LossFunction( size_loss=size_loss_fn, classification_loss=classification_loss_fn, detection_loss=detection_loss_fn,