Removing legacy types

This commit is contained in:
mbsantiago 2026-03-17 12:53:03 +00:00
parent 8ac4f4c44d
commit 1a7c0b4b3a
32 changed files with 246 additions and 277 deletions

View File

@ -98,7 +98,6 @@ consult the API documentation in the code.
""" """
import warnings import warnings
from typing import List, Tuple
import numpy as np import numpy as np
import torch import torch
@ -272,7 +271,7 @@ def process_spectrogram(
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: ProcessingConfiguration | None = None, config: ProcessingConfiguration | None = None,
) -> Tuple[List[Annotation], np.ndarray]: ) -> tuple[list[Annotation], np.ndarray]:
"""Process spectrogram with model. """Process spectrogram with model.
Parameters Parameters
@ -314,7 +313,7 @@ def process_audio(
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: ProcessingConfiguration | None = None, config: ProcessingConfiguration | None = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]: ) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
"""Process audio array with model. """Process audio array with model.
Parameters Parameters
@ -357,7 +356,7 @@ def postprocess(
outputs: ModelOutput, outputs: ModelOutput,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
config: ProcessingConfiguration | None = None, config: ProcessingConfiguration | None = None,
) -> Tuple[List[Annotation], np.ndarray]: ) -> tuple[list[Annotation], np.ndarray]:
"""Postprocess model outputs. """Postprocess model outputs.
Convert model tensor outputs to predicted bounding boxes and Convert model tensor outputs to predicted bounding boxes and

View File

