mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 23:30:21 +02:00
Removing legacy types
This commit is contained in:
parent
8ac4f4c44d
commit
1a7c0b4b3a
@ -98,7 +98,6 @@ consult the API documentation in the code.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -272,7 +271,7 @@ def process_spectrogram(
|
|||||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
model: DetectionModel = MODEL,
|
model: DetectionModel = MODEL,
|
||||||
config: ProcessingConfiguration | None = None,
|
config: ProcessingConfiguration | None = None,
|
||||||
) -> Tuple[List[Annotation], np.ndarray]:
|
) -> tuple[list[Annotation], np.ndarray]:
|
||||||
"""Process spectrogram with model.
|
"""Process spectrogram with model.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -314,7 +313,7 @@ def process_audio(
|
|||||||
model: DetectionModel = MODEL,
|
model: DetectionModel = MODEL,
|
||||||
config: ProcessingConfiguration | None = None,
|
config: ProcessingConfiguration | None = None,
|
||||||
device: torch.device = DEVICE,
|
device: torch.device = DEVICE,
|
||||||
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
|
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
|
||||||
"""Process audio array with model.
|
"""Process audio array with model.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -357,7 +356,7 @@ def postprocess(
|
|||||||
outputs: ModelOutput,
|
outputs: ModelOutput,
|
||||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
config: ProcessingConfiguration | None = None,
|
config: ProcessingConfiguration | None = None,
|
||||||
) -> Tuple[List[Annotation], np.ndarray]:
|
) -> tuple[list[Annotation], np.ndarray]:
|
||||||
"""Postprocess model outputs.
|
"""Postprocess model outputs.
|
||||||
|
|
||||||
Convert model tensor outputs to predicted bounding boxes and
|
Convert model tensor outputs to predicted bounding boxes and
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Sequence, Tuple
|
from typing import Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -110,7 +110,7 @@ class BatDetect2API:
|
|||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
save_predictions: bool = True,
|
save_predictions: bool = True,
|
||||||
) -> Tuple[Dict[str, float], List[List[Detection]]]:
|
) -> tuple[dict[str, float], list[list[Detection]]]:
|
||||||
return evaluate(
|
return evaluate(
|
||||||
self.model,
|
self.model,
|
||||||
test_annotations,
|
test_annotations,
|
||||||
@ -187,7 +187,7 @@ class BatDetect2API:
|
|||||||
def process_audio(
|
def process_audio(
|
||||||
self,
|
self,
|
||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
) -> List[Detection]:
|
) -> list[Detection]:
|
||||||
spec = self.generate_spectrogram(audio)
|
spec = self.generate_spectrogram(audio)
|
||||||
return self.process_spectrogram(spec)
|
return self.process_spectrogram(spec)
|
||||||
|
|
||||||
@ -195,7 +195,7 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
start_time: float = 0,
|
start_time: float = 0,
|
||||||
) -> List[Detection]:
|
) -> list[Detection]:
|
||||||
if spec.ndim == 4 and spec.shape[0] > 1:
|
if spec.ndim == 4 and spec.shape[0] > 1:
|
||||||
raise ValueError("Batched spectrograms not supported.")
|
raise ValueError("Batched spectrograms not supported.")
|
||||||
|
|
||||||
@ -214,7 +214,7 @@ class BatDetect2API:
|
|||||||
def process_directory(
|
def process_directory(
|
||||||
self,
|
self,
|
||||||
audio_dir: data.PathLike,
|
audio_dir: data.PathLike,
|
||||||
) -> List[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
files = list(get_audio_files(audio_dir))
|
files = list(get_audio_files(audio_dir))
|
||||||
return self.process_files(files)
|
return self.process_files(files)
|
||||||
|
|
||||||
@ -222,7 +222,7 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
audio_files: Sequence[data.PathLike],
|
audio_files: Sequence[data.PathLike],
|
||||||
num_workers: int | None = None,
|
num_workers: int | None = None,
|
||||||
) -> List[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
return process_file_list(
|
return process_file_list(
|
||||||
self.model,
|
self.model,
|
||||||
audio_files,
|
audio_files,
|
||||||
@ -238,7 +238,7 @@ class BatDetect2API:
|
|||||||
clips: Sequence[data.Clip],
|
clips: Sequence[data.Clip],
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int | None = None,
|
||||||
) -> List[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
return run_batch_inference(
|
return run_batch_inference(
|
||||||
self.model,
|
self.model,
|
||||||
clips,
|
clips,
|
||||||
@ -274,7 +274,7 @@ class BatDetect2API:
|
|||||||
def load_predictions(
|
def load_predictions(
|
||||||
self,
|
self,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
) -> List[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
return self.formatter.load(path)
|
return self.formatter.load(path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested
|
|||||||
configuration sections.
|
configuration sections.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Type, TypeVar, Union, overload
|
from typing import Any, Type, TypeVar, overload
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from deepmerge.merger import Merger
|
from deepmerge.merger import Merger
|
||||||
@ -69,8 +69,7 @@ class BaseConfig(BaseModel):
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||||
|
Schema = Type[T_Model] | TypeAdapter[T]
|
||||||
Schema = Union[Type[T_Model], TypeAdapter[T]]
|
|
||||||
|
|
||||||
|
|
||||||
def get_object_field(obj: dict, current_key: str) -> Any:
|
def get_object_field(obj: dict, current_key: str) -> Any:
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -10,7 +9,7 @@ from batdetect2.typing.targets import TargetProtocol
|
|||||||
def iterate_over_sound_events(
|
def iterate_over_sound_events(
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
) -> Generator[Tuple[str | None, data.SoundEventAnnotation], None, None]:
|
) -> Generator[tuple[str | None, data.SoundEventAnnotation], None, None]:
|
||||||
"""Iterate over sound events in a dataset.
|
"""Iterate over sound events in a dataset.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -24,7 +23,7 @@ def iterate_over_sound_events(
|
|||||||
|
|
||||||
Yields
|
Yields
|
||||||
------
|
------
|
||||||
Tuple[Optional[str], data.SoundEventAnnotation]
|
tuple[Optional[str], data.SoundEventAnnotation]
|
||||||
A tuple containing:
|
A tuple containing:
|
||||||
- The encoded class name (str) for the sound event, or None if it
|
- The encoded class name (str) for the sound event, or None if it
|
||||||
cannot be encoded to a specific class.
|
cannot be encoded to a specific class.
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from typing import Literal
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
from batdetect2.data.datasets import Dataset
|
from batdetect2.data.datasets import Dataset
|
||||||
@ -15,7 +13,7 @@ def split_dataset_by_recordings(
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
train_size: float = 0.75,
|
train_size: float = 0.75,
|
||||||
random_state: int | None = None,
|
random_state: int | None = None,
|
||||||
) -> Tuple[Dataset, Dataset]:
|
) -> tuple[Dataset, Dataset]:
|
||||||
recordings = extract_recordings_df(dataset)
|
recordings = extract_recordings_df(dataset)
|
||||||
|
|
||||||
sound_events = extract_sound_events_df(
|
sound_events = extract_sound_events_df(
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
"""Post-processing of the output of the model."""
|
"""Post-processing of the output of the model."""
|
||||||
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -45,7 +43,7 @@ def run_nms(
|
|||||||
outputs: ModelOutput,
|
outputs: ModelOutput,
|
||||||
params: NonMaximumSuppressionConfig,
|
params: NonMaximumSuppressionConfig,
|
||||||
sampling_rate: np.ndarray,
|
sampling_rate: np.ndarray,
|
||||||
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
|
) -> tuple[list[PredictionResults], list[np.ndarray]]:
|
||||||
"""Run non-maximum suppression on the output of the model.
|
"""Run non-maximum suppression on the output of the model.
|
||||||
|
|
||||||
Model outputs processed are expected to have a batch dimension.
|
Model outputs processed are expected to have a batch dimension.
|
||||||
@ -73,8 +71,8 @@ def run_nms(
|
|||||||
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
|
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
|
||||||
|
|
||||||
# loop over batch to save outputs
|
# loop over batch to save outputs
|
||||||
preds: List[PredictionResults] = []
|
preds: list[PredictionResults] = []
|
||||||
feats: List[np.ndarray] = []
|
feats: list[np.ndarray] = []
|
||||||
for num_detection in range(pred_det_nms.shape[0]):
|
for num_detection in range(pred_det_nms.shape[0]):
|
||||||
# get valid indices
|
# get valid indices
|
||||||
inds_ord = torch.argsort(x_pos[num_detection, :])
|
inds_ord = torch.argsort(x_pos[num_detection, :])
|
||||||
@ -151,7 +149,7 @@ def run_nms(
|
|||||||
|
|
||||||
def non_max_suppression(
|
def non_max_suppression(
|
||||||
heat: torch.Tensor,
|
heat: torch.Tensor,
|
||||||
kernel_size: int | Tuple[int, int],
|
kernel_size: int | tuple[int, int],
|
||||||
):
|
):
|
||||||
# kernel can be an int or list/tuple
|
# kernel can be an int or list/tuple
|
||||||
if isinstance(kernel_size, int):
|
if isinstance(kernel_size, int):
|
||||||
|
|||||||
@ -4,12 +4,9 @@ from dataclasses import dataclass, field
|
|||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
|
||||||
Literal,
|
Literal,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -32,7 +29,7 @@ from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
|||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
||||||
|
|
||||||
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]
|
||||||
|
|
||||||
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
|
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
|
||||||
name="top_class_plot"
|
name="top_class_plot"
|
||||||
@ -73,7 +70,7 @@ class PRCurve(BasePlot):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Iterable[Tuple[str, Figure]]:
|
) -> Iterable[tuple[str, Figure]]:
|
||||||
y_true = []
|
y_true = []
|
||||||
y_score = []
|
y_score = []
|
||||||
num_positives = 0
|
num_positives = 0
|
||||||
@ -140,7 +137,7 @@ class ROCCurve(BasePlot):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Iterable[Tuple[str, Figure]]:
|
) -> Iterable[tuple[str, Figure]]:
|
||||||
y_true = []
|
y_true = []
|
||||||
y_score = []
|
y_score = []
|
||||||
|
|
||||||
@ -223,7 +220,7 @@ class ConfusionMatrix(BasePlot):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Iterable[Tuple[str, Figure]]:
|
) -> Iterable[tuple[str, Figure]]:
|
||||||
cm, labels = compute_confusion_matrix(
|
cm, labels = compute_confusion_matrix(
|
||||||
clip_evaluations,
|
clip_evaluations,
|
||||||
self.targets,
|
self.targets,
|
||||||
@ -295,26 +292,26 @@ class ExampleClassificationPlot(BasePlot):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Iterable[Tuple[str, Figure]]:
|
) -> Iterable[tuple[str, Figure]]:
|
||||||
grouped = group_matches(clip_evaluations, threshold=self.threshold)
|
grouped = group_matches(clip_evaluations, threshold=self.threshold)
|
||||||
|
|
||||||
for class_name, matches in grouped.items():
|
for class_name, matches in grouped.items():
|
||||||
true_positives: List[MatchEval] = get_binned_sample(
|
true_positives: list[MatchEval] = get_binned_sample(
|
||||||
matches.true_positives,
|
matches.true_positives,
|
||||||
n_examples=self.num_examples,
|
n_examples=self.num_examples,
|
||||||
)
|
)
|
||||||
|
|
||||||
false_positives: List[MatchEval] = get_binned_sample(
|
false_positives: list[MatchEval] = get_binned_sample(
|
||||||
matches.false_positives,
|
matches.false_positives,
|
||||||
n_examples=self.num_examples,
|
n_examples=self.num_examples,
|
||||||
)
|
)
|
||||||
|
|
||||||
false_negatives: List[MatchEval] = random.sample(
|
false_negatives: list[MatchEval] = random.sample(
|
||||||
matches.false_negatives,
|
matches.false_negatives,
|
||||||
k=min(self.num_examples, len(matches.false_negatives)),
|
k=min(self.num_examples, len(matches.false_negatives)),
|
||||||
)
|
)
|
||||||
|
|
||||||
cross_triggers: List[MatchEval] = get_binned_sample(
|
cross_triggers: list[MatchEval] = get_binned_sample(
|
||||||
matches.cross_triggers, n_examples=self.num_examples
|
matches.cross_triggers, n_examples=self.num_examples
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -374,16 +371,16 @@ def build_top_class_plotter(
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClassMatches:
|
class ClassMatches:
|
||||||
false_positives: List[MatchEval] = field(default_factory=list)
|
false_positives: list[MatchEval] = field(default_factory=list)
|
||||||
false_negatives: List[MatchEval] = field(default_factory=list)
|
false_negatives: list[MatchEval] = field(default_factory=list)
|
||||||
true_positives: List[MatchEval] = field(default_factory=list)
|
true_positives: list[MatchEval] = field(default_factory=list)
|
||||||
cross_triggers: List[MatchEval] = field(default_factory=list)
|
cross_triggers: list[MatchEval] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
def group_matches(
|
def group_matches(
|
||||||
clip_evals: Sequence[ClipEval],
|
clip_evals: Sequence[ClipEval],
|
||||||
threshold: float = 0.2,
|
threshold: float = 0.2,
|
||||||
) -> Dict[str, ClassMatches]:
|
) -> dict[str, ClassMatches]:
|
||||||
class_examples = defaultdict(ClassMatches)
|
class_examples = defaultdict(ClassMatches)
|
||||||
|
|
||||||
for clip_eval in clip_evals:
|
for clip_eval in clip_evals:
|
||||||
@ -412,7 +409,7 @@ def group_matches(
|
|||||||
return class_examples
|
return class_examples
|
||||||
|
|
||||||
|
|
||||||
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
|
def get_binned_sample(matches: list[MatchEval], n_examples: int = 5):
|
||||||
if len(matches) < n_examples:
|
if len(matches) < n_examples:
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.model_selection import StratifiedGroupKFold
|
from sklearn.model_selection import StratifiedGroupKFold
|
||||||
@ -12,8 +11,8 @@ from batdetect2 import types
|
|||||||
|
|
||||||
|
|
||||||
def print_dataset_stats(
|
def print_dataset_stats(
|
||||||
data: List[types.FileAnnotation],
|
data: list[types.FileAnnotation],
|
||||||
classes_to_ignore: List[str] | None = None,
|
classes_to_ignore: list[str] | None = None,
|
||||||
) -> Counter[str]:
|
) -> Counter[str]:
|
||||||
print("Num files:", len(data))
|
print("Num files:", len(data))
|
||||||
counts, _ = tu.get_class_names(data, classes_to_ignore)
|
counts, _ = tu.get_class_names(data, classes_to_ignore)
|
||||||
@ -22,7 +21,7 @@ def print_dataset_stats(
|
|||||||
return counts
|
return counts
|
||||||
|
|
||||||
|
|
||||||
def load_file_names(file_name: str) -> List[str]:
|
def load_file_names(file_name: str) -> list[str]:
|
||||||
if not os.path.isfile(file_name):
|
if not os.path.isfile(file_name):
|
||||||
raise FileNotFoundError(f"Input file not found - {file_name}")
|
raise FileNotFoundError(f"Input file not found - {file_name}")
|
||||||
|
|
||||||
@ -100,12 +99,12 @@ def parse_args():
|
|||||||
|
|
||||||
|
|
||||||
def split_data(
|
def split_data(
|
||||||
data: List[types.FileAnnotation],
|
data: list[types.FileAnnotation],
|
||||||
train_file: str,
|
train_file: str,
|
||||||
test_file: str,
|
test_file: str,
|
||||||
n_splits: int = 5,
|
n_splits: int = 5,
|
||||||
random_state: int = 0,
|
random_state: int = 0,
|
||||||
) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]:
|
) -> tuple[list[types.FileAnnotation], list[types.FileAnnotation]]:
|
||||||
if train_file != "" and test_file != "":
|
if train_file != "" and test_file != "":
|
||||||
# user has specifed the train / test split
|
# user has specifed the train / test split
|
||||||
mapping = {
|
mapping = {
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, NamedTuple, Sequence
|
from typing import NamedTuple, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@ -29,7 +29,7 @@ class DatasetItem(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class InferenceDataset(Dataset[DatasetItem]):
|
class InferenceDataset(Dataset[DatasetItem]):
|
||||||
clips: List[data.Clip]
|
clips: list[data.Clip]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -111,7 +111,7 @@ def build_inference_dataset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
|
def _collate_fn(batch: list[DatasetItem]) -> DatasetItem:
|
||||||
max_width = max(item.spec.shape[-1] for item in batch)
|
max_width = max(item.spec.shape[-1] for item in batch)
|
||||||
return DatasetItem(
|
return DatasetItem(
|
||||||
spec=torch.stack(
|
spec=torch.stack(
|
||||||
|
|||||||
@ -26,8 +26,6 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
|
|||||||
is the ``build_model`` factory function exported from this module.
|
is the ``build_model`` factory function exported from this module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import (
|
||||||
@ -142,7 +140,7 @@ class Model(torch.nn.Module):
|
|||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
|
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
||||||
"""Run the full detection pipeline on a waveform tensor.
|
"""Run the full detection pipeline on a waveform tensor.
|
||||||
|
|
||||||
Converts the waveform to a spectrogram, passes it through the
|
Converts the waveform to a spectrogram, passes it through the
|
||||||
@ -157,7 +155,7 @@ class Model(torch.nn.Module):
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
List[ClipDetectionsTensor]
|
list[ClipDetectionsTensor]
|
||||||
One detection tensor per clip in the batch. Each tensor encodes
|
One detection tensor per clip in the batch. Each tensor encodes
|
||||||
the detected events (locations, class scores, sizes) for that
|
the detected events (locations, class scores, sizes) for that
|
||||||
clip.
|
clip.
|
||||||
|
|||||||
@ -23,7 +23,7 @@ output so that the output spatial dimensions always match the input spatial
|
|||||||
dimensions.
|
dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, Literal, Tuple, Union
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -58,6 +58,14 @@ from batdetect2.typing.models import (
|
|||||||
EncoderProtocol,
|
EncoderProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BackboneImportConfig",
|
||||||
|
"UNetBackbone",
|
||||||
|
"BackboneConfig",
|
||||||
|
"load_backbone_config",
|
||||||
|
"build_backbone",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class UNetBackboneConfig(BaseConfig):
|
class UNetBackboneConfig(BaseConfig):
|
||||||
"""Configuration for a U-Net-style encoder-decoder backbone.
|
"""Configuration for a U-Net-style encoder-decoder backbone.
|
||||||
@ -110,15 +118,6 @@ class BackboneImportConfig(ImportConfig):
|
|||||||
name: Literal["import"] = "import"
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BackboneImportConfig",
|
|
||||||
"UNetBackbone",
|
|
||||||
"BackboneConfig",
|
|
||||||
"load_backbone_config",
|
|
||||||
"build_backbone",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class UNetBackbone(BackboneModel):
|
class UNetBackbone(BackboneModel):
|
||||||
"""U-Net-style encoder-decoder backbone network.
|
"""U-Net-style encoder-decoder backbone network.
|
||||||
|
|
||||||
@ -262,7 +261,8 @@ class UNetBackbone(BackboneModel):
|
|||||||
|
|
||||||
|
|
||||||
BackboneConfig = Annotated[
|
BackboneConfig = Annotated[
|
||||||
Union[UNetBackboneConfig,], Field(discriminator="name")
|
UNetBackboneConfig | BackboneImportConfig,
|
||||||
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -292,7 +292,7 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
|||||||
def _pad_adjust(
|
def _pad_adjust(
|
||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
factor: int = 32,
|
factor: int = 32,
|
||||||
) -> Tuple[torch.Tensor, int, int]:
|
) -> tuple[torch.Tensor, int, int]:
|
||||||
"""Pad a tensor's height and width to be divisible by ``factor``.
|
"""Pad a tensor's height and width to be divisible by ``factor``.
|
||||||
|
|
||||||
Adds zero-padding to the bottom and right edges of the tensor so that
|
Adds zero-padding to the bottom and right edges of the tensor so that
|
||||||
@ -308,7 +308,7 @@ def _pad_adjust(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Tuple[torch.Tensor, int, int]
|
tuple[torch.Tensor, int, int]
|
||||||
- Padded tensor.
|
- Padded tensor.
|
||||||
- Number of rows added to the height (``h_pad``).
|
- Number of rows added to the height (``h_pad``).
|
||||||
- Number of columns added to the width (``w_pad``).
|
- Number of columns added to the width (``w_pad``).
|
||||||
|
|||||||
@ -46,7 +46,7 @@ configuration object (one of the ``*Config`` classes exported here), using
|
|||||||
a discriminated-union ``name`` field to dispatch to the correct class.
|
a discriminated-union ``name`` field to dispatch to the correct class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, List, Literal, Tuple, Union
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -687,7 +687,7 @@ class FreqCoordConvUpConfig(BaseConfig):
|
|||||||
up_mode: str = "bilinear"
|
up_mode: str = "bilinear"
|
||||||
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
||||||
|
|
||||||
up_scale: Tuple[int, int] = (2, 2)
|
up_scale: tuple[int, int] = (2, 2)
|
||||||
"""Scaling factor for height and width during upsampling."""
|
"""Scaling factor for height and width during upsampling."""
|
||||||
|
|
||||||
|
|
||||||
@ -706,22 +706,22 @@ class FreqCoordConvUpBlock(Block):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels
|
||||||
Number of channels in the input tensor (before upsampling).
|
Number of channels in the input tensor (before upsampling).
|
||||||
out_channels : int
|
out_channels
|
||||||
Number of output channels after the convolution.
|
Number of output channels after the convolution.
|
||||||
input_height : int
|
input_height
|
||||||
Height (H dimension, frequency bins) of the tensor *before* upsampling.
|
Height (H dimension, frequency bins) of the tensor *before* upsampling.
|
||||||
Used to calculate the height for coordinate feature generation after
|
Used to calculate the height for coordinate feature generation after
|
||||||
upsampling.
|
upsampling.
|
||||||
kernel_size : int, default=3
|
kernel_size
|
||||||
Size of the square convolutional kernel.
|
Size of the square convolutional kernel.
|
||||||
pad_size : int, default=1
|
pad_size
|
||||||
Padding added before convolution.
|
Padding added before convolution.
|
||||||
up_mode : str, default="bilinear"
|
up_mode
|
||||||
Interpolation mode for upsampling (e.g., "nearest", "bilinear",
|
Interpolation mode for upsampling (e.g., "nearest", "bilinear",
|
||||||
"bicubic").
|
"bicubic").
|
||||||
up_scale : Tuple[int, int], default=(2, 2)
|
up_scale
|
||||||
Scaling factor for height and width during upsampling
|
Scaling factor for height and width during upsampling
|
||||||
(typically (2, 2)).
|
(typically (2, 2)).
|
||||||
"""
|
"""
|
||||||
@ -734,7 +734,7 @@ class FreqCoordConvUpBlock(Block):
|
|||||||
kernel_size: int = 3,
|
kernel_size: int = 3,
|
||||||
pad_size: int = 1,
|
pad_size: int = 1,
|
||||||
up_mode: str = "bilinear",
|
up_mode: str = "bilinear",
|
||||||
up_scale: Tuple[int, int] = (2, 2),
|
up_scale: tuple[int, int] = (2, 2),
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
@ -824,7 +824,7 @@ class StandardConvUpConfig(BaseConfig):
|
|||||||
up_mode: str = "bilinear"
|
up_mode: str = "bilinear"
|
||||||
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
||||||
|
|
||||||
up_scale: Tuple[int, int] = (2, 2)
|
up_scale: tuple[int, int] = (2, 2)
|
||||||
"""Scaling factor for height and width during upsampling."""
|
"""Scaling factor for height and width during upsampling."""
|
||||||
|
|
||||||
|
|
||||||
@ -839,17 +839,17 @@ class StandardConvUpBlock(Block):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels
|
||||||
Number of channels in the input tensor (before upsampling).
|
Number of channels in the input tensor (before upsampling).
|
||||||
out_channels : int
|
out_channels
|
||||||
Number of output channels after the convolution.
|
Number of output channels after the convolution.
|
||||||
kernel_size : int, default=3
|
kernel_size
|
||||||
Size of the square convolutional kernel.
|
Size of the square convolutional kernel.
|
||||||
pad_size : int, default=1
|
pad_size
|
||||||
Padding added before convolution.
|
Padding added before convolution.
|
||||||
up_mode : str, default="bilinear"
|
up_mode
|
||||||
Interpolation mode for upsampling (e.g., "nearest", "bilinear").
|
Interpolation mode for upsampling (e.g., "nearest", "bilinear").
|
||||||
up_scale : Tuple[int, int], default=(2, 2)
|
up_scale
|
||||||
Scaling factor for height and width during upsampling.
|
Scaling factor for height and width during upsampling.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -860,7 +860,7 @@ class StandardConvUpBlock(Block):
|
|||||||
kernel_size: int = 3,
|
kernel_size: int = 3,
|
||||||
pad_size: int = 1,
|
pad_size: int = 1,
|
||||||
up_mode: str = "bilinear",
|
up_mode: str = "bilinear",
|
||||||
up_scale: Tuple[int, int] = (2, 2),
|
up_scale: tuple[int, int] = (2, 2),
|
||||||
):
|
):
|
||||||
super(StandardConvUpBlock, self).__init__()
|
super(StandardConvUpBlock, self).__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
@ -922,15 +922,14 @@ class StandardConvUpBlock(Block):
|
|||||||
|
|
||||||
|
|
||||||
LayerConfig = Annotated[
|
LayerConfig = Annotated[
|
||||||
Union[
|
ConvConfig
|
||||||
ConvConfig,
|
| BlockImportConfig
|
||||||
FreqCoordConvDownConfig,
|
| FreqCoordConvDownConfig
|
||||||
StandardConvDownConfig,
|
| StandardConvDownConfig
|
||||||
FreqCoordConvUpConfig,
|
| FreqCoordConvUpConfig
|
||||||
StandardConvUpConfig,
|
| StandardConvUpConfig
|
||||||
SelfAttentionConfig,
|
| SelfAttentionConfig
|
||||||
"LayerGroupConfig",
|
| "LayerGroupConfig",
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configuration models."""
|
"""Type alias for the discriminated union of block configuration models."""
|
||||||
@ -952,7 +951,7 @@ class LayerGroupConfig(BaseConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name: Literal["LayerGroup"] = "LayerGroup"
|
name: Literal["LayerGroup"] = "LayerGroup"
|
||||||
layers: List[LayerConfig]
|
layers: list[LayerConfig]
|
||||||
|
|
||||||
|
|
||||||
class LayerGroup(nn.Module):
|
class LayerGroup(nn.Module):
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from soundevent import data, plot
|
from soundevent import data, plot
|
||||||
|
|
||||||
@ -16,7 +14,7 @@ __all__ = [
|
|||||||
def plot_clip_annotation(
|
def plot_clip_annotation(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
ax: Axes | None = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
@ -50,7 +48,7 @@ def plot_clip_annotation(
|
|||||||
def plot_anchor_points(
|
def plot_anchor_points(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
ax: Axes | None = None,
|
ax: Axes | None = None,
|
||||||
size: int = 1,
|
size: int = 1,
|
||||||
color: str = "red",
|
color: str = "red",
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Iterable, Tuple
|
from typing import Iterable
|
||||||
|
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -18,7 +18,7 @@ __all__ = [
|
|||||||
def plot_clip_prediction(
|
def plot_clip_prediction(
|
||||||
clip_prediction: data.ClipPrediction,
|
clip_prediction: data.ClipPrediction,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
ax: Axes | None = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
add_legend: bool = False,
|
add_legend: bool = False,
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
@ -19,7 +17,7 @@ def plot_clip(
|
|||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
ax: Axes | None = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
"""General plotting utilities."""
|
"""General plotting utilities."""
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -14,7 +12,7 @@ __all__ = [
|
|||||||
|
|
||||||
def create_ax(
|
def create_ax(
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
"""Create a new axis if none is provided"""
|
"""Create a new axis if none is provided"""
|
||||||
@ -31,7 +29,7 @@ def plot_spectrogram(
|
|||||||
min_freq: float | None = None,
|
min_freq: float | None = None,
|
||||||
max_freq: float | None = None,
|
max_freq: float | None = None,
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
add_colorbar: bool = False,
|
add_colorbar: bool = False,
|
||||||
colorbar_kwargs: dict | None = None,
|
colorbar_kwargs: dict | None = None,
|
||||||
vmin: float | None = None,
|
vmin: float | None = None,
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
"""Plot heatmaps"""
|
"""Plot heatmaps"""
|
||||||
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import axes, patches
|
from matplotlib import axes, patches
|
||||||
@ -14,7 +12,7 @@ from batdetect2.plotting.common import create_ax
|
|||||||
def plot_detection_heatmap(
|
def plot_detection_heatmap(
|
||||||
heatmap: torch.Tensor | np.ndarray,
|
heatmap: torch.Tensor | np.ndarray,
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] = (10, 10),
|
figsize: tuple[int, int] = (10, 10),
|
||||||
threshold: float | None = None,
|
threshold: float | None = None,
|
||||||
alpha: float = 1,
|
alpha: float = 1,
|
||||||
cmap: str | Colormap = "jet",
|
cmap: str | Colormap = "jet",
|
||||||
@ -50,8 +48,8 @@ def plot_detection_heatmap(
|
|||||||
def plot_classification_heatmap(
|
def plot_classification_heatmap(
|
||||||
heatmap: torch.Tensor | np.ndarray,
|
heatmap: torch.Tensor | np.ndarray,
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] = (10, 10),
|
figsize: tuple[int, int] = (10, 10),
|
||||||
class_names: List[str] | None = None,
|
class_names: list[str] | None = None,
|
||||||
threshold: float | None = 0.1,
|
threshold: float | None = 0.1,
|
||||||
alpha: float = 1,
|
alpha: float = 1,
|
||||||
cmap: str | Colormap = "tab20",
|
cmap: str | Colormap = "tab20",
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Plot functions to visualize detections and spectrograms."""
|
"""Plot functions to visualize detections and spectrograms."""
|
||||||
|
|
||||||
from typing import List, Tuple, cast
|
from typing import cast
|
||||||
|
|
||||||
import matplotlib.ticker as tick
|
import matplotlib.ticker as tick
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -27,7 +27,7 @@ def spectrogram(
|
|||||||
spec: torch.Tensor | np.ndarray,
|
spec: torch.Tensor | np.ndarray,
|
||||||
config: ProcessingConfiguration | None = None,
|
config: ProcessingConfiguration | None = None,
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
cmap: str = "plasma",
|
cmap: str = "plasma",
|
||||||
start_time: float = 0,
|
start_time: float = 0,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
@ -35,18 +35,18 @@ def spectrogram(
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
|
spec: Spectrogram to plot.
|
||||||
config (Optional[ProcessingConfiguration], optional): Configuration
|
config: Configuration
|
||||||
used to compute the spectrogram. Defaults to None. If None,
|
used to compute the spectrogram. Defaults to None. If None,
|
||||||
the default configuration will be used.
|
the default configuration will be used.
|
||||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
ax: Matplotlib axes object.
|
||||||
Defaults to None. if provided, the spectrogram will be plotted
|
Defaults to None. if provided, the spectrogram will be plotted
|
||||||
on this axes.
|
on this axes.
|
||||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
figsize: Figure size.
|
||||||
Defaults to None. If `ax` is None, this will be used to create
|
Defaults to None. If `ax` is None, this will be used to create
|
||||||
a new figure of the given size.
|
a new figure of the given size.
|
||||||
cmap (str, optional): Colormap to use. Defaults to "plasma".
|
cmap: Colormap to use. Defaults to "plasma".
|
||||||
start_time (float, optional): Start time of the spectrogram.
|
start_time: Start time of the spectrogram.
|
||||||
Defaults to 0. This is useful if plotting a spectrogram
|
Defaults to 0. This is useful if plotting a spectrogram
|
||||||
of a segment of a longer audio file.
|
of a segment of a longer audio file.
|
||||||
|
|
||||||
@ -104,10 +104,10 @@ def spectrogram(
|
|||||||
|
|
||||||
def spectrogram_with_detections(
|
def spectrogram_with_detections(
|
||||||
spec: torch.Tensor | np.ndarray,
|
spec: torch.Tensor | np.ndarray,
|
||||||
dets: List[Annotation],
|
dets: list[Annotation],
|
||||||
config: ProcessingConfiguration | None = None,
|
config: ProcessingConfiguration | None = None,
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
cmap: str = "plasma",
|
cmap: str = "plasma",
|
||||||
with_names: bool = True,
|
with_names: bool = True,
|
||||||
start_time: float = 0,
|
start_time: float = 0,
|
||||||
@ -117,21 +117,21 @@ def spectrogram_with_detections(
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
|
spec: Spectrogram to plot.
|
||||||
detections (List[Annotation]): List of detections.
|
detections: List of detections.
|
||||||
config (Optional[ProcessingConfiguration], optional): Configuration
|
config: Configuration
|
||||||
used to compute the spectrogram. Defaults to None. If None,
|
used to compute the spectrogram. Defaults to None. If None,
|
||||||
the default configuration will be used.
|
the default configuration will be used.
|
||||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
ax: Matplotlib axes object.
|
||||||
Defaults to None. if provided, the spectrogram will be plotted
|
Defaults to None. if provided, the spectrogram will be plotted
|
||||||
on this axes.
|
on this axes.
|
||||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
figsize: Figure size.
|
||||||
Defaults to None. If `ax` is None, this will be used to create
|
Defaults to None. If `ax` is None, this will be used to create
|
||||||
a new figure of the given size.
|
a new figure of the given size.
|
||||||
cmap (str, optional): Colormap to use. Defaults to "plasma".
|
cmap: Colormap to use. Defaults to "plasma".
|
||||||
with_names (bool, optional): Whether to plot the name of the
|
with_names: Whether to plot the name of the
|
||||||
predicted class next to the detection. Defaults to True.
|
predicted class next to the detection. Defaults to True.
|
||||||
start_time (float, optional): Start time of the spectrogram.
|
start_time: Start time of the spectrogram.
|
||||||
Defaults to 0. This is useful if plotting a spectrogram
|
Defaults to 0. This is useful if plotting a spectrogram
|
||||||
of a segment of a longer audio file.
|
of a segment of a longer audio file.
|
||||||
**kwargs: Additional keyword arguments to pass to the
|
**kwargs: Additional keyword arguments to pass to the
|
||||||
@ -167,9 +167,9 @@ def spectrogram_with_detections(
|
|||||||
|
|
||||||
|
|
||||||
def detections(
|
def detections(
|
||||||
dets: List[Annotation],
|
dets: list[Annotation],
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
with_names: bool = True,
|
with_names: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
@ -177,14 +177,14 @@ def detections(
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
dets (List[Annotation]): List of detections.
|
dets: List of detections.
|
||||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
ax: Matplotlib axes object.
|
||||||
Defaults to None. if provided, the spectrogram will be plotted
|
Defaults to None. if provided, the spectrogram will be plotted
|
||||||
on this axes.
|
on this axes.
|
||||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
figsize: Figure size.
|
||||||
Defaults to None. If `ax` is None, this will be used to create
|
Defaults to None. If `ax` is None, this will be used to create
|
||||||
a new figure of the given size.
|
a new figure of the given size.
|
||||||
with_names (bool, optional): Whether to plot the name of the
|
with_names: Whether to plot the name of the
|
||||||
predicted class next to the detection. Defaults to True.
|
predicted class next to the detection. Defaults to True.
|
||||||
**kwargs: Additional keyword arguments to pass to the
|
**kwargs: Additional keyword arguments to pass to the
|
||||||
`plot.detection` function.
|
`plot.detection` function.
|
||||||
@ -214,7 +214,7 @@ def detections(
|
|||||||
def detection(
|
def detection(
|
||||||
det: Annotation,
|
det: Annotation,
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
linewidth: float = 1,
|
linewidth: float = 1,
|
||||||
edgecolor: str = "w",
|
edgecolor: str = "w",
|
||||||
facecolor: str = "none",
|
facecolor: str = "none",
|
||||||
@ -224,19 +224,19 @@ def detection(
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
det (Annotation): Detection to plot.
|
det: Detection to plot.
|
||||||
ax (Optional[axes.Axes], optional): Matplotlib axes object. Defaults
|
ax: Matplotlib axes object. Defaults
|
||||||
to None. If provided, the spectrogram will be plotted on this axes.
|
to None. If provided, the spectrogram will be plotted on this axes.
|
||||||
figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults
|
figsize: Figure size. Defaults
|
||||||
to None. If `ax` is None, this will be used to create a new figure
|
to None. If `ax` is None, this will be used to create a new figure
|
||||||
of the given size.
|
of the given size.
|
||||||
linewidth (float, optional): Line width of the detection.
|
linewidth: Line width of the detection.
|
||||||
Defaults to 1.
|
Defaults to 1.
|
||||||
edgecolor (str, optional): Edge color of the detection.
|
edgecolor: Edge color of the detection.
|
||||||
Defaults to "w", i.e. white.
|
Defaults to "w", i.e. white.
|
||||||
facecolor (str, optional): Face color of the detection.
|
facecolor: Face color of the detection.
|
||||||
Defaults to "none", i.e. transparent.
|
Defaults to "none", i.e. transparent.
|
||||||
with_name (bool, optional): Whether to plot the name of the
|
with_name: Whether to plot the name of the
|
||||||
predicted class next to the detection. Defaults to True.
|
predicted class next to the detection. Defaults to True.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@ -277,22 +277,22 @@ def detection(
|
|||||||
|
|
||||||
|
|
||||||
def _compute_spec_extent(
|
def _compute_spec_extent(
|
||||||
shape: Tuple[int, int],
|
shape: tuple[int, int],
|
||||||
params: SpectrogramParameters,
|
params: SpectrogramParameters,
|
||||||
) -> Tuple[float, float, float, float]:
|
) -> tuple[float, float, float, float]:
|
||||||
"""Compute the extent of a spectrogram.
|
"""Compute the extent of a spectrogram.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
shape (Tuple[int, int]): Shape of the spectrogram.
|
shape: Shape of the spectrogram.
|
||||||
The first dimension is the frequency axis and the second
|
The first dimension is the frequency axis and the second
|
||||||
dimension is the time axis.
|
dimension is the time axis.
|
||||||
params (SpectrogramParameters): Spectrogram parameters.
|
params: Spectrogram parameters.
|
||||||
Should be the same as the ones used to compute the spectrogram.
|
Should be the same as the ones used to compute the spectrogram.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Tuple[float, float, float, float]: Extent of the spectrogram.
|
tuple[float, float, float, float]: Extent of the spectrogram.
|
||||||
The first two values are the minimum and maximum time values,
|
The first two values are the minimum and maximum time values,
|
||||||
the last two values are the minimum and maximum frequency values.
|
the last two values are the minimum and maximum frequency values.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Protocol, Tuple
|
from typing import Protocol
|
||||||
|
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from soundevent import data, plot
|
from soundevent import data, plot
|
||||||
@ -40,7 +40,7 @@ def plot_false_positive_match(
|
|||||||
match: MatchProtocol,
|
match: MatchProtocol,
|
||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
ax: Axes | None = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
@ -111,7 +111,7 @@ def plot_false_negative_match(
|
|||||||
match: MatchProtocol,
|
match: MatchProtocol,
|
||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
ax: Axes | None = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
@ -171,7 +171,7 @@ def plot_true_positive_match(
|
|||||||
match: MatchProtocol,
|
match: MatchProtocol,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
ax: Axes | None = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
@ -259,7 +259,7 @@ def plot_cross_trigger_match(
|
|||||||
match: MatchProtocol,
|
match: MatchProtocol,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
ax: Axes | None = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from typing import Dict, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
from cycler import cycler
|
from cycler import cycler
|
||||||
@ -34,14 +32,14 @@ def plot_pr_curve(
|
|||||||
recall: np.ndarray,
|
recall: np.ndarray,
|
||||||
thresholds: np.ndarray,
|
thresholds: np.ndarray,
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
color: str | Tuple[float, float, float] | None = None,
|
color: str | tuple[float, float, float] | None = None,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
add_legend: bool = False,
|
add_legend: bool = False,
|
||||||
marker: str | Tuple[int, int, float] | None = "o",
|
marker: str | tuple[int, int, float] | None = "o",
|
||||||
markeredgecolor: str | Tuple[float, float, float] | None = None,
|
markeredgecolor: str | tuple[float, float, float] | None = None,
|
||||||
markersize: float | None = None,
|
markersize: float | None = None,
|
||||||
linestyle: str | Tuple[int, ...] | None = None,
|
linestyle: str | tuple[int, ...] | None = None,
|
||||||
linewidth: float | None = None,
|
linewidth: float | None = None,
|
||||||
label: str = "PR Curve",
|
label: str = "PR Curve",
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
@ -76,9 +74,9 @@ def plot_pr_curve(
|
|||||||
|
|
||||||
|
|
||||||
def plot_pr_curves(
|
def plot_pr_curves(
|
||||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
add_legend: bool = True,
|
add_legend: bool = True,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
include_ap: bool = False,
|
include_ap: bool = False,
|
||||||
@ -119,7 +117,7 @@ def plot_threshold_precision_curve(
|
|||||||
threshold: np.ndarray,
|
threshold: np.ndarray,
|
||||||
precision: np.ndarray,
|
precision: np.ndarray,
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
):
|
):
|
||||||
ax = create_ax(ax=ax, figsize=figsize)
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
@ -139,9 +137,9 @@ def plot_threshold_precision_curve(
|
|||||||
|
|
||||||
|
|
||||||
def plot_threshold_precision_curves(
|
def plot_threshold_precision_curves(
|
||||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
add_legend: bool = True,
|
add_legend: bool = True,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
):
|
):
|
||||||
@ -177,7 +175,7 @@ def plot_threshold_recall_curve(
|
|||||||
threshold: np.ndarray,
|
threshold: np.ndarray,
|
||||||
recall: np.ndarray,
|
recall: np.ndarray,
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
):
|
):
|
||||||
ax = create_ax(ax=ax, figsize=figsize)
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
@ -197,9 +195,9 @@ def plot_threshold_recall_curve(
|
|||||||
|
|
||||||
|
|
||||||
def plot_threshold_recall_curves(
|
def plot_threshold_recall_curves(
|
||||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
add_legend: bool = True,
|
add_legend: bool = True,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
):
|
):
|
||||||
@ -236,7 +234,7 @@ def plot_roc_curve(
|
|||||||
tpr: np.ndarray,
|
tpr: np.ndarray,
|
||||||
thresholds: np.ndarray,
|
thresholds: np.ndarray,
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
ax = create_ax(ax=ax, figsize=figsize)
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
@ -260,9 +258,9 @@ def plot_roc_curve(
|
|||||||
|
|
||||||
|
|
||||||
def plot_roc_curves(
|
def plot_roc_curves(
|
||||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
data: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
ax: axes.Axes | None = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
add_legend: bool = True,
|
add_legend: bool = True,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
|
|||||||
@ -11,8 +11,6 @@ activations that have lower scores than a local maximum. This helps prevent
|
|||||||
multiple, overlapping detections originating from the same sound event.
|
multiple, overlapping detections originating from the same sound event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
NMS_KERNEL_SIZE = 9
|
NMS_KERNEL_SIZE = 9
|
||||||
@ -27,7 +25,7 @@ BatDetect2.
|
|||||||
|
|
||||||
def non_max_suppression(
|
def non_max_suppression(
|
||||||
tensor: torch.Tensor,
|
tensor: torch.Tensor,
|
||||||
kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
|
kernel_size: int | tuple[int, int] = NMS_KERNEL_SIZE,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
|
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
|
||||||
|
|
||||||
@ -42,11 +40,11 @@ def non_max_suppression(
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
tensor : torch.Tensor
|
tensor
|
||||||
Input tensor, typically representing a detection heatmap. Must be a
|
Input tensor, typically representing a detection heatmap. Must be a
|
||||||
3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying
|
3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying
|
||||||
`torch.nn.functional.max_pool2d` operation.
|
`torch.nn.functional.max_pool2d` operation.
|
||||||
kernel_size : Union[int, Tuple[int, int]], default=NMS_KERNEL_SIZE
|
kernel_size
|
||||||
Size of the sliding window neighborhood used to find local maxima.
|
Size of the sliding window neighborhood used to find local maxima.
|
||||||
If an integer `k` is provided, a square kernel of size `(k, k)` is used.
|
If an integer `k` is provided, a square kernel of size `(k, k)` is used.
|
||||||
If a tuple `(h, w)` is provided, a rectangular kernel of height `h`
|
If a tuple `(h, w)` is provided, a rectangular kernel of height `h`
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
@ -51,7 +49,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
|||||||
max_freq: float,
|
max_freq: float,
|
||||||
top_k_per_sec: int = 200,
|
top_k_per_sec: int = 200,
|
||||||
detection_threshold: float = 0.01,
|
detection_threshold: float = 0.01,
|
||||||
nms_kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
|
nms_kernel_size: int | tuple[int, int] = NMS_KERNEL_SIZE,
|
||||||
):
|
):
|
||||||
"""Initialize the Postprocessor."""
|
"""Initialize the Postprocessor."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -66,8 +64,8 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
output: ModelOutput,
|
output: ModelOutput,
|
||||||
start_times: List[float] | None = None,
|
start_times: list[float] | None = None,
|
||||||
) -> List[ClipDetectionsTensor]:
|
) -> list[ClipDetectionsTensor]:
|
||||||
detection_heatmap = non_max_suppression(
|
detection_heatmap = non_max_suppression(
|
||||||
output.detection_probs.detach(),
|
output.detection_probs.detach(),
|
||||||
kernel_size=self.nms_kernel_size,
|
kernel_size=self.nms_kernel_size,
|
||||||
|
|||||||
@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
|
|||||||
*geometric* aspect of target definition from *semantic* classification.
|
*geometric* aspect of target definition from *semantic* classification.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, Literal, Tuple
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -144,7 +144,7 @@ class AnchorBBoxMapper(ROITargetMapper):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
dimension_names : List[str]
|
dimension_names : list[str]
|
||||||
The output dimension names: `['width', 'height']`.
|
The output dimension names: `['width', 'height']`.
|
||||||
anchor : Anchor
|
anchor : Anchor
|
||||||
The configured anchor point type (e.g., "center", "bottom-left").
|
The configured anchor point type (e.g., "center", "bottom-left").
|
||||||
@ -177,7 +177,7 @@ class AnchorBBoxMapper(ROITargetMapper):
|
|||||||
self.time_scale = time_scale
|
self.time_scale = time_scale
|
||||||
self.frequency_scale = frequency_scale
|
self.frequency_scale = frequency_scale
|
||||||
|
|
||||||
def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
|
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
|
||||||
"""Encode a SoundEvent into an anchor position and scaled box size.
|
"""Encode a SoundEvent into an anchor position and scaled box size.
|
||||||
|
|
||||||
The position is determined by the configured anchor on the sound
|
The position is determined by the configured anchor on the sound
|
||||||
@ -190,7 +190,7 @@ class AnchorBBoxMapper(ROITargetMapper):
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Tuple[Position, Size]
|
tuple[Position, Size]
|
||||||
A tuple of (anchor_position, [scaled_width, scaled_height]).
|
A tuple of (anchor_position, [scaled_width, scaled_height]).
|
||||||
"""
|
"""
|
||||||
from soundevent import geometry
|
from soundevent import geometry
|
||||||
@ -314,7 +314,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
dimension_names : List[str]
|
dimension_names : list[str]
|
||||||
The output dimension names: `['left', 'bottom', 'right', 'top']`.
|
The output dimension names: `['left', 'bottom', 'right', 'top']`.
|
||||||
preprocessor : PreprocessorProtocol
|
preprocessor : PreprocessorProtocol
|
||||||
The spectrogram preprocessor instance.
|
The spectrogram preprocessor instance.
|
||||||
@ -371,7 +371,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Tuple[Position, Size]
|
tuple[Position, Size]
|
||||||
A tuple of (peak_position, [l, b, r, t] distances).
|
A tuple of (peak_position, [l, b, r, t] distances).
|
||||||
"""
|
"""
|
||||||
from soundevent import geometry
|
from soundevent import geometry
|
||||||
@ -519,14 +519,14 @@ def _build_bounding_box(
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
pos : Tuple[float, float]
|
pos
|
||||||
Reference position (time, frequency).
|
Reference position (time, frequency).
|
||||||
duration : float
|
duration
|
||||||
The required *unscaled* duration (width) of the bounding box.
|
The required *unscaled* duration (width) of the bounding box.
|
||||||
bandwidth : float
|
bandwidth
|
||||||
The required *unscaled* frequency bandwidth (height) of the bounding
|
The required *unscaled* frequency bandwidth (height) of the bounding
|
||||||
box.
|
box.
|
||||||
anchor : Anchor
|
anchor
|
||||||
Specifies which part of the bounding box the input `pos` corresponds to.
|
Specifies which part of the bounding box the input `pos` corresponds to.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Iterable, List, Tuple
|
from typing import Iterable
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -33,20 +33,20 @@ class Targets(TargetProtocol):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
class_names : List[str]
|
class_names
|
||||||
An ordered list of the unique names of the specific target classes
|
An ordered list of the unique names of the specific target classes
|
||||||
defined in the configuration.
|
defined in the configuration.
|
||||||
generic_class_tags : List[data.Tag]
|
generic_class_tags
|
||||||
A list of `soundevent.data.Tag` objects representing the configured
|
A list of `soundevent.data.Tag` objects representing the configured
|
||||||
generic class category (used when no specific class matches).
|
generic class category (used when no specific class matches).
|
||||||
dimension_names : List[str]
|
dimension_names
|
||||||
The names of the size dimensions handled by the ROI mapper
|
The names of the size dimensions handled by the ROI mapper
|
||||||
(e.g., ['width', 'height']).
|
(e.g., ['width', 'height']).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class_names: List[str]
|
class_names: list[str]
|
||||||
detection_class_tags: List[data.Tag]
|
detection_class_tags: list[data.Tag]
|
||||||
dimension_names: List[str]
|
dimension_names: list[str]
|
||||||
detection_class_name: str
|
detection_class_name: str
|
||||||
|
|
||||||
def __init__(self, config: TargetConfig):
|
def __init__(self, config: TargetConfig):
|
||||||
@ -128,7 +128,7 @@ class Targets(TargetProtocol):
|
|||||||
"""
|
"""
|
||||||
return self._encode_fn(sound_event)
|
return self._encode_fn(sound_event)
|
||||||
|
|
||||||
def decode_class(self, class_label: str) -> List[data.Tag]:
|
def decode_class(self, class_label: str) -> list[data.Tag]:
|
||||||
"""Decode a predicted class name back into representative tags.
|
"""Decode a predicted class name back into representative tags.
|
||||||
|
|
||||||
Uses the configured mapping (based on `TargetClass.output_tags` or
|
Uses the configured mapping (based on `TargetClass.output_tags` or
|
||||||
@ -142,7 +142,7 @@ class Targets(TargetProtocol):
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
List[data.Tag]
|
list[data.Tag]
|
||||||
The list of tags corresponding to the input class name.
|
The list of tags corresponding to the input class name.
|
||||||
"""
|
"""
|
||||||
return self._decode_fn(class_label)
|
return self._decode_fn(class_label)
|
||||||
@ -161,7 +161,7 @@ class Targets(TargetProtocol):
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Tuple[float, float]
|
tuple[float, float]
|
||||||
The reference position `(time, frequency)`.
|
The reference position `(time, frequency)`.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
@ -192,9 +192,9 @@ class Targets(TargetProtocol):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
pos : Tuple[float, float]
|
pos
|
||||||
The reference position `(time, frequency)`.
|
The reference position `(time, frequency)`.
|
||||||
dims : np.ndarray
|
dims
|
||||||
NumPy array with size dimensions (e.g., from model prediction),
|
NumPy array with size dimensions (e.g., from model prediction),
|
||||||
matching the order in `self.dimension_names`.
|
matching the order in `self.dimension_names`.
|
||||||
|
|
||||||
@ -292,7 +292,7 @@ def load_targets(
|
|||||||
def iterate_encoded_sound_events(
|
def iterate_encoded_sound_events(
|
||||||
sound_events: Iterable[data.SoundEventAnnotation],
|
sound_events: Iterable[data.SoundEventAnnotation],
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
) -> Iterable[Tuple[str | None, Position, Size]]:
|
) -> Iterable[tuple[str | None, Position, Size]]:
|
||||||
for sound_event in sound_events:
|
for sound_event in sound_events:
|
||||||
if not targets.filter(sound_event):
|
if not targets.filter(sound_event):
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Tuple
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
import torch
|
import torch
|
||||||
@ -97,7 +97,7 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
def load_model_from_checkpoint(
|
def load_model_from_checkpoint(
|
||||||
path: PathLike,
|
path: PathLike,
|
||||||
) -> Tuple[Model, "BatDetect2Config"]:
|
) -> tuple[Model, "BatDetect2Config"]:
|
||||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||||
return module.model, module.config
|
return module.model, module.config
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Types used in the code base."""
|
"""Types used in the code base."""
|
||||||
|
|
||||||
from typing import Any, List, NamedTuple, TypedDict
|
from typing import Any, NamedTuple, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -86,7 +86,7 @@ class ModelParameters(TypedDict):
|
|||||||
resize_factor: float
|
resize_factor: float
|
||||||
"""Resize factor."""
|
"""Resize factor."""
|
||||||
|
|
||||||
class_names: List[str]
|
class_names: list[str]
|
||||||
"""Class names.
|
"""Class names.
|
||||||
|
|
||||||
The model is trained to detect these classes.
|
The model is trained to detect these classes.
|
||||||
@ -158,7 +158,7 @@ class FileAnnotation(TypedDict):
|
|||||||
notes: str
|
notes: str
|
||||||
"""Notes of file."""
|
"""Notes of file."""
|
||||||
|
|
||||||
annotation: List[Annotation]
|
annotation: list[Annotation]
|
||||||
"""List of annotations."""
|
"""List of annotations."""
|
||||||
|
|
||||||
|
|
||||||
@ -168,26 +168,26 @@ class RunResults(TypedDict):
|
|||||||
pred_dict: FileAnnotation
|
pred_dict: FileAnnotation
|
||||||
"""Predictions in the format expected by the annotation tool."""
|
"""Predictions in the format expected by the annotation tool."""
|
||||||
|
|
||||||
spec_feats: NotRequired[List[np.ndarray]]
|
spec_feats: NotRequired[list[np.ndarray]]
|
||||||
"""Spectrogram features."""
|
"""Spectrogram features."""
|
||||||
|
|
||||||
spec_feat_names: NotRequired[List[str]]
|
spec_feat_names: NotRequired[list[str]]
|
||||||
"""Spectrogram feature names."""
|
"""Spectrogram feature names."""
|
||||||
|
|
||||||
cnn_feats: NotRequired[List[np.ndarray]]
|
cnn_feats: NotRequired[list[np.ndarray]]
|
||||||
"""CNN features."""
|
"""CNN features."""
|
||||||
|
|
||||||
cnn_feat_names: NotRequired[List[str]]
|
cnn_feat_names: NotRequired[list[str]]
|
||||||
"""CNN feature names."""
|
"""CNN feature names."""
|
||||||
|
|
||||||
spec_slices: NotRequired[List[np.ndarray]]
|
spec_slices: NotRequired[list[np.ndarray]]
|
||||||
"""Spectrogram slices."""
|
"""Spectrogram slices."""
|
||||||
|
|
||||||
|
|
||||||
class ResultParams(TypedDict):
|
class ResultParams(TypedDict):
|
||||||
"""Result parameters."""
|
"""Result parameters."""
|
||||||
|
|
||||||
class_names: List[str]
|
class_names: list[str]
|
||||||
"""Class names."""
|
"""Class names."""
|
||||||
|
|
||||||
spec_features: bool
|
spec_features: bool
|
||||||
@ -234,7 +234,7 @@ class ProcessingConfiguration(TypedDict):
|
|||||||
scale_raw_audio: bool
|
scale_raw_audio: bool
|
||||||
"""Whether to scale the raw audio to be between -1 and 1."""
|
"""Whether to scale the raw audio to be between -1 and 1."""
|
||||||
|
|
||||||
class_names: List[str]
|
class_names: list[str]
|
||||||
"""Names of the classes the model can detect."""
|
"""Names of the classes the model can detect."""
|
||||||
|
|
||||||
detection_threshold: float
|
detection_threshold: float
|
||||||
@ -466,7 +466,7 @@ class FeatureExtractionParameters(TypedDict):
|
|||||||
class HeatmapParameters(TypedDict):
|
class HeatmapParameters(TypedDict):
|
||||||
"""Parameters that control the heatmap generation function."""
|
"""Parameters that control the heatmap generation function."""
|
||||||
|
|
||||||
class_names: List[str]
|
class_names: list[str]
|
||||||
|
|
||||||
fft_win_length: float
|
fft_win_length: float
|
||||||
"""Length of the FFT window in seconds."""
|
"""Length of the FFT window in seconds."""
|
||||||
@ -553,15 +553,15 @@ class AudioLoaderAnnotationGroup(TypedDict):
|
|||||||
individual_ids: np.ndarray
|
individual_ids: np.ndarray
|
||||||
x_inds: np.ndarray
|
x_inds: np.ndarray
|
||||||
y_inds: np.ndarray
|
y_inds: np.ndarray
|
||||||
annotation: List[Annotation]
|
annotation: list[Annotation]
|
||||||
annotated: bool
|
annotated: bool
|
||||||
class_id_file: int
|
class_id_file: int
|
||||||
"""ID of the class of the file."""
|
"""ID of the class of the file."""
|
||||||
|
|
||||||
|
|
||||||
class AudioLoaderParameters(TypedDict):
|
class AudioLoaderParameters(TypedDict):
|
||||||
class_names: List[str]
|
class_names: list[str]
|
||||||
classes_to_ignore: List[str]
|
classes_to_ignore: list[str]
|
||||||
target_samp_rate: int
|
target_samp_rate: int
|
||||||
scale_raw_audio: bool
|
scale_raw_audio: bool
|
||||||
fft_win_length: float
|
fft_win_length: float
|
||||||
|
|||||||
@ -1,12 +1,9 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
|
||||||
Generic,
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
|
||||||
Protocol,
|
Protocol,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
|
||||||
TypeVar,
|
TypeVar,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -33,7 +30,7 @@ class MatchEvaluation:
|
|||||||
gt_geometry: data.Geometry | None
|
gt_geometry: data.Geometry | None
|
||||||
|
|
||||||
pred_score: float
|
pred_score: float
|
||||||
pred_class_scores: Dict[str, float]
|
pred_class_scores: dict[str, float]
|
||||||
pred_geometry: data.Geometry | None
|
pred_geometry: data.Geometry | None
|
||||||
|
|
||||||
affinity: float
|
affinity: float
|
||||||
@ -66,7 +63,7 @@ class MatchEvaluation:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ClipMatches:
|
class ClipMatches:
|
||||||
clip: data.Clip
|
clip: data.Clip
|
||||||
matches: List[MatchEvaluation]
|
matches: list[MatchEvaluation]
|
||||||
|
|
||||||
|
|
||||||
class MatcherProtocol(Protocol):
|
class MatcherProtocol(Protocol):
|
||||||
@ -75,7 +72,7 @@ class MatcherProtocol(Protocol):
|
|||||||
ground_truth: Sequence[data.Geometry],
|
ground_truth: Sequence[data.Geometry],
|
||||||
predictions: Sequence[data.Geometry],
|
predictions: Sequence[data.Geometry],
|
||||||
scores: Sequence[float],
|
scores: Sequence[float],
|
||||||
) -> Iterable[Tuple[int | None, int | None, float]]: ...
|
) -> Iterable[tuple[int | None, int | None, float]]: ...
|
||||||
|
|
||||||
|
|
||||||
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
||||||
@ -94,7 +91,7 @@ class MetricsProtocol(Protocol):
|
|||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[Sequence[Detection]],
|
predictions: Sequence[Sequence[Detection]],
|
||||||
) -> Dict[str, float]: ...
|
) -> dict[str, float]: ...
|
||||||
|
|
||||||
|
|
||||||
class PlotterProtocol(Protocol):
|
class PlotterProtocol(Protocol):
|
||||||
@ -102,7 +99,7 @@ class PlotterProtocol(Protocol):
|
|||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[Sequence[Detection]],
|
predictions: Sequence[Sequence[Detection]],
|
||||||
) -> Iterable[Tuple[str, Figure]]: ...
|
) -> Iterable[tuple[str, Figure]]: ...
|
||||||
|
|
||||||
|
|
||||||
EvaluationOutput = TypeVar("EvaluationOutput")
|
EvaluationOutput = TypeVar("EvaluationOutput")
|
||||||
@ -119,8 +116,8 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
|||||||
|
|
||||||
def compute_metrics(
|
def compute_metrics(
|
||||||
self, eval_outputs: EvaluationOutput
|
self, eval_outputs: EvaluationOutput
|
||||||
) -> Dict[str, float]: ...
|
) -> dict[str, float]: ...
|
||||||
|
|
||||||
def generate_plots(
|
def generate_plots(
|
||||||
self, eval_outputs: EvaluationOutput
|
self, eval_outputs: EvaluationOutput
|
||||||
) -> Iterable[Tuple[str, Figure]]: ...
|
) -> Iterable[tuple[str, Figure]]: ...
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Callable, NamedTuple, Protocol, Tuple
|
from typing import Callable, NamedTuple, Protocol
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -52,7 +52,7 @@ steps, and returns the final `Heatmaps` used for model training.
|
|||||||
|
|
||||||
Augmentation = Callable[
|
Augmentation = Callable[
|
||||||
[torch.Tensor, data.ClipAnnotation],
|
[torch.Tensor, data.ClipAnnotation],
|
||||||
Tuple[torch.Tensor, data.ClipAnnotation],
|
tuple[torch.Tensor, data.ClipAnnotation],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import librosa.core.spectrum
|
import librosa.core.spectrum
|
||||||
@ -147,7 +146,7 @@ def load_audio(
|
|||||||
target_samp_rate: int,
|
target_samp_rate: int,
|
||||||
scale: bool = False,
|
scale: bool = False,
|
||||||
max_duration: float | None = None,
|
max_duration: float | None = None,
|
||||||
) -> Tuple[int, np.ndarray]:
|
) -> tuple[int, np.ndarray]:
|
||||||
"""Load an audio file and resample it to the target sampling rate.
|
"""Load an audio file and resample it to the target sampling rate.
|
||||||
|
|
||||||
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Iterator, List, Tuple
|
from typing import Any, Iterator
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -60,7 +60,7 @@ def get_default_bd_args():
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def list_audio_files(ip_dir: str) -> List[str]:
|
def list_audio_files(ip_dir: str) -> list[str]:
|
||||||
"""Get all audio files in directory.
|
"""Get all audio files in directory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -86,7 +86,7 @@ def load_model(
|
|||||||
load_weights: bool = True,
|
load_weights: bool = True,
|
||||||
device: torch.device | str | None = None,
|
device: torch.device | str | None = None,
|
||||||
weights_only: bool = True,
|
weights_only: bool = True,
|
||||||
) -> Tuple[DetectionModel, ModelParameters]:
|
) -> tuple[DetectionModel, ModelParameters]:
|
||||||
"""Load model from file.
|
"""Load model from file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -185,26 +185,28 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
|||||||
|
|
||||||
def get_annotations_from_preds(
|
def get_annotations_from_preds(
|
||||||
predictions: PredictionResults,
|
predictions: PredictionResults,
|
||||||
class_names: List[str],
|
class_names: list[str],
|
||||||
) -> List[Annotation]:
|
) -> list[Annotation]:
|
||||||
"""Get list of annotations from predictions."""
|
"""Get list of annotations from predictions."""
|
||||||
# Get the best class prediction probability and index for each detection
|
# Get the best class prediction probability and index for each detection
|
||||||
class_prob_best = predictions["class_probs"].max(0)
|
class_prob_best = predictions["class_probs"].max(0)
|
||||||
class_ind_best = predictions["class_probs"].argmax(0)
|
class_ind_best = predictions["class_probs"].argmax(0)
|
||||||
|
|
||||||
# Pack the results into a list of dictionaries
|
# Pack the results into a list of dictionaries
|
||||||
annotations: List[Annotation] = [
|
annotations: list[Annotation] = [
|
||||||
{
|
Annotation(
|
||||||
"start_time": round(float(start_time), 4),
|
{
|
||||||
"end_time": round(float(end_time), 4),
|
"start_time": round(float(start_time), 4),
|
||||||
"low_freq": int(low_freq),
|
"end_time": round(float(end_time), 4),
|
||||||
"high_freq": int(high_freq),
|
"low_freq": int(low_freq),
|
||||||
"class": str(class_names[class_index]),
|
"high_freq": int(high_freq),
|
||||||
"class_prob": round(float(class_prob), 3),
|
"class": str(class_names[class_index]),
|
||||||
"det_prob": round(float(det_prob), 3),
|
"class_prob": round(float(class_prob), 3),
|
||||||
"individual": "-1",
|
"det_prob": round(float(det_prob), 3),
|
||||||
"event": "Echolocation",
|
"individual": "-1",
|
||||||
}
|
"event": "Echolocation",
|
||||||
|
}
|
||||||
|
)
|
||||||
for (
|
for (
|
||||||
start_time,
|
start_time,
|
||||||
end_time,
|
end_time,
|
||||||
@ -232,7 +234,7 @@ def format_single_result(
|
|||||||
time_exp: float,
|
time_exp: float,
|
||||||
duration: float,
|
duration: float,
|
||||||
predictions: PredictionResults,
|
predictions: PredictionResults,
|
||||||
class_names: List[str],
|
class_names: list[str],
|
||||||
) -> FileAnnotation:
|
) -> FileAnnotation:
|
||||||
"""Format results into the format expected by the annotation tool.
|
"""Format results into the format expected by the annotation tool.
|
||||||
|
|
||||||
@ -315,9 +317,9 @@ def convert_results(
|
|||||||
]
|
]
|
||||||
|
|
||||||
# combine into final results dictionary
|
# combine into final results dictionary
|
||||||
results: RunResults = {
|
results: RunResults = RunResults({ # type: ignore
|
||||||
"pred_dict": pred_dict,
|
"pred_dict": pred_dict,
|
||||||
}
|
})
|
||||||
|
|
||||||
# add spectrogram features if they exist
|
# add spectrogram features if they exist
|
||||||
if len(spec_feats) > 0 and params["spec_features"]:
|
if len(spec_feats) > 0 and params["spec_features"]:
|
||||||
@ -413,7 +415,7 @@ def compute_spectrogram(
|
|||||||
sampling_rate: int,
|
sampling_rate: int,
|
||||||
params: SpectrogramParameters,
|
params: SpectrogramParameters,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> Tuple[float, torch.Tensor]:
|
) -> tuple[float, torch.Tensor]:
|
||||||
"""Compute a spectrogram from an audio array.
|
"""Compute a spectrogram from an audio array.
|
||||||
|
|
||||||
Will pad the audio array so that it is evenly divisible by the
|
Will pad the audio array so that it is evenly divisible by the
|
||||||
@ -475,7 +477,7 @@ def iterate_over_chunks(
|
|||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
samplerate: float,
|
samplerate: float,
|
||||||
chunk_size: float,
|
chunk_size: float,
|
||||||
) -> Iterator[Tuple[float, np.ndarray]]:
|
) -> Iterator[tuple[float, np.ndarray]]:
|
||||||
"""Iterate over audio in chunks of size chunk_size.
|
"""Iterate over audio in chunks of size chunk_size.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -509,7 +511,7 @@ def _process_spectrogram(
|
|||||||
samplerate: float,
|
samplerate: float,
|
||||||
model: DetectionModel,
|
model: DetectionModel,
|
||||||
config: ProcessingConfiguration,
|
config: ProcessingConfiguration,
|
||||||
) -> Tuple[PredictionResults, np.ndarray]:
|
) -> tuple[PredictionResults, np.ndarray]:
|
||||||
# evaluate model
|
# evaluate model
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(spec)
|
outputs = model(spec)
|
||||||
@ -546,7 +548,7 @@ def postprocess_model_outputs(
|
|||||||
outputs: ModelOutput,
|
outputs: ModelOutput,
|
||||||
samp_rate: int,
|
samp_rate: int,
|
||||||
config: ProcessingConfiguration,
|
config: ProcessingConfiguration,
|
||||||
) -> Tuple[List[Annotation], np.ndarray]:
|
) -> tuple[list[Annotation], np.ndarray]:
|
||||||
# run non-max suppression
|
# run non-max suppression
|
||||||
pred_nms_list, features = pp.run_nms(
|
pred_nms_list, features = pp.run_nms(
|
||||||
outputs,
|
outputs,
|
||||||
@ -585,7 +587,7 @@ def process_spectrogram(
|
|||||||
samplerate: int,
|
samplerate: int,
|
||||||
model: DetectionModel,
|
model: DetectionModel,
|
||||||
config: ProcessingConfiguration,
|
config: ProcessingConfiguration,
|
||||||
) -> Tuple[List[Annotation], np.ndarray]:
|
) -> tuple[list[Annotation], np.ndarray]:
|
||||||
"""Process a spectrogram with detection model.
|
"""Process a spectrogram with detection model.
|
||||||
|
|
||||||
Will run non-maximum suppression on the output of the model.
|
Will run non-maximum suppression on the output of the model.
|
||||||
@ -604,9 +606,9 @@ def process_spectrogram(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
detections: List[Annotation]
|
detections
|
||||||
List of detections predicted by the model.
|
List of detections predicted by the model.
|
||||||
features : np.ndarray
|
features
|
||||||
An array of CNN features associated with each annotation.
|
An array of CNN features associated with each annotation.
|
||||||
The array is of shape (num_detections, num_features).
|
The array is of shape (num_detections, num_features).
|
||||||
Is empty if `config["cnn_features"]` is False.
|
Is empty if `config["cnn_features"]` is False.
|
||||||
@ -632,7 +634,7 @@ def _process_audio_array(
|
|||||||
model: DetectionModel,
|
model: DetectionModel,
|
||||||
config: ProcessingConfiguration,
|
config: ProcessingConfiguration,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
|
) -> tuple[PredictionResults, np.ndarray, torch.Tensor]:
|
||||||
# load audio file and compute spectrogram
|
# load audio file and compute spectrogram
|
||||||
_, spec = compute_spectrogram(
|
_, spec = compute_spectrogram(
|
||||||
audio,
|
audio,
|
||||||
@ -669,7 +671,7 @@ def process_audio_array(
|
|||||||
model: DetectionModel,
|
model: DetectionModel,
|
||||||
config: ProcessingConfiguration,
|
config: ProcessingConfiguration,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
|
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
|
||||||
"""Process a single audio array with detection model.
|
"""Process a single audio array with detection model.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -689,7 +691,7 @@ def process_audio_array(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
annotations : List[Annotation]
|
annotations : list[Annotation]
|
||||||
List of annotations predicted by the model.
|
List of annotations predicted by the model.
|
||||||
features : np.ndarray
|
features : np.ndarray
|
||||||
Array of CNN features associated with each annotation.
|
Array of CNN features associated with each annotation.
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -20,11 +20,11 @@ def create_legacy_file_annotation(
|
|||||||
duration: float = 5.0,
|
duration: float = 5.0,
|
||||||
time_exp: float = 1.0,
|
time_exp: float = 1.0,
|
||||||
class_name: str = "Myotis",
|
class_name: str = "Myotis",
|
||||||
annotations: Optional[List[Dict[str, Any]]] = None,
|
annotations: list[dict[str, Any]] | None = None,
|
||||||
annotated: bool = True,
|
annotated: bool = True,
|
||||||
issues: bool = False,
|
issues: bool = False,
|
||||||
notes: str = "",
|
notes: str = "",
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if annotations is None:
|
if annotations is None:
|
||||||
annotations = [
|
annotations = [
|
||||||
{
|
{
|
||||||
@ -61,7 +61,7 @@ def create_legacy_file_annotation(
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def batdetect2_files_test_setup(
|
def batdetect2_files_test_setup(
|
||||||
tmp_path: Path, wav_factory
|
tmp_path: Path, wav_factory
|
||||||
) -> Tuple[Path, Path, List[Dict[str, Any]]]:
|
) -> tuple[Path, Path, list[dict[str, Any]]]:
|
||||||
"""Sets up a directory structure for batdetect2 files format tests."""
|
"""Sets up a directory structure for batdetect2 files format tests."""
|
||||||
audio_dir = tmp_path / "audio"
|
audio_dir = tmp_path / "audio"
|
||||||
audio_dir.mkdir()
|
audio_dir.mkdir()
|
||||||
@ -143,7 +143,7 @@ def batdetect2_files_test_setup(
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def batdetect2_merged_test_setup(
|
def batdetect2_merged_test_setup(
|
||||||
tmp_path: Path, batdetect2_files_test_setup
|
tmp_path: Path, batdetect2_files_test_setup
|
||||||
) -> Tuple[Path, Path, List[Dict[str, Any]]]:
|
) -> tuple[Path, Path, list[dict[str, Any]]]:
|
||||||
"""Sets up a directory structure for batdetect2 merged file format tests."""
|
"""Sets up a directory structure for batdetect2 merged file format tests."""
|
||||||
audio_dir, _, files_data = batdetect2_files_test_setup
|
audio_dir, _, files_data = batdetect2_files_test_setup
|
||||||
merged_anns_path = tmp_path / "merged_anns.json"
|
merged_anns_path = tmp_path / "merged_anns.json"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user