diff --git a/src/batdetect2/api.py b/src/batdetect2/api.py index 9b137e0..f27b951 100644 --- a/src/batdetect2/api.py +++ b/src/batdetect2/api.py @@ -98,7 +98,6 @@ consult the API documentation in the code. """ import warnings -from typing import List, Tuple import numpy as np import torch @@ -272,7 +271,7 @@ def process_spectrogram( samp_rate: int = TARGET_SAMPLERATE_HZ, model: DetectionModel = MODEL, config: ProcessingConfiguration | None = None, -) -> Tuple[List[Annotation], np.ndarray]: +) -> tuple[list[Annotation], np.ndarray]: """Process spectrogram with model. Parameters @@ -314,7 +313,7 @@ def process_audio( model: DetectionModel = MODEL, config: ProcessingConfiguration | None = None, device: torch.device = DEVICE, -) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]: +) -> tuple[list[Annotation], np.ndarray, torch.Tensor]: """Process audio array with model. Parameters @@ -357,7 +356,7 @@ def postprocess( outputs: ModelOutput, samp_rate: int = TARGET_SAMPLERATE_HZ, config: ProcessingConfiguration | None = None, -) -> Tuple[List[Annotation], np.ndarray]: +) -> tuple[list[Annotation], np.ndarray]: """Postprocess model outputs. Convert model tensor outputs to predicted bounding boxes and diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index ec5310b..f21d473 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, List, Sequence, Tuple +from typing import Sequence import numpy as np import torch @@ -110,7 +110,7 @@ class BatDetect2API: experiment_name: str | None = None, run_name: str | None = None, save_predictions: bool = True, - ) -> Tuple[Dict[str, float], List[List[Detection]]]: + ) -> tuple[dict[str, float], list[list[Detection]]]: return evaluate( self.model, test_annotations, @@ -187,7 +187,7 @@ class BatDetect2API: def process_audio( self, audio: np.ndarray, - ) -> List[Detection]: + ) -> list[Detection]: spec = self.generate_spectrogram(audio) return self.process_spectrogram(spec) @@ -195,7 +195,7 @@ class BatDetect2API: self, spec: torch.Tensor, start_time: float = 0, - ) -> List[Detection]: + ) -> list[Detection]: if spec.ndim == 4 and spec.shape[0] > 1: raise ValueError("Batched spectrograms not supported.") @@ -214,7 +214,7 @@ class BatDetect2API: def process_directory( self, audio_dir: data.PathLike, - ) -> List[ClipDetections]: + ) -> list[ClipDetections]: files = list(get_audio_files(audio_dir)) return self.process_files(files) @@ -222,7 +222,7 @@ class BatDetect2API: self, audio_files: Sequence[data.PathLike], num_workers: int | None = None, - ) -> List[ClipDetections]: + ) -> list[ClipDetections]: return process_file_list( self.model, audio_files, @@ -238,7 +238,7 @@ class BatDetect2API: clips: Sequence[data.Clip], batch_size: int | None = None, num_workers: int | None = None, - ) -> List[ClipDetections]: + ) -> list[ClipDetections]: return run_batch_inference( self.model, clips, @@ -274,7 +274,7 @@ class BatDetect2API: def load_predictions( self, path: data.PathLike, - ) -> List[ClipDetections]: + ) -> list[ClipDetections]: return self.formatter.load(path) @classmethod diff --git a/src/batdetect2/core/configs.py b/src/batdetect2/core/configs.py index 60d4266..5a3a524 100644 --- a/src/batdetect2/core/configs.py +++ b/src/batdetect2/core/configs.py @@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested configuration sections. """ -from typing import Any, Type, TypeVar, Union, overload +from typing import Any, Type, TypeVar, overload import yaml from deepmerge.merger import Merger @@ -69,8 +69,7 @@ class BaseConfig(BaseModel): T = TypeVar("T") T_Model = TypeVar("T_Model", bound=BaseModel) - -Schema = Union[Type[T_Model], TypeAdapter[T]] +Schema = Type[T_Model] | TypeAdapter[T] def get_object_field(obj: dict, current_key: str) -> Any: diff --git a/src/batdetect2/data/iterators.py b/src/batdetect2/data/iterators.py index c669776..c499b0b 100644 --- a/src/batdetect2/data/iterators.py +++ b/src/batdetect2/data/iterators.py @@ -1,5 +1,4 @@ from collections.abc import Generator -from typing import Tuple from soundevent import data @@ -10,7 +9,7 @@ from batdetect2.typing.targets import TargetProtocol def iterate_over_sound_events( dataset: Dataset, targets: TargetProtocol, -) -> Generator[Tuple[str | None, data.SoundEventAnnotation], None, None]: +) -> Generator[tuple[str | None, data.SoundEventAnnotation], None, None]: """Iterate over sound events in a dataset. Parameters @@ -24,7 +23,7 @@ def iterate_over_sound_events( Yields ------ - Tuple[Optional[str], data.SoundEventAnnotation] + tuple[Optional[str], data.SoundEventAnnotation] A tuple containing: - The encoded class name (str) for the sound event, or None if it cannot be encoded to a specific class. diff --git a/src/batdetect2/data/predictions/base.py b/src/batdetect2/data/predictions/base.py index 0d92a2d..70e3bf5 100644 --- a/src/batdetect2/data/predictions/base.py +++ b/src/batdetect2/data/predictions/base.py @@ -1,6 +1,5 @@ -from typing import Literal - from pathlib import Path +from typing import Literal from soundevent.data import PathLike diff --git a/src/batdetect2/data/split.py b/src/batdetect2/data/split.py index 47d3bd6..fa4f871 100644 --- a/src/batdetect2/data/split.py +++ b/src/batdetect2/data/split.py @@ -1,5 +1,3 @@ -from typing import Tuple - from sklearn.model_selection import train_test_split from batdetect2.data.datasets import Dataset @@ -15,7 +13,7 @@ def split_dataset_by_recordings( targets: TargetProtocol, train_size: float = 0.75, random_state: int | None = None, -) -> Tuple[Dataset, Dataset]: +) -> tuple[Dataset, Dataset]: recordings = extract_recordings_df(dataset) sound_events = extract_sound_events_df( diff --git a/src/batdetect2/detector/post_process.py b/src/batdetect2/detector/post_process.py index 637deb1..ff2a6b5 100644 --- a/src/batdetect2/detector/post_process.py +++ b/src/batdetect2/detector/post_process.py @@ -1,7 +1,5 @@ """Post-processing of the output of the model.""" -from typing import List, Tuple - import numpy as np import torch from torch import nn @@ -45,7 +43,7 @@ def run_nms( outputs: ModelOutput, params: NonMaximumSuppressionConfig, sampling_rate: np.ndarray, -) -> Tuple[List[PredictionResults], List[np.ndarray]]: +) -> tuple[list[PredictionResults], list[np.ndarray]]: """Run non-maximum suppression on the output of the model. Model outputs processed are expected to have a batch dimension. @@ -73,8 +71,8 @@ def run_nms( scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k) # loop over batch to save outputs - preds: List[PredictionResults] = [] - feats: List[np.ndarray] = [] + preds: list[PredictionResults] = [] + feats: list[np.ndarray] = [] for num_detection in range(pred_det_nms.shape[0]): # get valid indices inds_ord = torch.argsort(x_pos[num_detection, :]) @@ -151,7 +149,7 @@ def run_nms( def non_max_suppression( heat: torch.Tensor, - kernel_size: int | Tuple[int, int], + kernel_size: int | tuple[int, int], ): # kernel can be an int or list/tuple if isinstance(kernel_size, int): diff --git a/src/batdetect2/evaluate/plots/top_class.py b/src/batdetect2/evaluate/plots/top_class.py index d7a3ba1..003ce42 100644 --- a/src/batdetect2/evaluate/plots/top_class.py +++ b/src/batdetect2/evaluate/plots/top_class.py @@ -4,12 +4,9 @@ from dataclasses import dataclass, field from typing import ( Annotated, Callable, - Dict, Iterable, - List, Literal, Sequence, - Tuple, ) import matplotlib.pyplot as plt @@ -32,7 +29,7 @@ from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol -TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]] +TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]] top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry( name="top_class_plot" @@ -73,7 +70,7 @@ class PRCurve(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Iterable[Tuple[str, Figure]]: + ) -> Iterable[tuple[str, Figure]]: y_true = [] y_score = [] num_positives = 0 @@ -140,7 +137,7 @@ class ROCCurve(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Iterable[Tuple[str, Figure]]: + ) -> Iterable[tuple[str, Figure]]: y_true = [] y_score = [] @@ -223,7 +220,7 @@ class ConfusionMatrix(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Iterable[Tuple[str, Figure]]: + ) -> Iterable[tuple[str, Figure]]: cm, labels = compute_confusion_matrix( clip_evaluations, self.targets, @@ -295,26 +292,26 @@ class ExampleClassificationPlot(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Iterable[Tuple[str, Figure]]: + ) -> Iterable[tuple[str, Figure]]: grouped = group_matches(clip_evaluations, threshold=self.threshold) for class_name, matches in grouped.items(): - true_positives: List[MatchEval] = get_binned_sample( + true_positives: list[MatchEval] = get_binned_sample( matches.true_positives, n_examples=self.num_examples, ) - false_positives: List[MatchEval] = get_binned_sample( + false_positives: list[MatchEval] = get_binned_sample( matches.false_positives, n_examples=self.num_examples, ) - false_negatives: List[MatchEval] = random.sample( + false_negatives: list[MatchEval] = random.sample( matches.false_negatives, k=min(self.num_examples, len(matches.false_negatives)), ) - cross_triggers: List[MatchEval] = get_binned_sample( + cross_triggers: list[MatchEval] = get_binned_sample( matches.cross_triggers, n_examples=self.num_examples ) @@ -374,16 +371,16 @@ def build_top_class_plotter( @dataclass class ClassMatches: - false_positives: List[MatchEval] = field(default_factory=list) - false_negatives: List[MatchEval] = field(default_factory=list) - true_positives: List[MatchEval] = field(default_factory=list) - cross_triggers: List[MatchEval] = field(default_factory=list) + false_positives: list[MatchEval] = field(default_factory=list) + false_negatives: list[MatchEval] = field(default_factory=list) + true_positives: list[MatchEval] = field(default_factory=list) + cross_triggers: list[MatchEval] = field(default_factory=list) def group_matches( clip_evals: Sequence[ClipEval], threshold: float = 0.2, -) -> Dict[str, ClassMatches]: +) -> dict[str, ClassMatches]: class_examples = defaultdict(ClassMatches) for clip_eval in clip_evals: @@ -412,7 +409,7 @@ def group_matches( return class_examples -def get_binned_sample(matches: List[MatchEval], n_examples: int = 5): +def get_binned_sample(matches: list[MatchEval], n_examples: int = 5): if len(matches) < n_examples: return matches diff --git a/src/batdetect2/finetune/prep_data_finetune.py b/src/batdetect2/finetune/prep_data_finetune.py index 76bd627..87d7ba3 100644 --- a/src/batdetect2/finetune/prep_data_finetune.py +++ b/src/batdetect2/finetune/prep_data_finetune.py @@ -2,7 +2,6 @@ import argparse import json import os from collections import Counter -from typing import List, Tuple import numpy as np from sklearn.model_selection import StratifiedGroupKFold @@ -12,8 +11,8 @@ from batdetect2 import types def print_dataset_stats( - data: List[types.FileAnnotation], - classes_to_ignore: List[str] | None = None, + data: list[types.FileAnnotation], + classes_to_ignore: list[str] | None = None, ) -> Counter[str]: print("Num files:", len(data)) counts, _ = tu.get_class_names(data, classes_to_ignore) @@ -22,7 +21,7 @@ def print_dataset_stats( return counts -def load_file_names(file_name: str) -> List[str]: +def load_file_names(file_name: str) -> list[str]: if not os.path.isfile(file_name): raise FileNotFoundError(f"Input file not found - {file_name}") @@ -100,12 +99,12 @@ def parse_args(): def split_data( - data: List[types.FileAnnotation], + data: list[types.FileAnnotation], train_file: str, test_file: str, n_splits: int = 5, random_state: int = 0, -) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]: +) -> tuple[list[types.FileAnnotation], list[types.FileAnnotation]]: if train_file != "" and test_file != "": # user has specifed the train / test split mapping = { diff --git a/src/batdetect2/inference/dataset.py b/src/batdetect2/inference/dataset.py index 62dcd36..13bfd32 100644 --- a/src/batdetect2/inference/dataset.py +++ b/src/batdetect2/inference/dataset.py @@ -1,4 +1,4 @@ -from typing import List, NamedTuple, Sequence +from typing import NamedTuple, Sequence import torch from loguru import logger @@ -29,7 +29,7 @@ class DatasetItem(NamedTuple): class InferenceDataset(Dataset[DatasetItem]): - clips: List[data.Clip] + clips: list[data.Clip] def __init__( self, @@ -111,7 +111,7 @@ def build_inference_dataset( ) -def _collate_fn(batch: List[DatasetItem]) -> DatasetItem: +def _collate_fn(batch: list[DatasetItem]) -> DatasetItem: max_width = max(item.spec.shape[-1] for item in batch) return DatasetItem( spec=torch.stack( diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index ec6605d..ff99c77 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -26,8 +26,6 @@ The primary entry point for building a full, ready-to-use BatDetect2 model is the ``build_model`` factory function exported from this module. """ -from typing import List - import torch from batdetect2.models.backbones import ( @@ -142,7 +140,7 @@ class Model(torch.nn.Module): self.postprocessor = postprocessor self.targets = targets - def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]: + def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]: """Run the full detection pipeline on a waveform tensor. Converts the waveform to a spectrogram, passes it through the @@ -157,7 +155,7 @@ class Model(torch.nn.Module): Returns ------- - List[ClipDetectionsTensor] + list[ClipDetectionsTensor] One detection tensor per clip in the batch. Each tensor encodes the detected events (locations, class scores, sizes) for that clip. diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index 8f7a2b1..f13aa7f 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -23,7 +23,7 @@ output so that the output spatial dimensions always match the input spatial dimensions. """ -from typing import Annotated, Literal, Tuple, Union +from typing import Annotated, Literal import torch import torch.nn.functional as F @@ -58,6 +58,14 @@ from batdetect2.typing.models import ( EncoderProtocol, ) +__all__ = [ + "BackboneImportConfig", + "UNetBackbone", + "BackboneConfig", + "load_backbone_config", + "build_backbone", +] + class UNetBackboneConfig(BaseConfig): """Configuration for a U-Net-style encoder-decoder backbone. @@ -110,15 +118,6 @@ class BackboneImportConfig(ImportConfig): name: Literal["import"] = "import" -__all__ = [ - "BackboneImportConfig", - "UNetBackbone", - "BackboneConfig", - "load_backbone_config", - "build_backbone", -] - - class UNetBackbone(BackboneModel): """U-Net-style encoder-decoder backbone network. @@ -262,7 +261,8 @@ class UNetBackbone(BackboneModel): BackboneConfig = Annotated[ - Union[UNetBackboneConfig,], Field(discriminator="name") + UNetBackboneConfig | BackboneImportConfig, + Field(discriminator="name"), ] @@ -292,7 +292,7 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel: def _pad_adjust( spec: torch.Tensor, factor: int = 32, -) -> Tuple[torch.Tensor, int, int]: +) -> tuple[torch.Tensor, int, int]: """Pad a tensor's height and width to be divisible by ``factor``. Adds zero-padding to the bottom and right edges of the tensor so that @@ -308,7 +308,7 @@ def _pad_adjust( Returns ------- - Tuple[torch.Tensor, int, int] + tuple[torch.Tensor, int, int] - Padded tensor. - Number of rows added to the height (``h_pad``). - Number of columns added to the width (``w_pad``). diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index 7e083c7..f7f96fc 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -46,7 +46,7 @@ configuration object (one of the ``*Config`` classes exported here), using a discriminated-union ``name`` field to dispatch to the correct class. """ -from typing import Annotated, List, Literal, Tuple, Union +from typing import Annotated, Literal import torch import torch.nn.functional as F @@ -687,7 +687,7 @@ class FreqCoordConvUpConfig(BaseConfig): up_mode: str = "bilinear" """Interpolation mode for upsampling (e.g., "nearest", "bilinear").""" - up_scale: Tuple[int, int] = (2, 2) + up_scale: tuple[int, int] = (2, 2) """Scaling factor for height and width during upsampling.""" @@ -706,22 +706,22 @@ class FreqCoordConvUpBlock(Block): Parameters ---------- - in_channels : int + in_channels Number of channels in the input tensor (before upsampling). - out_channels : int + out_channels Number of output channels after the convolution. - input_height : int + input_height Height (H dimension, frequency bins) of the tensor *before* upsampling. Used to calculate the height for coordinate feature generation after upsampling. - kernel_size : int, default=3 + kernel_size Size of the square convolutional kernel. - pad_size : int, default=1 + pad_size Padding added before convolution. - up_mode : str, default="bilinear" + up_mode Interpolation mode for upsampling (e.g., "nearest", "bilinear", "bicubic"). - up_scale : Tuple[int, int], default=(2, 2) + up_scale Scaling factor for height and width during upsampling (typically (2, 2)). """ @@ -734,7 +734,7 @@ class FreqCoordConvUpBlock(Block): kernel_size: int = 3, pad_size: int = 1, up_mode: str = "bilinear", - up_scale: Tuple[int, int] = (2, 2), + up_scale: tuple[int, int] = (2, 2), ): super().__init__() self.in_channels = in_channels @@ -824,7 +824,7 @@ class StandardConvUpConfig(BaseConfig): up_mode: str = "bilinear" """Interpolation mode for upsampling (e.g., "nearest", "bilinear").""" - up_scale: Tuple[int, int] = (2, 2) + up_scale: tuple[int, int] = (2, 2) """Scaling factor for height and width during upsampling.""" @@ -839,17 +839,17 @@ class StandardConvUpBlock(Block): Parameters ---------- - in_channels : int + in_channels Number of channels in the input tensor (before upsampling). - out_channels : int + out_channels Number of output channels after the convolution. - kernel_size : int, default=3 + kernel_size Size of the square convolutional kernel. - pad_size : int, default=1 + pad_size Padding added before convolution. - up_mode : str, default="bilinear" + up_mode Interpolation mode for upsampling (e.g., "nearest", "bilinear"). - up_scale : Tuple[int, int], default=(2, 2) + up_scale Scaling factor for height and width during upsampling. """ @@ -860,7 +860,7 @@ class StandardConvUpBlock(Block): kernel_size: int = 3, pad_size: int = 1, up_mode: str = "bilinear", - up_scale: Tuple[int, int] = (2, 2), + up_scale: tuple[int, int] = (2, 2), ): super(StandardConvUpBlock, self).__init__() self.in_channels = in_channels @@ -922,15 +922,14 @@ class StandardConvUpBlock(Block): LayerConfig = Annotated[ - Union[ - ConvConfig, - FreqCoordConvDownConfig, - StandardConvDownConfig, - FreqCoordConvUpConfig, - StandardConvUpConfig, - SelfAttentionConfig, - "LayerGroupConfig", - ], + ConvConfig + | BlockImportConfig + | FreqCoordConvDownConfig + | StandardConvDownConfig + | FreqCoordConvUpConfig + | StandardConvUpConfig + | SelfAttentionConfig + | "LayerGroupConfig", Field(discriminator="name"), ] """Type alias for the discriminated union of block configuration models.""" @@ -952,7 +951,7 @@ class LayerGroupConfig(BaseConfig): """ name: Literal["LayerGroup"] = "LayerGroup" - layers: List[LayerConfig] + layers: list[LayerConfig] class LayerGroup(nn.Module): diff --git a/src/batdetect2/plotting/clip_annotations.py b/src/batdetect2/plotting/clip_annotations.py index 7872747..be6f798 100644 --- a/src/batdetect2/plotting/clip_annotations.py +++ b/src/batdetect2/plotting/clip_annotations.py @@ -1,5 +1,3 @@ -from typing import Tuple - from matplotlib.axes import Axes from soundevent import data, plot @@ -16,7 +14,7 @@ __all__ = [ def plot_clip_annotation( clip_annotation: data.ClipAnnotation, preprocessor: PreprocessorProtocol | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, ax: Axes | None = None, audio_dir: data.PathLike | None = None, add_points: bool = False, @@ -50,7 +48,7 @@ def plot_clip_annotation( def plot_anchor_points( clip_annotation: data.ClipAnnotation, targets: TargetProtocol, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, ax: Axes | None = None, size: int = 1, color: str = "red", diff --git a/src/batdetect2/plotting/clip_predictions.py b/src/batdetect2/plotting/clip_predictions.py index fc28e22..a8cc198 100644 --- a/src/batdetect2/plotting/clip_predictions.py +++ b/src/batdetect2/plotting/clip_predictions.py @@ -1,4 +1,4 @@ -from typing import Iterable, Tuple +from typing import Iterable from matplotlib.axes import Axes from soundevent import data @@ -18,7 +18,7 @@ __all__ = [ def plot_clip_prediction( clip_prediction: data.ClipPrediction, preprocessor: PreprocessorProtocol | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, ax: Axes | None = None, audio_dir: data.PathLike | None = None, add_legend: bool = False, diff --git a/src/batdetect2/plotting/clips.py b/src/batdetect2/plotting/clips.py index 9a7e541..9b67e25 100644 --- a/src/batdetect2/plotting/clips.py +++ b/src/batdetect2/plotting/clips.py @@ -1,5 +1,3 @@ -from typing import Tuple - import matplotlib.pyplot as plt import torch from matplotlib.axes import Axes @@ -19,7 +17,7 @@ def plot_clip( clip: data.Clip, audio_loader: AudioLoader | None = None, preprocessor: PreprocessorProtocol | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, ax: Axes | None = None, audio_dir: data.PathLike | None = None, spec_cmap: str = "gray", diff --git a/src/batdetect2/plotting/common.py b/src/batdetect2/plotting/common.py index 286cab4..425a40d 100644 --- a/src/batdetect2/plotting/common.py +++ b/src/batdetect2/plotting/common.py @@ -1,7 +1,5 @@ """General plotting utilities.""" -from typing import Tuple - import matplotlib.pyplot as plt import numpy as np import torch @@ -14,7 +12,7 @@ __all__ = [ def create_ax( ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, **kwargs, ) -> axes.Axes: """Create a new axis if none is provided""" @@ -31,7 +29,7 @@ def plot_spectrogram( min_freq: float | None = None, max_freq: float | None = None, ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, add_colorbar: bool = False, colorbar_kwargs: dict | None = None, vmin: float | None = None, diff --git a/src/batdetect2/plotting/heatmaps.py b/src/batdetect2/plotting/heatmaps.py index 2fd9187..db29c82 100644 --- a/src/batdetect2/plotting/heatmaps.py +++ b/src/batdetect2/plotting/heatmaps.py @@ -1,7 +1,5 @@ """Plot heatmaps""" -from typing import List, Tuple - import numpy as np import torch from matplotlib import axes, patches @@ -14,7 +12,7 @@ from batdetect2.plotting.common import create_ax def plot_detection_heatmap( heatmap: torch.Tensor | np.ndarray, ax: axes.Axes | None = None, - figsize: Tuple[int, int] = (10, 10), + figsize: tuple[int, int] = (10, 10), threshold: float | None = None, alpha: float = 1, cmap: str | Colormap = "jet", @@ -50,8 +48,8 @@ def plot_detection_heatmap( def plot_classification_heatmap( heatmap: torch.Tensor | np.ndarray, ax: axes.Axes | None = None, - figsize: Tuple[int, int] = (10, 10), - class_names: List[str] | None = None, + figsize: tuple[int, int] = (10, 10), + class_names: list[str] | None = None, threshold: float | None = 0.1, alpha: float = 1, cmap: str | Colormap = "tab20", diff --git a/src/batdetect2/plotting/legacy/plot.py b/src/batdetect2/plotting/legacy/plot.py index bac240f..6051abe 100644 --- a/src/batdetect2/plotting/legacy/plot.py +++ b/src/batdetect2/plotting/legacy/plot.py @@ -1,6 +1,6 @@ """Plot functions to visualize detections and spectrograms.""" -from typing import List, Tuple, cast +from typing import cast import matplotlib.ticker as tick import numpy as np @@ -27,7 +27,7 @@ def spectrogram( spec: torch.Tensor | np.ndarray, config: ProcessingConfiguration | None = None, ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, cmap: str = "plasma", start_time: float = 0, ) -> axes.Axes: @@ -35,18 +35,18 @@ def spectrogram( Parameters ---------- - spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot. - config (Optional[ProcessingConfiguration], optional): Configuration + spec: Spectrogram to plot. + config: Configuration used to compute the spectrogram. Defaults to None. If None, the default configuration will be used. - ax (Optional[axes.Axes], optional): Matplotlib axes object. + ax: Matplotlib axes object. Defaults to None. if provided, the spectrogram will be plotted on this axes. - figsize (Optional[Tuple[int, int]], optional): Figure size. + figsize: Figure size. Defaults to None. If `ax` is None, this will be used to create a new figure of the given size. - cmap (str, optional): Colormap to use. Defaults to "plasma". - start_time (float, optional): Start time of the spectrogram. + cmap: Colormap to use. Defaults to "plasma". + start_time: Start time of the spectrogram. Defaults to 0. This is useful if plotting a spectrogram of a segment of a longer audio file. @@ -104,10 +104,10 @@ def spectrogram( def spectrogram_with_detections( spec: torch.Tensor | np.ndarray, - dets: List[Annotation], + dets: list[Annotation], config: ProcessingConfiguration | None = None, ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, cmap: str = "plasma", with_names: bool = True, start_time: float = 0, @@ -117,21 +117,21 @@ def spectrogram_with_detections( Parameters ---------- - spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot. - detections (List[Annotation]): List of detections. - config (Optional[ProcessingConfiguration], optional): Configuration + spec: Spectrogram to plot. + detections: List of detections. + config: Configuration used to compute the spectrogram. Defaults to None. If None, the default configuration will be used. - ax (Optional[axes.Axes], optional): Matplotlib axes object. + ax: Matplotlib axes object. Defaults to None. if provided, the spectrogram will be plotted on this axes. - figsize (Optional[Tuple[int, int]], optional): Figure size. + figsize: Figure size. Defaults to None. If `ax` is None, this will be used to create a new figure of the given size. - cmap (str, optional): Colormap to use. Defaults to "plasma". - with_names (bool, optional): Whether to plot the name of the + cmap: Colormap to use. Defaults to "plasma". + with_names: Whether to plot the name of the predicted class next to the detection. Defaults to True. - start_time (float, optional): Start time of the spectrogram. + start_time: Start time of the spectrogram. Defaults to 0. This is useful if plotting a spectrogram of a segment of a longer audio file. **kwargs: Additional keyword arguments to pass to the @@ -167,9 +167,9 @@ def spectrogram_with_detections( def detections( - dets: List[Annotation], + dets: list[Annotation], ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, with_names: bool = True, **kwargs, ) -> axes.Axes: @@ -177,14 +177,14 @@ def detections( Parameters ---------- - dets (List[Annotation]): List of detections. - ax (Optional[axes.Axes], optional): Matplotlib axes object. + dets: List of detections. + ax: Matplotlib axes object. Defaults to None. if provided, the spectrogram will be plotted on this axes. - figsize (Optional[Tuple[int, int]], optional): Figure size. + figsize: Figure size. Defaults to None. If `ax` is None, this will be used to create a new figure of the given size. - with_names (bool, optional): Whether to plot the name of the + with_names: Whether to plot the name of the predicted class next to the detection. Defaults to True. **kwargs: Additional keyword arguments to pass to the `plot.detection` function. @@ -214,7 +214,7 @@ def detections( def detection( det: Annotation, ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, linewidth: float = 1, edgecolor: str = "w", facecolor: str = "none", @@ -224,19 +224,19 @@ def detection( Parameters ---------- - det (Annotation): Detection to plot. - ax (Optional[axes.Axes], optional): Matplotlib axes object. Defaults + det: Detection to plot. + ax: Matplotlib axes object. Defaults to None. If provided, the spectrogram will be plotted on this axes. - figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults + figsize: Figure size. Defaults to None. If `ax` is None, this will be used to create a new figure of the given size. - linewidth (float, optional): Line width of the detection. + linewidth: Line width of the detection. Defaults to 1. - edgecolor (str, optional): Edge color of the detection. + edgecolor: Edge color of the detection. Defaults to "w", i.e. white. - facecolor (str, optional): Face color of the detection. + facecolor: Face color of the detection. Defaults to "none", i.e. transparent. - with_name (bool, optional): Whether to plot the name of the + with_name: Whether to plot the name of the predicted class next to the detection. Defaults to True. Returns @@ -277,22 +277,22 @@ def detection( def _compute_spec_extent( - shape: Tuple[int, int], + shape: tuple[int, int], params: SpectrogramParameters, -) -> Tuple[float, float, float, float]: +) -> tuple[float, float, float, float]: """Compute the extent of a spectrogram. Parameters ---------- - shape (Tuple[int, int]): Shape of the spectrogram. + shape: Shape of the spectrogram. The first dimension is the frequency axis and the second dimension is the time axis. - params (SpectrogramParameters): Spectrogram parameters. + params: Spectrogram parameters. Should be the same as the ones used to compute the spectrogram. Returns ------- - Tuple[float, float, float, float]: Extent of the spectrogram. + tuple[float, float, float, float]: Extent of the spectrogram. The first two values are the minimum and maximum time values, the last two values are the minimum and maximum frequency values. """ diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index 69032ca..cbd1be6 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ -1,4 +1,4 @@ -from typing import Protocol, Tuple +from typing import Protocol from matplotlib.axes import Axes from soundevent import data, plot @@ -40,7 +40,7 @@ def plot_false_positive_match( match: MatchProtocol, audio_loader: AudioLoader | None = None, preprocessor: PreprocessorProtocol | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, ax: Axes | None = None, audio_dir: data.PathLike | None = None, duration: float = DEFAULT_DURATION, @@ -111,7 +111,7 @@ def plot_false_negative_match( match: MatchProtocol, audio_loader: AudioLoader | None = None, preprocessor: PreprocessorProtocol | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, ax: Axes | None = None, audio_dir: data.PathLike | None = None, duration: float = DEFAULT_DURATION, @@ -171,7 +171,7 @@ def plot_true_positive_match( match: MatchProtocol, preprocessor: PreprocessorProtocol | None = None, audio_loader: AudioLoader | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, ax: Axes | None = None, audio_dir: data.PathLike | None = None, duration: float = DEFAULT_DURATION, @@ -259,7 +259,7 @@ def plot_cross_trigger_match( match: MatchProtocol, preprocessor: PreprocessorProtocol | None = None, audio_loader: AudioLoader | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, ax: Axes | None = None, audio_dir: data.PathLike | None = None, duration: float = DEFAULT_DURATION, diff --git a/src/batdetect2/plotting/metrics.py b/src/batdetect2/plotting/metrics.py index 78c73d2..b267a00 100644 --- a/src/batdetect2/plotting/metrics.py +++ b/src/batdetect2/plotting/metrics.py @@ -1,5 +1,3 @@ -from typing import Dict, Tuple - import numpy as np import seaborn as sns from cycler import cycler @@ -34,14 +32,14 @@ def plot_pr_curve( recall: np.ndarray, thresholds: np.ndarray, ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, - color: str | Tuple[float, float, float] | None = None, + figsize: tuple[int, int] | None = None, + color: str | tuple[float, float, float] | None = None, add_labels: bool = True, add_legend: bool = False, - marker: str | Tuple[int, int, float] | None = "o", - markeredgecolor: str | Tuple[float, float, float] | None = None, + marker: str | tuple[int, int, float] | None = "o", + markeredgecolor: str | tuple[float, float, float] | None = None, markersize: float | None = None, - linestyle: str | Tuple[int, ...] | None = None, + linestyle: str | tuple[int, ...] | None = None, linewidth: float | None = None, label: str = "PR Curve", ) -> axes.Axes: @@ -76,9 +74,9 @@ def plot_pr_curve( def plot_pr_curves( - data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], + data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]], ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, add_legend: bool = True, add_labels: bool = True, include_ap: bool = False, @@ -119,7 +117,7 @@ def plot_threshold_precision_curve( threshold: np.ndarray, precision: np.ndarray, ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, add_labels: bool = True, ): ax = create_ax(ax=ax, figsize=figsize) @@ -139,9 +137,9 @@ def plot_threshold_precision_curve( def plot_threshold_precision_curves( - data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], + data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]], ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, add_legend: bool = True, add_labels: bool = True, ): @@ -177,7 +175,7 @@ def plot_threshold_recall_curve( threshold: np.ndarray, recall: np.ndarray, ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, add_labels: bool = True, ): ax = create_ax(ax=ax, figsize=figsize) @@ -197,9 +195,9 @@ def plot_threshold_recall_curve( def plot_threshold_recall_curves( - data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], + data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]], ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, add_legend: bool = True, add_labels: bool = True, ): @@ -236,7 +234,7 @@ def plot_roc_curve( tpr: np.ndarray, thresholds: np.ndarray, ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, add_labels: bool = True, ) -> axes.Axes: ax = create_ax(ax=ax, figsize=figsize) @@ -260,9 +258,9 @@ def plot_roc_curve( def plot_roc_curves( - data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], + data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]], ax: axes.Axes | None = None, - figsize: Tuple[int, int] | None = None, + figsize: tuple[int, int] | None = None, add_legend: bool = True, add_labels: bool = True, ) -> axes.Axes: diff --git a/src/batdetect2/postprocess/nms.py b/src/batdetect2/postprocess/nms.py index c4731f6..6a41430 100644 --- a/src/batdetect2/postprocess/nms.py +++ b/src/batdetect2/postprocess/nms.py @@ -11,8 +11,6 @@ activations that have lower scores than a local maximum. This helps prevent multiple, overlapping detections originating from the same sound event. """ -from typing import Tuple - import torch NMS_KERNEL_SIZE = 9 @@ -27,7 +25,7 @@ BatDetect2. def non_max_suppression( tensor: torch.Tensor, - kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE, + kernel_size: int | tuple[int, int] = NMS_KERNEL_SIZE, ) -> torch.Tensor: """Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap. @@ -42,11 +40,11 @@ def non_max_suppression( Parameters ---------- - tensor : torch.Tensor + tensor Input tensor, typically representing a detection heatmap. Must be a 3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying `torch.nn.functional.max_pool2d` operation. - kernel_size : Union[int, Tuple[int, int]], default=NMS_KERNEL_SIZE + kernel_size Size of the sliding window neighborhood used to find local maxima. If an integer `k` is provided, a square kernel of size `(k, k)` is used. If a tuple `(h, w)` is provided, a rectangular kernel of height `h` diff --git a/src/batdetect2/postprocess/postprocessor.py b/src/batdetect2/postprocess/postprocessor.py index 9d7c339..2a27560 100644 --- a/src/batdetect2/postprocess/postprocessor.py +++ b/src/batdetect2/postprocess/postprocessor.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - import torch from loguru import logger @@ -51,7 +49,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): max_freq: float, top_k_per_sec: int = 200, detection_threshold: float = 0.01, - nms_kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE, + nms_kernel_size: int | tuple[int, int] = NMS_KERNEL_SIZE, ): """Initialize the Postprocessor.""" super().__init__() @@ -66,8 +64,8 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): def forward( self, output: ModelOutput, - start_times: List[float] | None = None, - ) -> List[ClipDetectionsTensor]: + start_times: list[float] | None = None, + ) -> list[ClipDetectionsTensor]: detection_heatmap = non_max_suppression( output.detection_probs.detach(), kernel_size=self.nms_kernel_size, diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py index eb4200c..9b87495 100644 --- a/src/batdetect2/targets/rois.py +++ b/src/batdetect2/targets/rois.py @@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the *geometric* aspect of target definition from *semantic* classification. """ -from typing import Annotated, Literal, Tuple +from typing import Annotated, Literal import numpy as np from pydantic import Field @@ -144,7 +144,7 @@ class AnchorBBoxMapper(ROITargetMapper): Attributes ---------- - dimension_names : List[str] + dimension_names : list[str] The output dimension names: `['width', 'height']`. anchor : Anchor The configured anchor point type (e.g., "center", "bottom-left"). @@ -177,7 +177,7 @@ class AnchorBBoxMapper(ROITargetMapper): self.time_scale = time_scale self.frequency_scale = frequency_scale - def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]: + def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]: """Encode a SoundEvent into an anchor position and scaled box size. The position is determined by the configured anchor on the sound @@ -190,7 +190,7 @@ class AnchorBBoxMapper(ROITargetMapper): Returns ------- - Tuple[Position, Size] + tuple[Position, Size] A tuple of (anchor_position, [scaled_width, scaled_height]). """ from soundevent import geometry @@ -314,7 +314,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper): Attributes ---------- - dimension_names : List[str] + dimension_names : list[str] The output dimension names: `['left', 'bottom', 'right', 'top']`. preprocessor : PreprocessorProtocol The spectrogram preprocessor instance. @@ -371,7 +371,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper): Returns ------- - Tuple[Position, Size] + tuple[Position, Size] A tuple of (peak_position, [l, b, r, t] distances). """ from soundevent import geometry @@ -519,14 +519,14 @@ def _build_bounding_box( Parameters ---------- - pos : Tuple[float, float] + pos Reference position (time, frequency). - duration : float + duration The required *unscaled* duration (width) of the bounding box. - bandwidth : float + bandwidth The required *unscaled* frequency bandwidth (height) of the bounding box. - anchor : Anchor + anchor Specifies which part of the bounding box the input `pos` corresponds to. Returns diff --git a/src/batdetect2/targets/targets.py b/src/batdetect2/targets/targets.py index 9b09c2d..0c83d50 100644 --- a/src/batdetect2/targets/targets.py +++ b/src/batdetect2/targets/targets.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Tuple +from typing import Iterable from loguru import logger from soundevent import data @@ -33,20 +33,20 @@ class Targets(TargetProtocol): Attributes ---------- - class_names : List[str] + class_names An ordered list of the unique names of the specific target classes defined in the configuration. - generic_class_tags : List[data.Tag] + generic_class_tags A list of `soundevent.data.Tag` objects representing the configured generic class category (used when no specific class matches). - dimension_names : List[str] + dimension_names The names of the size dimensions handled by the ROI mapper (e.g., ['width', 'height']). """ - class_names: List[str] - detection_class_tags: List[data.Tag] - dimension_names: List[str] + class_names: list[str] + detection_class_tags: list[data.Tag] + dimension_names: list[str] detection_class_name: str def __init__(self, config: TargetConfig): @@ -128,7 +128,7 @@ class Targets(TargetProtocol): """ return self._encode_fn(sound_event) - def decode_class(self, class_label: str) -> List[data.Tag]: + def decode_class(self, class_label: str) -> list[data.Tag]: """Decode a predicted class name back into representative tags. Uses the configured mapping (based on `TargetClass.output_tags` or @@ -142,7 +142,7 @@ class Targets(TargetProtocol): Returns ------- - List[data.Tag] + list[data.Tag] The list of tags corresponding to the input class name. """ return self._decode_fn(class_label) @@ -161,7 +161,7 @@ class Targets(TargetProtocol): Returns ------- - Tuple[float, float] + tuple[float, float] The reference position `(time, frequency)`. Raises @@ -192,9 +192,9 @@ class Targets(TargetProtocol): Parameters ---------- - pos : Tuple[float, float] + pos The reference position `(time, frequency)`. - dims : np.ndarray + dims NumPy array with size dimensions (e.g., from model prediction), matching the order in `self.dimension_names`. @@ -292,7 +292,7 @@ def load_targets( def iterate_encoded_sound_events( sound_events: Iterable[data.SoundEventAnnotation], targets: TargetProtocol, -) -> Iterable[Tuple[str | None, Position, Size]]: +) -> Iterable[tuple[str | None, Position, Size]]: for sound_event in sound_events: if not targets.filter(sound_event): continue diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 5e1966c..be357fe 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING import lightning as L import torch @@ -97,7 +97,7 @@ class TrainingModule(L.LightningModule): def load_model_from_checkpoint( path: PathLike, -) -> Tuple[Model, "BatDetect2Config"]: +) -> tuple[Model, "BatDetect2Config"]: module = TrainingModule.load_from_checkpoint(path) # type: ignore return module.model, module.config diff --git a/src/batdetect2/types.py b/src/batdetect2/types.py index 4c8dd4f..35cdacc 100644 --- a/src/batdetect2/types.py +++ b/src/batdetect2/types.py @@ -1,6 +1,6 @@ """Types used in the code base.""" -from typing import Any, List, NamedTuple, TypedDict +from typing import Any, NamedTuple, TypedDict import numpy as np import torch @@ -86,7 +86,7 @@ class ModelParameters(TypedDict): resize_factor: float """Resize factor.""" - class_names: List[str] + class_names: list[str] """Class names. The model is trained to detect these classes. @@ -158,7 +158,7 @@ class FileAnnotation(TypedDict): notes: str """Notes of file.""" - annotation: List[Annotation] + annotation: list[Annotation] """List of annotations.""" @@ -168,26 +168,26 @@ class RunResults(TypedDict): pred_dict: FileAnnotation """Predictions in the format expected by the annotation tool.""" - spec_feats: NotRequired[List[np.ndarray]] + spec_feats: NotRequired[list[np.ndarray]] """Spectrogram features.""" - spec_feat_names: NotRequired[List[str]] + spec_feat_names: NotRequired[list[str]] """Spectrogram feature names.""" - cnn_feats: NotRequired[List[np.ndarray]] + cnn_feats: NotRequired[list[np.ndarray]] """CNN features.""" - cnn_feat_names: NotRequired[List[str]] + cnn_feat_names: NotRequired[list[str]] """CNN feature names.""" - spec_slices: NotRequired[List[np.ndarray]] + spec_slices: NotRequired[list[np.ndarray]] """Spectrogram slices.""" class ResultParams(TypedDict): """Result parameters.""" - class_names: List[str] + class_names: list[str] """Class names.""" spec_features: bool @@ -234,7 +234,7 @@ class ProcessingConfiguration(TypedDict): scale_raw_audio: bool """Whether to scale the raw audio to be between -1 and 1.""" - class_names: List[str] + class_names: list[str] """Names of the classes the model can detect.""" detection_threshold: float @@ -466,7 +466,7 @@ class FeatureExtractionParameters(TypedDict): class HeatmapParameters(TypedDict): """Parameters that control the heatmap generation function.""" - class_names: List[str] + class_names: list[str] fft_win_length: float """Length of the FFT window in seconds.""" @@ -553,15 +553,15 @@ class AudioLoaderAnnotationGroup(TypedDict): individual_ids: np.ndarray x_inds: np.ndarray y_inds: np.ndarray - annotation: List[Annotation] + annotation: list[Annotation] annotated: bool class_id_file: int """ID of the class of the file.""" class AudioLoaderParameters(TypedDict): - class_names: List[str] - classes_to_ignore: List[str] + class_names: list[str] + classes_to_ignore: list[str] target_samp_rate: int scale_raw_audio: bool fft_win_length: float diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/typing/evaluate.py index 0d1c1d3..9698342 100644 --- a/src/batdetect2/typing/evaluate.py +++ b/src/batdetect2/typing/evaluate.py @@ -1,12 +1,9 @@ from dataclasses import dataclass from typing import ( - Dict, Generic, Iterable, - List, Protocol, Sequence, - Tuple, TypeVar, ) @@ -33,7 +30,7 @@ class MatchEvaluation: gt_geometry: data.Geometry | None pred_score: float - pred_class_scores: Dict[str, float] + pred_class_scores: dict[str, float] pred_geometry: data.Geometry | None affinity: float @@ -66,7 +63,7 @@ class MatchEvaluation: @dataclass class ClipMatches: clip: data.Clip - matches: List[MatchEvaluation] + matches: list[MatchEvaluation] class MatcherProtocol(Protocol): @@ -75,7 +72,7 @@ class MatcherProtocol(Protocol): ground_truth: Sequence[data.Geometry], predictions: Sequence[data.Geometry], scores: Sequence[float], - ) -> Iterable[Tuple[int | None, int | None, float]]: ... + ) -> Iterable[tuple[int | None, int | None, float]]: ... Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True) @@ -94,7 +91,7 @@ class MetricsProtocol(Protocol): self, clip_annotations: Sequence[data.ClipAnnotation], predictions: Sequence[Sequence[Detection]], - ) -> Dict[str, float]: ... + ) -> dict[str, float]: ... class PlotterProtocol(Protocol): @@ -102,7 +99,7 @@ class PlotterProtocol(Protocol): self, clip_annotations: Sequence[data.ClipAnnotation], predictions: Sequence[Sequence[Detection]], - ) -> Iterable[Tuple[str, Figure]]: ... + ) -> Iterable[tuple[str, Figure]]: ... EvaluationOutput = TypeVar("EvaluationOutput") @@ -119,8 +116,8 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]): def compute_metrics( self, eval_outputs: EvaluationOutput - ) -> Dict[str, float]: ... + ) -> dict[str, float]: ... def generate_plots( self, eval_outputs: EvaluationOutput - ) -> Iterable[Tuple[str, Figure]]: ... + ) -> Iterable[tuple[str, Figure]]: ... diff --git a/src/batdetect2/typing/train.py b/src/batdetect2/typing/train.py index a1e27f2..076e287 100644 --- a/src/batdetect2/typing/train.py +++ b/src/batdetect2/typing/train.py @@ -1,4 +1,4 @@ -from typing import Callable, NamedTuple, Protocol, Tuple +from typing import Callable, NamedTuple, Protocol import torch from soundevent import data @@ -52,7 +52,7 @@ steps, and returns the final `Heatmaps` used for model training. Augmentation = Callable[ [torch.Tensor, data.ClipAnnotation], - Tuple[torch.Tensor, data.ClipAnnotation], + tuple[torch.Tensor, data.ClipAnnotation], ] diff --git a/src/batdetect2/utils/audio_utils.py b/src/batdetect2/utils/audio_utils.py index bc3b857..f5887cb 100644 --- a/src/batdetect2/utils/audio_utils.py +++ b/src/batdetect2/utils/audio_utils.py @@ -1,5 +1,4 @@ import warnings -from typing import Tuple import librosa import librosa.core.spectrum @@ -147,7 +146,7 @@ def load_audio( target_samp_rate: int, scale: bool = False, max_duration: float | None = None, -) -> Tuple[int, np.ndarray]: +) -> tuple[int, np.ndarray]: """Load an audio file and resample it to the target sampling rate. The audio is also scaled to [-1, 1] and clipped to the maximum duration. diff --git a/src/batdetect2/utils/detector_utils.py b/src/batdetect2/utils/detector_utils.py index 629f669..fc448c3 100644 --- a/src/batdetect2/utils/detector_utils.py +++ b/src/batdetect2/utils/detector_utils.py @@ -1,6 +1,6 @@ import json import os -from typing import Any, Iterator, List, Tuple +from typing import Any, Iterator import librosa import numpy as np @@ -60,7 +60,7 @@ def get_default_bd_args(): return args -def list_audio_files(ip_dir: str) -> List[str]: +def list_audio_files(ip_dir: str) -> list[str]: """Get all audio files in directory. Args: @@ -86,7 +86,7 @@ def load_model( load_weights: bool = True, device: torch.device | str | None = None, weights_only: bool = True, -) -> Tuple[DetectionModel, ModelParameters]: +) -> tuple[DetectionModel, ModelParameters]: """Load model from file. Args: @@ -185,26 +185,28 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices): def get_annotations_from_preds( predictions: PredictionResults, - class_names: List[str], -) -> List[Annotation]: + class_names: list[str], +) -> list[Annotation]: """Get list of annotations from predictions.""" # Get the best class prediction probability and index for each detection class_prob_best = predictions["class_probs"].max(0) class_ind_best = predictions["class_probs"].argmax(0) # Pack the results into a list of dictionaries - annotations: List[Annotation] = [ - { - "start_time": round(float(start_time), 4), - "end_time": round(float(end_time), 4), - "low_freq": int(low_freq), - "high_freq": int(high_freq), - "class": str(class_names[class_index]), - "class_prob": round(float(class_prob), 3), - "det_prob": round(float(det_prob), 3), - "individual": "-1", - "event": "Echolocation", - } + annotations: list[Annotation] = [ + Annotation( + { + "start_time": round(float(start_time), 4), + "end_time": round(float(end_time), 4), + "low_freq": int(low_freq), + "high_freq": int(high_freq), + "class": str(class_names[class_index]), + "class_prob": round(float(class_prob), 3), + "det_prob": round(float(det_prob), 3), + "individual": "-1", + "event": "Echolocation", + } + ) for ( start_time, end_time, @@ -232,7 +234,7 @@ def format_single_result( time_exp: float, duration: float, predictions: PredictionResults, - class_names: List[str], + class_names: list[str], ) -> FileAnnotation: """Format results into the format expected by the annotation tool. @@ -315,9 +317,9 @@ def convert_results( ] # combine into final results dictionary - results: RunResults = { + results: RunResults = RunResults({ # type: ignore "pred_dict": pred_dict, - } + }) # add spectrogram features if they exist if len(spec_feats) > 0 and params["spec_features"]: @@ -413,7 +415,7 @@ def compute_spectrogram( sampling_rate: int, params: SpectrogramParameters, device: torch.device, -) -> Tuple[float, torch.Tensor]: +) -> tuple[float, torch.Tensor]: """Compute a spectrogram from an audio array. Will pad the audio array so that it is evenly divisible by the @@ -475,7 +477,7 @@ def iterate_over_chunks( audio: np.ndarray, samplerate: float, chunk_size: float, -) -> Iterator[Tuple[float, np.ndarray]]: +) -> Iterator[tuple[float, np.ndarray]]: """Iterate over audio in chunks of size chunk_size. Parameters @@ -509,7 +511,7 @@ def _process_spectrogram( samplerate: float, model: DetectionModel, config: ProcessingConfiguration, -) -> Tuple[PredictionResults, np.ndarray]: +) -> tuple[PredictionResults, np.ndarray]: # evaluate model with torch.no_grad(): outputs = model(spec) @@ -546,7 +548,7 @@ def postprocess_model_outputs( outputs: ModelOutput, samp_rate: int, config: ProcessingConfiguration, -) -> Tuple[List[Annotation], np.ndarray]: +) -> tuple[list[Annotation], np.ndarray]: # run non-max suppression pred_nms_list, features = pp.run_nms( outputs, @@ -585,7 +587,7 @@ def process_spectrogram( samplerate: int, model: DetectionModel, config: ProcessingConfiguration, -) -> Tuple[List[Annotation], np.ndarray]: +) -> tuple[list[Annotation], np.ndarray]: """Process a spectrogram with detection model. Will run non-maximum suppression on the output of the model. @@ -604,9 +606,9 @@ def process_spectrogram( Returns ------- - detections: List[Annotation] + detections List of detections predicted by the model. - features : np.ndarray + features An array of CNN features associated with each annotation. The array is of shape (num_detections, num_features). Is empty if `config["cnn_features"]` is False. @@ -632,7 +634,7 @@ def _process_audio_array( model: DetectionModel, config: ProcessingConfiguration, device: torch.device, -) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]: +) -> tuple[PredictionResults, np.ndarray, torch.Tensor]: # load audio file and compute spectrogram _, spec = compute_spectrogram( audio, @@ -669,7 +671,7 @@ def process_audio_array( model: DetectionModel, config: ProcessingConfiguration, device: torch.device, -) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]: +) -> tuple[list[Annotation], np.ndarray, torch.Tensor]: """Process a single audio array with detection model. Parameters @@ -689,7 +691,7 @@ def process_audio_array( Returns ------- - annotations : List[Annotation] + annotations : list[Annotation] List of annotations predicted by the model. features : np.ndarray Array of CNN features associated with each annotation. diff --git a/tests/test_data/test_annotations/test_batdetect2.py b/tests/test_data/test_annotations/test_batdetect2.py index 83ef326..446a8fd 100644 --- a/tests/test_data/test_annotations/test_batdetect2.py +++ b/tests/test_data/test_annotations/test_batdetect2.py @@ -1,7 +1,7 @@ import json import uuid from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import pytest from soundevent import data @@ -20,11 +20,11 @@ def create_legacy_file_annotation( duration: float = 5.0, time_exp: float = 1.0, class_name: str = "Myotis", - annotations: Optional[List[Dict[str, Any]]] = None, + annotations: list[dict[str, Any]] | None = None, annotated: bool = True, issues: bool = False, notes: str = "", -) -> Dict[str, Any]: +) -> dict[str, Any]: if annotations is None: annotations = [ { @@ -61,7 +61,7 @@ def create_legacy_file_annotation( @pytest.fixture def batdetect2_files_test_setup( tmp_path: Path, wav_factory -) -> Tuple[Path, Path, List[Dict[str, Any]]]: +) -> tuple[Path, Path, list[dict[str, Any]]]: """Sets up a directory structure for batdetect2 files format tests.""" audio_dir = tmp_path / "audio" audio_dir.mkdir() @@ -143,7 +143,7 @@ def batdetect2_files_test_setup( @pytest.fixture def batdetect2_merged_test_setup( tmp_path: Path, batdetect2_files_test_setup -) -> Tuple[Path, Path, List[Dict[str, Any]]]: +) -> tuple[Path, Path, list[dict[str, Any]]]: """Sets up a directory structure for batdetect2 merged file format tests.""" audio_dir, _, files_data = batdetect2_files_test_setup merged_anns_path = tmp_path / "merged_anns.json"