mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Removing legacy types
This commit is contained in:
parent
8ac4f4c44d
commit
1a7c0b4b3a
@ -98,7 +98,6 @@ consult the API documentation in the code.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -272,7 +271,7 @@ def process_spectrogram(
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
model: DetectionModel = MODEL,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
) -> tuple[list[Annotation], np.ndarray]:
|
||||
"""Process spectrogram with model.
|
||||
|
||||
Parameters
|
||||
@ -314,7 +313,7 @@ def process_audio(
|
||||
model: DetectionModel = MODEL,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
|
||||
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
|
||||
"""Process audio array with model.
|
||||
|
||||
Parameters
|
||||
@ -357,7 +356,7 @@ def postprocess(
|
||||
outputs: ModelOutput,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
) -> tuple[list[Annotation], np.ndarray]:
|
||||
"""Postprocess model outputs.
|
||||
|
||||
Convert model tensor outputs to predicted bounding boxes and
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Sequence, Tuple
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -110,7 +110,7 @@ class BatDetect2API:
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
save_predictions: bool = True,
|
||||
) -> Tuple[Dict[str, float], List[List[Detection]]]:
|
||||
) -> tuple[dict[str, float], list[list[Detection]]]:
|
||||
return evaluate(
|
||||
self.model,
|
||||
test_annotations,
|
||||
@ -187,7 +187,7 @@ class BatDetect2API:
|
||||
def process_audio(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
) -> List[Detection]:
|
||||
) -> list[Detection]:
|
||||
spec = self.generate_spectrogram(audio)
|
||||
return self.process_spectrogram(spec)
|
||||
|
||||
@ -195,7 +195,7 @@ class BatDetect2API:
|
||||
self,
|
||||
spec: torch.Tensor,
|
||||
start_time: float = 0,
|
||||
) -> List[Detection]:
|
||||
) -> list[Detection]:
|
||||
if spec.ndim == 4 and spec.shape[0] > 1:
|
||||
raise ValueError("Batched spectrograms not supported.")
|
||||
|
||||
@ -214,7 +214,7 @@ class BatDetect2API:
|
||||
def process_directory(
|
||||
self,
|
||||
audio_dir: data.PathLike,
|
||||
) -> List[ClipDetections]:
|
||||
) -> list[ClipDetections]:
|
||||
files = list(get_audio_files(audio_dir))
|
||||
return self.process_files(files)
|
||||
|
||||
@ -222,7 +222,7 @@ class BatDetect2API:
|
||||
self,
|
||||
audio_files: Sequence[data.PathLike],
|
||||
num_workers: int | None = None,
|
||||
) -> List[ClipDetections]:
|
||||
) -> list[ClipDetections]:
|
||||
return process_file_list(
|
||||
self.model,
|
||||
audio_files,
|
||||
@ -238,7 +238,7 @@ class BatDetect2API:
|
||||
clips: Sequence[data.Clip],
|
||||
batch_size: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> List[ClipDetections]:
|
||||
) -> list[ClipDetections]:
|
||||
return run_batch_inference(
|
||||
self.model,
|
||||
clips,
|
||||
@ -274,7 +274,7 @@ class BatDetect2API:
|
||||
def load_predictions(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
) -> List[ClipDetections]:
|
||||
) -> list[ClipDetections]:
|
||||
return self.formatter.load(path)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested
|
||||
configuration sections.
|
||||
"""
|
||||
|
||||
from typing import Any, Type, TypeVar, Union, overload
|
||||
from typing import Any, Type, TypeVar, overload
|
||||
|
||||
import yaml
|
||||
from deepmerge.merger import Merger
|
||||
@ -69,8 +69,7 @@ class BaseConfig(BaseModel):
|
||||
|
||||
T = TypeVar("T")
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
|
||||
Schema = Union[Type[T_Model], TypeAdapter[T]]
|
||||
Schema = Type[T_Model] | TypeAdapter[T]
|
||||
|
||||
|
||||
def get_object_field(obj: dict, current_key: str) -> Any:
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Tuple
|
||||
|
||||
from soundevent import data
|
||||
|
||||
@ -10,7 +9,7 @@ from batdetect2.typing.targets import TargetProtocol
|
||||
def iterate_over_sound_events(
|
||||
dataset: Dataset,
|
||||
targets: TargetProtocol,
|
||||
) -> Generator[Tuple[str | None, data.SoundEventAnnotation], None, None]:
|
||||
) -> Generator[tuple[str | None, data.SoundEventAnnotation], None, None]:
|
||||
"""Iterate over sound events in a dataset.
|
||||
|
||||
Parameters
|
||||
@ -24,7 +23,7 @@ def iterate_over_sound_events(
|
||||
|
||||
Yields
|
||||
------
|
||||
Tuple[Optional[str], data.SoundEventAnnotation]
|
||||
tuple[str | None, data.SoundEventAnnotation]
|
||||
A tuple containing:
|
||||
- The encoded class name (str) for the sound event, or None if it
|
||||
cannot be encoded to a specific class.
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from typing import Literal
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from soundevent.data import PathLike
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Tuple
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from batdetect2.data.datasets import Dataset
|
||||
@ -15,7 +13,7 @@ def split_dataset_by_recordings(
|
||||
targets: TargetProtocol,
|
||||
train_size: float = 0.75,
|
||||
random_state: int | None = None,
|
||||
) -> Tuple[Dataset, Dataset]:
|
||||
) -> tuple[Dataset, Dataset]:
|
||||
recordings = extract_recordings_df(dataset)
|
||||
|
||||
sound_events = extract_sound_events_df(
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
"""Post-processing of the output of the model."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -45,7 +43,7 @@ def run_nms(
|
||||
outputs: ModelOutput,
|
||||
params: NonMaximumSuppressionConfig,
|
||||
sampling_rate: np.ndarray,
|
||||
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
|
||||
) -> tuple[list[PredictionResults], list[np.ndarray]]:
|
||||
"""Run non-maximum suppression on the output of the model.
|
||||
|
||||
Model outputs processed are expected to have a batch dimension.
|
||||
@ -73,8 +71,8 @@ def run_nms(
|
||||
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
|
||||
|
||||
# loop over batch to save outputs
|
||||
preds: List[PredictionResults] = []
|
||||
feats: List[np.ndarray] = []
|
||||
preds: list[PredictionResults] = []
|
||||
feats: list[np.ndarray] = []
|
||||
for num_detection in range(pred_det_nms.shape[0]):
|
||||
# get valid indices
|
||||
inds_ord = torch.argsort(x_pos[num_detection, :])
|
||||
@ -151,7 +149,7 @@ def run_nms(
|
||||
|
||||
def non_max_suppression(
|
||||
heat: torch.Tensor,
|
||||
kernel_size: int | Tuple[int, int],
|
||||
kernel_size: int | tuple[int, int],
|
||||
):
|
||||
# kernel can be an int or list/tuple
|
||||
if isinstance(kernel_size, int):
|
||||
|
||||
@ -4,12 +4,9 @@ from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Sequence,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@ -32,7 +29,7 @@ from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
||||
|
||||
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
||||
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]
|
||||
|
||||
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
|
||||
name="top_class_plot"
|
||||
@ -73,7 +70,7 @@ class PRCurve(BasePlot):
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
) -> Iterable[tuple[str, Figure]]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
num_positives = 0
|
||||
@ -140,7 +137,7 @@ class ROCCurve(BasePlot):
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
) -> Iterable[tuple[str, Figure]]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
@ -223,7 +220,7 @@ class ConfusionMatrix(BasePlot):
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
) -> Iterable[tuple[str, Figure]]:
|
||||
cm, labels = compute_confusion_matrix(
|
||||
clip_evaluations,
|
||||
self.targets,
|
||||
@ -295,26 +292,26 @@ class ExampleClassificationPlot(BasePlot):
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
) -> Iterable[tuple[str, Figure]]:
|
||||
grouped = group_matches(clip_evaluations, threshold=self.threshold)
|
||||
|
||||
for class_name, matches in grouped.items():
|
||||
true_positives: List[MatchEval] = get_binned_sample(
|
||||
true_positives: list[MatchEval] = get_binned_sample(
|
||||
matches.true_positives,
|
||||
n_examples=self.num_examples,
|
||||
)
|
||||
|
||||
false_positives: List[MatchEval] = get_binned_sample(
|
||||
false_positives: list[MatchEval] = get_binned_sample(
|
||||
matches.false_positives,
|
||||
n_examples=self.num_examples,
|
||||
)
|
||||
|
||||
false_negatives: List[MatchEval] = random.sample(
|
||||
false_negatives: list[MatchEval] = random.sample(
|
||||
matches.false_negatives,
|
||||
k=min(self.num_examples, len(matches.false_negatives)),
|
||||
)
|
||||
|
||||
cross_triggers: List[MatchEval] = get_binned_sample(
|
||||
cross_triggers: list[MatchEval] = get_binned_sample(
|
||||
matches.cross_triggers, n_examples=self.num_examples
|
||||
)
|
||||
|
||||
@ -374,16 +371,16 @@ def build_top_class_plotter(
|
||||
|
||||
@dataclass
|
||||
class ClassMatches:
|
||||
false_positives: List[MatchEval] = field(default_factory=list)
|
||||
false_negatives: List[MatchEval] = field(default_factory=list)
|
||||
true_positives: List[MatchEval] = field(default_factory=list)
|
||||
cross_triggers: List[MatchEval] = field(default_factory=list)
|
||||
false_positives: list[MatchEval] = field(default_factory=list)
|
||||
false_negatives: list[MatchEval] = field(default_factory=list)
|
||||
true_positives: list[MatchEval] = field(default_factory=list)
|
||||
cross_triggers: list[MatchEval] = field(default_factory=list)
|
||||
|
||||
|
||||
def group_matches(
|
||||
clip_evals: Sequence[ClipEval],
|
||||
threshold: float = 0.2,
|
||||
) -> Dict[str, ClassMatches]:
|
||||
) -> dict[str, ClassMatches]:
|
||||
class_examples = defaultdict(ClassMatches)
|
||||
|
||||
for clip_eval in clip_evals:
|
||||
@ -412,7 +409,7 @@ def group_matches(
|
||||
return class_examples
|
||||
|
||||
|
||||
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
|
||||
def get_binned_sample(matches: list[MatchEval], n_examples: int = 5):
|
||||
if len(matches) < n_examples:
|
||||
return matches
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@ import argparse
|
||||
import json
|
||||
import os
|
||||
from collections import Counter
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from sklearn.model_selection import StratifiedGroupKFold
|
||||
@ -12,8 +11,8 @@ from batdetect2 import types
|
||||
|
||||
|
||||
def print_dataset_stats(
|
||||
data: List[types.FileAnnotation],
|
||||
classes_to_ignore: List[str] | None = None,
|
||||
data: list[types.FileAnnotation],
|
||||
classes_to_ignore: list[str] | None = None,
|
||||
) -> Counter[str]:
|
||||
print("Num files:", len(data))
|
||||
counts, _ = tu.get_class_names(data, classes_to_ignore)
|
||||
@ -22,7 +21,7 @@ def print_dataset_stats(
|
||||
return counts
|
||||
|
||||
|
||||
def load_file_names(file_name: str) -> List[str]:
|
||||
def load_file_names(file_name: str) -> list[str]:
|
||||
if not os.path.isfile(file_name):
|
||||
raise FileNotFoundError(f"Input file not found - {file_name}")
|
||||
|
||||
@ -100,12 +99,12 @@ def parse_args():
|
||||
|
||||
|
||||
def split_data(
|
||||
data: List[types.FileAnnotation],
|
||||
data: list[types.FileAnnotation],
|
||||
train_file: str,
|
||||
test_file: str,
|
||||
n_splits: int = 5,
|
||||
random_state: int = 0,
|
||||
) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]:
|
||||
) -> tuple[list[types.FileAnnotation], list[types.FileAnnotation]]:
|
||||
if train_file != "" and test_file != "":
|
||||
# user has specified the train / test split
|
||||
mapping = {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, NamedTuple, Sequence
|
||||
from typing import NamedTuple, Sequence
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
@ -29,7 +29,7 @@ class DatasetItem(NamedTuple):
|
||||
|
||||
|
||||
class InferenceDataset(Dataset[DatasetItem]):
|
||||
clips: List[data.Clip]
|
||||
clips: list[data.Clip]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -111,7 +111,7 @@ def build_inference_dataset(
|
||||
)
|
||||
|
||||
|
||||
def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
|
||||
def _collate_fn(batch: list[DatasetItem]) -> DatasetItem:
|
||||
max_width = max(item.spec.shape[-1] for item in batch)
|
||||
return DatasetItem(
|
||||
spec=torch.stack(
|
||||
|
||||
@ -26,8 +26,6 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
|
||||
is the ``build_model`` factory function exported from this module.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.models.backbones import (
|
||||
@ -142,7 +140,7 @@ class Model(torch.nn.Module):
|
||||
self.postprocessor = postprocessor
|
||||
self.targets = targets
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
|
||||
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
||||
"""Run the full detection pipeline on a waveform tensor.
|
||||
|
||||
Converts the waveform to a spectrogram, passes it through the
|
||||
@ -157,7 +155,7 @@ class Model(torch.nn.Module):
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[ClipDetectionsTensor]
|
||||
list[ClipDetectionsTensor]
|
||||
One detection tensor per clip in the batch. Each tensor encodes
|
||||
the detected events (locations, class scores, sizes) for that
|
||||
clip.
|
||||
|
||||
@ -23,7 +23,7 @@ output so that the output spatial dimensions always match the input spatial
|
||||
dimensions.
|
||||
"""
|
||||
|
||||
from typing import Annotated, Literal, Tuple, Union
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -58,6 +58,14 @@ from batdetect2.typing.models import (
|
||||
EncoderProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BackboneImportConfig",
|
||||
"UNetBackbone",
|
||||
"BackboneConfig",
|
||||
"load_backbone_config",
|
||||
"build_backbone",
|
||||
]
|
||||
|
||||
|
||||
class UNetBackboneConfig(BaseConfig):
|
||||
"""Configuration for a U-Net-style encoder-decoder backbone.
|
||||
@ -110,15 +118,6 @@ class BackboneImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BackboneImportConfig",
|
||||
"UNetBackbone",
|
||||
"BackboneConfig",
|
||||
"load_backbone_config",
|
||||
"build_backbone",
|
||||
]
|
||||
|
||||
|
||||
class UNetBackbone(BackboneModel):
|
||||
"""U-Net-style encoder-decoder backbone network.
|
||||
|
||||
@ -262,7 +261,8 @@ class UNetBackbone(BackboneModel):
|
||||
|
||||
|
||||
BackboneConfig = Annotated[
|
||||
Union[UNetBackboneConfig,], Field(discriminator="name")
|
||||
UNetBackboneConfig | BackboneImportConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -292,7 +292,7 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
||||
def _pad_adjust(
|
||||
spec: torch.Tensor,
|
||||
factor: int = 32,
|
||||
) -> Tuple[torch.Tensor, int, int]:
|
||||
) -> tuple[torch.Tensor, int, int]:
|
||||
"""Pad a tensor's height and width to be divisible by ``factor``.
|
||||
|
||||
Adds zero-padding to the bottom and right edges of the tensor so that
|
||||
@ -308,7 +308,7 @@ def _pad_adjust(
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[torch.Tensor, int, int]
|
||||
tuple[torch.Tensor, int, int]
|
||||
- Padded tensor.
|
||||
- Number of rows added to the height (``h_pad``).
|
||||
- Number of columns added to the width (``w_pad``).
|
||||
|
||||
@ -46,7 +46,7 @@ configuration object (one of the ``*Config`` classes exported here), using
|
||||
a discriminated-union ``name`` field to dispatch to the correct class.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Literal, Tuple, Union
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -687,7 +687,7 @@ class FreqCoordConvUpConfig(BaseConfig):
|
||||
up_mode: str = "bilinear"
|
||||
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
||||
|
||||
up_scale: Tuple[int, int] = (2, 2)
|
||||
up_scale: tuple[int, int] = (2, 2)
|
||||
"""Scaling factor for height and width during upsampling."""
|
||||
|
||||
|
||||
@ -706,22 +706,22 @@ class FreqCoordConvUpBlock(Block):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
in_channels
|
||||
Number of channels in the input tensor (before upsampling).
|
||||
out_channels : int
|
||||
out_channels
|
||||
Number of output channels after the convolution.
|
||||
input_height : int
|
||||
input_height
|
||||
Height (H dimension, frequency bins) of the tensor *before* upsampling.
|
||||
Used to calculate the height for coordinate feature generation after
|
||||
upsampling.
|
||||
kernel_size : int, default=3
|
||||
kernel_size
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
pad_size
|
||||
Padding added before convolution.
|
||||
up_mode : str, default="bilinear"
|
||||
up_mode
|
||||
Interpolation mode for upsampling (e.g., "nearest", "bilinear",
|
||||
"bicubic").
|
||||
up_scale : Tuple[int, int], default=(2, 2)
|
||||
up_scale
|
||||
Scaling factor for height and width during upsampling
|
||||
(typically (2, 2)).
|
||||
"""
|
||||
@ -734,7 +734,7 @@ class FreqCoordConvUpBlock(Block):
|
||||
kernel_size: int = 3,
|
||||
pad_size: int = 1,
|
||||
up_mode: str = "bilinear",
|
||||
up_scale: Tuple[int, int] = (2, 2),
|
||||
up_scale: tuple[int, int] = (2, 2),
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
@ -824,7 +824,7 @@ class StandardConvUpConfig(BaseConfig):
|
||||
up_mode: str = "bilinear"
|
||||
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
||||
|
||||
up_scale: Tuple[int, int] = (2, 2)
|
||||
up_scale: tuple[int, int] = (2, 2)
|
||||
"""Scaling factor for height and width during upsampling."""
|
||||
|
||||
|
||||
@ -839,17 +839,17 @@ class StandardConvUpBlock(Block):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
in_channels
|
||||
Number of channels in the input tensor (before upsampling).
|
||||
out_channels : int
|
||||
out_channels
|
||||
Number of output channels after the convolution.
|
||||
kernel_size : int, default=3
|
||||
kernel_size
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
pad_size
|
||||
Padding added before convolution.
|
||||
up_mode : str, default="bilinear"
|
||||
up_mode
|
||||
Interpolation mode for upsampling (e.g., "nearest", "bilinear").
|
||||
up_scale : Tuple[int, int], default=(2, 2)
|
||||
up_scale
|
||||
Scaling factor for height and width during upsampling.
|
||||
"""
|
||||
|
||||
@ -860,7 +860,7 @@ class StandardConvUpBlock(Block):
|
||||
kernel_size: int = 3,
|
||||
pad_size: int = 1,
|
||||
up_mode: str = "bilinear",
|
||||
up_scale: Tuple[int, int] = (2, 2),
|
||||
up_scale: tuple[int, int] = (2, 2),
|
||||
):
|
||||
super(StandardConvUpBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
@ -922,15 +922,14 @@ class StandardConvUpBlock(Block):
|
||||
|
||||
|
||||
LayerConfig = Annotated[
|
||||
Union[
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
StandardConvDownConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvUpConfig,
|
||||
SelfAttentionConfig,
|
||||
"LayerGroupConfig",
|
||||
],
|
||||
ConvConfig
|
||||
| BlockImportConfig
|
||||
| FreqCoordConvDownConfig
|
||||
| StandardConvDownConfig
|
||||
| FreqCoordConvUpConfig
|
||||
| StandardConvUpConfig
|
||||
| SelfAttentionConfig
|
||||
| "LayerGroupConfig",
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configuration models."""
|
||||
@ -952,7 +951,7 @@ class LayerGroupConfig(BaseConfig):
|
||||
"""
|
||||
|
||||
name: Literal["LayerGroup"] = "LayerGroup"
|
||||
layers: List[LayerConfig]
|
||||
layers: list[LayerConfig]
|
||||
|
||||
|
||||
class LayerGroup(nn.Module):
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Tuple
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from soundevent import data, plot
|
||||
|
||||
@ -16,7 +14,7 @@ __all__ = [
|
||||
def plot_clip_annotation(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
add_points: bool = False,
|
||||
@ -50,7 +48,7 @@ def plot_clip_annotation(
|
||||
def plot_anchor_points(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
targets: TargetProtocol,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
size: int = 1,
|
||||
color: str = "red",
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Iterable, Tuple
|
||||
from typing import Iterable
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from soundevent import data
|
||||
@ -18,7 +18,7 @@ __all__ = [
|
||||
def plot_clip_prediction(
|
||||
clip_prediction: data.ClipPrediction,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
add_legend: bool = False,
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from matplotlib.axes import Axes
|
||||
@ -19,7 +17,7 @@ def plot_clip(
|
||||
clip: data.Clip,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
spec_cmap: str = "gray",
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
"""General plotting utilities."""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -14,7 +12,7 @@ __all__ = [
|
||||
|
||||
def create_ax(
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
**kwargs,
|
||||
) -> axes.Axes:
|
||||
"""Create a new axis if none is provided"""
|
||||
@ -31,7 +29,7 @@ def plot_spectrogram(
|
||||
min_freq: float | None = None,
|
||||
max_freq: float | None = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
add_colorbar: bool = False,
|
||||
colorbar_kwargs: dict | None = None,
|
||||
vmin: float | None = None,
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
"""Plot heatmaps"""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import axes, patches
|
||||
@ -14,7 +12,7 @@ from batdetect2.plotting.common import create_ax
|
||||
def plot_detection_heatmap(
|
||||
heatmap: torch.Tensor | np.ndarray,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] = (10, 10),
|
||||
figsize: tuple[int, int] = (10, 10),
|
||||
threshold: float | None = None,
|
||||
alpha: float = 1,
|
||||
cmap: str | Colormap = "jet",
|
||||
@ -50,8 +48,8 @@ def plot_detection_heatmap(
|
||||
def plot_classification_heatmap(
|
||||
heatmap: torch.Tensor | np.ndarray,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] = (10, 10),
|
||||
class_names: List[str] | None = None,
|
||||
figsize: tuple[int, int] = (10, 10),
|
||||
class_names: list[str] | None = None,
|
||||
threshold: float | None = 0.1,
|
||||
alpha: float = 1,
|
||||
cmap: str | Colormap = "tab20",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Plot functions to visualize detections and spectrograms."""
|
||||
|
||||
from typing import List, Tuple, cast
|
||||
from typing import cast
|
||||
|
||||
import matplotlib.ticker as tick
|
||||
import numpy as np
|
||||
@ -27,7 +27,7 @@ def spectrogram(
|
||||
spec: torch.Tensor | np.ndarray,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
cmap: str = "plasma",
|
||||
start_time: float = 0,
|
||||
) -> axes.Axes:
|
||||
@ -35,18 +35,18 @@ def spectrogram(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
|
||||
config (Optional[ProcessingConfiguration], optional): Configuration
|
||||
spec: Spectrogram to plot.
|
||||
config: Configuration
|
||||
used to compute the spectrogram. Defaults to None. If None,
|
||||
the default configuration will be used.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
||||
ax: Matplotlib axes object.
|
||||
Defaults to None. If provided, the spectrogram will be plotted
|
||||
on this axes.
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
||||
figsize: Figure size.
|
||||
Defaults to None. If `ax` is None, this will be used to create
|
||||
a new figure of the given size.
|
||||
cmap (str, optional): Colormap to use. Defaults to "plasma".
|
||||
start_time (float, optional): Start time of the spectrogram.
|
||||
cmap: Colormap to use. Defaults to "plasma".
|
||||
start_time: Start time of the spectrogram.
|
||||
Defaults to 0. This is useful if plotting a spectrogram
|
||||
of a segment of a longer audio file.
|
||||
|
||||
@ -104,10 +104,10 @@ def spectrogram(
|
||||
|
||||
def spectrogram_with_detections(
|
||||
spec: torch.Tensor | np.ndarray,
|
||||
dets: List[Annotation],
|
||||
dets: list[Annotation],
|
||||
config: ProcessingConfiguration | None = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
cmap: str = "plasma",
|
||||
with_names: bool = True,
|
||||
start_time: float = 0,
|
||||
@ -117,21 +117,21 @@ def spectrogram_with_detections(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
|
||||
detections (List[Annotation]): List of detections.
|
||||
config (Optional[ProcessingConfiguration], optional): Configuration
|
||||
spec: Spectrogram to plot.
|
||||
dets: List of detections.
|
||||
config: Configuration
|
||||
used to compute the spectrogram. Defaults to None. If None,
|
||||
the default configuration will be used.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
||||
ax: Matplotlib axes object.
|
||||
Defaults to None. If provided, the spectrogram will be plotted
|
||||
on this axes.
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
||||
figsize: Figure size.
|
||||
Defaults to None. If `ax` is None, this will be used to create
|
||||
a new figure of the given size.
|
||||
cmap (str, optional): Colormap to use. Defaults to "plasma".
|
||||
with_names (bool, optional): Whether to plot the name of the
|
||||
cmap: Colormap to use. Defaults to "plasma".
|
||||
with_names: Whether to plot the name of the
|
||||
predicted class next to the detection. Defaults to True.
|
||||
start_time (float, optional): Start time of the spectrogram.
|
||||
start_time: Start time of the spectrogram.
|
||||
Defaults to 0. This is useful if plotting a spectrogram
|
||||
of a segment of a longer audio file.
|
||||
**kwargs: Additional keyword arguments to pass to the
|
||||
@ -167,9 +167,9 @@ def spectrogram_with_detections(
|
||||
|
||||
|
||||
def detections(
|
||||
dets: List[Annotation],
|
||||
dets: list[Annotation],
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
with_names: bool = True,
|
||||
**kwargs,
|
||||
) -> axes.Axes:
|
||||
@ -177,14 +177,14 @@ def detections(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dets (List[Annotation]): List of detections.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
||||
dets: List of detections.
|
||||
ax: Matplotlib axes object.
|
||||
Defaults to None. If provided, the spectrogram will be plotted
|
||||
on this axes.
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
||||
figsize: Figure size.
|
||||
Defaults to None. If `ax` is None, this will be used to create
|
||||
a new figure of the given size.
|
||||
with_names (bool, optional): Whether to plot the name of the
|
||||
with_names: Whether to plot the name of the
|
||||
predicted class next to the detection. Defaults to True.
|
||||
**kwargs: Additional keyword arguments to pass to the
|
||||
`plot.detection` function.
|
||||
@ -214,7 +214,7 @@ def detections(
|
||||
def detection(
|
||||
det: Annotation,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
linewidth: float = 1,
|
||||
edgecolor: str = "w",
|
||||
facecolor: str = "none",
|
||||
@ -224,19 +224,19 @@ def detection(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
det (Annotation): Detection to plot.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object. Defaults
|
||||
det: Detection to plot.
|
||||
ax: Matplotlib axes object. Defaults
|
||||
to None. If provided, the spectrogram will be plotted on this axes.
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults
|
||||
figsize: Figure size. Defaults
|
||||
to None. If `ax` is None, this will be used to create a new figure
|
||||
of the given size.
|
||||
linewidth (float, optional): Line width of the detection.
|
||||
linewidth: Line width of the detection.
|
||||
Defaults to 1.
|
||||
edgecolor (str, optional): Edge color of the detection.
|
||||
edgecolor: Edge color of the detection.
|
||||
Defaults to "w", i.e. white.
|
||||
facecolor (str, optional): Face color of the detection.
|
||||
facecolor: Face color of the detection.
|
||||
Defaults to "none", i.e. transparent.
|
||||
with_name (bool, optional): Whether to plot the name of the
|
||||
with_name: Whether to plot the name of the
|
||||
predicted class next to the detection. Defaults to True.
|
||||
|
||||
Returns
|
||||
@ -277,22 +277,22 @@ def detection(
|
||||
|
||||
|
||||
def _compute_spec_extent(
|
||||
shape: Tuple[int, int],
|
||||
shape: tuple[int, int],
|
||||
params: SpectrogramParameters,
|
||||
) -> Tuple[float, float, float, float]:
|
||||
) -> tuple[float, float, float, float]:
|
||||
"""Compute the extent of a spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shape (Tuple[int, int]): Shape of the spectrogram.
|
||||
shape: Shape of the spectrogram.
|
||||
The first dimension is the frequency axis and the second
|
||||
dimension is the time axis.
|
||||
params (SpectrogramParameters): Spectrogram parameters.
|
||||
params: Spectrogram parameters.
|
||||
Should be the same as the ones used to compute the spectrogram.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, float, float, float]: Extent of the spectrogram.
|
||||
tuple[float, float, float, float]: Extent of the spectrogram.
|
||||
The first two values are the minimum and maximum time values,
|
||||
the last two values are the minimum and maximum frequency values.
|
||||
"""
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Protocol, Tuple
|
||||
from typing import Protocol
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from soundevent import data, plot
|
||||
@ -40,7 +40,7 @@ def plot_false_positive_match(
|
||||
match: MatchProtocol,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
@ -111,7 +111,7 @@ def plot_false_negative_match(
|
||||
match: MatchProtocol,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
@ -171,7 +171,7 @@ def plot_true_positive_match(
|
||||
match: MatchProtocol,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
@ -259,7 +259,7 @@ def plot_cross_trigger_match(
|
||||
match: MatchProtocol,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
from cycler import cycler
|
||||
@ -34,14 +32,14 @@ def plot_pr_curve(
|
||||
recall: np.ndarray,
|
||||
thresholds: np.ndarray,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
color: str | Tuple[float, float, float] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
color: str | tuple[float, float, float] | None = None,
|
||||
add_labels: bool = True,
|
||||
add_legend: bool = False,
|
||||
marker: str | Tuple[int, int, float] | None = "o",
|
||||
markeredgecolor: str | Tuple[float, float, float] | None = None,
|
||||
marker: str | tuple[int, int, float] | None = "o",
|
||||
markeredgecolor: str | tuple[float, float, float] | None = None,
|
||||
markersize: float | None = None,
|
||||
linestyle: str | Tuple[int, ...] | None = None,
|
||||
linestyle: str | tuple[int, ...] | None = None,
|
||||
linewidth: float | None = None,
|
||||
label: str = "PR Curve",
|
||||
) -> axes.Axes:
|
||||
@ -76,9 +74,9 @@ def plot_pr_curve(
|
||||
|
||||
|
||||
def plot_pr_curves(
|
||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
add_legend: bool = True,
|
||||
add_labels: bool = True,
|
||||
include_ap: bool = False,
|
||||
@ -119,7 +117,7 @@ def plot_threshold_precision_curve(
|
||||
threshold: np.ndarray,
|
||||
precision: np.ndarray,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
add_labels: bool = True,
|
||||
):
|
||||
ax = create_ax(ax=ax, figsize=figsize)
|
||||
@ -139,9 +137,9 @@ def plot_threshold_precision_curve(
|
||||
|
||||
|
||||
def plot_threshold_precision_curves(
|
||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
add_legend: bool = True,
|
||||
add_labels: bool = True,
|
||||
):
|
||||
@ -177,7 +175,7 @@ def plot_threshold_recall_curve(
|
||||
threshold: np.ndarray,
|
||||
recall: np.ndarray,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
add_labels: bool = True,
|
||||
):
|
||||
ax = create_ax(ax=ax, figsize=figsize)
|
||||
@ -197,9 +195,9 @@ def plot_threshold_recall_curve(
|
||||
|
||||
|
||||
def plot_threshold_recall_curves(
|
||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
add_legend: bool = True,
|
||||
add_labels: bool = True,
|
||||
):
|
||||
@ -236,7 +234,7 @@ def plot_roc_curve(
|
||||
tpr: np.ndarray,
|
||||
thresholds: np.ndarray,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
add_labels: bool = True,
|
||||
) -> axes.Axes:
|
||||
ax = create_ax(ax=ax, figsize=figsize)
|
||||
@ -260,9 +258,9 @@ def plot_roc_curve(
|
||||
|
||||
|
||||
def plot_roc_curves(
|
||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
add_legend: bool = True,
|
||||
add_labels: bool = True,
|
||||
) -> axes.Axes:
|
||||
|
||||
@ -11,8 +11,6 @@ activations that have lower scores than a local maximum. This helps prevent
|
||||
multiple, overlapping detections originating from the same sound event.
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
NMS_KERNEL_SIZE = 9
|
||||
@ -27,7 +25,7 @@ BatDetect2.
|
||||
|
||||
def non_max_suppression(
|
||||
tensor: torch.Tensor,
|
||||
kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
|
||||
kernel_size: int | tuple[int, int] = NMS_KERNEL_SIZE,
|
||||
) -> torch.Tensor:
|
||||
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
|
||||
|
||||
@ -42,11 +40,11 @@ def non_max_suppression(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tensor : torch.Tensor
|
||||
tensor
|
||||
Input tensor, typically representing a detection heatmap. Must be a
|
||||
3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying
|
||||
`torch.nn.functional.max_pool2d` operation.
|
||||
kernel_size : Union[int, Tuple[int, int]], default=NMS_KERNEL_SIZE
|
||||
kernel_size
|
||||
Size of the sliding window neighborhood used to find local maxima.
|
||||
If an integer `k` is provided, a square kernel of size `(k, k)` is used.
|
||||
If a tuple `(h, w)` is provided, a rectangular kernel of height `h`
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
@ -51,7 +49,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||
max_freq: float,
|
||||
top_k_per_sec: int = 200,
|
||||
detection_threshold: float = 0.01,
|
||||
nms_kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
|
||||
nms_kernel_size: int | tuple[int, int] = NMS_KERNEL_SIZE,
|
||||
):
|
||||
"""Initialize the Postprocessor."""
|
||||
super().__init__()
|
||||
@ -66,8 +64,8 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||
def forward(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
start_times: List[float] | None = None,
|
||||
) -> List[ClipDetectionsTensor]:
|
||||
start_times: list[float] | None = None,
|
||||
) -> list[ClipDetectionsTensor]:
|
||||
detection_heatmap = non_max_suppression(
|
||||
output.detection_probs.detach(),
|
||||
kernel_size=self.nms_kernel_size,
|
||||
|
||||
@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
|
||||
*geometric* aspect of target definition from *semantic* classification.
|
||||
"""
|
||||
|
||||
from typing import Annotated, Literal, Tuple
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
@ -144,7 +144,7 @@ class AnchorBBoxMapper(ROITargetMapper):
|
||||
|
||||
Attributes
|
||||
----------
|
||||
dimension_names : List[str]
|
||||
dimension_names : list[str]
|
||||
The output dimension names: `['width', 'height']`.
|
||||
anchor : Anchor
|
||||
The configured anchor point type (e.g., "center", "bottom-left").
|
||||
@ -177,7 +177,7 @@ class AnchorBBoxMapper(ROITargetMapper):
|
||||
self.time_scale = time_scale
|
||||
self.frequency_scale = frequency_scale
|
||||
|
||||
def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
|
||||
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
|
||||
"""Encode a SoundEvent into an anchor position and scaled box size.
|
||||
|
||||
The position is determined by the configured anchor on the sound
|
||||
@ -190,7 +190,7 @@ class AnchorBBoxMapper(ROITargetMapper):
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[Position, Size]
|
||||
tuple[Position, Size]
|
||||
A tuple of (anchor_position, [scaled_width, scaled_height]).
|
||||
"""
|
||||
from soundevent import geometry
|
||||
@ -314,7 +314,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
||||
|
||||
Attributes
|
||||
----------
|
||||
dimension_names : List[str]
|
||||
dimension_names : list[str]
|
||||
The output dimension names: `['left', 'bottom', 'right', 'top']`.
|
||||
preprocessor : PreprocessorProtocol
|
||||
The spectrogram preprocessor instance.
|
||||
@ -371,7 +371,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[Position, Size]
|
||||
tuple[Position, Size]
|
||||
A tuple of (peak_position, [l, b, r, t] distances).
|
||||
"""
|
||||
from soundevent import geometry
|
||||
@ -519,14 +519,14 @@ def _build_bounding_box(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos : Tuple[float, float]
|
||||
pos
|
||||
Reference position (time, frequency).
|
||||
duration : float
|
||||
duration
|
||||
The required *unscaled* duration (width) of the bounding box.
|
||||
bandwidth : float
|
||||
bandwidth
|
||||
The required *unscaled* frequency bandwidth (height) of the bounding
|
||||
box.
|
||||
anchor : Anchor
|
||||
anchor
|
||||
Specifies which part of the bounding box the input `pos` corresponds to.
|
||||
|
||||
Returns
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Iterable, List, Tuple
|
||||
from typing import Iterable
|
||||
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
@ -33,20 +33,20 @@ class Targets(TargetProtocol):
|
||||
|
||||
Attributes
|
||||
----------
|
||||
class_names : List[str]
|
||||
class_names
|
||||
An ordered list of the unique names of the specific target classes
|
||||
defined in the configuration.
|
||||
generic_class_tags : List[data.Tag]
|
||||
generic_class_tags
|
||||
A list of `soundevent.data.Tag` objects representing the configured
|
||||
generic class category (used when no specific class matches).
|
||||
dimension_names : List[str]
|
||||
dimension_names
|
||||
The names of the size dimensions handled by the ROI mapper
|
||||
(e.g., ['width', 'height']).
|
||||
"""
|
||||
|
||||
class_names: List[str]
|
||||
detection_class_tags: List[data.Tag]
|
||||
dimension_names: List[str]
|
||||
class_names: list[str]
|
||||
detection_class_tags: list[data.Tag]
|
||||
dimension_names: list[str]
|
||||
detection_class_name: str
|
||||
|
||||
def __init__(self, config: TargetConfig):
|
||||
@ -128,7 +128,7 @@ class Targets(TargetProtocol):
|
||||
"""
|
||||
return self._encode_fn(sound_event)
|
||||
|
||||
def decode_class(self, class_label: str) -> List[data.Tag]:
|
||||
def decode_class(self, class_label: str) -> list[data.Tag]:
|
||||
"""Decode a predicted class name back into representative tags.
|
||||
|
||||
Uses the configured mapping (based on `TargetClass.output_tags` or
|
||||
@ -142,7 +142,7 @@ class Targets(TargetProtocol):
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.Tag]
|
||||
list[data.Tag]
|
||||
The list of tags corresponding to the input class name.
|
||||
"""
|
||||
return self._decode_fn(class_label)
|
||||
@ -161,7 +161,7 @@ class Targets(TargetProtocol):
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, float]
|
||||
tuple[float, float]
|
||||
The reference position `(time, frequency)`.
|
||||
|
||||
Raises
|
||||
@ -192,9 +192,9 @@ class Targets(TargetProtocol):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos : Tuple[float, float]
|
||||
pos
|
||||
The reference position `(time, frequency)`.
|
||||
dims : np.ndarray
|
||||
dims
|
||||
NumPy array with size dimensions (e.g., from model prediction),
|
||||
matching the order in `self.dimension_names`.
|
||||
|
||||
@ -292,7 +292,7 @@ def load_targets(
|
||||
def iterate_encoded_sound_events(
|
||||
sound_events: Iterable[data.SoundEventAnnotation],
|
||||
targets: TargetProtocol,
|
||||
) -> Iterable[Tuple[str | None, Position, Size]]:
|
||||
) -> Iterable[tuple[str | None, Position, Size]]:
|
||||
for sound_event in sound_events:
|
||||
if not targets.filter(sound_event):
|
||||
continue
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import lightning as L
|
||||
import torch
|
||||
@ -97,7 +97,7 @@ class TrainingModule(L.LightningModule):
|
||||
|
||||
def load_model_from_checkpoint(
|
||||
path: PathLike,
|
||||
) -> Tuple[Model, "BatDetect2Config"]:
|
||||
) -> tuple[Model, "BatDetect2Config"]:
|
||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||
return module.model, module.config
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Types used in the code base."""
|
||||
|
||||
from typing import Any, List, NamedTuple, TypedDict
|
||||
from typing import Any, NamedTuple, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -86,7 +86,7 @@ class ModelParameters(TypedDict):
|
||||
resize_factor: float
|
||||
"""Resize factor."""
|
||||
|
||||
class_names: List[str]
|
||||
class_names: list[str]
|
||||
"""Class names.
|
||||
|
||||
The model is trained to detect these classes.
|
||||
@ -158,7 +158,7 @@ class FileAnnotation(TypedDict):
|
||||
notes: str
|
||||
"""Notes of file."""
|
||||
|
||||
annotation: List[Annotation]
|
||||
annotation: list[Annotation]
|
||||
"""List of annotations."""
|
||||
|
||||
|
||||
@ -168,26 +168,26 @@ class RunResults(TypedDict):
|
||||
pred_dict: FileAnnotation
|
||||
"""Predictions in the format expected by the annotation tool."""
|
||||
|
||||
spec_feats: NotRequired[List[np.ndarray]]
|
||||
spec_feats: NotRequired[list[np.ndarray]]
|
||||
"""Spectrogram features."""
|
||||
|
||||
spec_feat_names: NotRequired[List[str]]
|
||||
spec_feat_names: NotRequired[list[str]]
|
||||
"""Spectrogram feature names."""
|
||||
|
||||
cnn_feats: NotRequired[List[np.ndarray]]
|
||||
cnn_feats: NotRequired[list[np.ndarray]]
|
||||
"""CNN features."""
|
||||
|
||||
cnn_feat_names: NotRequired[List[str]]
|
||||
cnn_feat_names: NotRequired[list[str]]
|
||||
"""CNN feature names."""
|
||||
|
||||
spec_slices: NotRequired[List[np.ndarray]]
|
||||
spec_slices: NotRequired[list[np.ndarray]]
|
||||
"""Spectrogram slices."""
|
||||
|
||||
|
||||
class ResultParams(TypedDict):
|
||||
"""Result parameters."""
|
||||
|
||||
class_names: List[str]
|
||||
class_names: list[str]
|
||||
"""Class names."""
|
||||
|
||||
spec_features: bool
|
||||
@ -234,7 +234,7 @@ class ProcessingConfiguration(TypedDict):
|
||||
scale_raw_audio: bool
|
||||
"""Whether to scale the raw audio to be between -1 and 1."""
|
||||
|
||||
class_names: List[str]
|
||||
class_names: list[str]
|
||||
"""Names of the classes the model can detect."""
|
||||
|
||||
detection_threshold: float
|
||||
@ -466,7 +466,7 @@ class FeatureExtractionParameters(TypedDict):
|
||||
class HeatmapParameters(TypedDict):
|
||||
"""Parameters that control the heatmap generation function."""
|
||||
|
||||
class_names: List[str]
|
||||
class_names: list[str]
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
@ -553,15 +553,15 @@ class AudioLoaderAnnotationGroup(TypedDict):
|
||||
individual_ids: np.ndarray
|
||||
x_inds: np.ndarray
|
||||
y_inds: np.ndarray
|
||||
annotation: List[Annotation]
|
||||
annotation: list[Annotation]
|
||||
annotated: bool
|
||||
class_id_file: int
|
||||
"""ID of the class of the file."""
|
||||
|
||||
|
||||
class AudioLoaderParameters(TypedDict):
|
||||
class_names: List[str]
|
||||
classes_to_ignore: List[str]
|
||||
class_names: list[str]
|
||||
classes_to_ignore: list[str]
|
||||
target_samp_rate: int
|
||||
scale_raw_audio: bool
|
||||
fft_win_length: float
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Dict,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
@ -33,7 +30,7 @@ class MatchEvaluation:
|
||||
gt_geometry: data.Geometry | None
|
||||
|
||||
pred_score: float
|
||||
pred_class_scores: Dict[str, float]
|
||||
pred_class_scores: dict[str, float]
|
||||
pred_geometry: data.Geometry | None
|
||||
|
||||
affinity: float
|
||||
@ -66,7 +63,7 @@ class MatchEvaluation:
|
||||
@dataclass
|
||||
class ClipMatches:
|
||||
clip: data.Clip
|
||||
matches: List[MatchEvaluation]
|
||||
matches: list[MatchEvaluation]
|
||||
|
||||
|
||||
class MatcherProtocol(Protocol):
|
||||
@ -75,7 +72,7 @@ class MatcherProtocol(Protocol):
|
||||
ground_truth: Sequence[data.Geometry],
|
||||
predictions: Sequence[data.Geometry],
|
||||
scores: Sequence[float],
|
||||
) -> Iterable[Tuple[int | None, int | None, float]]: ...
|
||||
) -> Iterable[tuple[int | None, int | None, float]]: ...
|
||||
|
||||
|
||||
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
||||
@ -94,7 +91,7 @@ class MetricsProtocol(Protocol):
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[Detection]],
|
||||
) -> Dict[str, float]: ...
|
||||
) -> dict[str, float]: ...
|
||||
|
||||
|
||||
class PlotterProtocol(Protocol):
|
||||
@ -102,7 +99,7 @@ class PlotterProtocol(Protocol):
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[Detection]],
|
||||
) -> Iterable[Tuple[str, Figure]]: ...
|
||||
) -> Iterable[tuple[str, Figure]]: ...
|
||||
|
||||
|
||||
EvaluationOutput = TypeVar("EvaluationOutput")
|
||||
@ -119,8 +116,8 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
||||
|
||||
def compute_metrics(
|
||||
self, eval_outputs: EvaluationOutput
|
||||
) -> Dict[str, float]: ...
|
||||
) -> dict[str, float]: ...
|
||||
|
||||
def generate_plots(
|
||||
self, eval_outputs: EvaluationOutput
|
||||
) -> Iterable[Tuple[str, Figure]]: ...
|
||||
) -> Iterable[tuple[str, Figure]]: ...
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Callable, NamedTuple, Protocol, Tuple
|
||||
from typing import Callable, NamedTuple, Protocol
|
||||
|
||||
import torch
|
||||
from soundevent import data
|
||||
@ -52,7 +52,7 @@ steps, and returns the final `Heatmaps` used for model training.
|
||||
|
||||
Augmentation = Callable[
|
||||
[torch.Tensor, data.ClipAnnotation],
|
||||
Tuple[torch.Tensor, data.ClipAnnotation],
|
||||
tuple[torch.Tensor, data.ClipAnnotation],
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import librosa
|
||||
import librosa.core.spectrum
|
||||
@ -147,7 +146,7 @@ def load_audio(
|
||||
target_samp_rate: int,
|
||||
scale: bool = False,
|
||||
max_duration: float | None = None,
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
) -> tuple[int, np.ndarray]:
|
||||
"""Load an audio file and resample it to the target sampling rate.
|
||||
|
||||
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Iterator, List, Tuple
|
||||
from typing import Any, Iterator
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
@ -60,7 +60,7 @@ def get_default_bd_args():
|
||||
return args
|
||||
|
||||
|
||||
def list_audio_files(ip_dir: str) -> List[str]:
|
||||
def list_audio_files(ip_dir: str) -> list[str]:
|
||||
"""Get all audio files in directory.
|
||||
|
||||
Args:
|
||||
@ -86,7 +86,7 @@ def load_model(
|
||||
load_weights: bool = True,
|
||||
device: torch.device | str | None = None,
|
||||
weights_only: bool = True,
|
||||
) -> Tuple[DetectionModel, ModelParameters]:
|
||||
) -> tuple[DetectionModel, ModelParameters]:
|
||||
"""Load model from file.
|
||||
|
||||
Args:
|
||||
@ -185,26 +185,28 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
||||
|
||||
def get_annotations_from_preds(
|
||||
predictions: PredictionResults,
|
||||
class_names: List[str],
|
||||
) -> List[Annotation]:
|
||||
class_names: list[str],
|
||||
) -> list[Annotation]:
|
||||
"""Get list of annotations from predictions."""
|
||||
# Get the best class prediction probability and index for each detection
|
||||
class_prob_best = predictions["class_probs"].max(0)
|
||||
class_ind_best = predictions["class_probs"].argmax(0)
|
||||
|
||||
# Pack the results into a list of dictionaries
|
||||
annotations: List[Annotation] = [
|
||||
{
|
||||
"start_time": round(float(start_time), 4),
|
||||
"end_time": round(float(end_time), 4),
|
||||
"low_freq": int(low_freq),
|
||||
"high_freq": int(high_freq),
|
||||
"class": str(class_names[class_index]),
|
||||
"class_prob": round(float(class_prob), 3),
|
||||
"det_prob": round(float(det_prob), 3),
|
||||
"individual": "-1",
|
||||
"event": "Echolocation",
|
||||
}
|
||||
annotations: list[Annotation] = [
|
||||
Annotation(
|
||||
{
|
||||
"start_time": round(float(start_time), 4),
|
||||
"end_time": round(float(end_time), 4),
|
||||
"low_freq": int(low_freq),
|
||||
"high_freq": int(high_freq),
|
||||
"class": str(class_names[class_index]),
|
||||
"class_prob": round(float(class_prob), 3),
|
||||
"det_prob": round(float(det_prob), 3),
|
||||
"individual": "-1",
|
||||
"event": "Echolocation",
|
||||
}
|
||||
)
|
||||
for (
|
||||
start_time,
|
||||
end_time,
|
||||
@ -232,7 +234,7 @@ def format_single_result(
|
||||
time_exp: float,
|
||||
duration: float,
|
||||
predictions: PredictionResults,
|
||||
class_names: List[str],
|
||||
class_names: list[str],
|
||||
) -> FileAnnotation:
|
||||
"""Format results into the format expected by the annotation tool.
|
||||
|
||||
@ -315,9 +317,9 @@ def convert_results(
|
||||
]
|
||||
|
||||
# combine into final results dictionary
|
||||
results: RunResults = {
|
||||
results: RunResults = RunResults({ # type: ignore
|
||||
"pred_dict": pred_dict,
|
||||
}
|
||||
})
|
||||
|
||||
# add spectrogram features if they exist
|
||||
if len(spec_feats) > 0 and params["spec_features"]:
|
||||
@ -413,7 +415,7 @@ def compute_spectrogram(
|
||||
sampling_rate: int,
|
||||
params: SpectrogramParameters,
|
||||
device: torch.device,
|
||||
) -> Tuple[float, torch.Tensor]:
|
||||
) -> tuple[float, torch.Tensor]:
|
||||
"""Compute a spectrogram from an audio array.
|
||||
|
||||
Will pad the audio array so that it is evenly divisible by the
|
||||
@ -475,7 +477,7 @@ def iterate_over_chunks(
|
||||
audio: np.ndarray,
|
||||
samplerate: float,
|
||||
chunk_size: float,
|
||||
) -> Iterator[Tuple[float, np.ndarray]]:
|
||||
) -> Iterator[tuple[float, np.ndarray]]:
|
||||
"""Iterate over audio in chunks of size chunk_size.
|
||||
|
||||
Parameters
|
||||
@ -509,7 +511,7 @@ def _process_spectrogram(
|
||||
samplerate: float,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
) -> Tuple[PredictionResults, np.ndarray]:
|
||||
) -> tuple[PredictionResults, np.ndarray]:
|
||||
# evaluate model
|
||||
with torch.no_grad():
|
||||
outputs = model(spec)
|
||||
@ -546,7 +548,7 @@ def postprocess_model_outputs(
|
||||
outputs: ModelOutput,
|
||||
samp_rate: int,
|
||||
config: ProcessingConfiguration,
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
) -> tuple[list[Annotation], np.ndarray]:
|
||||
# run non-max suppression
|
||||
pred_nms_list, features = pp.run_nms(
|
||||
outputs,
|
||||
@ -585,7 +587,7 @@ def process_spectrogram(
|
||||
samplerate: int,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
) -> tuple[list[Annotation], np.ndarray]:
|
||||
"""Process a spectrogram with detection model.
|
||||
|
||||
Will run non-maximum suppression on the output of the model.
|
||||
@ -604,9 +606,9 @@ def process_spectrogram(
|
||||
|
||||
Returns
|
||||
-------
|
||||
detections: List[Annotation]
|
||||
detections
|
||||
List of detections predicted by the model.
|
||||
features : np.ndarray
|
||||
features
|
||||
An array of CNN features associated with each annotation.
|
||||
The array is of shape (num_detections, num_features).
|
||||
Is empty if `config["cnn_features"]` is False.
|
||||
@ -632,7 +634,7 @@ def _process_audio_array(
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
device: torch.device,
|
||||
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
|
||||
) -> tuple[PredictionResults, np.ndarray, torch.Tensor]:
|
||||
# load audio file and compute spectrogram
|
||||
_, spec = compute_spectrogram(
|
||||
audio,
|
||||
@ -669,7 +671,7 @@ def process_audio_array(
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
device: torch.device,
|
||||
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
|
||||
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
|
||||
"""Process a single audio array with detection model.
|
||||
|
||||
Parameters
|
||||
@ -689,7 +691,7 @@ def process_audio_array(
|
||||
|
||||
Returns
|
||||
-------
|
||||
annotations : List[Annotation]
|
||||
annotations : list[Annotation]
|
||||
List of annotations predicted by the model.
|
||||
features : np.ndarray
|
||||
Array of CNN features associated with each annotation.
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from soundevent import data
|
||||
@ -20,11 +20,11 @@ def create_legacy_file_annotation(
|
||||
duration: float = 5.0,
|
||||
time_exp: float = 1.0,
|
||||
class_name: str = "Myotis",
|
||||
annotations: Optional[List[Dict[str, Any]]] = None,
|
||||
annotations: list[dict[str, Any]] | None = None,
|
||||
annotated: bool = True,
|
||||
issues: bool = False,
|
||||
notes: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
if annotations is None:
|
||||
annotations = [
|
||||
{
|
||||
@ -61,7 +61,7 @@ def create_legacy_file_annotation(
|
||||
@pytest.fixture
|
||||
def batdetect2_files_test_setup(
|
||||
tmp_path: Path, wav_factory
|
||||
) -> Tuple[Path, Path, List[Dict[str, Any]]]:
|
||||
) -> tuple[Path, Path, list[dict[str, Any]]]:
|
||||
"""Sets up a directory structure for batdetect2 files format tests."""
|
||||
audio_dir = tmp_path / "audio"
|
||||
audio_dir.mkdir()
|
||||
@ -143,7 +143,7 @@ def batdetect2_files_test_setup(
|
||||
@pytest.fixture
|
||||
def batdetect2_merged_test_setup(
|
||||
tmp_path: Path, batdetect2_files_test_setup
|
||||
) -> Tuple[Path, Path, List[Dict[str, Any]]]:
|
||||
) -> tuple[Path, Path, list[dict[str, Any]]]:
|
||||
"""Sets up a directory structure for batdetect2 merged file format tests."""
|
||||
audio_dir, _, files_data = batdetect2_files_test_setup
|
||||
merged_anns_path = tmp_path / "merged_anns.json"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user