Removing legacy types

This commit is contained in:
mbsantiago 2026-03-17 12:53:03 +00:00
parent 8ac4f4c44d
commit 1a7c0b4b3a
32 changed files with 246 additions and 277 deletions

View File

@ -98,7 +98,6 @@ consult the API documentation in the code.
"""
import warnings
from typing import List, Tuple
import numpy as np
import torch
@ -272,7 +271,7 @@ def process_spectrogram(
samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL,
config: ProcessingConfiguration | None = None,
) -> Tuple[List[Annotation], np.ndarray]:
) -> tuple[list[Annotation], np.ndarray]:
"""Process spectrogram with model.
Parameters
@ -314,7 +313,7 @@ def process_audio(
model: DetectionModel = MODEL,
config: ProcessingConfiguration | None = None,
device: torch.device = DEVICE,
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
"""Process audio array with model.
Parameters
@ -357,7 +356,7 @@ def postprocess(
outputs: ModelOutput,
samp_rate: int = TARGET_SAMPLERATE_HZ,
config: ProcessingConfiguration | None = None,
) -> Tuple[List[Annotation], np.ndarray]:
) -> tuple[list[Annotation], np.ndarray]:
"""Postprocess model outputs.
Convert model tensor outputs to predicted bounding boxes and

View File

@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, List, Sequence, Tuple
from typing import Sequence
import numpy as np
import torch
@ -110,7 +110,7 @@ class BatDetect2API:
experiment_name: str | None = None,
run_name: str | None = None,
save_predictions: bool = True,
) -> Tuple[Dict[str, float], List[List[Detection]]]:
) -> tuple[dict[str, float], list[list[Detection]]]:
return evaluate(
self.model,
test_annotations,
@ -187,7 +187,7 @@ class BatDetect2API:
def process_audio(
self,
audio: np.ndarray,
) -> List[Detection]:
) -> list[Detection]:
spec = self.generate_spectrogram(audio)
return self.process_spectrogram(spec)
@ -195,7 +195,7 @@ class BatDetect2API:
self,
spec: torch.Tensor,
start_time: float = 0,
) -> List[Detection]:
) -> list[Detection]:
if spec.ndim == 4 and spec.shape[0] > 1:
raise ValueError("Batched spectrograms not supported.")
@ -214,7 +214,7 @@ class BatDetect2API:
def process_directory(
self,
audio_dir: data.PathLike,
) -> List[ClipDetections]:
) -> list[ClipDetections]:
files = list(get_audio_files(audio_dir))
return self.process_files(files)
@ -222,7 +222,7 @@ class BatDetect2API:
self,
audio_files: Sequence[data.PathLike],
num_workers: int | None = None,
) -> List[ClipDetections]:
) -> list[ClipDetections]:
return process_file_list(
self.model,
audio_files,
@ -238,7 +238,7 @@ class BatDetect2API:
clips: Sequence[data.Clip],
batch_size: int | None = None,
num_workers: int | None = None,
) -> List[ClipDetections]:
) -> list[ClipDetections]:
return run_batch_inference(
self.model,
clips,
@ -274,7 +274,7 @@ class BatDetect2API:
def load_predictions(
self,
path: data.PathLike,
) -> List[ClipDetections]:
) -> list[ClipDetections]:
return self.formatter.load(path)
@classmethod

View File