@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Dict, List, Sequence, Tuple from typing import Sequence
import numpy as np import numpy as np
import torch import torch
@ -110,7 +110,7 @@ class BatDetect2API:
experiment_name: str | None = None, experiment_name: str | None = None,
run_name: str | None = None, run_name: str | None = None,
save_predictions: bool = True, save_predictions: bool = True,
) -> Tuple[Dict[str, float], List[List[Detection]]]: ) -> tuple[dict[str, float], list[list[Detection]]]:
return evaluate( return evaluate(
self.model, self.model,
test_annotations, test_annotations,
@ -187,7 +187,7 @@ class BatDetect2API:
def process_audio( def process_audio(
self, self,
audio: np.ndarray, audio: np.ndarray,
) -> List[Detection]: ) -> list[Detection]:
spec = self.generate_spectrogram(audio) spec = self.generate_spectrogram(audio)
return self.process_spectrogram(spec) return self.process_spectrogram(spec)
@ -195,7 +195,7 @@ class BatDetect2API:
self, self,
spec: torch.Tensor, spec: torch.Tensor,
start_time: float = 0, start_time: float = 0,
) -> List[Detection]: ) -> list[Detection]:
if spec.ndim == 4 and spec.shape[0] > 1: if spec.ndim == 4 and spec.shape[0] > 1:
raise ValueError("Batched spectrograms not supported.") raise ValueError("Batched spectrograms not supported.")
@ -214,7 +214,7 @@ class BatDetect2API:
def process_directory( def process_directory(
self, self,
audio_dir: data.PathLike, audio_dir: data.PathLike,
) -> List[ClipDetections]: ) -> list[ClipDetections]:
files = list(get_audio_files(audio_dir)) files = list(get_audio_files(audio_dir))
return self.process_files(files) return self.process_files(files)
@ -222,7 +222,7 @@ class BatDetect2API:
self, self,
audio_files: Sequence[data.PathLike], audio_files: Sequence[data.PathLike],
num_workers: int | None = None, num_workers: int | None = None,
) -> List[ClipDetections]: ) -> list[ClipDetections]:
return process_file_list( return process_file_list(
self.model, self.model,
audio_files, audio_files,
@ -238,7 +238,7 @@ class BatDetect2API:
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
batch_size: int | None = None, batch_size: int | None = None,
num_workers: int | None = None, num_workers: int | None = None,
) -> List[ClipDetections]: ) -> list[ClipDetections]:
return run_batch_inference( return run_batch_inference(
self.model, self.model,
clips, clips,
@ -274,7 +274,7 @@ class BatDetect2API:
def load_predictions( def load_predictions(
self, self,
path: data.PathLike, path: data.PathLike,
) -> List[ClipDetections]: ) -> list[ClipDetections]:
return self.formatter.load(path) return self.formatter.load(path)
@classmethod @classmethod

View File

@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested
configuration sections. configuration sections.
""" """
from typing import Any, Type, TypeVar, Union, overload from typing import Any, Type, TypeVar, overload
import yaml import yaml
from deepmerge.merger import Merger from deepmerge.merger import Merger
@ -69,8 +69,7 @@ class BaseConfig(BaseModel):
T = TypeVar("T") T = TypeVar("T")
T_Model = TypeVar("T_Model", bound=BaseModel) T_Model = TypeVar("T_Model", bound=BaseModel)
Schema = Type[T_Model] | TypeAdapter[T]
Schema = Union[Type[T_Model], TypeAdapter[T]]
def get_object_field(obj: dict, current_key: str) -> Any: def get_object_field(obj: dict, current_key: str) -> Any:

View File

@ -1,5 +1,4 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Tuple
from soundevent import data from soundevent import data
@ -10,7 +9,7 @@ from batdetect2.typing.targets import TargetProtocol
def iterate_over_sound_events( def iterate_over_sound_events(
dataset: Dataset, dataset: Dataset,
targets: TargetProtocol, targets: TargetProtocol,
) -> Generator[Tuple[str | None, data.SoundEventAnnotation], None, None]: ) -> Generator[tuple[str | None, data.SoundEventAnnotation], None, None]:
"""Iterate over sound events in a dataset. """Iterate over sound events in a dataset.
Parameters Parameters
@ -24,7 +23,7 @@ def iterate_over_sound_events(
Yields Yields
------ ------
Tuple[Optional[str], data.SoundEventAnnotation] tuple[Optional[str], data.SoundEventAnnotation]
A tuple containing: A tuple containing:
- The encoded class name (str) for the sound event, or None if it - The encoded class name (str) for the sound event, or None if it
cannot be encoded to a specific class. cannot be encoded to a specific class.

View File

@ -1,6 +1,5 @@
from typing import Literal
from pathlib import Path from pathlib import Path
from typing import Literal
from soundevent.data import PathLike from soundevent.data import PathLike

View File

@ -1,5 +1,3 @@
from typing import Tuple
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from batdetect2.data.datasets import Dataset from batdetect2.data.datasets import Dataset
@ -15,7 +13,7 @@ def split_dataset_by_recordings(
targets: TargetProtocol, targets: TargetProtocol,
train_size: float = 0.75, train_size: float = 0.75,
random_state: int | None = None, random_state: int | None = None,
) -> Tuple[Dataset, Dataset]: ) -> tuple[Dataset, Dataset]:
recordings = extract_recordings_df(dataset) recordings = extract_recordings_df(dataset)
sound_events = extract_sound_events_df( sound_events = extract_sound_events_df(

View File

@ -1,7 +1,5 @@
"""Post-processing of the output of the model.""" """Post-processing of the output of the model."""
from typing import List, Tuple
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
@ -45,7 +43,7 @@ def run_nms(
outputs: ModelOutput, outputs: ModelOutput,
params: NonMaximumSuppressionConfig, params: NonMaximumSuppressionConfig,
sampling_rate: np.ndarray, sampling_rate: np.ndarray,
) -> Tuple[List[PredictionResults], List[np.ndarray]]: ) -> tuple[list[PredictionResults], list[np.ndarray]]:
"""Run non-maximum suppression on the output of the model. """Run non-maximum suppression on the output of the model.
Model outputs processed are expected to have a batch dimension. Model outputs processed are expected to have a batch dimension.
@ -73,8 +71,8 @@ def run_nms(
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k) scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
# loop over batch to save outputs # loop over batch to save outputs
preds: List[PredictionResults] = [] preds: list[PredictionResults] = []
feats: List[np.ndarray] = [] feats: list[np.ndarray] = []
for num_detection in range(pred_det_nms.shape[0]): for num_detection in range(pred_det_nms.shape[0]):
# get valid indices # get valid indices
inds_ord = torch.argsort(x_pos[num_detection, :]) inds_ord = torch.argsort(x_pos[num_detection, :])
@ -151,7 +149,7 @@ def run_nms(
def non_max_suppression( def non_max_suppression(
heat: torch.Tensor, heat: torch.Tensor,
kernel_size: int | Tuple[int, int], kernel_size: int | tuple[int, int],
): ):
# kernel can be an int or list/tuple # kernel can be an int or list/tuple
if isinstance(kernel_size, int): if isinstance(kernel_size, int):

View File

@ -4,12 +4,9 @@ from dataclasses import dataclass, field
from typing import ( from typing import (
Annotated, Annotated,
Callable, Callable,
Dict,
Iterable, Iterable,
List,
Literal, Literal,
Sequence, Sequence,
Tuple,
) )
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -32,7 +29,7 @@ from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]] TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry( top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
name="top_class_plot" name="top_class_plot"
@ -73,7 +70,7 @@ class PRCurve(BasePlot):
def __call__( def __call__(
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[tuple[str, Figure]]:
y_true = [] y_true = []
y_score = [] y_score = []
num_positives = 0 num_positives = 0
@ -140,7 +137,7 @@ class ROCCurve(BasePlot):
def __call__( def __call__(
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[tuple[str, Figure]]:
y_true = [] y_true = []
y_score = [] y_score = []
@ -223,7 +220,7 @@ class ConfusionMatrix(BasePlot):
def __call__( def __call__(
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[tuple[str, Figure]]:
cm, labels = compute_confusion_matrix( cm, labels = compute_confusion_matrix(
clip_evaluations, clip_evaluations,
self.targets, self.targets,
@ -295,26 +292,26 @@ class ExampleClassificationPlot(BasePlot):
def __call__( def __call__(
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[tuple[str, Figure]]:
grouped = group_matches(clip_evaluations, threshold=self.threshold) grouped = group_matches(clip_evaluations, threshold=self.threshold)
for class_name, matches in grouped.items(): for class_name, matches in grouped.items():
true_positives: List[MatchEval] = get_binned_sample( true_positives: list[MatchEval] = get_binned_sample(
matches.true_positives, matches.true_positives,
n_examples=self.num_examples, n_examples=self.num_examples,
) )
false_positives: List[MatchEval] = get_binned_sample( false_positives: list[MatchEval] = get_binned_sample(
matches.false_positives, matches.false_positives,
n_examples=self.num_examples, n_examples=self.num_examples,
) )
false_negatives: List[MatchEval] = random.sample( false_negatives: list[MatchEval] = random.sample(
matches.false_negatives, matches.false_negatives,
k=min(self.num_examples, len(matches.false_negatives)), k=min(self.num_examples, len(matches.false_negatives)),
) )
cross_triggers: List[MatchEval] = get_binned_sample( cross_triggers: list[MatchEval] = get_binned_sample(
matches.cross_triggers, n_examples=self.num_examples matches.cross_triggers, n_examples=self.num_examples
) )
@ -374,16 +371,16 @@ def build_top_class_plotter(
@dataclass @dataclass
class ClassMatches: class ClassMatches:
false_positives: List[MatchEval] = field(default_factory=list) false_positives: list[MatchEval] = field(default_factory=list)
false_negatives: List[MatchEval] = field(default_factory=list) false_negatives: list[MatchEval] = field(default_factory=list)
true_positives: List[MatchEval] = field(default_factory=list) true_positives: list[MatchEval] = field(default_factory=list)
cross_triggers: List[MatchEval] = field(default_factory=list) cross_triggers: list[MatchEval] = field(default_factory=list)
def group_matches( def group_matches(
clip_evals: Sequence[ClipEval], clip_evals: Sequence[ClipEval],
threshold: float = 0.2, threshold: float = 0.2,
) -> Dict[str, ClassMatches]: ) -> dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches) class_examples = defaultdict(ClassMatches)
for clip_eval in clip_evals: for clip_eval in clip_evals:
@ -412,7 +409,7 @@ def group_matches(
return class_examples return class_examples
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5): def get_binned_sample(matches: list[MatchEval], n_examples: int = 5):
if len(matches) < n_examples: if len(matches) < n_examples:
return matches return matches

View File

@ -2,7 +2,6 @@ import argparse
import json import json
import os import os
from collections import Counter from collections import Counter
from typing import List, Tuple
import numpy as np import numpy as np
from sklearn.model_selection import StratifiedGroupKFold from sklearn.model_selection import StratifiedGroupKFold
@ -12,8 +11,8 @@ from batdetect2 import types
def print_dataset_stats( def print_dataset_stats(
data: List[types.FileAnnotation], data: list[types.FileAnnotation],
classes_to_ignore: List[str] | None = None, classes_to_ignore: list[str] | None = None,
) -> Counter[str]: ) -> Counter[str]:
print("Num files:", len(data)) print("Num files:", len(data))
counts, _ = tu.get_class_names(data, classes_to_ignore) counts, _ = tu.get_class_names(data, classes_to_ignore)
@ -22,7 +21,7 @@ def print_dataset_stats(
return counts return counts
def load_file_names(file_name: str) -> List[str]: def load_file_names(file_name: str) -> list[str]:
if not os.path.isfile(file_name): if not os.path.isfile(file_name):
raise FileNotFoundError(f"Input file not found - {file_name}") raise FileNotFoundError(f"Input file not found - {file_name}")
@ -100,12 +99,12 @@ def parse_args():
def split_data( def split_data(
data: List[types.FileAnnotation], data: list[types.FileAnnotation],
train_file: str, train_file: str,
test_file: str, test_file: str,
n_splits: int = 5, n_splits: int = 5,
random_state: int = 0, random_state: int = 0,
) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]: ) -> tuple[list[types.FileAnnotation], list[types.FileAnnotation]]:
if train_file != "" and test_file != "": if train_file != "" and test_file != "":
# user has specified the train / test split # user has specified the train / test split
mapping = { mapping = {

View File

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Sequence from typing import NamedTuple, Sequence
import torch import torch
from loguru import logger from loguru import logger
@ -29,7 +29,7 @@ class DatasetItem(NamedTuple):
class InferenceDataset(Dataset[DatasetItem]): class InferenceDataset(Dataset[DatasetItem]):
clips: List[data.Clip] clips: list[data.Clip]
def __init__( def __init__(
self, self,
@ -111,7 +111,7 @@ def build_inference_dataset(
) )
def _collate_fn(batch: List[DatasetItem]) -> DatasetItem: def _collate_fn(batch: list[DatasetItem]) -> DatasetItem:
max_width = max(item.spec.shape[-1] for item in batch) max_width = max(item.spec.shape[-1] for item in batch)
return DatasetItem( return DatasetItem(
spec=torch.stack( spec=torch.stack(

View File

@ -26,8 +26,6 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
is the ``build_model`` factory function exported from this module. is the ``build_model`` factory function exported from this module.
""" """
from typing import List
import torch import torch
from batdetect2.models.backbones import ( from batdetect2.models.backbones import (
@ -142,7 +140,7 @@ class Model(torch.nn.Module):
self.postprocessor = postprocessor self.postprocessor = postprocessor
self.targets = targets self.targets = targets
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]: def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor. """Run the full detection pipeline on a waveform tensor.
Converts the waveform to a spectrogram, passes it through the Converts the waveform to a spectrogram, passes it through the
@ -157,7 +155,7 @@ class Model(torch.nn.Module):
Returns Returns
------- -------
List[ClipDetectionsTensor] list[ClipDetectionsTensor]
One detection tensor per clip in the batch. Each tensor encodes One detection tensor per clip in the batch. Each tensor encodes
the detected events (locations, class scores, sizes) for that the detected events (locations, class scores, sizes) for that
clip. clip.

View File

@ -23,7 +23,7 @@ output so that the output spatial dimensions always match the input spatial
dimensions. dimensions.
""" """
from typing import Annotated, Literal, Tuple, Union from typing import Annotated, Literal
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -58,6 +58,14 @@ from batdetect2.typing.models import (
EncoderProtocol, EncoderProtocol,
) )
__all__ = [
"BackboneImportConfig",
"UNetBackbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone",
]
class UNetBackboneConfig(BaseConfig): class UNetBackboneConfig(BaseConfig):
"""Configuration for a U-Net-style encoder-decoder backbone. """Configuration for a U-Net-style encoder-decoder backbone.
@ -110,15 +118,6 @@ class BackboneImportConfig(ImportConfig):
name: Literal["import"] = "import" name: Literal["import"] = "import"
__all__ = [
"BackboneImportConfig",
"UNetBackbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone",
]
class UNetBackbone(BackboneModel): class UNetBackbone(BackboneModel):
"""U-Net-style encoder-decoder backbone network. """U-Net-style encoder-decoder backbone network.
@ -262,7 +261,8 @@ class UNetBackbone(BackboneModel):
BackboneConfig = Annotated[ BackboneConfig = Annotated[
Union[UNetBackboneConfig,], Field(discriminator="name") UNetBackboneConfig | BackboneImportConfig,
Field(discriminator="name"),
] ]
@ -292,7 +292,7 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
def _pad_adjust( def _pad_adjust(
spec: torch.Tensor, spec: torch.Tensor,
factor: int = 32, factor: int = 32,
) -> Tuple[torch.Tensor, int, int]: ) -> tuple[torch.Tensor, int, int]:
"""Pad a tensor's height and width to be divisible by ``factor``. """Pad a tensor's height and width to be divisible by ``factor``.
Adds zero-padding to the bottom and right edges of the tensor so that Adds zero-padding to the bottom and right edges of the tensor so that
@ -308,7 +308,7 @@ def _pad_adjust(
Returns Returns
------- -------
Tuple[torch.Tensor, int, int] tuple[torch.Tensor, int, int]
- Padded tensor. - Padded tensor.
- Number of rows added to the height (``h_pad``). - Number of rows added to the height (``h_pad``).
- Number of columns added to the width (``w_pad``). - Number of columns added to the width (``w_pad``).

View File

@ -46,7 +46,7 @@ configuration object (one of the ``*Config`` classes exported here), using
a discriminated-union ``name`` field to dispatch to the correct class. a discriminated-union ``name`` field to dispatch to the correct class.
""" """
from typing import Annotated, List, Literal, Tuple, Union from typing import Annotated, Literal
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -687,7 +687,7 @@ class FreqCoordConvUpConfig(BaseConfig):
up_mode: str = "bilinear" up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear").""" """Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
up_scale: Tuple[int, int] = (2, 2) up_scale: tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling.""" """Scaling factor for height and width during upsampling."""
@ -706,22 +706,22 @@ class FreqCoordConvUpBlock(Block):
Parameters Parameters
---------- ----------
in_channels : int in_channels
Number of channels in the input tensor (before upsampling). Number of channels in the input tensor (before upsampling).
out_channels : int out_channels
Number of output channels after the convolution. Number of output channels after the convolution.
input_height : int input_height
Height (H dimension, frequency bins) of the tensor *before* upsampling. Height (H dimension, frequency bins) of the tensor *before* upsampling.
Used to calculate the height for coordinate feature generation after Used to calculate the height for coordinate feature generation after
upsampling. upsampling.
kernel_size : int, default=3 kernel_size
Size of the square convolutional kernel. Size of the square convolutional kernel.
pad_size : int, default=1 pad_size
Padding added before convolution. Padding added before convolution.
up_mode : str, default="bilinear" up_mode
Interpolation mode for upsampling (e.g., "nearest", "bilinear", Interpolation mode for upsampling (e.g., "nearest", "bilinear",
"bicubic"). "bicubic").
up_scale : Tuple[int, int], default=(2, 2) up_scale
Scaling factor for height and width during upsampling Scaling factor for height and width during upsampling
(typically (2, 2)). (typically (2, 2)).
""" """
@ -734,7 +734,7 @@ class FreqCoordConvUpBlock(Block):
kernel_size: int = 3, kernel_size: int = 3,
pad_size: int = 1, pad_size: int = 1,
up_mode: str = "bilinear", up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2), up_scale: tuple[int, int] = (2, 2),
): ):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -824,7 +824,7 @@ class StandardConvUpConfig(BaseConfig):
up_mode: str = "bilinear" up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear").""" """Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
up_scale: Tuple[int, int] = (2, 2) up_scale: tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling.""" """Scaling factor for height and width during upsampling."""
@ -839,17 +839,17 @@ class StandardConvUpBlock(Block):
Parameters Parameters
---------- ----------
in_channels : int in_channels
Number of channels in the input tensor (before upsampling). Number of channels in the input tensor (before upsampling).
out_channels : int out_channels
Number of output channels after the convolution. Number of output channels after the convolution.
kernel_size : int, default=3 kernel_size
Size of the square convolutional kernel. Size of the square convolutional kernel.
pad_size : int, default=1 pad_size
Padding added before convolution. Padding added before convolution.
up_mode : str, default="bilinear" up_mode
Interpolation mode for upsampling (e.g., "nearest", "bilinear"). Interpolation mode for upsampling (e.g., "nearest", "bilinear").
up_scale : Tuple[int, int], default=(2, 2) up_scale
Scaling factor for height and width during upsampling. Scaling factor for height and width during upsampling.
""" """
@ -860,7 +860,7 @@ class StandardConvUpBlock(Block):
kernel_size: int = 3, kernel_size: int = 3,
pad_size: int = 1, pad_size: int = 1,
up_mode: str = "bilinear", up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2), up_scale: tuple[int, int] = (2, 2),
): ):
super(StandardConvUpBlock, self).__init__() super(StandardConvUpBlock, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -922,15 +922,14 @@ class StandardConvUpBlock(Block):
LayerConfig = Annotated[ LayerConfig = Annotated[
Union[ ConvConfig
ConvConfig, | BlockImportConfig
FreqCoordConvDownConfig, | FreqCoordConvDownConfig
StandardConvDownConfig, | StandardConvDownConfig
FreqCoordConvUpConfig, | FreqCoordConvUpConfig
StandardConvUpConfig, | StandardConvUpConfig
SelfAttentionConfig, | SelfAttentionConfig
"LayerGroupConfig", | "LayerGroupConfig",
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of block configuration models.""" """Type alias for the discriminated union of block configuration models."""
@ -952,7 +951,7 @@ class LayerGroupConfig(BaseConfig):
""" """
name: Literal["LayerGroup"] = "LayerGroup" name: Literal["LayerGroup"] = "LayerGroup"
layers: List[LayerConfig] layers: list[LayerConfig]
class LayerGroup(nn.Module): class LayerGroup(nn.Module):

View File

@ -1,5 +1,3 @@
from typing import Tuple
from matplotlib.axes import Axes from matplotlib.axes import Axes
from soundevent import data, plot from soundevent import data, plot
@ -16,7 +14,7 @@ __all__ = [
def plot_clip_annotation( def plot_clip_annotation(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
ax: Axes | None = None, ax: Axes | None = None,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
add_points: bool = False, add_points: bool = False,
@ -50,7 +48,7 @@ def plot_clip_annotation(
def plot_anchor_points( def plot_anchor_points(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
targets: TargetProtocol, targets: TargetProtocol,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
ax: Axes | None = None, ax: Axes | None = None,
size: int = 1, size: int = 1,
color: str = "red", color: str = "red",

View File

@ -1,4 +1,4 @@
from typing import Iterable, Tuple from typing import Iterable
from matplotlib.axes import Axes from matplotlib.axes import Axes
from soundevent import data from soundevent import data
@ -18,7 +18,7 @@ __all__ = [
def plot_clip_prediction( def plot_clip_prediction(
clip_prediction: data.ClipPrediction, clip_prediction: data.ClipPrediction,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
ax: Axes | None = None, ax: Axes | None = None,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
add_legend: bool = False, add_legend: bool = False,

View File

@ -1,5 +1,3 @@
from typing import Tuple
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
from matplotlib.axes import Axes from matplotlib.axes import Axes
@ -19,7 +17,7 @@ def plot_clip(
clip: data.Clip, clip: data.Clip,
audio_loader: AudioLoader | None = None, audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
ax: Axes | None = None, ax: Axes | None = None,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
spec_cmap: str = "gray", spec_cmap: str = "gray",

View File

@ -1,7 +1,5 @@
"""General plotting utilities.""" """General plotting utilities."""
from typing import Tuple
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
@ -14,7 +12,7 @@ __all__ = [
def create_ax( def create_ax(
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
**kwargs, **kwargs,
) -> axes.Axes: ) -> axes.Axes:
"""Create a new axis if none is provided""" """Create a new axis if none is provided"""
@ -31,7 +29,7 @@ def plot_spectrogram(
min_freq: float | None = None, min_freq: float | None = None,
max_freq: float | None = None, max_freq: float | None = None,
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
add_colorbar: bool = False, add_colorbar: bool = False,
colorbar_kwargs: dict | None = None, colorbar_kwargs: dict | None = None,
vmin: float | None = None, vmin: float | None = None,

View File

@ -1,7 +1,5 @@
"""Plot heatmaps""" """Plot heatmaps"""
from typing import List, Tuple
import numpy as np import numpy as np
import torch import torch
from matplotlib import axes, patches from matplotlib import axes, patches
@ -14,7 +12,7 @@ from batdetect2.plotting.common import create_ax
def plot_detection_heatmap( def plot_detection_heatmap(
heatmap: torch.Tensor | np.ndarray, heatmap: torch.Tensor | np.ndarray,
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] = (10, 10), figsize: tuple[int, int] = (10, 10),
threshold: float | None = None, threshold: float | None = None,
alpha: float = 1, alpha: float = 1,
cmap: str | Colormap = "jet", cmap: str | Colormap = "jet",
@ -50,8 +48,8 @@ def plot_detection_heatmap(
def plot_classification_heatmap( def plot_classification_heatmap(
heatmap: torch.Tensor | np.ndarray, heatmap: torch.Tensor | np.ndarray,
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] = (10, 10), figsize: tuple[int, int] = (10, 10),
class_names: List[str] | None = None, class_names: list[str] | None = None,
threshold: float | None = 0.1, threshold: float | None = 0.1,
alpha: float = 1, alpha: float = 1,
cmap: str | Colormap = "tab20", cmap: str | Colormap = "tab20",

View File

@ -1,6 +1,6 @@
"""Plot functions to visualize detections and spectrograms.""" """Plot functions to visualize detections and spectrograms."""
from typing import List, Tuple, cast from typing import cast
import matplotlib.ticker as tick import matplotlib.ticker as tick
import numpy as np import numpy as np
@ -27,7 +27,7 @@ def spectrogram(
spec: torch.Tensor | np.ndarray, spec: torch.Tensor | np.ndarray,
config: ProcessingConfiguration | None = None, config: ProcessingConfiguration | None = None,
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
cmap: str = "plasma", cmap: str = "plasma",
start_time: float = 0, start_time: float = 0,
) -> axes.Axes: ) -> axes.Axes:
@ -35,18 +35,18 @@ def spectrogram(
Parameters Parameters
---------- ----------
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot. spec: Spectrogram to plot.
config (Optional[ProcessingConfiguration], optional): Configuration config: Configuration
used to compute the spectrogram. Defaults to None. If None, used to compute the spectrogram. Defaults to None. If None,
the default configuration will be used. the default configuration will be used.
ax (Optional[axes.Axes], optional): Matplotlib axes object. ax: Matplotlib axes object.
Defaults to None. If provided, the spectrogram will be plotted Defaults to None. If provided, the spectrogram will be plotted
on this axes. on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size. figsize: Figure size.
Defaults to None. If `ax` is None, this will be used to create Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size. a new figure of the given size.
cmap (str, optional): Colormap to use. Defaults to "plasma". cmap: Colormap to use. Defaults to "plasma".
start_time (float, optional): Start time of the spectrogram. start_time: Start time of the spectrogram.
Defaults to 0. This is useful if plotting a spectrogram Defaults to 0. This is useful if plotting a spectrogram
of a segment of a longer audio file. of a segment of a longer audio file.
@ -104,10 +104,10 @@ def spectrogram(
def spectrogram_with_detections( def spectrogram_with_detections(
spec: torch.Tensor | np.ndarray, spec: torch.Tensor | np.ndarray,
dets: List[Annotation], dets: list[Annotation],
config: ProcessingConfiguration | None = None, config: ProcessingConfiguration | None = None,
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
cmap: str = "plasma", cmap: str = "plasma",
with_names: bool = True, with_names: bool = True,
start_time: float = 0, start_time: float = 0,
@ -117,21 +117,21 @@ def spectrogram_with_detections(
Parameters Parameters
---------- ----------
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot. spec: Spectrogram to plot.
detections (List[Annotation]): List of detections. detections: List of detections.
config (Optional[ProcessingConfiguration], optional): Configuration config: Configuration
used to compute the spectrogram. Defaults to None. If None, used to compute the spectrogram. Defaults to None. If None,
the default configuration will be used. the default configuration will be used.
ax (Optional[axes.Axes], optional): Matplotlib axes object. ax: Matplotlib axes object.
Defaults to None. if provided, the spectrogram will be plotted Defaults to None. if provided, the spectrogram will be plotted
on this axes. on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size. figsize: Figure size.
Defaults to None. If `ax` is None, this will be used to create Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size. a new figure of the given size.
cmap (str, optional): Colormap to use. Defaults to "plasma". cmap: Colormap to use. Defaults to "plasma".
with_names (bool, optional): Whether to plot the name of the with_names: Whether to plot the name of the
predicted class next to the detection. Defaults to True. predicted class next to the detection. Defaults to True.
start_time (float, optional): Start time of the spectrogram. start_time: Start time of the spectrogram.
Defaults to 0. This is useful if plotting a spectrogram Defaults to 0. This is useful if plotting a spectrogram
of a segment of a longer audio file. of a segment of a longer audio file.
**kwargs: Additional keyword arguments to pass to the **kwargs: Additional keyword arguments to pass to the
@ -167,9 +167,9 @@ def spectrogram_with_detections(
def detections( def detections(
dets: List[Annotation], dets: list[Annotation],
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
with_names: bool = True, with_names: bool = True,
**kwargs, **kwargs,
) -> axes.Axes: ) -> axes.Axes:
@ -177,14 +177,14 @@ def detections(
Parameters Parameters
---------- ----------
dets (List[Annotation]): List of detections. dets: List of detections.
ax (Optional[axes.Axes], optional): Matplotlib axes object. ax: Matplotlib axes object.
Defaults to None. if provided, the spectrogram will be plotted Defaults to None. if provided, the spectrogram will be plotted
on this axes. on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size. figsize: Figure size.
Defaults to None. If `ax` is None, this will be used to create Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size. a new figure of the given size.
with_names (bool, optional): Whether to plot the name of the with_names: Whether to plot the name of the
predicted class next to the detection. Defaults to True. predicted class next to the detection. Defaults to True.
**kwargs: Additional keyword arguments to pass to the **kwargs: Additional keyword arguments to pass to the
`plot.detection` function. `plot.detection` function.
@ -214,7 +214,7 @@ def detections(
def detection( def detection(
det: Annotation, det: Annotation,
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
linewidth: float = 1, linewidth: float = 1,
edgecolor: str = "w", edgecolor: str = "w",
facecolor: str = "none", facecolor: str = "none",
@ -224,19 +224,19 @@ def detection(
Parameters Parameters
---------- ----------
det (Annotation): Detection to plot. det: Detection to plot.
ax (Optional[axes.Axes], optional): Matplotlib axes object. Defaults ax: Matplotlib axes object. Defaults
to None. If provided, the spectrogram will be plotted on this axes. to None. If provided, the spectrogram will be plotted on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults figsize: Figure size. Defaults
to None. If `ax` is None, this will be used to create a new figure to None. If `ax` is None, this will be used to create a new figure
of the given size. of the given size.
linewidth (float, optional): Line width of the detection. linewidth: Line width of the detection.
Defaults to 1. Defaults to 1.
edgecolor (str, optional): Edge color of the detection. edgecolor: Edge color of the detection.
Defaults to "w", i.e. white. Defaults to "w", i.e. white.
facecolor (str, optional): Face color of the detection. facecolor: Face color of the detection.
Defaults to "none", i.e. transparent. Defaults to "none", i.e. transparent.
with_name (bool, optional): Whether to plot the name of the with_name: Whether to plot the name of the
predicted class next to the detection. Defaults to True. predicted class next to the detection. Defaults to True.
Returns Returns
@ -277,22 +277,22 @@ def detection(
def _compute_spec_extent( def _compute_spec_extent(
shape: Tuple[int, int], shape: tuple[int, int],
params: SpectrogramParameters, params: SpectrogramParameters,
) -> Tuple[float, float, float, float]: ) -> tuple[float, float, float, float]:
"""Compute the extent of a spectrogram. """Compute the extent of a spectrogram.
Parameters Parameters
---------- ----------
shape (Tuple[int, int]): Shape of the spectrogram. shape: Shape of the spectrogram.
The first dimension is the frequency axis and the second The first dimension is the frequency axis and the second
dimension is the time axis. dimension is the time axis.
params (SpectrogramParameters): Spectrogram parameters. params: Spectrogram parameters.
Should be the same as the ones used to compute the spectrogram. Should be the same as the ones used to compute the spectrogram.
Returns Returns
------- -------
Tuple[float, float, float, float]: Extent of the spectrogram. tuple[float, float, float, float]: Extent of the spectrogram.
The first two values are the minimum and maximum time values, The first two values are the minimum and maximum time values,
the last two values are the minimum and maximum frequency values. the last two values are the minimum and maximum frequency values.
""" """

View File

@ -1,4 +1,4 @@
from typing import Protocol, Tuple from typing import Protocol
from matplotlib.axes import Axes from matplotlib.axes import Axes
from soundevent import data, plot from soundevent import data, plot
@ -40,7 +40,7 @@ def plot_false_positive_match(
match: MatchProtocol, match: MatchProtocol,
audio_loader: AudioLoader | None = None, audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
ax: Axes | None = None, ax: Axes | None = None,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION, duration: float = DEFAULT_DURATION,
@ -111,7 +111,7 @@ def plot_false_negative_match(
match: MatchProtocol, match: MatchProtocol,
audio_loader: AudioLoader | None = None, audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
ax: Axes | None = None, ax: Axes | None = None,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION, duration: float = DEFAULT_DURATION,
@ -171,7 +171,7 @@ def plot_true_positive_match(
match: MatchProtocol, match: MatchProtocol,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
audio_loader: AudioLoader | None = None, audio_loader: AudioLoader | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
ax: Axes | None = None, ax: Axes | None = None,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION, duration: float = DEFAULT_DURATION,
@ -259,7 +259,7 @@ def plot_cross_trigger_match(
match: MatchProtocol, match: MatchProtocol,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
audio_loader: AudioLoader | None = None, audio_loader: AudioLoader | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
ax: Axes | None = None, ax: Axes | None = None,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION, duration: float = DEFAULT_DURATION,

View File

@ -1,5 +1,3 @@
from typing import Dict, Tuple
import numpy as np import numpy as np
import seaborn as sns import seaborn as sns
from cycler import cycler from cycler import cycler
@ -34,14 +32,14 @@ def plot_pr_curve(
recall: np.ndarray, recall: np.ndarray,
thresholds: np.ndarray, thresholds: np.ndarray,
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
color: str | Tuple[float, float, float] | None = None, color: str | tuple[float, float, float] | None = None,
add_labels: bool = True, add_labels: bool = True,
add_legend: bool = False, add_legend: bool = False,
marker: str | Tuple[int, int, float] | None = "o", marker: str | tuple[int, int, float] | None = "o",
markeredgecolor: str | Tuple[float, float, float] | None = None, markeredgecolor: str | tuple[float, float, float] | None = None,
markersize: float | None = None, markersize: float | None = None,
linestyle: str | Tuple[int, ...] | None = None, linestyle: str | tuple[int, ...] | None = None,
linewidth: float | None = None, linewidth: float | None = None,
label: str = "PR Curve", label: str = "PR Curve",
) -> axes.Axes: ) -> axes.Axes:
@ -76,9 +74,9 @@ def plot_pr_curve(
def plot_pr_curves( def plot_pr_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
add_legend: bool = True, add_legend: bool = True,
add_labels: bool = True, add_labels: bool = True,
include_ap: bool = False, include_ap: bool = False,
@ -119,7 +117,7 @@ def plot_threshold_precision_curve(
threshold: np.ndarray, threshold: np.ndarray,
precision: np.ndarray, precision: np.ndarray,
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
add_labels: bool = True, add_labels: bool = True,
): ):
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
@ -139,9 +137,9 @@ def plot_threshold_precision_curve(
def plot_threshold_precision_curves( def plot_threshold_precision_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
add_legend: bool = True, add_legend: bool = True,
add_labels: bool = True, add_labels: bool = True,
): ):
@ -177,7 +175,7 @@ def plot_threshold_recall_curve(
threshold: np.ndarray, threshold: np.ndarray,
recall: np.ndarray, recall: np.ndarray,
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
add_labels: bool = True, add_labels: bool = True,
): ):
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
@ -197,9 +195,9 @@ def plot_threshold_recall_curve(
def plot_threshold_recall_curves( def plot_threshold_recall_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
add_legend: bool = True, add_legend: bool = True,
add_labels: bool = True, add_labels: bool = True,
): ):
@ -236,7 +234,7 @@ def plot_roc_curve(
tpr: np.ndarray, tpr: np.ndarray,
thresholds: np.ndarray, thresholds: np.ndarray,
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
add_labels: bool = True, add_labels: bool = True,
) -> axes.Axes: ) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
@ -260,9 +258,9 @@ def plot_roc_curve(
def plot_roc_curves( def plot_roc_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: axes.Axes | None = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
add_legend: bool = True, add_legend: bool = True,
add_labels: bool = True, add_labels: bool = True,
) -> axes.Axes: ) -> axes.Axes:

View File

@ -11,8 +11,6 @@ activations that have lower scores than a local maximum. This helps prevent
multiple, overlapping detections originating from the same sound event. multiple, overlapping detections originating from the same sound event.
""" """
from typing import Tuple
import torch import torch
NMS_KERNEL_SIZE = 9 NMS_KERNEL_SIZE = 9
@ -27,7 +25,7 @@ BatDetect2.
def non_max_suppression( def non_max_suppression(
tensor: torch.Tensor, tensor: torch.Tensor,
kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE, kernel_size: int | tuple[int, int] = NMS_KERNEL_SIZE,
) -> torch.Tensor: ) -> torch.Tensor:
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap. """Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
@ -42,11 +40,11 @@ def non_max_suppression(
Parameters Parameters
---------- ----------
tensor : torch.Tensor tensor
Input tensor, typically representing a detection heatmap. Must be a Input tensor, typically representing a detection heatmap. Must be a
3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying 3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying
`torch.nn.functional.max_pool2d` operation. `torch.nn.functional.max_pool2d` operation.
kernel_size : Union[int, Tuple[int, int]], default=NMS_KERNEL_SIZE kernel_size
Size of the sliding window neighborhood used to find local maxima. Size of the sliding window neighborhood used to find local maxima.
If an integer `k` is provided, a square kernel of size `(k, k)` is used. If an integer `k` is provided, a square kernel of size `(k, k)` is used.
If a tuple `(h, w)` is provided, a rectangular kernel of height `h` If a tuple `(h, w)` is provided, a rectangular kernel of height `h`

View File

@ -1,5 +1,3 @@
from typing import List, Tuple
import torch import torch
from loguru import logger from loguru import logger
@ -51,7 +49,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
max_freq: float, max_freq: float,
top_k_per_sec: int = 200, top_k_per_sec: int = 200,
detection_threshold: float = 0.01, detection_threshold: float = 0.01,
nms_kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE, nms_kernel_size: int | tuple[int, int] = NMS_KERNEL_SIZE,
): ):
"""Initialize the Postprocessor.""" """Initialize the Postprocessor."""
super().__init__() super().__init__()
@ -66,8 +64,8 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
def forward( def forward(
self, self,
output: ModelOutput, output: ModelOutput,
start_times: List[float] | None = None, start_times: list[float] | None = None,
) -> List[ClipDetectionsTensor]: ) -> list[ClipDetectionsTensor]:
detection_heatmap = non_max_suppression( detection_heatmap = non_max_suppression(
output.detection_probs.detach(), output.detection_probs.detach(),
kernel_size=self.nms_kernel_size, kernel_size=self.nms_kernel_size,

View File

@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
*geometric* aspect of target definition from *semantic* classification. *geometric* aspect of target definition from *semantic* classification.
""" """
from typing import Annotated, Literal, Tuple from typing import Annotated, Literal
import numpy as np import numpy as np
from pydantic import Field from pydantic import Field
@ -144,7 +144,7 @@ class AnchorBBoxMapper(ROITargetMapper):
Attributes Attributes
---------- ----------
dimension_names : List[str] dimension_names : list[str]
The output dimension names: `['width', 'height']`. The output dimension names: `['width', 'height']`.
anchor : Anchor anchor : Anchor
The configured anchor point type (e.g., "center", "bottom-left"). The configured anchor point type (e.g., "center", "bottom-left").
@ -177,7 +177,7 @@ class AnchorBBoxMapper(ROITargetMapper):
self.time_scale = time_scale self.time_scale = time_scale
self.frequency_scale = frequency_scale self.frequency_scale = frequency_scale
def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]: def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
"""Encode a SoundEvent into an anchor position and scaled box size. """Encode a SoundEvent into an anchor position and scaled box size.
The position is determined by the configured anchor on the sound The position is determined by the configured anchor on the sound
@ -190,7 +190,7 @@ class AnchorBBoxMapper(ROITargetMapper):
Returns Returns
------- -------
Tuple[Position, Size] tuple[Position, Size]
A tuple of (anchor_position, [scaled_width, scaled_height]). A tuple of (anchor_position, [scaled_width, scaled_height]).
""" """
from soundevent import geometry from soundevent import geometry
@ -314,7 +314,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
Attributes Attributes
---------- ----------
dimension_names : List[str] dimension_names : list[str]
The output dimension names: `['left', 'bottom', 'right', 'top']`. The output dimension names: `['left', 'bottom', 'right', 'top']`.
preprocessor : PreprocessorProtocol preprocessor : PreprocessorProtocol
The spectrogram preprocessor instance. The spectrogram preprocessor instance.
@ -371,7 +371,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
Returns Returns
------- -------
Tuple[Position, Size] tuple[Position, Size]
A tuple of (peak_position, [l, b, r, t] distances). A tuple of (peak_position, [l, b, r, t] distances).
""" """
from soundevent import geometry from soundevent import geometry
@ -519,14 +519,14 @@ def _build_bounding_box(
Parameters Parameters
---------- ----------
pos : Tuple[float, float] pos
Reference position (time, frequency). Reference position (time, frequency).
duration : float duration
The required *unscaled* duration (width) of the bounding box. The required *unscaled* duration (width) of the bounding box.
bandwidth : float bandwidth
The required *unscaled* frequency bandwidth (height) of the bounding The required *unscaled* frequency bandwidth (height) of the bounding
box. box.
anchor : Anchor anchor
Specifies which part of the bounding box the input `pos` corresponds to. Specifies which part of the bounding box the input `pos` corresponds to.
Returns Returns

View File

@ -1,4 +1,4 @@
from typing import Iterable, List, Tuple from typing import Iterable
from loguru import logger from loguru import logger
from soundevent import data from soundevent import data
@ -33,20 +33,20 @@ class Targets(TargetProtocol):
Attributes Attributes
---------- ----------
class_names : List[str] class_names
An ordered list of the unique names of the specific target classes An ordered list of the unique names of the specific target classes
defined in the configuration. defined in the configuration.
generic_class_tags : List[data.Tag] generic_class_tags
A list of `soundevent.data.Tag` objects representing the configured A list of `soundevent.data.Tag` objects representing the configured
generic class category (used when no specific class matches). generic class category (used when no specific class matches).
dimension_names : List[str] dimension_names
The names of the size dimensions handled by the ROI mapper The names of the size dimensions handled by the ROI mapper
(e.g., ['width', 'height']). (e.g., ['width', 'height']).
""" """
class_names: List[str] class_names: list[str]
detection_class_tags: List[data.Tag] detection_class_tags: list[data.Tag]
dimension_names: List[str] dimension_names: list[str]
detection_class_name: str detection_class_name: str
def __init__(self, config: TargetConfig): def __init__(self, config: TargetConfig):
@ -128,7 +128,7 @@ class Targets(TargetProtocol):
""" """
return self._encode_fn(sound_event) return self._encode_fn(sound_event)
def decode_class(self, class_label: str) -> List[data.Tag]: def decode_class(self, class_label: str) -> list[data.Tag]:
"""Decode a predicted class name back into representative tags. """Decode a predicted class name back into representative tags.
Uses the configured mapping (based on `TargetClass.output_tags` or Uses the configured mapping (based on `TargetClass.output_tags` or
@ -142,7 +142,7 @@ class Targets(TargetProtocol):
Returns Returns
------- -------
List[data.Tag] list[data.Tag]
The list of tags corresponding to the input class name. The list of tags corresponding to the input class name.
""" """
return self._decode_fn(class_label) return self._decode_fn(class_label)
@ -161,7 +161,7 @@ class Targets(TargetProtocol):
Returns Returns
------- -------
Tuple[float, float] tuple[float, float]
The reference position `(time, frequency)`. The reference position `(time, frequency)`.
Raises Raises
@ -192,9 +192,9 @@ class Targets(TargetProtocol):
Parameters Parameters
---------- ----------
pos : Tuple[float, float] pos
The reference position `(time, frequency)`. The reference position `(time, frequency)`.
dims : np.ndarray dims
NumPy array with size dimensions (e.g., from model prediction), NumPy array with size dimensions (e.g., from model prediction),
matching the order in `self.dimension_names`. matching the order in `self.dimension_names`.
@ -292,7 +292,7 @@ def load_targets(
def iterate_encoded_sound_events( def iterate_encoded_sound_events(
sound_events: Iterable[data.SoundEventAnnotation], sound_events: Iterable[data.SoundEventAnnotation],
targets: TargetProtocol, targets: TargetProtocol,
) -> Iterable[Tuple[str | None, Position, Size]]: ) -> Iterable[tuple[str | None, Position, Size]]:
for sound_event in sound_events: for sound_event in sound_events:
if not targets.filter(sound_event): if not targets.filter(sound_event):
continue continue

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING
import lightning as L import lightning as L
import torch import torch
@ -97,7 +97,7 @@ class TrainingModule(L.LightningModule):
def load_model_from_checkpoint( def load_model_from_checkpoint(
path: PathLike, path: PathLike,
) -> Tuple[Model, "BatDetect2Config"]: ) -> tuple[Model, "BatDetect2Config"]:
module = TrainingModule.load_from_checkpoint(path) # type: ignore module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.config return module.model, module.config

View File

@ -1,6 +1,6 @@
"""Types used in the code base.""" """Types used in the code base."""
from typing import Any, List, NamedTuple, TypedDict from typing import Any, NamedTuple, TypedDict
import numpy as np import numpy as np
import torch import torch
@ -86,7 +86,7 @@ class ModelParameters(TypedDict):
resize_factor: float resize_factor: float
"""Resize factor.""" """Resize factor."""
class_names: List[str] class_names: list[str]
"""Class names. """Class names.
The model is trained to detect these classes. The model is trained to detect these classes.
@ -158,7 +158,7 @@ class FileAnnotation(TypedDict):
notes: str notes: str
"""Notes of file.""" """Notes of file."""
annotation: List[Annotation] annotation: list[Annotation]
"""List of annotations.""" """List of annotations."""
@ -168,26 +168,26 @@ class RunResults(TypedDict):
pred_dict: FileAnnotation pred_dict: FileAnnotation
"""Predictions in the format expected by the annotation tool.""" """Predictions in the format expected by the annotation tool."""
spec_feats: NotRequired[List[np.ndarray]] spec_feats: NotRequired[list[np.ndarray]]
"""Spectrogram features.""" """Spectrogram features."""
spec_feat_names: NotRequired[List[str]] spec_feat_names: NotRequired[list[str]]
"""Spectrogram feature names.""" """Spectrogram feature names."""
cnn_feats: NotRequired[List[np.ndarray]] cnn_feats: NotRequired[list[np.ndarray]]
"""CNN features.""" """CNN features."""
cnn_feat_names: NotRequired[List[str]] cnn_feat_names: NotRequired[list[str]]
"""CNN feature names.""" """CNN feature names."""
spec_slices: NotRequired[List[np.ndarray]] spec_slices: NotRequired[list[np.ndarray]]
"""Spectrogram slices.""" """Spectrogram slices."""
class ResultParams(TypedDict): class ResultParams(TypedDict):
"""Result parameters.""" """Result parameters."""
class_names: List[str] class_names: list[str]
"""Class names.""" """Class names."""
spec_features: bool spec_features: bool
@ -234,7 +234,7 @@ class ProcessingConfiguration(TypedDict):
scale_raw_audio: bool scale_raw_audio: bool
"""Whether to scale the raw audio to be between -1 and 1.""" """Whether to scale the raw audio to be between -1 and 1."""
class_names: List[str] class_names: list[str]
"""Names of the classes the model can detect.""" """Names of the classes the model can detect."""
detection_threshold: float detection_threshold: float
@ -466,7 +466,7 @@ class FeatureExtractionParameters(TypedDict):
class HeatmapParameters(TypedDict): class HeatmapParameters(TypedDict):
"""Parameters that control the heatmap generation function.""" """Parameters that control the heatmap generation function."""
class_names: List[str] class_names: list[str]
fft_win_length: float fft_win_length: float
"""Length of the FFT window in seconds.""" """Length of the FFT window in seconds."""
@ -553,15 +553,15 @@ class AudioLoaderAnnotationGroup(TypedDict):
individual_ids: np.ndarray individual_ids: np.ndarray
x_inds: np.ndarray x_inds: np.ndarray
y_inds: np.ndarray y_inds: np.ndarray
annotation: List[Annotation] annotation: list[Annotation]
annotated: bool annotated: bool
class_id_file: int class_id_file: int
"""ID of the class of the file.""" """ID of the class of the file."""
class AudioLoaderParameters(TypedDict): class AudioLoaderParameters(TypedDict):
class_names: List[str] class_names: list[str]
classes_to_ignore: List[str] classes_to_ignore: list[str]
target_samp_rate: int target_samp_rate: int
scale_raw_audio: bool scale_raw_audio: bool
fft_win_length: float fft_win_length: float

View File

@ -1,12 +1,9 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import (
Dict,
Generic, Generic,
Iterable, Iterable,
List,
Protocol, Protocol,
Sequence, Sequence,
Tuple,
TypeVar, TypeVar,
) )
@ -33,7 +30,7 @@ class MatchEvaluation:
gt_geometry: data.Geometry | None gt_geometry: data.Geometry | None
pred_score: float pred_score: float
pred_class_scores: Dict[str, float] pred_class_scores: dict[str, float]
pred_geometry: data.Geometry | None pred_geometry: data.Geometry | None
affinity: float affinity: float
@ -66,7 +63,7 @@ class MatchEvaluation:
@dataclass @dataclass
class ClipMatches: class ClipMatches:
clip: data.Clip clip: data.Clip
matches: List[MatchEvaluation] matches: list[MatchEvaluation]
class MatcherProtocol(Protocol): class MatcherProtocol(Protocol):
@ -75,7 +72,7 @@ class MatcherProtocol(Protocol):
ground_truth: Sequence[data.Geometry], ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry], predictions: Sequence[data.Geometry],
scores: Sequence[float], scores: Sequence[float],
) -> Iterable[Tuple[int | None, int | None, float]]: ... ) -> Iterable[tuple[int | None, int | None, float]]: ...
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True) Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
@ -94,7 +91,7 @@ class MetricsProtocol(Protocol):
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[Detection]], predictions: Sequence[Sequence[Detection]],
) -> Dict[str, float]: ... ) -> dict[str, float]: ...
class PlotterProtocol(Protocol): class PlotterProtocol(Protocol):
@ -102,7 +99,7 @@ class PlotterProtocol(Protocol):
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[Detection]], predictions: Sequence[Sequence[Detection]],
) -> Iterable[Tuple[str, Figure]]: ... ) -> Iterable[tuple[str, Figure]]: ...
EvaluationOutput = TypeVar("EvaluationOutput") EvaluationOutput = TypeVar("EvaluationOutput")
@ -119,8 +116,8 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
def compute_metrics( def compute_metrics(
self, eval_outputs: EvaluationOutput self, eval_outputs: EvaluationOutput
) -> Dict[str, float]: ... ) -> dict[str, float]: ...
def generate_plots( def generate_plots(
self, eval_outputs: EvaluationOutput self, eval_outputs: EvaluationOutput
) -> Iterable[Tuple[str, Figure]]: ... ) -> Iterable[tuple[str, Figure]]: ...

View File

@ -1,4 +1,4 @@
from typing import Callable, NamedTuple, Protocol, Tuple from typing import Callable, NamedTuple, Protocol
import torch import torch
from soundevent import data from soundevent import data
@ -52,7 +52,7 @@ steps, and returns the final `Heatmaps` used for model training.
Augmentation = Callable[ Augmentation = Callable[
[torch.Tensor, data.ClipAnnotation], [torch.Tensor, data.ClipAnnotation],
Tuple[torch.Tensor, data.ClipAnnotation], tuple[torch.Tensor, data.ClipAnnotation],
] ]

View File

@ -1,5 +1,4 @@
import warnings import warnings
from typing import Tuple
import librosa import librosa
import librosa.core.spectrum import librosa.core.spectrum
@ -147,7 +146,7 @@ def load_audio(
target_samp_rate: int, target_samp_rate: int,
scale: bool = False, scale: bool = False,
max_duration: float | None = None, max_duration: float | None = None,
) -> Tuple[int, np.ndarray]: ) -> tuple[int, np.ndarray]:
"""Load an audio file and resample it to the target sampling rate. """Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration. The audio is also scaled to [-1, 1] and clipped to the maximum duration.

View File

@ -1,6 +1,6 @@
import json import json
import os import os
from typing import Any, Iterator, List, Tuple from typing import Any, Iterator
import librosa import librosa
import numpy as np import numpy as np
@ -60,7 +60,7 @@ def get_default_bd_args():
return args return args
def list_audio_files(ip_dir: str) -> List[str]: def list_audio_files(ip_dir: str) -> list[str]:
"""Get all audio files in directory. """Get all audio files in directory.
Args: Args:
@ -86,7 +86,7 @@ def load_model(
load_weights: bool = True, load_weights: bool = True,
device: torch.device | str | None = None, device: torch.device | str | None = None,
weights_only: bool = True, weights_only: bool = True,
) -> Tuple[DetectionModel, ModelParameters]: ) -> tuple[DetectionModel, ModelParameters]:
"""Load model from file. """Load model from file.
Args: Args:
@ -185,26 +185,28 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
def get_annotations_from_preds( def get_annotations_from_preds(
predictions: PredictionResults, predictions: PredictionResults,
class_names: List[str], class_names: list[str],
) -> List[Annotation]: ) -> list[Annotation]:
"""Get list of annotations from predictions.""" """Get list of annotations from predictions."""
# Get the best class prediction probability and index for each detection # Get the best class prediction probability and index for each detection
class_prob_best = predictions["class_probs"].max(0) class_prob_best = predictions["class_probs"].max(0)
class_ind_best = predictions["class_probs"].argmax(0) class_ind_best = predictions["class_probs"].argmax(0)
# Pack the results into a list of dictionaries # Pack the results into a list of dictionaries
annotations: List[Annotation] = [ annotations: list[Annotation] = [
{ Annotation(
"start_time": round(float(start_time), 4), {
"end_time": round(float(end_time), 4), "start_time": round(float(start_time), 4),
"low_freq": int(low_freq), "end_time": round(float(end_time), 4),
"high_freq": int(high_freq), "low_freq": int(low_freq),
"class": str(class_names[class_index]), "high_freq": int(high_freq),
"class_prob": round(float(class_prob), 3), "class": str(class_names[class_index]),
"det_prob": round(float(det_prob), 3), "class_prob": round(float(class_prob), 3),
"individual": "-1", "det_prob": round(float(det_prob), 3),
"event": "Echolocation", "individual": "-1",
} "event": "Echolocation",
}
)
for ( for (
start_time, start_time,
end_time, end_time,
@ -232,7 +234,7 @@ def format_single_result(
time_exp: float, time_exp: float,
duration: float, duration: float,
predictions: PredictionResults, predictions: PredictionResults,
class_names: List[str], class_names: list[str],
) -> FileAnnotation: ) -> FileAnnotation:
"""Format results into the format expected by the annotation tool. """Format results into the format expected by the annotation tool.
@ -315,9 +317,9 @@ def convert_results(
] ]
# combine into final results dictionary # combine into final results dictionary
results: RunResults = { results: RunResults = RunResults({ # type: ignore
"pred_dict": pred_dict, "pred_dict": pred_dict,
} })
# add spectrogram features if they exist # add spectrogram features if they exist
if len(spec_feats) > 0 and params["spec_features"]: if len(spec_feats) > 0 and params["spec_features"]:
@ -413,7 +415,7 @@ def compute_spectrogram(
sampling_rate: int, sampling_rate: int,
params: SpectrogramParameters, params: SpectrogramParameters,
device: torch.device, device: torch.device,
) -> Tuple[float, torch.Tensor]: ) -> tuple[float, torch.Tensor]:
"""Compute a spectrogram from an audio array. """Compute a spectrogram from an audio array.
Will pad the audio array so that it is evenly divisible by the Will pad the audio array so that it is evenly divisible by the
@ -475,7 +477,7 @@ def iterate_over_chunks(
audio: np.ndarray, audio: np.ndarray,
samplerate: float, samplerate: float,
chunk_size: float, chunk_size: float,
) -> Iterator[Tuple[float, np.ndarray]]: ) -> Iterator[tuple[float, np.ndarray]]:
"""Iterate over audio in chunks of size chunk_size. """Iterate over audio in chunks of size chunk_size.
Parameters Parameters
@ -509,7 +511,7 @@ def _process_spectrogram(
samplerate: float, samplerate: float,
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
) -> Tuple[PredictionResults, np.ndarray]: ) -> tuple[PredictionResults, np.ndarray]:
# evaluate model # evaluate model
with torch.no_grad(): with torch.no_grad():
outputs = model(spec) outputs = model(spec)
@ -546,7 +548,7 @@ def postprocess_model_outputs(
outputs: ModelOutput, outputs: ModelOutput,
samp_rate: int, samp_rate: int,
config: ProcessingConfiguration, config: ProcessingConfiguration,
) -> Tuple[List[Annotation], np.ndarray]: ) -> tuple[list[Annotation], np.ndarray]:
# run non-max suppression # run non-max suppression
pred_nms_list, features = pp.run_nms( pred_nms_list, features = pp.run_nms(
outputs, outputs,
@ -585,7 +587,7 @@ def process_spectrogram(
samplerate: int, samplerate: int,
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
) -> Tuple[List[Annotation], np.ndarray]: ) -> tuple[list[Annotation], np.ndarray]:
"""Process a spectrogram with detection model. """Process a spectrogram with detection model.
Will run non-maximum suppression on the output of the model. Will run non-maximum suppression on the output of the model.
@ -604,9 +606,9 @@ def process_spectrogram(
Returns Returns
------- -------
detections: List[Annotation] detections
List of detections predicted by the model. List of detections predicted by the model.
features : np.ndarray features
An array of CNN features associated with each annotation. An array of CNN features associated with each annotation.
The array is of shape (num_detections, num_features). The array is of shape (num_detections, num_features).
Is empty if `config["cnn_features"]` is False. Is empty if `config["cnn_features"]` is False.
@ -632,7 +634,7 @@ def _process_audio_array(
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
device: torch.device, device: torch.device,
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]: ) -> tuple[PredictionResults, np.ndarray, torch.Tensor]:
# load audio file and compute spectrogram # load audio file and compute spectrogram
_, spec = compute_spectrogram( _, spec = compute_spectrogram(
audio, audio,
@ -669,7 +671,7 @@ def process_audio_array(
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
device: torch.device, device: torch.device,
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]: ) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
"""Process a single audio array with detection model. """Process a single audio array with detection model.
Parameters Parameters
@ -689,7 +691,7 @@ def process_audio_array(
Returns Returns
------- -------
annotations : List[Annotation] annotations : list[Annotation]
List of annotations predicted by the model. List of annotations predicted by the model.
features : np.ndarray features : np.ndarray
Array of CNN features associated with each annotation. Array of CNN features associated with each annotation.

View File

@ -1,7 +1,7 @@
import json import json
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any
import pytest import pytest
from soundevent import data from soundevent import data
@ -20,11 +20,11 @@ def create_legacy_file_annotation(
duration: float = 5.0, duration: float = 5.0,
time_exp: float = 1.0, time_exp: float = 1.0,
class_name: str = "Myotis", class_name: str = "Myotis",
annotations: Optional[List[Dict[str, Any]]] = None, annotations: list[dict[str, Any]] | None = None,
annotated: bool = True, annotated: bool = True,
issues: bool = False, issues: bool = False,
notes: str = "", notes: str = "",
) -> Dict[str, Any]: ) -> dict[str, Any]:
if annotations is None: if annotations is None:
annotations = [ annotations = [
{ {
@ -61,7 +61,7 @@ def create_legacy_file_annotation(
@pytest.fixture @pytest.fixture
def batdetect2_files_test_setup( def batdetect2_files_test_setup(
tmp_path: Path, wav_factory tmp_path: Path, wav_factory
) -> Tuple[Path, Path, List[Dict[str, Any]]]: ) -> tuple[Path, Path, list[dict[str, Any]]]:
"""Sets up a directory structure for batdetect2 files format tests.""" """Sets up a directory structure for batdetect2 files format tests."""
audio_dir = tmp_path / "audio" audio_dir = tmp_path / "audio"
audio_dir.mkdir() audio_dir.mkdir()
@ -143,7 +143,7 @@ def batdetect2_files_test_setup(
@pytest.fixture @pytest.fixture
def batdetect2_merged_test_setup( def batdetect2_merged_test_setup(
tmp_path: Path, batdetect2_files_test_setup tmp_path: Path, batdetect2_files_test_setup
) -> Tuple[Path, Path, List[Dict[str, Any]]]: ) -> tuple[Path, Path, list[dict[str, Any]]]:
"""Sets up a directory structure for batdetect2 merged file format tests.""" """Sets up a directory structure for batdetect2 merged file format tests."""
audio_dir, _, files_data = batdetect2_files_test_setup audio_dir, _, files_data = batdetect2_files_test_setup
merged_anns_path = tmp_path / "merged_anns.json" merged_anns_path = tmp_path / "merged_anns.json"