@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested
configuration sections.
"""
from typing import Any, Type, TypeVar, Union, overload
from typing import Any, Type, TypeVar, overload
import yaml
from deepmerge.merger import Merger
@ -69,8 +69,7 @@ class BaseConfig(BaseModel):
T = TypeVar("T")
T_Model = TypeVar("T_Model", bound=BaseModel)
Schema = Union[Type[T_Model], TypeAdapter[T]]
Schema = Type[T_Model] | TypeAdapter[T]
def get_object_field(obj: dict, current_key: str) -> Any:

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import Tuple
from soundevent import data
@ -10,7 +9,7 @@ from batdetect2.typing.targets import TargetProtocol
def iterate_over_sound_events(
dataset: Dataset,
targets: TargetProtocol,
) -> Generator[Tuple[str | None, data.SoundEventAnnotation], None, None]:
) -> Generator[tuple[str | None, data.SoundEventAnnotation], None, None]:
"""Iterate over sound events in a dataset.
Parameters
@ -24,7 +23,7 @@ def iterate_over_sound_events(
Yields
------
Tuple[Optional[str], data.SoundEventAnnotation]
tuple[str | None, data.SoundEventAnnotation]
A tuple containing:
- The encoded class name (str) for the sound event, or None if it
cannot be encoded to a specific class.

View File

@ -1,6 +1,5 @@
from typing import Literal
from pathlib import Path
from typing import Literal
from soundevent.data import PathLike

View File

@ -1,5 +1,3 @@
from typing import Tuple
from sklearn.model_selection import train_test_split
from batdetect2.data.datasets import Dataset
@ -15,7 +13,7 @@ def split_dataset_by_recordings(
targets: TargetProtocol,
train_size: float = 0.75,
random_state: int | None = None,
) -> Tuple[Dataset, Dataset]:
) -> tuple[Dataset, Dataset]:
recordings = extract_recordings_df(dataset)
sound_events = extract_sound_events_df(

View File

@ -1,7 +1,5 @@
"""Post-processing of the output of the model."""
from typing import List, Tuple
import numpy as np
import torch
from torch import nn
@ -45,7 +43,7 @@ def run_nms(
outputs: ModelOutput,
params: NonMaximumSuppressionConfig,
sampling_rate: np.ndarray,
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
) -> tuple[list[PredictionResults], list[np.ndarray]]:
"""Run non-maximum suppression on the output of the model.
Model outputs processed are expected to have a batch dimension.
@ -73,8 +71,8 @@ def run_nms(
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
# loop over batch to save outputs
preds: List[PredictionResults] = []
feats: List[np.ndarray] = []
preds: list[PredictionResults] = []
feats: list[np.ndarray] = []
for num_detection in range(pred_det_nms.shape[0]):
# get valid indices
inds_ord = torch.argsort(x_pos[num_detection, :])
@ -151,7 +149,7 @@ def run_nms(
def non_max_suppression(
heat: torch.Tensor,
kernel_size: int | Tuple[int, int],
kernel_size: int | tuple[int, int],
):
# kernel can be an int or list/tuple
if isinstance(kernel_size, int):

View File

@ -4,12 +4,9 @@ from dataclasses import dataclass, field
from typing import (
Annotated,
Callable,
Dict,
Iterable,
List,
Literal,
Sequence,
Tuple,
)
import matplotlib.pyplot as plt
@ -32,7 +29,7 @@ from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
name="top_class_plot"
@ -73,7 +70,7 @@ class PRCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
) -> Iterable[tuple[str, Figure]]:
y_true = []
y_score = []
num_positives = 0
@ -140,7 +137,7 @@ class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
) -> Iterable[tuple[str, Figure]]:
y_true = []
y_score = []
@ -223,7 +220,7 @@ class ConfusionMatrix(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
) -> Iterable[tuple[str, Figure]]:
cm, labels = compute_confusion_matrix(
clip_evaluations,
self.targets,
@ -295,26 +292,26 @@ class ExampleClassificationPlot(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
) -> Iterable[tuple[str, Figure]]:
grouped = group_matches(clip_evaluations, threshold=self.threshold)
for class_name, matches in grouped.items():
true_positives: List[MatchEval] = get_binned_sample(
true_positives: list[MatchEval] = get_binned_sample(
matches.true_positives,
n_examples=self.num_examples,
)
false_positives: List[MatchEval] = get_binned_sample(
false_positives: list[MatchEval] = get_binned_sample(
matches.false_positives,
n_examples=self.num_examples,
)
false_negatives: List[MatchEval] = random.sample(
false_negatives: list[MatchEval] = random.sample(
matches.false_negatives,
k=min(self.num_examples, len(matches.false_negatives)),
)
cross_triggers: List[MatchEval] = get_binned_sample(
cross_triggers: list[MatchEval] = get_binned_sample(
matches.cross_triggers, n_examples=self.num_examples
)
@ -374,16 +371,16 @@ def build_top_class_plotter(
@dataclass
class ClassMatches:
false_positives: List[MatchEval] = field(default_factory=list)
false_negatives: List[MatchEval] = field(default_factory=list)
true_positives: List[MatchEval] = field(default_factory=list)
cross_triggers: List[MatchEval] = field(default_factory=list)
false_positives: list[MatchEval] = field(default_factory=list)
false_negatives: list[MatchEval] = field(default_factory=list)
true_positives: list[MatchEval] = field(default_factory=list)
cross_triggers: list[MatchEval] = field(default_factory=list)
def group_matches(
clip_evals: Sequence[ClipEval],
threshold: float = 0.2,
) -> Dict[str, ClassMatches]:
) -> dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches)
for clip_eval in clip_evals:
@ -412,7 +409,7 @@ def group_matches(
return class_examples
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
def get_binned_sample(matches: list[MatchEval], n_examples: int = 5):
if len(matches) < n_examples:
return matches

View File

@ -2,7 +2,6 @@ import argparse
import json
import os
from collections import Counter
from typing import List, Tuple
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
@ -12,8 +11,8 @@ from batdetect2 import types
def print_dataset_stats(
data: List[types.FileAnnotation],
classes_to_ignore: List[str] | None = None,
data: list[types.FileAnnotation],
classes_to_ignore: list[str] | None = None,
) -> Counter[str]:
print("Num files:", len(data))
counts, _ = tu.get_class_names(data, classes_to_ignore)
@ -22,7 +21,7 @@ def print_dataset_stats(
return counts
def load_file_names(file_name: str) -> List[str]:
def load_file_names(file_name: str) -> list[str]:
if not os.path.isfile(file_name):
raise FileNotFoundError(f"Input file not found - {file_name}")
@ -100,12 +99,12 @@ def parse_args():
def split_data(
data: List[types.FileAnnotation],
data: list[types.FileAnnotation],
train_file: str,
test_file: str,
n_splits: int = 5,
random_state: int = 0,
) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]:
) -> tuple[list[types.FileAnnotation], list[types.FileAnnotation]]:
if train_file != "" and test_file != "":
# user has specified the train / test split
mapping = {

View File

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Sequence
from typing import NamedTuple, Sequence
import torch
from loguru import logger
@ -29,7 +29,7 @@ class DatasetItem(NamedTuple):
class InferenceDataset(Dataset[DatasetItem]):
clips: List[data.Clip]
clips: list[data.Clip]
def __init__(
self,
@ -111,7 +111,7 @@ def build_inference_dataset(
)
def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
def _collate_fn(batch: list[DatasetItem]) -> DatasetItem:
max_width = max(item.spec.shape[-1] for item in batch)
return DatasetItem(
spec=torch.stack(

View File

@ -26,8 +26,6 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
is the ``build_model`` factory function exported from this module.
"""
from typing import List
import torch
from batdetect2.models.backbones import (
@ -142,7 +140,7 @@ class Model(torch.nn.Module):
self.postprocessor = postprocessor
self.targets = targets
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor.
Converts the waveform to a spectrogram, passes it through the
@ -157,7 +155,7 @@ class Model(torch.nn.Module):
Returns
-------
List[ClipDetectionsTensor]
list[ClipDetectionsTensor]
One detection tensor per clip in the batch. Each tensor encodes
the detected events (locations, class scores, sizes) for that
clip.

View File

@ -23,7 +23,7 @@ output so that the output spatial dimensions always match the input spatial
dimensions.
"""
from typing import Annotated, Literal, Tuple, Union
from typing import Annotated, Literal
import torch
import torch.nn.functional as F
@ -58,6 +58,14 @@ from batdetect2.typing.models import (
EncoderProtocol,
)
__all__ = [
"BackboneImportConfig",
"UNetBackbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone",
]
class UNetBackboneConfig(BaseConfig):
"""Configuration for a U-Net-style encoder-decoder backbone.
@ -110,15 +118,6 @@ class BackboneImportConfig(ImportConfig):
name: Literal["import"] = "import"
__all__ = [
"BackboneImportConfig",
"UNetBackbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone",
]
class UNetBackbone(BackboneModel):
"""U-Net-style encoder-decoder backbone network.
@ -262,7 +261,8 @@ class UNetBackbone(BackboneModel):
BackboneConfig = Annotated[
Union[UNetBackboneConfig,], Field(discriminator="name")
UNetBackboneConfig | BackboneImportConfig,
Field(discriminator="name"),
]
@ -292,7 +292,7 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
def _pad_adjust(
spec: torch.Tensor,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
) -> tuple[torch.Tensor, int, int]:
"""Pad a tensor's height and width to be divisible by ``factor``.
Adds zero-padding to the bottom and right edges of the tensor so that
@ -308,7 +308,7 @@ def _pad_adjust(
Returns
-------
Tuple[torch.Tensor, int, int]
tuple[torch.Tensor, int, int]
- Padded tensor.
- Number of rows added to the height (``h_pad``).
- Number of columns added to the width (``w_pad``).

View File

@ -46,7 +46,7 @@ configuration object (one of the ``*Config`` classes exported here), using
a discriminated-union ``name`` field to dispatch to the correct class.
"""
from typing import Annotated, List, Literal, Tuple, Union
from typing import Annotated, Literal
import torch
import torch.nn.functional as F
@ -687,7 +687,7 @@ class FreqCoordConvUpConfig(BaseConfig):
up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
up_scale: Tuple[int, int] = (2, 2)
up_scale: tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling."""
@ -706,22 +706,22 @@ class FreqCoordConvUpBlock(Block):
Parameters
----------
in_channels : int
in_channels
Number of channels in the input tensor (before upsampling).
out_channels : int
out_channels
Number of output channels after the convolution.
input_height : int
input_height
Height (H dimension, frequency bins) of the tensor *before* upsampling.
Used to calculate the height for coordinate feature generation after
upsampling.
kernel_size : int, default=3
kernel_size
Size of the square convolutional kernel.
pad_size : int, default=1
pad_size
Padding added before convolution.
up_mode : str, default="bilinear"
up_mode
Interpolation mode for upsampling (e.g., "nearest", "bilinear",
"bicubic").
up_scale : Tuple[int, int], default=(2, 2)
up_scale
Scaling factor for height and width during upsampling
(typically (2, 2)).
"""
@ -734,7 +734,7 @@ class FreqCoordConvUpBlock(Block):
kernel_size: int = 3,
pad_size: int = 1,
up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2),
up_scale: tuple[int, int] = (2, 2),
):
super().__init__()
self.in_channels = in_channels
@ -824,7 +824,7 @@ class StandardConvUpConfig(BaseConfig):
up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
up_scale: Tuple[int, int] = (2, 2)
up_scale: tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling."""
@ -839,17 +839,17 @@ class StandardConvUpBlock(Block):
Parameters
----------
in_channels : int
in_channels
Number of channels in the input tensor (before upsampling).
out_channels : int
out_channels
Number of output channels after the convolution.
kernel_size : int, default=3
kernel_size
Size of the square convolutional kernel.
pad_size : int, default=1
pad_size
Padding added before convolution.
up_mode : str, default="bilinear"
up_mode
Interpolation mode for upsampling (e.g., "nearest", "bilinear").
up_scale : Tuple[int, int], default=(2, 2)
up_scale
Scaling factor for height and width during upsampling.
"""
@ -860,7 +860,7 @@ class StandardConvUpBlock(Block):
kernel_size: int = 3,
pad_size: int = 1,
up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2),
up_scale: tuple[int, int] = (2, 2),
):
super(StandardConvUpBlock, self).__init__()
self.in_channels = in_channels
@ -922,15 +922,14 @@ class StandardConvUpBlock(Block):
LayerConfig = Annotated[
Union[
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
SelfAttentionConfig,
"LayerGroupConfig",
],
ConvConfig
| BlockImportConfig
| FreqCoordConvDownConfig
| StandardConvDownConfig
| FreqCoordConvUpConfig
| StandardConvUpConfig
| SelfAttentionConfig
| "LayerGroupConfig",
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configuration models."""
@ -952,7 +951,7 @@ class LayerGroupConfig(BaseConfig):
"""
name: Literal["LayerGroup"] = "LayerGroup"
layers: List[LayerConfig]
layers: list[LayerConfig]
class LayerGroup(nn.Module):

View File

@ -1,5 +1,3 @@
from typing import Tuple
from matplotlib.axes import Axes
from soundevent import data, plot
@ -16,7 +14,7 @@ __all__ = [
def plot_clip_annotation(
clip_annotation: data.ClipAnnotation,
preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
add_points: bool = False,
@ -50,7 +48,7 @@ def plot_clip_annotation(
def plot_anchor_points(
clip_annotation: data.ClipAnnotation,
targets: TargetProtocol,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
ax: Axes | None = None,
size: int = 1,
color: str = "red",

View File

@ -1,4 +1,4 @@
from typing import Iterable, Tuple
from typing import Iterable
from matplotlib.axes import Axes
from soundevent import data
@ -18,7 +18,7 @@ __all__ = [
def plot_clip_prediction(
clip_prediction: data.ClipPrediction,
preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
add_legend: bool = False,

View File

@ -1,5 +1,3 @@
from typing import Tuple
import matplotlib.pyplot as plt
import torch
from matplotlib.axes import Axes
@ -19,7 +17,7 @@ def plot_clip(
clip: data.Clip,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
spec_cmap: str = "gray",

View File

@ -1,7 +1,5 @@
"""General plotting utilities."""
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
@ -14,7 +12,7 @@ __all__ = [
def create_ax(
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
**kwargs,
) -> axes.Axes:
"""Create a new axis if none is provided"""
@ -31,7 +29,7 @@ def plot_spectrogram(
min_freq: float | None = None,
max_freq: float | None = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
add_colorbar: bool = False,
colorbar_kwargs: dict | None = None,
vmin: float | None = None,

View File

@ -1,7 +1,5 @@
"""Plot heatmaps"""
from typing import List, Tuple
import numpy as np
import torch
from matplotlib import axes, patches
@ -14,7 +12,7 @@ from batdetect2.plotting.common import create_ax
def plot_detection_heatmap(
heatmap: torch.Tensor | np.ndarray,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] = (10, 10),
figsize: tuple[int, int] = (10, 10),
threshold: float | None = None,
alpha: float = 1,
cmap: str | Colormap = "jet",
@ -50,8 +48,8 @@ def plot_detection_heatmap(
def plot_classification_heatmap(
heatmap: torch.Tensor | np.ndarray,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] = (10, 10),
class_names: List[str] | None = None,
figsize: tuple[int, int] = (10, 10),
class_names: list[str] | None = None,
threshold: float | None = 0.1,
alpha: float = 1,
cmap: str | Colormap = "tab20",

View File

@ -1,6 +1,6 @@
"""Plot functions to visualize detections and spectrograms."""
from typing import List, Tuple, cast
from typing import cast
import matplotlib.ticker as tick
import numpy as np
@ -27,7 +27,7 @@ def spectrogram(
spec: torch.Tensor | np.ndarray,
config: ProcessingConfiguration | None = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
cmap: str = "plasma",
start_time: float = 0,
) -> axes.Axes:
@ -35,18 +35,18 @@ def spectrogram(
Parameters
----------
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
config (Optional[ProcessingConfiguration], optional): Configuration
spec: Spectrogram to plot.
config: Configuration
used to compute the spectrogram. Defaults to None. If None,
the default configuration will be used.
ax (Optional[axes.Axes], optional): Matplotlib axes object.
ax: Matplotlib axes object.
Defaults to None. If provided, the spectrogram will be plotted
on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size.
figsize: Figure size.
Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size.
cmap (str, optional): Colormap to use. Defaults to "plasma".
start_time (float, optional): Start time of the spectrogram.
cmap: Colormap to use. Defaults to "plasma".
start_time: Start time of the spectrogram.
Defaults to 0. This is useful if plotting a spectrogram
of a segment of a longer audio file.
@ -104,10 +104,10 @@ def spectrogram(
def spectrogram_with_detections(
spec: torch.Tensor | np.ndarray,
dets: List[Annotation],
dets: list[Annotation],
config: ProcessingConfiguration | None = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
cmap: str = "plasma",
with_names: bool = True,
start_time: float = 0,
@ -117,21 +117,21 @@ def spectrogram_with_detections(
Parameters
----------
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
detections (List[Annotation]): List of detections.
config (Optional[ProcessingConfiguration], optional): Configuration
spec: Spectrogram to plot.
dets: List of detections.
config: Configuration
used to compute the spectrogram. Defaults to None. If None,
the default configuration will be used.
ax (Optional[axes.Axes], optional): Matplotlib axes object.
ax: Matplotlib axes object.
Defaults to None. If provided, the spectrogram will be plotted
on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size.
figsize: Figure size.
Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size.
cmap (str, optional): Colormap to use. Defaults to "plasma".
with_names (bool, optional): Whether to plot the name of the
cmap: Colormap to use. Defaults to "plasma".
with_names: Whether to plot the name of the
predicted class next to the detection. Defaults to True.
start_time (float, optional): Start time of the spectrogram.
start_time: Start time of the spectrogram.
Defaults to 0. This is useful if plotting a spectrogram
of a segment of a longer audio file.
**kwargs: Additional keyword arguments to pass to the
@ -167,9 +167,9 @@ def spectrogram_with_detections(
def detections(
dets: List[Annotation],
dets: list[Annotation],
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
with_names: bool = True,
**kwargs,
) -> axes.Axes:
@ -177,14 +177,14 @@ def detections(
Parameters
----------
dets (List[Annotation]): List of detections.
ax (Optional[axes.Axes], optional): Matplotlib axes object.
dets: List of detections.
ax: Matplotlib axes object.
Defaults to None. If provided, the spectrogram will be plotted
on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size.
figsize: Figure size.
Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size.
with_names (bool, optional): Whether to plot the name of the
with_names: Whether to plot the name of the
predicted class next to the detection. Defaults to True.
**kwargs: Additional keyword arguments to pass to the
`plot.detection` function.
@ -214,7 +214,7 @@ def detections(
def detection(
det: Annotation,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
linewidth: float = 1,
edgecolor: str = "w",
facecolor: str = "none",
@ -224,19 +224,19 @@ def detection(
Parameters
----------
det (Annotation): Detection to plot.
ax (Optional[axes.Axes], optional): Matplotlib axes object. Defaults
det: Detection to plot.
ax: Matplotlib axes object. Defaults
to None. If provided, the spectrogram will be plotted on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults
figsize: Figure size. Defaults
to None. If `ax` is None, this will be used to create a new figure
of the given size.
linewidth (float, optional): Line width of the detection.
linewidth: Line width of the detection.
Defaults to 1.
edgecolor (str, optional): Edge color of the detection.
edgecolor: Edge color of the detection.
Defaults to "w", i.e. white.
facecolor (str, optional): Face color of the detection.
facecolor: Face color of the detection.
Defaults to "none", i.e. transparent.
with_name (bool, optional): Whether to plot the name of the
with_name: Whether to plot the name of the
predicted class next to the detection. Defaults to True.
Returns
@ -277,22 +277,22 @@ def detection(
def _compute_spec_extent(
shape: Tuple[int, int],
shape: tuple[int, int],
params: SpectrogramParameters,
) -> Tuple[float, float, float, float]:
) -> tuple[float, float, float, float]:
"""Compute the extent of a spectrogram.
Parameters
----------
shape (Tuple[int, int]): Shape of the spectrogram.
shape: Shape of the spectrogram.
The first dimension is the frequency axis and the second
dimension is the time axis.
params (SpectrogramParameters): Spectrogram parameters.
params: Spectrogram parameters.
Should be the same as the ones used to compute the spectrogram.
Returns
-------
Tuple[float, float, float, float]: Extent of the spectrogram.
tuple[float, float, float, float]: Extent of the spectrogram.
The first two values are the minimum and maximum time values,
the last two values are the minimum and maximum frequency values.
"""

View File

@ -1,4 +1,4 @@
from typing import Protocol, Tuple
from typing import Protocol
from matplotlib.axes import Axes
from soundevent import data, plot
@ -40,7 +40,7 @@ def plot_false_positive_match(
match: MatchProtocol,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION,
@ -111,7 +111,7 @@ def plot_false_negative_match(
match: MatchProtocol,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION,
@ -171,7 +171,7 @@ def plot_true_positive_match(
match: MatchProtocol,
preprocessor: PreprocessorProtocol | None = None,
audio_loader: AudioLoader | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION,
@ -259,7 +259,7 @@ def plot_cross_trigger_match(
match: MatchProtocol,
preprocessor: PreprocessorProtocol | None = None,
audio_loader: AudioLoader | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION,

View File

@ -1,5 +1,3 @@
from typing import Dict, Tuple
import numpy as np
import seaborn as sns
from cycler import cycler
@ -34,14 +32,14 @@ def plot_pr_curve(
recall: np.ndarray,
thresholds: np.ndarray,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
color: str | Tuple[float, float, float] | None = None,
figsize: tuple[int, int] | None = None,
color: str | tuple[float, float, float] | None = None,
add_labels: bool = True,
add_legend: bool = False,
marker: str | Tuple[int, int, float] | None = "o",
markeredgecolor: str | Tuple[float, float, float] | None = None,
marker: str | tuple[int, int, float] | None = "o",
markeredgecolor: str | tuple[float, float, float] | None = None,
markersize: float | None = None,
linestyle: str | Tuple[int, ...] | None = None,
linestyle: str | tuple[int, ...] | None = None,
linewidth: float | None = None,
label: str = "PR Curve",
) -> axes.Axes:
@ -76,9 +74,9 @@ def plot_pr_curve(
def plot_pr_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
add_legend: bool = True,
add_labels: bool = True,
include_ap: bool = False,
@ -119,7 +117,7 @@ def plot_threshold_precision_curve(
threshold: np.ndarray,
precision: np.ndarray,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
@ -139,9 +137,9 @@ def plot_threshold_precision_curve(
def plot_threshold_precision_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
add_legend: bool = True,
add_labels: bool = True,
):
@ -177,7 +175,7 @@ def plot_threshold_recall_curve(
threshold: np.ndarray,
recall: np.ndarray,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
@ -197,9 +195,9 @@ def plot_threshold_recall_curve(
def plot_threshold_recall_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
add_legend: bool = True,
add_labels: bool = True,
):
@ -236,7 +234,7 @@ def plot_roc_curve(
tpr: np.ndarray,
thresholds: np.ndarray,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
@ -260,9 +258,9 @@ def plot_roc_curve(
def plot_roc_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
figsize: tuple[int, int] | None = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:

View File

@ -11,8 +11,6 @@ activations that have lower scores than a local maximum. This helps prevent
multiple, overlapping detections originating from the same sound event.
"""
from typing import Tuple
import torch
NMS_KERNEL_SIZE = 9
@ -27,7 +25,7 @@ BatDetect2.
def non_max_suppression(
tensor: torch.Tensor,
kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
kernel_size: int | tuple[int, int] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
@ -42,11 +40,11 @@ def non_max_suppression(
Parameters
----------
tensor : torch.Tensor
tensor
Input tensor, typically representing a detection heatmap. Must be a
3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying
`torch.nn.functional.max_pool2d` operation.
kernel_size : Union[int, Tuple[int, int]], default=NMS_KERNEL_SIZE
kernel_size
Size of the sliding window neighborhood used to find local maxima.
If an integer `k` is provided, a square kernel of size `(k, k)` is used.
If a tuple `(h, w)` is provided, a rectangular kernel of height `h`

View File

@ -1,5 +1,3 @@
from typing import List, Tuple
import torch
from loguru import logger
@ -51,7 +49,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
nms_kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
nms_kernel_size: int | tuple[int, int] = NMS_KERNEL_SIZE,
):
"""Initialize the Postprocessor."""
super().__init__()
@ -66,8 +64,8 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
def forward(
self,
output: ModelOutput,
start_times: List[float] | None = None,
) -> List[ClipDetectionsTensor]:
start_times: list[float] | None = None,
) -> list[ClipDetectionsTensor]:
detection_heatmap = non_max_suppression(
output.detection_probs.detach(),
kernel_size=self.nms_kernel_size,

View File

@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
*geometric* aspect of target definition from *semantic* classification.
"""
from typing import Annotated, Literal, Tuple
from typing import Annotated, Literal
import numpy as np
from pydantic import Field
@ -144,7 +144,7 @@ class AnchorBBoxMapper(ROITargetMapper):
Attributes
----------
dimension_names : List[str]
dimension_names : list[str]
The output dimension names: `['width', 'height']`.
anchor : Anchor
The configured anchor point type (e.g., "center", "bottom-left").
@ -177,7 +177,7 @@ class AnchorBBoxMapper(ROITargetMapper):
self.time_scale = time_scale
self.frequency_scale = frequency_scale
def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
"""Encode a SoundEvent into an anchor position and scaled box size.
The position is determined by the configured anchor on the sound
@ -190,7 +190,7 @@ class AnchorBBoxMapper(ROITargetMapper):
Returns
-------
Tuple[Position, Size]
tuple[Position, Size]
A tuple of (anchor_position, [scaled_width, scaled_height]).
"""
from soundevent import geometry
@ -314,7 +314,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
Attributes
----------
dimension_names : List[str]
dimension_names : list[str]
The output dimension names: `['left', 'bottom', 'right', 'top']`.
preprocessor : PreprocessorProtocol
The spectrogram preprocessor instance.
@ -371,7 +371,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
Returns
-------
Tuple[Position, Size]
tuple[Position, Size]
A tuple of (peak_position, [l, b, r, t] distances).
"""
from soundevent import geometry
@ -519,14 +519,14 @@ def _build_bounding_box(
Parameters
----------
pos : Tuple[float, float]
pos
Reference position (time, frequency).
duration : float
duration
The required *unscaled* duration (width) of the bounding box.
bandwidth : float
bandwidth
The required *unscaled* frequency bandwidth (height) of the bounding
box.
anchor : Anchor
anchor
Specifies which part of the bounding box the input `pos` corresponds to.
Returns

View File

@ -1,4 +1,4 @@
from typing import Iterable, List, Tuple
from typing import Iterable
from loguru import logger
from soundevent import data
@ -33,20 +33,20 @@ class Targets(TargetProtocol):
Attributes
----------
class_names : List[str]
class_names
An ordered list of the unique names of the specific target classes
defined in the configuration.
generic_class_tags : List[data.Tag]
detection_class_tags
A list of `soundevent.data.Tag` objects representing the configured
generic class category (used when no specific class matches).
dimension_names : List[str]
dimension_names
The names of the size dimensions handled by the ROI mapper
(e.g., ['width', 'height']).
"""
class_names: List[str]
detection_class_tags: List[data.Tag]
dimension_names: List[str]
class_names: list[str]
detection_class_tags: list[data.Tag]
dimension_names: list[str]
detection_class_name: str
def __init__(self, config: TargetConfig):
@ -128,7 +128,7 @@ class Targets(TargetProtocol):
"""
return self._encode_fn(sound_event)
def decode_class(self, class_label: str) -> List[data.Tag]:
def decode_class(self, class_label: str) -> list[data.Tag]:
"""Decode a predicted class name back into representative tags.
Uses the configured mapping (based on `TargetClass.output_tags` or
@ -142,7 +142,7 @@ class Targets(TargetProtocol):
Returns
-------
List[data.Tag]
list[data.Tag]
The list of tags corresponding to the input class name.
"""
return self._decode_fn(class_label)
@ -161,7 +161,7 @@ class Targets(TargetProtocol):
Returns
-------
Tuple[float, float]
tuple[float, float]
The reference position `(time, frequency)`.
Raises
@ -192,9 +192,9 @@ class Targets(TargetProtocol):
Parameters
----------
pos : Tuple[float, float]
pos
The reference position `(time, frequency)`.
dims : np.ndarray
dims
NumPy array with size dimensions (e.g., from model prediction),
matching the order in `self.dimension_names`.
@ -292,7 +292,7 @@ def load_targets(
def iterate_encoded_sound_events(
sound_events: Iterable[data.SoundEventAnnotation],
targets: TargetProtocol,
) -> Iterable[Tuple[str | None, Position, Size]]:
) -> Iterable[tuple[str | None, Position, Size]]:
for sound_event in sound_events:
if not targets.filter(sound_event):
continue

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING
import lightning as L
import torch
@ -97,7 +97,7 @@ class TrainingModule(L.LightningModule):
def load_model_from_checkpoint(
path: PathLike,
) -> Tuple[Model, "BatDetect2Config"]:
) -> tuple[Model, "BatDetect2Config"]:
module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.config

View File

@ -1,6 +1,6 @@
"""Types used in the code base."""
from typing import Any, List, NamedTuple, TypedDict
from typing import Any, NamedTuple, TypedDict
import numpy as np
import torch
@ -86,7 +86,7 @@ class ModelParameters(TypedDict):
resize_factor: float
"""Resize factor."""
class_names: List[str]
class_names: list[str]
"""Class names.
The model is trained to detect these classes.
@ -158,7 +158,7 @@ class FileAnnotation(TypedDict):
notes: str
"""Notes of file."""
annotation: List[Annotation]
annotation: list[Annotation]
"""List of annotations."""
@ -168,26 +168,26 @@ class RunResults(TypedDict):
pred_dict: FileAnnotation
"""Predictions in the format expected by the annotation tool."""
spec_feats: NotRequired[List[np.ndarray]]
spec_feats: NotRequired[list[np.ndarray]]
"""Spectrogram features."""
spec_feat_names: NotRequired[List[str]]
spec_feat_names: NotRequired[list[str]]
"""Spectrogram feature names."""
cnn_feats: NotRequired[List[np.ndarray]]
cnn_feats: NotRequired[list[np.ndarray]]
"""CNN features."""
cnn_feat_names: NotRequired[List[str]]
cnn_feat_names: NotRequired[list[str]]
"""CNN feature names."""
spec_slices: NotRequired[List[np.ndarray]]
spec_slices: NotRequired[list[np.ndarray]]
"""Spectrogram slices."""
class ResultParams(TypedDict):
"""Result parameters."""
class_names: List[str]
class_names: list[str]
"""Class names."""
spec_features: bool
@ -234,7 +234,7 @@ class ProcessingConfiguration(TypedDict):
scale_raw_audio: bool
"""Whether to scale the raw audio to be between -1 and 1."""
class_names: List[str]
class_names: list[str]
"""Names of the classes the model can detect."""
detection_threshold: float
@ -466,7 +466,7 @@ class FeatureExtractionParameters(TypedDict):
class HeatmapParameters(TypedDict):
"""Parameters that control the heatmap generation function."""
class_names: List[str]
class_names: list[str]
fft_win_length: float
"""Length of the FFT window in seconds."""
@ -553,15 +553,15 @@ class AudioLoaderAnnotationGroup(TypedDict):
individual_ids: np.ndarray
x_inds: np.ndarray
y_inds: np.ndarray
annotation: List[Annotation]
annotation: list[Annotation]
annotated: bool
class_id_file: int
"""ID of the class of the file."""
class AudioLoaderParameters(TypedDict):
class_names: List[str]
classes_to_ignore: List[str]
class_names: list[str]
classes_to_ignore: list[str]
target_samp_rate: int
scale_raw_audio: bool
fft_win_length: float

View File

@ -1,12 +1,9 @@
from dataclasses import dataclass
from typing import (
Dict,
Generic,
Iterable,
List,
Protocol,
Sequence,
Tuple,
TypeVar,
)
@ -33,7 +30,7 @@ class MatchEvaluation:
gt_geometry: data.Geometry | None
pred_score: float
pred_class_scores: Dict[str, float]
pred_class_scores: dict[str, float]
pred_geometry: data.Geometry | None
affinity: float
@ -66,7 +63,7 @@ class MatchEvaluation:
@dataclass
class ClipMatches:
clip: data.Clip
matches: List[MatchEvaluation]
matches: list[MatchEvaluation]
class MatcherProtocol(Protocol):
@ -75,7 +72,7 @@ class MatcherProtocol(Protocol):
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
) -> Iterable[Tuple[int | None, int | None, float]]: ...
) -> Iterable[tuple[int | None, int | None, float]]: ...
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
@ -94,7 +91,7 @@ class MetricsProtocol(Protocol):
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[Detection]],
) -> Dict[str, float]: ...
) -> dict[str, float]: ...
class PlotterProtocol(Protocol):
@ -102,7 +99,7 @@ class PlotterProtocol(Protocol):
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[Detection]],
) -> Iterable[Tuple[str, Figure]]: ...
) -> Iterable[tuple[str, Figure]]: ...
EvaluationOutput = TypeVar("EvaluationOutput")
@ -119,8 +116,8 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
def compute_metrics(
self, eval_outputs: EvaluationOutput
) -> Dict[str, float]: ...
) -> dict[str, float]: ...
def generate_plots(
self, eval_outputs: EvaluationOutput
) -> Iterable[Tuple[str, Figure]]: ...
) -> Iterable[tuple[str, Figure]]: ...

View File

@ -1,4 +1,4 @@
from typing import Callable, NamedTuple, Protocol, Tuple
from typing import Callable, NamedTuple, Protocol
import torch
from soundevent import data
@ -52,7 +52,7 @@ steps, and returns the final `Heatmaps` used for model training.
Augmentation = Callable[
[torch.Tensor, data.ClipAnnotation],
Tuple[torch.Tensor, data.ClipAnnotation],
tuple[torch.Tensor, data.ClipAnnotation],
]

View File

@ -1,5 +1,4 @@
import warnings
from typing import Tuple
import librosa
import librosa.core.spectrum
@ -147,7 +146,7 @@ def load_audio(
target_samp_rate: int,
scale: bool = False,
max_duration: float | None = None,
) -> Tuple[int, np.ndarray]:
) -> tuple[int, np.ndarray]:
"""Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration.

View File

@ -1,6 +1,6 @@
import json
import os
from typing import Any, Iterator, List, Tuple
from typing import Any, Iterator
import librosa
import numpy as np
@ -60,7 +60,7 @@ def get_default_bd_args():
return args
def list_audio_files(ip_dir: str) -> List[str]:
def list_audio_files(ip_dir: str) -> list[str]:
"""Get all audio files in directory.
Args:
@ -86,7 +86,7 @@ def load_model(
load_weights: bool = True,
device: torch.device | str | None = None,
weights_only: bool = True,
) -> Tuple[DetectionModel, ModelParameters]:
) -> tuple[DetectionModel, ModelParameters]:
"""Load model from file.
Args:
@ -185,15 +185,16 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
def get_annotations_from_preds(
predictions: PredictionResults,
class_names: List[str],
) -> List[Annotation]:
class_names: list[str],
) -> list[Annotation]:
"""Get list of annotations from predictions."""
# Get the best class prediction probability and index for each detection
class_prob_best = predictions["class_probs"].max(0)
class_ind_best = predictions["class_probs"].argmax(0)
# Pack the results into a list of dictionaries
annotations: List[Annotation] = [
annotations: list[Annotation] = [
Annotation(
{
"start_time": round(float(start_time), 4),
"end_time": round(float(end_time), 4),
@ -205,6 +206,7 @@ def get_annotations_from_preds(
"individual": "-1",
"event": "Echolocation",
}
)
for (
start_time,
end_time,
@ -232,7 +234,7 @@ def format_single_result(
time_exp: float,
duration: float,
predictions: PredictionResults,
class_names: List[str],
class_names: list[str],
) -> FileAnnotation:
"""Format results into the format expected by the annotation tool.
@ -315,9 +317,9 @@ def convert_results(
]
# combine into final results dictionary
results: RunResults = {
results: RunResults = RunResults({ # type: ignore
"pred_dict": pred_dict,
}
})
# add spectrogram features if they exist
if len(spec_feats) > 0 and params["spec_features"]:
@ -413,7 +415,7 @@ def compute_spectrogram(
sampling_rate: int,
params: SpectrogramParameters,
device: torch.device,
) -> Tuple[float, torch.Tensor]:
) -> tuple[float, torch.Tensor]:
"""Compute a spectrogram from an audio array.
Will pad the audio array so that it is evenly divisible by the
@ -475,7 +477,7 @@ def iterate_over_chunks(
audio: np.ndarray,
samplerate: float,
chunk_size: float,
) -> Iterator[Tuple[float, np.ndarray]]:
) -> Iterator[tuple[float, np.ndarray]]:
"""Iterate over audio in chunks of size chunk_size.
Parameters
@ -509,7 +511,7 @@ def _process_spectrogram(
samplerate: float,
model: DetectionModel,
config: ProcessingConfiguration,
) -> Tuple[PredictionResults, np.ndarray]:
) -> tuple[PredictionResults, np.ndarray]:
# evaluate model
with torch.no_grad():
outputs = model(spec)
@ -546,7 +548,7 @@ def postprocess_model_outputs(
outputs: ModelOutput,
samp_rate: int,
config: ProcessingConfiguration,
) -> Tuple[List[Annotation], np.ndarray]:
) -> tuple[list[Annotation], np.ndarray]:
# run non-max suppression
pred_nms_list, features = pp.run_nms(
outputs,
@ -585,7 +587,7 @@ def process_spectrogram(
samplerate: int,
model: DetectionModel,
config: ProcessingConfiguration,
) -> Tuple[List[Annotation], np.ndarray]:
) -> tuple[list[Annotation], np.ndarray]:
"""Process a spectrogram with detection model.
Will run non-maximum suppression on the output of the model.
@ -604,9 +606,9 @@ def process_spectrogram(
Returns
-------
detections: List[Annotation]
detections
List of detections predicted by the model.
features : np.ndarray
features
An array of CNN features associated with each annotation.
The array is of shape (num_detections, num_features).
Is empty if `config["cnn_features"]` is False.
@ -632,7 +634,7 @@ def _process_audio_array(
model: DetectionModel,
config: ProcessingConfiguration,
device: torch.device,
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
) -> tuple[PredictionResults, np.ndarray, torch.Tensor]:
# load audio file and compute spectrogram
_, spec = compute_spectrogram(
audio,
@ -669,7 +671,7 @@ def process_audio_array(
model: DetectionModel,
config: ProcessingConfiguration,
device: torch.device,
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
"""Process a single audio array with detection model.
Parameters
@ -689,7 +691,7 @@ def process_audio_array(
Returns
-------
annotations : List[Annotation]
annotations : list[Annotation]
List of annotations predicted by the model.
features : np.ndarray
Array of CNN features associated with each annotation.

View File

@ -1,7 +1,7 @@
import json
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any
import pytest
from soundevent import data
@ -20,11 +20,11 @@ def create_legacy_file_annotation(
duration: float = 5.0,
time_exp: float = 1.0,
class_name: str = "Myotis",
annotations: Optional[List[Dict[str, Any]]] = None,
annotations: list[dict[str, Any]] | None = None,
annotated: bool = True,
issues: bool = False,
notes: str = "",
) -> Dict[str, Any]:
) -> dict[str, Any]:
if annotations is None:
annotations = [
{
@ -61,7 +61,7 @@ def create_legacy_file_annotation(
@pytest.fixture
def batdetect2_files_test_setup(
tmp_path: Path, wav_factory
) -> Tuple[Path, Path, List[Dict[str, Any]]]:
) -> tuple[Path, Path, list[dict[str, Any]]]:
"""Sets up a directory structure for batdetect2 files format tests."""
audio_dir = tmp_path / "audio"
audio_dir.mkdir()
@ -143,7 +143,7 @@ def batdetect2_files_test_setup(
@pytest.fixture
def batdetect2_merged_test_setup(
tmp_path: Path, batdetect2_files_test_setup
) -> Tuple[Path, Path, List[Dict[str, Any]]]:
) -> tuple[Path, Path, list[dict[str, Any]]]:
"""Sets up a directory structure for batdetect2 merged file format tests."""
audio_dir, _, files_data = batdetect2_files_test_setup
merged_anns_path = tmp_path / "merged_anns.json"