mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Update type hints to python 3.10
This commit is contained in:
parent
9c72537ddd
commit
2563f26ed3
@ -93,7 +93,7 @@ mlflow = ["mlflow>=3.1.1"]
|
|||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 79
|
line-length = 79
|
||||||
target-version = "py39"
|
target-version = "py310"
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
docstring-code-format = true
|
docstring-code-format = true
|
||||||
@ -107,7 +107,7 @@ convention = "numpy"
|
|||||||
|
|
||||||
[tool.pyright]
|
[tool.pyright]
|
||||||
include = ["src", "tests"]
|
include = ["src", "tests"]
|
||||||
pythonVersion = "3.9"
|
pythonVersion = "3.10"
|
||||||
pythonPlatform = "All"
|
pythonPlatform = "All"
|
||||||
exclude = [
|
exclude = [
|
||||||
"src/batdetect2/detector/",
|
"src/batdetect2/detector/",
|
||||||
|
|||||||
@ -165,7 +165,7 @@ def load_audio(
|
|||||||
time_exp_fact: float = 1,
|
time_exp_fact: float = 1,
|
||||||
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
|
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
scale: bool = False,
|
scale: bool = False,
|
||||||
max_duration: Optional[float] = None,
|
max_duration: float | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load audio from file.
|
"""Load audio from file.
|
||||||
|
|
||||||
@ -203,7 +203,7 @@ def load_audio(
|
|||||||
def generate_spectrogram(
|
def generate_spectrogram(
|
||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
config: Optional[SpectrogramParameters] = None,
|
config: SpectrogramParameters | None = None,
|
||||||
device: torch.device = DEVICE,
|
device: torch.device = DEVICE,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Generate spectrogram from audio array.
|
"""Generate spectrogram from audio array.
|
||||||
@ -240,7 +240,7 @@ def generate_spectrogram(
|
|||||||
def process_file(
|
def process_file(
|
||||||
audio_file: str,
|
audio_file: str,
|
||||||
model: DetectionModel = MODEL,
|
model: DetectionModel = MODEL,
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: ProcessingConfiguration | None = None,
|
||||||
device: torch.device = DEVICE,
|
device: torch.device = DEVICE,
|
||||||
) -> du.RunResults:
|
) -> du.RunResults:
|
||||||
"""Process audio file with model.
|
"""Process audio file with model.
|
||||||
@ -271,7 +271,7 @@ def process_spectrogram(
|
|||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
model: DetectionModel = MODEL,
|
model: DetectionModel = MODEL,
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: ProcessingConfiguration | None = None,
|
||||||
) -> Tuple[List[Annotation], np.ndarray]:
|
) -> Tuple[List[Annotation], np.ndarray]:
|
||||||
"""Process spectrogram with model.
|
"""Process spectrogram with model.
|
||||||
|
|
||||||
@ -312,7 +312,7 @@ def process_audio(
|
|||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
model: DetectionModel = MODEL,
|
model: DetectionModel = MODEL,
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: ProcessingConfiguration | None = None,
|
||||||
device: torch.device = DEVICE,
|
device: torch.device = DEVICE,
|
||||||
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
|
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
|
||||||
"""Process audio array with model.
|
"""Process audio array with model.
|
||||||
@ -356,7 +356,7 @@ def process_audio(
|
|||||||
def postprocess(
|
def postprocess(
|
||||||
outputs: ModelOutput,
|
outputs: ModelOutput,
|
||||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: ProcessingConfiguration | None = None,
|
||||||
) -> Tuple[List[Annotation], np.ndarray]:
|
) -> Tuple[List[Annotation], np.ndarray]:
|
||||||
"""Postprocess model outputs.
|
"""Postprocess model outputs.
|
||||||
|
|
||||||
|
|||||||
@ -67,22 +67,22 @@ class BatDetect2API:
|
|||||||
def load_annotations(
|
def load_annotations(
|
||||||
self,
|
self,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
base_dir: Optional[data.PathLike] = None,
|
base_dir: data.PathLike | None = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
return load_dataset_from_config(path, base_dir=base_dir)
|
return load_dataset_from_config(path, base_dir=base_dir)
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
train_workers: Optional[int] = None,
|
train_workers: int | None = None,
|
||||||
val_workers: Optional[int] = None,
|
val_workers: int | None = None,
|
||||||
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
|
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
|
||||||
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
|
log_dir: Path | None = DEFAULT_LOGS_DIR,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
num_epochs: Optional[int] = None,
|
num_epochs: int | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
seed: Optional[int] = None,
|
seed: int | None = None,
|
||||||
):
|
):
|
||||||
train(
|
train(
|
||||||
train_annotations=train_annotations,
|
train_annotations=train_annotations,
|
||||||
@ -105,10 +105,10 @@ class BatDetect2API:
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
test_annotations: Sequence[data.ClipAnnotation],
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
num_workers: Optional[int] = None,
|
num_workers: int | None = None,
|
||||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
save_predictions: bool = True,
|
save_predictions: bool = True,
|
||||||
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
||||||
return evaluate(
|
return evaluate(
|
||||||
@ -129,7 +129,7 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
annotations: Sequence[data.ClipAnnotation],
|
annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[BatDetect2Prediction],
|
||||||
output_dir: Optional[data.PathLike] = None,
|
output_dir: data.PathLike | None = None,
|
||||||
):
|
):
|
||||||
clip_evals = self.evaluator.evaluate(
|
clip_evals = self.evaluator.evaluate(
|
||||||
annotations,
|
annotations,
|
||||||
@ -221,7 +221,7 @@ class BatDetect2API:
|
|||||||
def process_files(
|
def process_files(
|
||||||
self,
|
self,
|
||||||
audio_files: Sequence[data.PathLike],
|
audio_files: Sequence[data.PathLike],
|
||||||
num_workers: Optional[int] = None,
|
num_workers: int | None = None,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[BatDetect2Prediction]:
|
||||||
return process_file_list(
|
return process_file_list(
|
||||||
self.model,
|
self.model,
|
||||||
@ -236,8 +236,8 @@ class BatDetect2API:
|
|||||||
def process_clips(
|
def process_clips(
|
||||||
self,
|
self,
|
||||||
clips: Sequence[data.Clip],
|
clips: Sequence[data.Clip],
|
||||||
batch_size: Optional[int] = None,
|
batch_size: int | None = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: int | None = None,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[BatDetect2Prediction]:
|
||||||
return run_batch_inference(
|
return run_batch_inference(
|
||||||
self.model,
|
self.model,
|
||||||
@ -254,9 +254,9 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[BatDetect2Prediction],
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
format: Optional[str] = None,
|
format: str | None = None,
|
||||||
config: Optional[OutputFormatConfig] = None,
|
config: OutputFormatConfig | None = None,
|
||||||
):
|
):
|
||||||
formatter = self.formatter
|
formatter = self.formatter
|
||||||
|
|
||||||
@ -331,7 +331,7 @@ class BatDetect2API:
|
|||||||
def from_checkpoint(
|
def from_checkpoint(
|
||||||
cls,
|
cls,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
config: Optional[BatDetect2Config] = None,
|
config: BatDetect2Config | None = None,
|
||||||
):
|
):
|
||||||
model, stored_config = load_model_from_checkpoint(path)
|
model, stored_config = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
|
|||||||
@ -245,16 +245,12 @@ class FixedDurationClip:
|
|||||||
|
|
||||||
|
|
||||||
ClipConfig = Annotated[
|
ClipConfig = Annotated[
|
||||||
Union[
|
RandomClipConfig | PaddedClipConfig | FixedDurationClipConfig,
|
||||||
RandomClipConfig,
|
|
||||||
PaddedClipConfig,
|
|
||||||
FixedDurationClipConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
|
def build_clipper(config: ClipConfig | None = None) -> ClipperProtocol:
|
||||||
config = config or RandomClipConfig()
|
config = config or RandomClipConfig()
|
||||||
|
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
|
|||||||
@ -50,7 +50,7 @@ class AudioConfig(BaseConfig):
|
|||||||
resample: ResampleConfig = Field(default_factory=ResampleConfig)
|
resample: ResampleConfig = Field(default_factory=ResampleConfig)
|
||||||
|
|
||||||
|
|
||||||
def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
|
def build_audio_loader(config: AudioConfig | None = None) -> AudioLoader:
|
||||||
"""Factory function to create an AudioLoader based on configuration."""
|
"""Factory function to create an AudioLoader based on configuration."""
|
||||||
config = config or AudioConfig()
|
config = config or AudioConfig()
|
||||||
return SoundEventAudioLoader(
|
return SoundEventAudioLoader(
|
||||||
@ -65,7 +65,7 @@ class SoundEventAudioLoader(AudioLoader):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
config: Optional[ResampleConfig] = None,
|
config: ResampleConfig | None = None,
|
||||||
):
|
):
|
||||||
self.samplerate = samplerate
|
self.samplerate = samplerate
|
||||||
self.config = config or ResampleConfig()
|
self.config = config or ResampleConfig()
|
||||||
@ -73,7 +73,7 @@ class SoundEventAudioLoader(AudioLoader):
|
|||||||
def load_file(
|
def load_file(
|
||||||
self,
|
self,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess audio directly from a file path."""
|
"""Load and preprocess audio directly from a file path."""
|
||||||
return load_file_audio(
|
return load_file_audio(
|
||||||
@ -86,7 +86,7 @@ class SoundEventAudioLoader(AudioLoader):
|
|||||||
def load_recording(
|
def load_recording(
|
||||||
self,
|
self,
|
||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess the entire audio for a Recording object."""
|
"""Load and preprocess the entire audio for a Recording object."""
|
||||||
return load_recording_audio(
|
return load_recording_audio(
|
||||||
@ -99,7 +99,7 @@ class SoundEventAudioLoader(AudioLoader):
|
|||||||
def load_clip(
|
def load_clip(
|
||||||
self,
|
self,
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess the audio segment defined by a Clip object."""
|
"""Load and preprocess the audio segment defined by a Clip object."""
|
||||||
return load_clip_audio(
|
return load_clip_audio(
|
||||||
@ -112,9 +112,9 @@ class SoundEventAudioLoader(AudioLoader):
|
|||||||
|
|
||||||
def load_file_audio(
|
def load_file_audio(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
samplerate: Optional[int] = None,
|
samplerate: int | None = None,
|
||||||
config: Optional[ResampleConfig] = None,
|
config: ResampleConfig | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
dtype: DTypeLike = np.float32, # type: ignore
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess audio from a file path using specified config."""
|
"""Load and preprocess audio from a file path using specified config."""
|
||||||
@ -136,9 +136,9 @@ def load_file_audio(
|
|||||||
|
|
||||||
def load_recording_audio(
|
def load_recording_audio(
|
||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
samplerate: Optional[int] = None,
|
samplerate: int | None = None,
|
||||||
config: Optional[ResampleConfig] = None,
|
config: ResampleConfig | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
dtype: DTypeLike = np.float32, # type: ignore
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess the entire audio content of a recording using config."""
|
"""Load and preprocess the entire audio content of a recording using config."""
|
||||||
@ -158,9 +158,9 @@ def load_recording_audio(
|
|||||||
|
|
||||||
def load_clip_audio(
|
def load_clip_audio(
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
samplerate: Optional[int] = None,
|
samplerate: int | None = None,
|
||||||
config: Optional[ResampleConfig] = None,
|
config: ResampleConfig | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
dtype: DTypeLike = np.float32, # type: ignore
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess a specific audio clip segment based on config."""
|
"""Load and preprocess a specific audio clip segment based on config."""
|
||||||
|
|||||||
@ -34,9 +34,9 @@ def data(): ...
|
|||||||
)
|
)
|
||||||
def summary(
|
def summary(
|
||||||
dataset_config: Path,
|
dataset_config: Path,
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
targets_path: Optional[Path] = None,
|
targets_path: Path | None = None,
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Path | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.data import compute_class_summary, load_dataset_from_config
|
from batdetect2.data import compute_class_summary, load_dataset_from_config
|
||||||
from batdetect2.targets import load_targets
|
from batdetect2.targets import load_targets
|
||||||
@ -83,9 +83,9 @@ def summary(
|
|||||||
)
|
)
|
||||||
def convert(
|
def convert(
|
||||||
dataset_config: Path,
|
dataset_config: Path,
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
output: Path = Path("annotations.json"),
|
output: Path = Path("annotations.json"),
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Path | None = None,
|
||||||
):
|
):
|
||||||
"""Convert a dataset config file to soundevent format."""
|
"""Convert a dataset config file to soundevent format."""
|
||||||
from soundevent import data, io
|
from soundevent import data, io
|
||||||
|
|||||||
@ -25,11 +25,11 @@ def evaluate_command(
|
|||||||
model_path: Path,
|
model_path: Path,
|
||||||
test_dataset: Path,
|
test_dataset: Path,
|
||||||
base_dir: Path,
|
base_dir: Path,
|
||||||
config_path: Optional[Path],
|
config_path: Path | None,
|
||||||
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: int | None = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.config import load_full_config
|
from batdetect2.config import load_full_config
|
||||||
|
|||||||
@ -26,19 +26,19 @@ __all__ = ["train_command"]
|
|||||||
@click.option("--seed", type=int)
|
@click.option("--seed", type=int)
|
||||||
def train_command(
|
def train_command(
|
||||||
train_dataset: Path,
|
train_dataset: Path,
|
||||||
val_dataset: Optional[Path] = None,
|
val_dataset: Path | None = None,
|
||||||
model_path: Optional[Path] = None,
|
model_path: Path | None = None,
|
||||||
ckpt_dir: Optional[Path] = None,
|
ckpt_dir: Path | None = None,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Path | None = None,
|
||||||
config: Optional[Path] = None,
|
config: Path | None = None,
|
||||||
targets_config: Optional[Path] = None,
|
targets_config: Path | None = None,
|
||||||
config_field: Optional[str] = None,
|
config_field: str | None = None,
|
||||||
seed: Optional[int] = None,
|
seed: int | None = None,
|
||||||
num_epochs: Optional[int] = None,
|
num_epochs: int | None = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
val_workers: int = 0,
|
val_workers: int = 0,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.config import (
|
from batdetect2.config import (
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -17,7 +17,7 @@ from batdetect2.types import (
|
|||||||
FileAnnotation,
|
FileAnnotation,
|
||||||
)
|
)
|
||||||
|
|
||||||
PathLike = Union[Path, str, os.PathLike]
|
PathLike = Path | str | os.PathLike
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"convert_to_annotation_group",
|
"convert_to_annotation_group",
|
||||||
@ -33,7 +33,7 @@ UNKNOWN_CLASS = "__UNKNOWN__"
|
|||||||
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
|
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
|
||||||
|
|
||||||
|
|
||||||
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
|
EventFn = Callable[[data.SoundEventAnnotation], str | None]
|
||||||
|
|
||||||
ClassFn = Callable[[data.Recording], int]
|
ClassFn = Callable[[data.Recording], int]
|
||||||
|
|
||||||
@ -221,7 +221,7 @@ def annotation_to_sound_event_prediction(
|
|||||||
|
|
||||||
def file_annotation_to_clip(
|
def file_annotation_to_clip(
|
||||||
file_annotation: FileAnnotation,
|
file_annotation: FileAnnotation,
|
||||||
audio_dir: Optional[PathLike] = None,
|
audio_dir: PathLike | None = None,
|
||||||
label_key: str = "class",
|
label_key: str = "class",
|
||||||
) -> data.Clip:
|
) -> data.Clip:
|
||||||
"""Convert file annotation to recording."""
|
"""Convert file annotation to recording."""
|
||||||
|
|||||||
@ -43,7 +43,7 @@ class BatDetect2Config(BaseConfig):
|
|||||||
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
|
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
|
||||||
|
|
||||||
|
|
||||||
def validate_config(config: Optional[dict]) -> BatDetect2Config:
|
def validate_config(config: dict | None) -> BatDetect2Config:
|
||||||
if config is None:
|
if config is None:
|
||||||
return BatDetect2Config()
|
return BatDetect2Config()
|
||||||
|
|
||||||
@ -52,6 +52,6 @@ def validate_config(config: Optional[dict]) -> BatDetect2Config:
|
|||||||
|
|
||||||
def load_full_config(
|
def load_full_config(
|
||||||
path: PathLike,
|
path: PathLike,
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
) -> BatDetect2Config:
|
) -> BatDetect2Config:
|
||||||
return load_config(path, schema=BatDetect2Config, field=field)
|
return load_config(path, schema=BatDetect2Config, field=field)
|
||||||
|
|||||||
@ -86,8 +86,8 @@ def adjust_width(
|
|||||||
|
|
||||||
def slice_tensor(
|
def slice_tensor(
|
||||||
tensor: torch.Tensor,
|
tensor: torch.Tensor,
|
||||||
start: Optional[int] = None,
|
start: int | None = None,
|
||||||
end: Optional[int] = None,
|
end: int | None = None,
|
||||||
dim: int = -1,
|
dim: int = -1,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
slices = [slice(None)] * tensor.ndim
|
slices = [slice(None)] * tensor.ndim
|
||||||
|
|||||||
@ -128,7 +128,7 @@ def get_object_field(obj: dict, current_key: str) -> Any:
|
|||||||
def load_config(
|
def load_config(
|
||||||
path: PathLike,
|
path: PathLike,
|
||||||
schema: Type[T],
|
schema: Type[T],
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Load and validate configuration data from a file against a schema.
|
"""Load and validate configuration data from a file against a schema.
|
||||||
|
|
||||||
|
|||||||
@ -43,11 +43,7 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
AnnotationFormats = Annotated[
|
AnnotationFormats = Annotated[
|
||||||
Union[
|
BatDetect2MergedAnnotations | BatDetect2FilesAnnotations | AOEFAnnotations,
|
||||||
BatDetect2MergedAnnotations,
|
|
||||||
BatDetect2FilesAnnotations,
|
|
||||||
AOEFAnnotations,
|
|
||||||
],
|
|
||||||
Field(discriminator="format"),
|
Field(discriminator="format"),
|
||||||
]
|
]
|
||||||
"""Type Alias representing all supported data source configurations.
|
"""Type Alias representing all supported data source configurations.
|
||||||
@ -63,7 +59,7 @@ source configuration represents.
|
|||||||
|
|
||||||
def load_annotated_dataset(
|
def load_annotated_dataset(
|
||||||
dataset: AnnotatedDataset,
|
dataset: AnnotatedDataset,
|
||||||
base_dir: Optional[data.PathLike] = None,
|
base_dir: data.PathLike | None = None,
|
||||||
) -> data.AnnotationSet:
|
) -> data.AnnotationSet:
|
||||||
"""Load annotations for a single data source based on its configuration.
|
"""Load annotations for a single data source based on its configuration.
|
||||||
|
|
||||||
|
|||||||
@ -77,14 +77,14 @@ class AOEFAnnotations(AnnotatedDataset):
|
|||||||
|
|
||||||
annotations_path: Path
|
annotations_path: Path
|
||||||
|
|
||||||
filter: Optional[AnnotationTaskFilter] = Field(
|
filter: AnnotationTaskFilter | None = Field(
|
||||||
default_factory=AnnotationTaskFilter
|
default_factory=AnnotationTaskFilter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_aoef_annotated_dataset(
|
def load_aoef_annotated_dataset(
|
||||||
dataset: AOEFAnnotations,
|
dataset: AOEFAnnotations,
|
||||||
base_dir: Optional[data.PathLike] = None,
|
base_dir: data.PathLike | None = None,
|
||||||
) -> data.AnnotationSet:
|
) -> data.AnnotationSet:
|
||||||
"""Load annotations from an AnnotationSet or AnnotationProject file.
|
"""Load annotations from an AnnotationSet or AnnotationProject file.
|
||||||
|
|
||||||
|
|||||||
@ -27,7 +27,7 @@ aggregated into a `soundevent.data.AnnotationSet`.
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field, ValidationError
|
from pydantic import Field, ValidationError
|
||||||
@ -43,7 +43,7 @@ from batdetect2.data.annotations.legacy import (
|
|||||||
)
|
)
|
||||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
from batdetect2.data.annotations.types import AnnotatedDataset
|
||||||
|
|
||||||
PathLike = Union[Path, str, os.PathLike]
|
PathLike = Path | str | os.PathLike
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -102,7 +102,7 @@ class BatDetect2FilesAnnotations(AnnotatedDataset):
|
|||||||
format: Literal["batdetect2"] = "batdetect2"
|
format: Literal["batdetect2"] = "batdetect2"
|
||||||
annotations_dir: Path
|
annotations_dir: Path
|
||||||
|
|
||||||
filter: Optional[AnnotationFilter] = Field(
|
filter: AnnotationFilter | None = Field(
|
||||||
default_factory=AnnotationFilter,
|
default_factory=AnnotationFilter,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -133,14 +133,14 @@ class BatDetect2MergedAnnotations(AnnotatedDataset):
|
|||||||
format: Literal["batdetect2_file"] = "batdetect2_file"
|
format: Literal["batdetect2_file"] = "batdetect2_file"
|
||||||
annotations_path: Path
|
annotations_path: Path
|
||||||
|
|
||||||
filter: Optional[AnnotationFilter] = Field(
|
filter: AnnotationFilter | None = Field(
|
||||||
default_factory=AnnotationFilter,
|
default_factory=AnnotationFilter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_batdetect2_files_annotated_dataset(
|
def load_batdetect2_files_annotated_dataset(
|
||||||
dataset: BatDetect2FilesAnnotations,
|
dataset: BatDetect2FilesAnnotations,
|
||||||
base_dir: Optional[PathLike] = None,
|
base_dir: PathLike | None = None,
|
||||||
) -> data.AnnotationSet:
|
) -> data.AnnotationSet:
|
||||||
"""Load and convert 'batdetect2_file' annotations into an AnnotationSet.
|
"""Load and convert 'batdetect2_file' annotations into an AnnotationSet.
|
||||||
|
|
||||||
@ -244,7 +244,7 @@ def load_batdetect2_files_annotated_dataset(
|
|||||||
|
|
||||||
def load_batdetect2_merged_annotated_dataset(
|
def load_batdetect2_merged_annotated_dataset(
|
||||||
dataset: BatDetect2MergedAnnotations,
|
dataset: BatDetect2MergedAnnotations,
|
||||||
base_dir: Optional[PathLike] = None,
|
base_dir: PathLike | None = None,
|
||||||
) -> data.AnnotationSet:
|
) -> data.AnnotationSet:
|
||||||
"""Load and convert 'batdetect2_merged' annotations into an AnnotationSet.
|
"""Load and convert 'batdetect2_merged' annotations into an AnnotationSet.
|
||||||
|
|
||||||
|
|||||||
@ -3,12 +3,12 @@
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
PathLike = Union[Path, str, os.PathLike]
|
PathLike = Path | str | os.PathLike
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
|
EventFn = Callable[[data.SoundEventAnnotation], str | None]
|
||||||
|
|
||||||
ClassFn = Callable[[data.Recording], int]
|
ClassFn = Callable[[data.Recording], int]
|
||||||
|
|
||||||
@ -130,7 +130,7 @@ def get_sound_event_tags(
|
|||||||
|
|
||||||
def file_annotation_to_clip(
|
def file_annotation_to_clip(
|
||||||
file_annotation: FileAnnotation,
|
file_annotation: FileAnnotation,
|
||||||
audio_dir: Optional[PathLike] = None,
|
audio_dir: PathLike | None = None,
|
||||||
label_key: str = "class",
|
label_key: str = "class",
|
||||||
) -> data.Clip:
|
) -> data.Clip:
|
||||||
"""Convert file annotation to recording."""
|
"""Convert file annotation to recording."""
|
||||||
|
|||||||
@ -264,16 +264,7 @@ class Not:
|
|||||||
|
|
||||||
|
|
||||||
SoundEventConditionConfig = Annotated[
|
SoundEventConditionConfig = Annotated[
|
||||||
Union[
|
HasTagConfig | HasAllTagsConfig | HasAnyTagConfig | DurationConfig | FrequencyConfig | AllOfConfig | AnyOfConfig | NotConfig,
|
||||||
HasTagConfig,
|
|
||||||
HasAllTagsConfig,
|
|
||||||
HasAnyTagConfig,
|
|
||||||
DurationConfig,
|
|
||||||
FrequencyConfig,
|
|
||||||
AllOfConfig,
|
|
||||||
AnyOfConfig,
|
|
||||||
NotConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -69,7 +69,7 @@ class DatasetConfig(BaseConfig):
|
|||||||
description: str
|
description: str
|
||||||
sources: List[AnnotationFormats]
|
sources: List[AnnotationFormats]
|
||||||
|
|
||||||
sound_event_filter: Optional[SoundEventConditionConfig] = None
|
sound_event_filter: SoundEventConditionConfig | None = None
|
||||||
sound_event_transforms: List[SoundEventTransformConfig] = Field(
|
sound_event_transforms: List[SoundEventTransformConfig] = Field(
|
||||||
default_factory=list
|
default_factory=list
|
||||||
)
|
)
|
||||||
@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig):
|
|||||||
|
|
||||||
def load_dataset(
|
def load_dataset(
|
||||||
config: DatasetConfig,
|
config: DatasetConfig,
|
||||||
base_dir: Optional[data.PathLike] = None,
|
base_dir: data.PathLike | None = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
||||||
clip_annotations = []
|
clip_annotations = []
|
||||||
@ -161,14 +161,14 @@ def insert_source_tag(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
|
def load_dataset_config(path: data.PathLike, field: str | None = None):
|
||||||
return load_config(path=path, schema=DatasetConfig, field=field)
|
return load_config(path=path, schema=DatasetConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_from_config(
|
def load_dataset_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
base_dir: Optional[data.PathLike] = None,
|
base_dir: data.PathLike | None = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Load dataset annotation metadata from a configuration file.
|
"""Load dataset annotation metadata from a configuration file.
|
||||||
|
|
||||||
@ -215,9 +215,9 @@ def load_dataset_from_config(
|
|||||||
def save_dataset(
|
def save_dataset(
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
name: Optional[str] = None,
|
name: str | None = None,
|
||||||
description: Optional[str] = None,
|
description: str | None = None,
|
||||||
audio_dir: Optional[Path] = None,
|
audio_dir: Path | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save a loaded dataset (list of ClipAnnotations) to a file.
|
"""Save a loaded dataset (list of ClipAnnotations) to a file.
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from batdetect2.typing.targets import TargetProtocol
|
|||||||
def iterate_over_sound_events(
|
def iterate_over_sound_events(
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
|
) -> Generator[Tuple[str | None, data.SoundEventAnnotation], None, None]:
|
||||||
"""Iterate over sound events in a dataset.
|
"""Iterate over sound events in a dataset.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
|||||||
@ -24,19 +24,14 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
OutputFormatConfig = Annotated[
|
OutputFormatConfig = Annotated[
|
||||||
Union[
|
BatDetect2OutputConfig | ParquetOutputConfig | SoundEventOutputConfig | RawOutputConfig,
|
||||||
BatDetect2OutputConfig,
|
|
||||||
ParquetOutputConfig,
|
|
||||||
SoundEventOutputConfig,
|
|
||||||
RawOutputConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_output_formatter(
|
def build_output_formatter(
|
||||||
targets: Optional[TargetProtocol] = None,
|
targets: TargetProtocol | None = None,
|
||||||
config: Optional[OutputFormatConfig] = None,
|
config: OutputFormatConfig | None = None,
|
||||||
) -> OutputFormatterProtocol:
|
) -> OutputFormatterProtocol:
|
||||||
"""Construct the final output formatter."""
|
"""Construct the final output formatter."""
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
@ -48,9 +43,9 @@ def build_output_formatter(
|
|||||||
|
|
||||||
|
|
||||||
def get_output_formatter(
|
def get_output_formatter(
|
||||||
name: Optional[str] = None,
|
name: str | None = None,
|
||||||
targets: Optional[TargetProtocol] = None,
|
targets: TargetProtocol | None = None,
|
||||||
config: Optional[OutputFormatConfig] = None,
|
config: OutputFormatConfig | None = None,
|
||||||
) -> OutputFormatterProtocol:
|
) -> OutputFormatterProtocol:
|
||||||
"""Get the output formatter by name."""
|
"""Get the output formatter by name."""
|
||||||
|
|
||||||
@ -71,9 +66,9 @@ def get_output_formatter(
|
|||||||
|
|
||||||
def load_predictions(
|
def load_predictions(
|
||||||
path: PathLike,
|
path: PathLike,
|
||||||
format: Optional[str] = "raw",
|
format: str | None = "raw",
|
||||||
config: Optional[OutputFormatConfig] = None,
|
config: OutputFormatConfig | None = None,
|
||||||
targets: Optional[TargetProtocol] = None,
|
targets: TargetProtocol | None = None,
|
||||||
):
|
):
|
||||||
"""Load predictions from a file."""
|
"""Load predictions from a file."""
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
|
|||||||
@ -123,7 +123,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
self,
|
self,
|
||||||
predictions: Sequence[FileAnnotation],
|
predictions: Sequence[FileAnnotation],
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
|
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
self,
|
self,
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[BatDetect2Prediction],
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
|
|
||||||
|
|||||||
@ -55,7 +55,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
self,
|
self,
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[BatDetect2Prediction],
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
def pred_to_xr(
|
def pred_to_xr(
|
||||||
self,
|
self,
|
||||||
prediction: BatDetect2Prediction,
|
prediction: BatDetect2Prediction,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> xr.Dataset:
|
) -> xr.Dataset:
|
||||||
clip = prediction.clip
|
clip = prediction.clip
|
||||||
recording = clip.recording
|
recording = clip.recording
|
||||||
|
|||||||
@ -18,16 +18,16 @@ from batdetect2.typing import (
|
|||||||
|
|
||||||
class SoundEventOutputConfig(BaseConfig):
|
class SoundEventOutputConfig(BaseConfig):
|
||||||
name: Literal["soundevent"] = "soundevent"
|
name: Literal["soundevent"] = "soundevent"
|
||||||
top_k: Optional[int] = 1
|
top_k: int | None = 1
|
||||||
min_score: Optional[float] = None
|
min_score: float | None = None
|
||||||
|
|
||||||
|
|
||||||
class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
top_k: Optional[int] = 1,
|
top_k: int | None = 1,
|
||||||
min_score: Optional[float] = 0,
|
min_score: float | None = 0,
|
||||||
):
|
):
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
@ -45,7 +45,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
|||||||
self,
|
self,
|
||||||
predictions: Sequence[data.ClipPrediction],
|
predictions: Sequence[data.ClipPrediction],
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
run = data.PredictionSet(clip_predictions=list(predictions))
|
run = data.PredictionSet(clip_predictions=list(predictions))
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,7 @@ def split_dataset_by_recordings(
|
|||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
train_size: float = 0.75,
|
train_size: float = 0.75,
|
||||||
random_state: Optional[int] = None,
|
random_state: int | None = None,
|
||||||
) -> Tuple[Dataset, Dataset]:
|
) -> Tuple[Dataset, Dataset]:
|
||||||
recordings = extract_recordings_df(dataset)
|
recordings = extract_recordings_df(dataset)
|
||||||
|
|
||||||
|
|||||||
@ -142,7 +142,7 @@ class MapTagValueConfig(BaseConfig):
|
|||||||
name: Literal["map_tag_value"] = "map_tag_value"
|
name: Literal["map_tag_value"] = "map_tag_value"
|
||||||
tag_key: str
|
tag_key: str
|
||||||
value_mapping: Dict[str, str]
|
value_mapping: Dict[str, str]
|
||||||
target_key: Optional[str] = None
|
target_key: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class MapTagValue:
|
class MapTagValue:
|
||||||
@ -150,7 +150,7 @@ class MapTagValue:
|
|||||||
self,
|
self,
|
||||||
tag_key: str,
|
tag_key: str,
|
||||||
value_mapping: Dict[str, str],
|
value_mapping: Dict[str, str],
|
||||||
target_key: Optional[str] = None,
|
target_key: str | None = None,
|
||||||
):
|
):
|
||||||
self.tag_key = tag_key
|
self.tag_key = tag_key
|
||||||
self.value_mapping = value_mapping
|
self.value_mapping = value_mapping
|
||||||
@ -221,13 +221,7 @@ class ApplyAll:
|
|||||||
|
|
||||||
|
|
||||||
SoundEventTransformConfig = Annotated[
|
SoundEventTransformConfig = Annotated[
|
||||||
Union[
|
SetFrequencyBoundConfig | ReplaceTagConfig | MapTagValueConfig | ApplyIfConfig | ApplyAllConfig,
|
||||||
SetFrequencyBoundConfig,
|
|
||||||
ReplaceTagConfig,
|
|
||||||
MapTagValueConfig,
|
|
||||||
ApplyIfConfig,
|
|
||||||
ApplyAllConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -86,7 +86,7 @@ def compute_bandwidth(
|
|||||||
|
|
||||||
def compute_max_power_bb(
|
def compute_max_power_bb(
|
||||||
prediction: types.Prediction,
|
prediction: types.Prediction,
|
||||||
spec: Optional[np.ndarray] = None,
|
spec: np.ndarray | None = None,
|
||||||
min_freq: int = MIN_FREQ_HZ,
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
max_freq: int = MAX_FREQ_HZ,
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
**_,
|
**_,
|
||||||
@ -131,7 +131,7 @@ def compute_max_power_bb(
|
|||||||
|
|
||||||
def compute_max_power(
|
def compute_max_power(
|
||||||
prediction: types.Prediction,
|
prediction: types.Prediction,
|
||||||
spec: Optional[np.ndarray] = None,
|
spec: np.ndarray | None = None,
|
||||||
min_freq: int = MIN_FREQ_HZ,
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
max_freq: int = MAX_FREQ_HZ,
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
**_,
|
**_,
|
||||||
@ -157,7 +157,7 @@ def compute_max_power(
|
|||||||
|
|
||||||
def compute_max_power_first(
|
def compute_max_power_first(
|
||||||
prediction: types.Prediction,
|
prediction: types.Prediction,
|
||||||
spec: Optional[np.ndarray] = None,
|
spec: np.ndarray | None = None,
|
||||||
min_freq: int = MIN_FREQ_HZ,
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
max_freq: int = MAX_FREQ_HZ,
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
**_,
|
**_,
|
||||||
@ -184,7 +184,7 @@ def compute_max_power_first(
|
|||||||
|
|
||||||
def compute_max_power_second(
|
def compute_max_power_second(
|
||||||
prediction: types.Prediction,
|
prediction: types.Prediction,
|
||||||
spec: Optional[np.ndarray] = None,
|
spec: np.ndarray | None = None,
|
||||||
min_freq: int = MIN_FREQ_HZ,
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
max_freq: int = MAX_FREQ_HZ,
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
**_,
|
**_,
|
||||||
@ -211,7 +211,7 @@ def compute_max_power_second(
|
|||||||
|
|
||||||
def compute_call_interval(
|
def compute_call_interval(
|
||||||
prediction: types.Prediction,
|
prediction: types.Prediction,
|
||||||
previous: Optional[types.Prediction] = None,
|
previous: types.Prediction | None = None,
|
||||||
**_,
|
**_,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Compute time between this call and the previous call in seconds."""
|
"""Compute time between this call and the previous call in seconds."""
|
||||||
|
|||||||
@ -198,8 +198,8 @@ class TrainingParameters(BaseModel):
|
|||||||
def get_params(
|
def get_params(
|
||||||
make_dirs: bool = False,
|
make_dirs: bool = False,
|
||||||
exps_dir: str = "../../experiments/",
|
exps_dir: str = "../../experiments/",
|
||||||
model_name: Optional[str] = None,
|
model_name: str | None = None,
|
||||||
experiment: Union[Path, str, None] = None,
|
experiment: Path | str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> TrainingParameters:
|
) -> TrainingParameters:
|
||||||
experiments_dir = Path(exps_dir)
|
experiments_dir = Path(exps_dir)
|
||||||
|
|||||||
@ -151,7 +151,7 @@ def run_nms(
|
|||||||
|
|
||||||
def non_max_suppression(
|
def non_max_suppression(
|
||||||
heat: torch.Tensor,
|
heat: torch.Tensor,
|
||||||
kernel_size: Union[int, Tuple[int, int]],
|
kernel_size: int | Tuple[int, int],
|
||||||
):
|
):
|
||||||
# kernel can be an int or list/tuple
|
# kernel can be an int or list/tuple
|
||||||
if isinstance(kernel_size, int):
|
if isinstance(kernel_size, int):
|
||||||
|
|||||||
@ -213,18 +213,13 @@ class GeometricIOU(AffinityFunction):
|
|||||||
|
|
||||||
|
|
||||||
AffinityConfig = Annotated[
|
AffinityConfig = Annotated[
|
||||||
Union[
|
TimeAffinityConfig | IntervalIOUConfig | BBoxIOUConfig | GeometricIOUConfig,
|
||||||
TimeAffinityConfig,
|
|
||||||
IntervalIOUConfig,
|
|
||||||
BBoxIOUConfig,
|
|
||||||
GeometricIOUConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_affinity_function(
|
def build_affinity_function(
|
||||||
config: Optional[AffinityConfig] = None,
|
config: AffinityConfig | None = None,
|
||||||
) -> AffinityFunction:
|
) -> AffinityFunction:
|
||||||
config = config or GeometricIOUConfig()
|
config = config or GeometricIOUConfig()
|
||||||
return affinity_functions.build(config)
|
return affinity_functions.build(config)
|
||||||
|
|||||||
@ -51,6 +51,6 @@ def get_default_eval_config() -> EvaluationConfig:
|
|||||||
|
|
||||||
def load_evaluation_config(
|
def load_evaluation_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
) -> EvaluationConfig:
|
) -> EvaluationConfig:
|
||||||
return load_config(path, schema=EvaluationConfig, field=field)
|
return load_config(path, schema=EvaluationConfig, field=field)
|
||||||
|
|||||||
@ -39,8 +39,8 @@ class TestDataset(Dataset[TestExample]):
|
|||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
audio_loader: AudioLoader,
|
audio_loader: AudioLoader,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
clipper: Optional[ClipperProtocol] = None,
|
clipper: ClipperProtocol | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
):
|
):
|
||||||
self.clip_annotations = list(clip_annotations)
|
self.clip_annotations = list(clip_annotations)
|
||||||
self.clipper = clipper
|
self.clipper = clipper
|
||||||
@ -78,10 +78,10 @@ class TestLoaderConfig(BaseConfig):
|
|||||||
|
|
||||||
def build_test_loader(
|
def build_test_loader(
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: Optional[TestLoaderConfig] = None,
|
config: TestLoaderConfig | None = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: int | None = None,
|
||||||
) -> DataLoader[TestExample]:
|
) -> DataLoader[TestExample]:
|
||||||
logger.info("Building test data loader...")
|
logger.info("Building test data loader...")
|
||||||
config = config or TestLoaderConfig()
|
config = config or TestLoaderConfig()
|
||||||
@ -109,9 +109,9 @@ def build_test_loader(
|
|||||||
|
|
||||||
def build_test_dataset(
|
def build_test_dataset(
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: Optional[TestLoaderConfig] = None,
|
config: TestLoaderConfig | None = None,
|
||||||
) -> TestDataset:
|
) -> TestDataset:
|
||||||
logger.info("Building training dataset...")
|
logger.info("Building training dataset...")
|
||||||
config = config or TestLoaderConfig()
|
config = config or TestLoaderConfig()
|
||||||
|
|||||||
@ -34,10 +34,10 @@ def evaluate(
|
|||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
config: Optional["BatDetect2Config"] = None,
|
config: Optional["BatDetect2Config"] = None,
|
||||||
formatter: Optional["OutputFormatterProtocol"] = None,
|
formatter: Optional["OutputFormatterProtocol"] = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: int | None = None,
|
||||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
|
|||||||
@ -51,8 +51,8 @@ class Evaluator:
|
|||||||
|
|
||||||
|
|
||||||
def build_evaluator(
|
def build_evaluator(
|
||||||
config: Optional[Union[EvaluationConfig, dict]] = None,
|
config: EvaluationConfig | dict | None = None,
|
||||||
targets: Optional[TargetProtocol] = None,
|
targets: TargetProtocol | None = None,
|
||||||
) -> EvaluatorProtocol:
|
) -> EvaluatorProtocol:
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
|
|
||||||
|
|||||||
@ -35,9 +35,9 @@ def match(
|
|||||||
sound_event_annotations: Sequence[data.SoundEventAnnotation],
|
sound_event_annotations: Sequence[data.SoundEventAnnotation],
|
||||||
raw_predictions: Sequence[RawPrediction],
|
raw_predictions: Sequence[RawPrediction],
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
scores: Optional[Sequence[float]] = None,
|
scores: Sequence[float] | None = None,
|
||||||
targets: Optional[TargetProtocol] = None,
|
targets: TargetProtocol | None = None,
|
||||||
matcher: Optional[MatcherProtocol] = None,
|
matcher: MatcherProtocol | None = None,
|
||||||
) -> ClipMatches:
|
) -> ClipMatches:
|
||||||
if matcher is None:
|
if matcher is None:
|
||||||
matcher = build_matcher()
|
matcher = build_matcher()
|
||||||
@ -151,7 +151,7 @@ def match_start_times(
|
|||||||
predictions: Sequence[data.Geometry],
|
predictions: Sequence[data.Geometry],
|
||||||
scores: Sequence[float],
|
scores: Sequence[float],
|
||||||
distance_threshold: float = 0.01,
|
distance_threshold: float = 0.01,
|
||||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
) -> Iterable[Tuple[int | None, int | None, float]]:
|
||||||
if not ground_truth:
|
if not ground_truth:
|
||||||
for index in range(len(predictions)):
|
for index in range(len(predictions)):
|
||||||
yield index, None, 0
|
yield index, None, 0
|
||||||
@ -287,7 +287,7 @@ def greedy_match(
|
|||||||
scores: Sequence[float],
|
scores: Sequence[float],
|
||||||
affinity_threshold: float = 0.5,
|
affinity_threshold: float = 0.5,
|
||||||
affinity_function: AffinityFunction = compute_affinity,
|
affinity_function: AffinityFunction = compute_affinity,
|
||||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
) -> Iterable[Tuple[int | None, int | None, float]]:
|
||||||
"""Performs a greedy, one-to-one matching of source to target geometries.
|
"""Performs a greedy, one-to-one matching of source to target geometries.
|
||||||
|
|
||||||
Iterates through source geometries, prioritizing by score if provided. Each
|
Iterates through source geometries, prioritizing by score if provided. Each
|
||||||
@ -514,12 +514,7 @@ class OptimalMatcher(MatcherProtocol):
|
|||||||
|
|
||||||
|
|
||||||
MatchConfig = Annotated[
|
MatchConfig = Annotated[
|
||||||
Union[
|
GreedyMatchConfig | StartTimeMatchConfig | OptimalMatchConfig | GreedyAffinityMatchConfig,
|
||||||
GreedyMatchConfig,
|
|
||||||
StartTimeMatchConfig,
|
|
||||||
OptimalMatchConfig,
|
|
||||||
GreedyAffinityMatchConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -558,7 +553,7 @@ def compute_affinity_matrix(
|
|||||||
def select_optimal_matches(
|
def select_optimal_matches(
|
||||||
affinity_matrix: np.ndarray,
|
affinity_matrix: np.ndarray,
|
||||||
affinity_threshold: float = 0.5,
|
affinity_threshold: float = 0.5,
|
||||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
) -> Iterable[Tuple[int | None, int | None, float]]:
|
||||||
num_gt, num_pred = affinity_matrix.shape
|
num_gt, num_pred = affinity_matrix.shape
|
||||||
gts = set(range(num_gt))
|
gts = set(range(num_gt))
|
||||||
preds = set(range(num_pred))
|
preds = set(range(num_pred))
|
||||||
@ -588,7 +583,7 @@ def select_optimal_matches(
|
|||||||
def select_greedy_matches(
|
def select_greedy_matches(
|
||||||
affinity_matrix: np.ndarray,
|
affinity_matrix: np.ndarray,
|
||||||
affinity_threshold: float = 0.5,
|
affinity_threshold: float = 0.5,
|
||||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
) -> Iterable[Tuple[int | None, int | None, float]]:
|
||||||
num_gt, num_pred = affinity_matrix.shape
|
num_gt, num_pred = affinity_matrix.shape
|
||||||
unmatched_pred = set(range(num_pred))
|
unmatched_pred = set(range(num_pred))
|
||||||
|
|
||||||
@ -612,6 +607,6 @@ def select_greedy_matches(
|
|||||||
yield pred_idx, None, 0
|
yield pred_idx, None, 0
|
||||||
|
|
||||||
|
|
||||||
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
|
def build_matcher(config: MatchConfig | None = None) -> MatcherProtocol:
|
||||||
config = config or StartTimeMatchConfig()
|
config = config or StartTimeMatchConfig()
|
||||||
return matching_strategies.build(config)
|
return matching_strategies.build(config)
|
||||||
|
|||||||
@ -36,13 +36,13 @@ __all__ = [
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MatchEval:
|
class MatchEval:
|
||||||
clip: data.Clip
|
clip: data.Clip
|
||||||
gt: Optional[data.SoundEventAnnotation]
|
gt: data.SoundEventAnnotation | None
|
||||||
pred: Optional[RawPrediction]
|
pred: RawPrediction | None
|
||||||
|
|
||||||
is_prediction: bool
|
is_prediction: bool
|
||||||
is_ground_truth: bool
|
is_ground_truth: bool
|
||||||
is_generic: bool
|
is_generic: bool
|
||||||
true_class: Optional[str]
|
true_class: str | None
|
||||||
score: float
|
score: float
|
||||||
|
|
||||||
|
|
||||||
@ -61,16 +61,16 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
|
|||||||
|
|
||||||
|
|
||||||
class BaseClassificationConfig(BaseConfig):
|
class BaseClassificationConfig(BaseConfig):
|
||||||
include: Optional[List[str]] = None
|
include: List[str] | None = None
|
||||||
exclude: Optional[List[str]] = None
|
exclude: List[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class BaseClassificationMetric:
|
class BaseClassificationMetric:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
include: Optional[List[str]] = None,
|
include: List[str] | None = None,
|
||||||
exclude: Optional[List[str]] = None,
|
exclude: List[str] | None = None,
|
||||||
):
|
):
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.include = include
|
self.include = include
|
||||||
@ -100,8 +100,8 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
|
|||||||
ignore_non_predictions: bool = True,
|
ignore_non_predictions: bool = True,
|
||||||
ignore_generic: bool = True,
|
ignore_generic: bool = True,
|
||||||
label: str = "average_precision",
|
label: str = "average_precision",
|
||||||
include: Optional[List[str]] = None,
|
include: List[str] | None = None,
|
||||||
exclude: Optional[List[str]] = None,
|
exclude: List[str] | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(include=include, exclude=exclude, targets=targets)
|
super().__init__(include=include, exclude=exclude, targets=targets)
|
||||||
self.ignore_non_predictions = ignore_non_predictions
|
self.ignore_non_predictions = ignore_non_predictions
|
||||||
@ -169,8 +169,8 @@ class ClassificationROCAUC(BaseClassificationMetric):
|
|||||||
ignore_non_predictions: bool = True,
|
ignore_non_predictions: bool = True,
|
||||||
ignore_generic: bool = True,
|
ignore_generic: bool = True,
|
||||||
label: str = "roc_auc",
|
label: str = "roc_auc",
|
||||||
include: Optional[List[str]] = None,
|
include: List[str] | None = None,
|
||||||
exclude: Optional[List[str]] = None,
|
exclude: List[str] | None = None,
|
||||||
):
|
):
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.ignore_non_predictions = ignore_non_predictions
|
self.ignore_non_predictions = ignore_non_predictions
|
||||||
@ -225,10 +225,7 @@ class ClassificationROCAUC(BaseClassificationMetric):
|
|||||||
|
|
||||||
|
|
||||||
ClassificationMetricConfig = Annotated[
|
ClassificationMetricConfig = Annotated[
|
||||||
Union[
|
ClassificationAveragePrecisionConfig | ClassificationROCAUCConfig,
|
||||||
ClassificationAveragePrecisionConfig,
|
|
||||||
ClassificationROCAUCConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -123,10 +123,7 @@ class ClipClassificationROCAUC:
|
|||||||
|
|
||||||
|
|
||||||
ClipClassificationMetricConfig = Annotated[
|
ClipClassificationMetricConfig = Annotated[
|
||||||
Union[
|
ClipClassificationAveragePrecisionConfig | ClipClassificationROCAUCConfig,
|
||||||
ClipClassificationAveragePrecisionConfig,
|
|
||||||
ClipClassificationROCAUCConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -159,12 +159,7 @@ class ClipDetectionPrecision:
|
|||||||
|
|
||||||
|
|
||||||
ClipDetectionMetricConfig = Annotated[
|
ClipDetectionMetricConfig = Annotated[
|
||||||
Union[
|
ClipDetectionAveragePrecisionConfig | ClipDetectionROCAUCConfig | ClipDetectionRecallConfig | ClipDetectionPrecisionConfig,
|
||||||
ClipDetectionAveragePrecisionConfig,
|
|
||||||
ClipDetectionROCAUCConfig,
|
|
||||||
ClipDetectionRecallConfig,
|
|
||||||
ClipDetectionPrecisionConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ __all__ = [
|
|||||||
def compute_precision_recall(
|
def compute_precision_recall(
|
||||||
y_true,
|
y_true,
|
||||||
y_score,
|
y_score,
|
||||||
num_positives: Optional[int] = None,
|
num_positives: int | None = None,
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
y_true = np.array(y_true)
|
y_true = np.array(y_true)
|
||||||
y_score = np.array(y_score)
|
y_score = np.array(y_score)
|
||||||
@ -41,7 +41,7 @@ def compute_precision_recall(
|
|||||||
def average_precision(
|
def average_precision(
|
||||||
y_true,
|
y_true,
|
||||||
y_score,
|
y_score,
|
||||||
num_positives: Optional[int] = None,
|
num_positives: int | None = None,
|
||||||
) -> float:
|
) -> float:
|
||||||
if num_positives == 0:
|
if num_positives == 0:
|
||||||
return np.nan
|
return np.nan
|
||||||
|
|||||||
@ -28,8 +28,8 @@ __all__ = [
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MatchEval:
|
class MatchEval:
|
||||||
gt: Optional[data.SoundEventAnnotation]
|
gt: data.SoundEventAnnotation | None
|
||||||
pred: Optional[RawPrediction]
|
pred: RawPrediction | None
|
||||||
|
|
||||||
is_prediction: bool
|
is_prediction: bool
|
||||||
is_ground_truth: bool
|
is_ground_truth: bool
|
||||||
@ -212,12 +212,7 @@ class DetectionPrecision:
|
|||||||
|
|
||||||
|
|
||||||
DetectionMetricConfig = Annotated[
|
DetectionMetricConfig = Annotated[
|
||||||
Union[
|
DetectionAveragePrecisionConfig | DetectionROCAUCConfig | DetectionRecallConfig | DetectionPrecisionConfig,
|
||||||
DetectionAveragePrecisionConfig,
|
|
||||||
DetectionROCAUCConfig,
|
|
||||||
DetectionRecallConfig,
|
|
||||||
DetectionPrecisionConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -31,14 +31,14 @@ __all__ = [
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MatchEval:
|
class MatchEval:
|
||||||
clip: data.Clip
|
clip: data.Clip
|
||||||
gt: Optional[data.SoundEventAnnotation]
|
gt: data.SoundEventAnnotation | None
|
||||||
pred: Optional[RawPrediction]
|
pred: RawPrediction | None
|
||||||
|
|
||||||
is_ground_truth: bool
|
is_ground_truth: bool
|
||||||
is_generic: bool
|
is_generic: bool
|
||||||
is_prediction: bool
|
is_prediction: bool
|
||||||
pred_class: Optional[str]
|
pred_class: str | None
|
||||||
true_class: Optional[str]
|
true_class: str | None
|
||||||
score: float
|
score: float
|
||||||
|
|
||||||
|
|
||||||
@ -301,13 +301,7 @@ class BalancedAccuracy:
|
|||||||
|
|
||||||
|
|
||||||
TopClassMetricConfig = Annotated[
|
TopClassMetricConfig = Annotated[
|
||||||
Union[
|
TopClassAveragePrecisionConfig | TopClassROCAUCConfig | TopClassRecallConfig | TopClassPrecisionConfig | BalancedAccuracyConfig,
|
||||||
TopClassAveragePrecisionConfig,
|
|
||||||
TopClassROCAUCConfig,
|
|
||||||
TopClassRecallConfig,
|
|
||||||
TopClassPrecisionConfig,
|
|
||||||
BalancedAccuracyConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from batdetect2.typing import TargetProtocol
|
|||||||
class BasePlotConfig(BaseConfig):
|
class BasePlotConfig(BaseConfig):
|
||||||
label: str = "plot"
|
label: str = "plot"
|
||||||
theme: str = "default"
|
theme: str = "default"
|
||||||
title: Optional[str] = None
|
title: str | None = None
|
||||||
figsize: tuple[int, int] = (10, 10)
|
figsize: tuple[int, int] = (10, 10)
|
||||||
dpi: int = 100
|
dpi: int = 100
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ class BasePlot:
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
label: str = "plot",
|
label: str = "plot",
|
||||||
figsize: tuple[int, int] = (10, 10),
|
figsize: tuple[int, int] = (10, 10),
|
||||||
title: Optional[str] = None,
|
title: str | None = None,
|
||||||
dpi: int = 100,
|
dpi: int = 100,
|
||||||
theme: str = "default",
|
theme: str = "default",
|
||||||
):
|
):
|
||||||
|
|||||||
@ -45,7 +45,7 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
|
|||||||
class PRCurveConfig(BasePlotConfig):
|
class PRCurveConfig(BasePlotConfig):
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
name: Literal["pr_curve"] = "pr_curve"
|
||||||
label: str = "pr_curve"
|
label: str = "pr_curve"
|
||||||
title: Optional[str] = "Classification Precision-Recall Curve"
|
title: str | None = "Classification Precision-Recall Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
separate_figures: bool = False
|
separate_figures: bool = False
|
||||||
@ -108,7 +108,7 @@ class PRCurve(BasePlot):
|
|||||||
class ThresholdPrecisionCurveConfig(BasePlotConfig):
|
class ThresholdPrecisionCurveConfig(BasePlotConfig):
|
||||||
name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
|
name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
|
||||||
label: str = "threshold_precision_curve"
|
label: str = "threshold_precision_curve"
|
||||||
title: Optional[str] = "Classification Threshold-Precision Curve"
|
title: str | None = "Classification Threshold-Precision Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
separate_figures: bool = False
|
separate_figures: bool = False
|
||||||
@ -181,7 +181,7 @@ class ThresholdPrecisionCurve(BasePlot):
|
|||||||
class ThresholdRecallCurveConfig(BasePlotConfig):
|
class ThresholdRecallCurveConfig(BasePlotConfig):
|
||||||
name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
|
name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
|
||||||
label: str = "threshold_recall_curve"
|
label: str = "threshold_recall_curve"
|
||||||
title: Optional[str] = "Classification Threshold-Recall Curve"
|
title: str | None = "Classification Threshold-Recall Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
separate_figures: bool = False
|
separate_figures: bool = False
|
||||||
@ -254,7 +254,7 @@ class ThresholdRecallCurve(BasePlot):
|
|||||||
class ROCCurveConfig(BasePlotConfig):
|
class ROCCurveConfig(BasePlotConfig):
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
name: Literal["roc_curve"] = "roc_curve"
|
||||||
label: str = "roc_curve"
|
label: str = "roc_curve"
|
||||||
title: Optional[str] = "Classification ROC Curve"
|
title: str | None = "Classification ROC Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
separate_figures: bool = False
|
separate_figures: bool = False
|
||||||
@ -326,12 +326,7 @@ class ROCCurve(BasePlot):
|
|||||||
|
|
||||||
|
|
||||||
ClassificationPlotConfig = Annotated[
|
ClassificationPlotConfig = Annotated[
|
||||||
Union[
|
PRCurveConfig | ROCCurveConfig | ThresholdPrecisionCurveConfig | ThresholdRecallCurveConfig,
|
||||||
PRCurveConfig,
|
|
||||||
ROCCurveConfig,
|
|
||||||
ThresholdPrecisionCurveConfig,
|
|
||||||
ThresholdRecallCurveConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -44,7 +44,7 @@ clip_classification_plots: Registry[
|
|||||||
class PRCurveConfig(BasePlotConfig):
|
class PRCurveConfig(BasePlotConfig):
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
name: Literal["pr_curve"] = "pr_curve"
|
||||||
label: str = "pr_curve"
|
label: str = "pr_curve"
|
||||||
title: Optional[str] = "Clip Classification Precision-Recall Curve"
|
title: str | None = "Clip Classification Precision-Recall Curve"
|
||||||
separate_figures: bool = False
|
separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
@ -111,7 +111,7 @@ class PRCurve(BasePlot):
|
|||||||
class ROCCurveConfig(BasePlotConfig):
|
class ROCCurveConfig(BasePlotConfig):
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
name: Literal["roc_curve"] = "roc_curve"
|
||||||
label: str = "roc_curve"
|
label: str = "roc_curve"
|
||||||
title: Optional[str] = "Clip Classification ROC Curve"
|
title: str | None = "Clip Classification ROC Curve"
|
||||||
separate_figures: bool = False
|
separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
@ -174,10 +174,7 @@ class ROCCurve(BasePlot):
|
|||||||
|
|
||||||
|
|
||||||
ClipClassificationPlotConfig = Annotated[
|
ClipClassificationPlotConfig = Annotated[
|
||||||
Union[
|
PRCurveConfig | ROCCurveConfig,
|
||||||
PRCurveConfig,
|
|
||||||
ROCCurveConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -41,7 +41,7 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
|
|||||||
class PRCurveConfig(BasePlotConfig):
|
class PRCurveConfig(BasePlotConfig):
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
name: Literal["pr_curve"] = "pr_curve"
|
||||||
label: str = "pr_curve"
|
label: str = "pr_curve"
|
||||||
title: Optional[str] = "Clip Detection Precision-Recall Curve"
|
title: str | None = "Clip Detection Precision-Recall Curve"
|
||||||
|
|
||||||
|
|
||||||
class PRCurve(BasePlot):
|
class PRCurve(BasePlot):
|
||||||
@ -74,7 +74,7 @@ class PRCurve(BasePlot):
|
|||||||
class ROCCurveConfig(BasePlotConfig):
|
class ROCCurveConfig(BasePlotConfig):
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
name: Literal["roc_curve"] = "roc_curve"
|
||||||
label: str = "roc_curve"
|
label: str = "roc_curve"
|
||||||
title: Optional[str] = "Clip Detection ROC Curve"
|
title: str | None = "Clip Detection ROC Curve"
|
||||||
|
|
||||||
|
|
||||||
class ROCCurve(BasePlot):
|
class ROCCurve(BasePlot):
|
||||||
@ -107,7 +107,7 @@ class ROCCurve(BasePlot):
|
|||||||
class ScoreDistributionPlotConfig(BasePlotConfig):
|
class ScoreDistributionPlotConfig(BasePlotConfig):
|
||||||
name: Literal["score_distribution"] = "score_distribution"
|
name: Literal["score_distribution"] = "score_distribution"
|
||||||
label: str = "score_distribution"
|
label: str = "score_distribution"
|
||||||
title: Optional[str] = "Clip Detection Score Distribution"
|
title: str | None = "Clip Detection Score Distribution"
|
||||||
|
|
||||||
|
|
||||||
class ScoreDistributionPlot(BasePlot):
|
class ScoreDistributionPlot(BasePlot):
|
||||||
@ -147,11 +147,7 @@ class ScoreDistributionPlot(BasePlot):
|
|||||||
|
|
||||||
|
|
||||||
ClipDetectionPlotConfig = Annotated[
|
ClipDetectionPlotConfig = Annotated[
|
||||||
Union[
|
PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig,
|
||||||
PRCurveConfig,
|
|
||||||
ROCCurveConfig,
|
|
||||||
ScoreDistributionPlotConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -37,7 +37,7 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
|
|||||||
class PRCurveConfig(BasePlotConfig):
|
class PRCurveConfig(BasePlotConfig):
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
name: Literal["pr_curve"] = "pr_curve"
|
||||||
label: str = "pr_curve"
|
label: str = "pr_curve"
|
||||||
title: Optional[str] = "Detection Precision-Recall Curve"
|
title: str | None = "Detection Precision-Recall Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -100,7 +100,7 @@ class PRCurve(BasePlot):
|
|||||||
class ROCCurveConfig(BasePlotConfig):
|
class ROCCurveConfig(BasePlotConfig):
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
name: Literal["roc_curve"] = "roc_curve"
|
||||||
label: str = "roc_curve"
|
label: str = "roc_curve"
|
||||||
title: Optional[str] = "Detection ROC Curve"
|
title: str | None = "Detection ROC Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -159,7 +159,7 @@ class ROCCurve(BasePlot):
|
|||||||
class ScoreDistributionPlotConfig(BasePlotConfig):
|
class ScoreDistributionPlotConfig(BasePlotConfig):
|
||||||
name: Literal["score_distribution"] = "score_distribution"
|
name: Literal["score_distribution"] = "score_distribution"
|
||||||
label: str = "score_distribution"
|
label: str = "score_distribution"
|
||||||
title: Optional[str] = "Detection Score Distribution"
|
title: str | None = "Detection Score Distribution"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -226,7 +226,7 @@ class ScoreDistributionPlot(BasePlot):
|
|||||||
class ExampleDetectionPlotConfig(BasePlotConfig):
|
class ExampleDetectionPlotConfig(BasePlotConfig):
|
||||||
name: Literal["example_detection"] = "example_detection"
|
name: Literal["example_detection"] = "example_detection"
|
||||||
label: str = "example_detection"
|
label: str = "example_detection"
|
||||||
title: Optional[str] = "Example Detection"
|
title: str | None = "Example Detection"
|
||||||
figsize: tuple[int, int] = (10, 4)
|
figsize: tuple[int, int] = (10, 4)
|
||||||
num_examples: int = 5
|
num_examples: int = 5
|
||||||
threshold: float = 0.2
|
threshold: float = 0.2
|
||||||
@ -292,12 +292,7 @@ class ExampleDetectionPlot(BasePlot):
|
|||||||
|
|
||||||
|
|
||||||
DetectionPlotConfig = Annotated[
|
DetectionPlotConfig = Annotated[
|
||||||
Union[
|
PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig | ExampleDetectionPlotConfig,
|
||||||
PRCurveConfig,
|
|
||||||
ROCCurveConfig,
|
|
||||||
ScoreDistributionPlotConfig,
|
|
||||||
ExampleDetectionPlotConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -44,7 +44,7 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
|
|||||||
class PRCurveConfig(BasePlotConfig):
|
class PRCurveConfig(BasePlotConfig):
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
name: Literal["pr_curve"] = "pr_curve"
|
||||||
label: str = "pr_curve"
|
label: str = "pr_curve"
|
||||||
title: Optional[str] = "Top Class Precision-Recall Curve"
|
title: str | None = "Top Class Precision-Recall Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -111,7 +111,7 @@ class PRCurve(BasePlot):
|
|||||||
class ROCCurveConfig(BasePlotConfig):
|
class ROCCurveConfig(BasePlotConfig):
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
name: Literal["roc_curve"] = "roc_curve"
|
||||||
label: str = "roc_curve"
|
label: str = "roc_curve"
|
||||||
title: Optional[str] = "Top Class ROC Curve"
|
title: str | None = "Top Class ROC Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -173,7 +173,7 @@ class ROCCurve(BasePlot):
|
|||||||
|
|
||||||
class ConfusionMatrixConfig(BasePlotConfig):
|
class ConfusionMatrixConfig(BasePlotConfig):
|
||||||
name: Literal["confusion_matrix"] = "confusion_matrix"
|
name: Literal["confusion_matrix"] = "confusion_matrix"
|
||||||
title: Optional[str] = "Top Class Confusion Matrix"
|
title: str | None = "Top Class Confusion Matrix"
|
||||||
figsize: tuple[int, int] = (10, 10)
|
figsize: tuple[int, int] = (10, 10)
|
||||||
label: str = "confusion_matrix"
|
label: str = "confusion_matrix"
|
||||||
exclude_generic: bool = True
|
exclude_generic: bool = True
|
||||||
@ -257,7 +257,7 @@ class ConfusionMatrix(BasePlot):
|
|||||||
class ExampleClassificationPlotConfig(BasePlotConfig):
|
class ExampleClassificationPlotConfig(BasePlotConfig):
|
||||||
name: Literal["example_classification"] = "example_classification"
|
name: Literal["example_classification"] = "example_classification"
|
||||||
label: str = "example_classification"
|
label: str = "example_classification"
|
||||||
title: Optional[str] = "Example Classification"
|
title: str | None = "Example Classification"
|
||||||
num_examples: int = 4
|
num_examples: int = 4
|
||||||
threshold: float = 0.2
|
threshold: float = 0.2
|
||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
@ -348,12 +348,7 @@ class ExampleClassificationPlot(BasePlot):
|
|||||||
|
|
||||||
|
|
||||||
TopClassPlotConfig = Annotated[
|
TopClassPlotConfig = Annotated[
|
||||||
Union[
|
PRCurveConfig | ROCCurveConfig | ConfusionMatrixConfig | ExampleClassificationPlotConfig,
|
||||||
PRCurveConfig,
|
|
||||||
ROCCurveConfig,
|
|
||||||
ConfusionMatrixConfig,
|
|
||||||
ExampleClassificationPlotConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -96,7 +96,7 @@ def extract_matches_dataframe(
|
|||||||
|
|
||||||
|
|
||||||
EvaluationTableConfig = Annotated[
|
EvaluationTableConfig = Annotated[
|
||||||
Union[FullEvaluationTableConfig,], Field(discriminator="name")
|
FullEvaluationTableConfig, Field(discriminator="name")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -26,20 +26,14 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
TaskConfig = Annotated[
|
TaskConfig = Annotated[
|
||||||
Union[
|
ClassificationTaskConfig | DetectionTaskConfig | ClipDetectionTaskConfig | ClipClassificationTaskConfig | TopClassDetectionTaskConfig,
|
||||||
ClassificationTaskConfig,
|
|
||||||
DetectionTaskConfig,
|
|
||||||
ClipDetectionTaskConfig,
|
|
||||||
ClipClassificationTaskConfig,
|
|
||||||
TopClassDetectionTaskConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_task(
|
def build_task(
|
||||||
config: TaskConfig,
|
config: TaskConfig,
|
||||||
targets: Optional[TargetProtocol] = None,
|
targets: TargetProtocol | None = None,
|
||||||
) -> EvaluatorProtocol:
|
) -> EvaluatorProtocol:
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
return tasks_registry.build(config, targets)
|
return tasks_registry.build(config, targets)
|
||||||
@ -49,8 +43,8 @@ def evaluate_task(
|
|||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[BatDetect2Prediction],
|
||||||
task: Optional["str"] = None,
|
task: Optional["str"] = None,
|
||||||
targets: Optional[TargetProtocol] = None,
|
targets: TargetProtocol | None = None,
|
||||||
config: Optional[Union[TaskConfig, dict]] = None,
|
config: TaskConfig | dict | None = None,
|
||||||
):
|
):
|
||||||
if isinstance(config, BaseTaskConfig):
|
if isinstance(config, BaseTaskConfig):
|
||||||
task_obj = build_task(config, targets)
|
task_obj = build_task(config, targets)
|
||||||
|
|||||||
@ -67,9 +67,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
ignore_start_end: float = 0.01,
|
ignore_start_end: float = 0.01,
|
||||||
plots: Optional[
|
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
|
||||||
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
|
|
||||||
] = None,
|
|
||||||
):
|
):
|
||||||
self.matcher = matcher
|
self.matcher = matcher
|
||||||
self.metrics = metrics
|
self.metrics = metrics
|
||||||
@ -147,9 +145,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
config: BaseTaskConfig,
|
config: BaseTaskConfig,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
||||||
plots: Optional[
|
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
|
||||||
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
|
|
||||||
] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
matcher = build_matcher(config.matching_strategy)
|
matcher = build_matcher(config.matching_strategy)
|
||||||
|
|||||||
@ -98,8 +98,8 @@ def load_annotations(
|
|||||||
dataset_name: str,
|
dataset_name: str,
|
||||||
ann_path: str,
|
ann_path: str,
|
||||||
audio_path: str,
|
audio_path: str,
|
||||||
classes_to_ignore: Optional[List[str]] = None,
|
classes_to_ignore: List[str] | None = None,
|
||||||
events_of_interest: Optional[List[str]] = None,
|
events_of_interest: List[str] | None = None,
|
||||||
) -> List[types.FileAnnotation]:
|
) -> List[types.FileAnnotation]:
|
||||||
train_sets: List[types.DatasetDict] = []
|
train_sets: List[types.DatasetDict] = []
|
||||||
train_sets.append(
|
train_sets.append(
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from batdetect2 import types
|
|||||||
|
|
||||||
def print_dataset_stats(
|
def print_dataset_stats(
|
||||||
data: List[types.FileAnnotation],
|
data: List[types.FileAnnotation],
|
||||||
classes_to_ignore: Optional[List[str]] = None,
|
classes_to_ignore: List[str] | None = None,
|
||||||
) -> Counter[str]:
|
) -> Counter[str]:
|
||||||
print("Num files:", len(data))
|
print("Num files:", len(data))
|
||||||
counts, _ = tu.get_class_names(data, classes_to_ignore)
|
counts, _ = tu.get_class_names(data, classes_to_ignore)
|
||||||
|
|||||||
@ -28,8 +28,8 @@ def run_batch_inference(
|
|||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
config: Optional["BatDetect2Config"] = None,
|
config: Optional["BatDetect2Config"] = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: int | None = None,
|
||||||
batch_size: Optional[int] = None,
|
batch_size: int | None = None,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[BatDetect2Prediction]:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ def process_file_list(
|
|||||||
targets: Optional["TargetProtocol"] = None,
|
targets: Optional["TargetProtocol"] = None,
|
||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: int | None = None,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[BatDetect2Prediction]:
|
||||||
clip_config = config.inference.clipping
|
clip_config = config.inference.clipping
|
||||||
clips = get_clips_from_files(
|
clips = get_clips_from_files(
|
||||||
|
|||||||
@ -36,7 +36,7 @@ class InferenceDataset(Dataset[DatasetItem]):
|
|||||||
clips: Sequence[data.Clip],
|
clips: Sequence[data.Clip],
|
||||||
audio_loader: AudioLoader,
|
audio_loader: AudioLoader,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
):
|
):
|
||||||
self.clips = list(clips)
|
self.clips = list(clips)
|
||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
@ -66,11 +66,11 @@ class InferenceLoaderConfig(BaseConfig):
|
|||||||
|
|
||||||
def build_inference_loader(
|
def build_inference_loader(
|
||||||
clips: Sequence[data.Clip],
|
clips: Sequence[data.Clip],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: Optional[InferenceLoaderConfig] = None,
|
config: InferenceLoaderConfig | None = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: int | None = None,
|
||||||
batch_size: Optional[int] = None,
|
batch_size: int | None = None,
|
||||||
) -> DataLoader[DatasetItem]:
|
) -> DataLoader[DatasetItem]:
|
||||||
logger.info("Building inference data loader...")
|
logger.info("Building inference data loader...")
|
||||||
config = config or InferenceLoaderConfig()
|
config = config or InferenceLoaderConfig()
|
||||||
@ -95,8 +95,8 @@ def build_inference_loader(
|
|||||||
|
|
||||||
def build_inference_dataset(
|
def build_inference_dataset(
|
||||||
clips: Sequence[data.Clip],
|
clips: Sequence[data.Clip],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
) -> InferenceDataset:
|
) -> InferenceDataset:
|
||||||
if audio_loader is None:
|
if audio_loader is None:
|
||||||
audio_loader = build_audio_loader()
|
audio_loader = build_audio_loader()
|
||||||
|
|||||||
@ -49,14 +49,14 @@ def enable_logging(level: int):
|
|||||||
|
|
||||||
class BaseLoggerConfig(BaseConfig):
|
class BaseLoggerConfig(BaseConfig):
|
||||||
log_dir: Path = DEFAULT_LOGS_DIR
|
log_dir: Path = DEFAULT_LOGS_DIR
|
||||||
experiment_name: Optional[str] = None
|
experiment_name: str | None = None
|
||||||
run_name: Optional[str] = None
|
run_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class DVCLiveConfig(BaseLoggerConfig):
|
class DVCLiveConfig(BaseLoggerConfig):
|
||||||
name: Literal["dvclive"] = "dvclive"
|
name: Literal["dvclive"] = "dvclive"
|
||||||
prefix: str = ""
|
prefix: str = ""
|
||||||
log_model: Union[bool, Literal["all"]] = False
|
log_model: bool | Literal["all"] = False
|
||||||
monitor_system: bool = False
|
monitor_system: bool = False
|
||||||
|
|
||||||
|
|
||||||
@ -72,18 +72,13 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):
|
|||||||
|
|
||||||
class MLFlowLoggerConfig(BaseLoggerConfig):
|
class MLFlowLoggerConfig(BaseLoggerConfig):
|
||||||
name: Literal["mlflow"] = "mlflow"
|
name: Literal["mlflow"] = "mlflow"
|
||||||
tracking_uri: Optional[str] = "http://localhost:5000"
|
tracking_uri: str | None = "http://localhost:5000"
|
||||||
tags: Optional[dict[str, Any]] = None
|
tags: dict[str, Any] | None = None
|
||||||
log_model: bool = False
|
log_model: bool = False
|
||||||
|
|
||||||
|
|
||||||
LoggerConfig = Annotated[
|
LoggerConfig = Annotated[
|
||||||
Union[
|
DVCLiveConfig | CSVLoggerConfig | TensorBoardLoggerConfig | MLFlowLoggerConfig,
|
||||||
DVCLiveConfig,
|
|
||||||
CSVLoggerConfig,
|
|
||||||
TensorBoardLoggerConfig,
|
|
||||||
MLFlowLoggerConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -95,17 +90,17 @@ class LoggerBuilder(Protocol, Generic[T]):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
config: T,
|
config: T,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Path | None = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
) -> Logger: ...
|
) -> Logger: ...
|
||||||
|
|
||||||
|
|
||||||
def create_dvclive_logger(
|
def create_dvclive_logger(
|
||||||
config: DVCLiveConfig,
|
config: DVCLiveConfig,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Path | None = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
) -> Logger:
|
) -> Logger:
|
||||||
try:
|
try:
|
||||||
from dvclive.lightning import DVCLiveLogger # type: ignore
|
from dvclive.lightning import DVCLiveLogger # type: ignore
|
||||||
@ -130,9 +125,9 @@ def create_dvclive_logger(
|
|||||||
|
|
||||||
def create_csv_logger(
|
def create_csv_logger(
|
||||||
config: CSVLoggerConfig,
|
config: CSVLoggerConfig,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Path | None = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
) -> Logger:
|
) -> Logger:
|
||||||
from lightning.pytorch.loggers import CSVLogger
|
from lightning.pytorch.loggers import CSVLogger
|
||||||
|
|
||||||
@ -159,9 +154,9 @@ def create_csv_logger(
|
|||||||
|
|
||||||
def create_tensorboard_logger(
|
def create_tensorboard_logger(
|
||||||
config: TensorBoardLoggerConfig,
|
config: TensorBoardLoggerConfig,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Path | None = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
) -> Logger:
|
) -> Logger:
|
||||||
from lightning.pytorch.loggers import TensorBoardLogger
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
|
||||||
@ -191,9 +186,9 @@ def create_tensorboard_logger(
|
|||||||
|
|
||||||
def create_mlflow_logger(
|
def create_mlflow_logger(
|
||||||
config: MLFlowLoggerConfig,
|
config: MLFlowLoggerConfig,
|
||||||
log_dir: Optional[data.PathLike] = None,
|
log_dir: data.PathLike | None = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
) -> Logger:
|
) -> Logger:
|
||||||
try:
|
try:
|
||||||
from lightning.pytorch.loggers import MLFlowLogger
|
from lightning.pytorch.loggers import MLFlowLogger
|
||||||
@ -232,9 +227,9 @@ LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
|
|||||||
|
|
||||||
def build_logger(
|
def build_logger(
|
||||||
config: LoggerConfig,
|
config: LoggerConfig,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Path | None = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
) -> Logger:
|
) -> Logger:
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building logger with config: \n{}",
|
"Building logger with config: \n{}",
|
||||||
@ -257,7 +252,7 @@ def build_logger(
|
|||||||
PlotLogger = Callable[[str, Figure, int], None]
|
PlotLogger = Callable[[str, Figure, int], None]
|
||||||
|
|
||||||
|
|
||||||
def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
|
def get_image_logger(logger: Logger) -> PlotLogger | None:
|
||||||
if isinstance(logger, TensorBoardLogger):
|
if isinstance(logger, TensorBoardLogger):
|
||||||
return logger.experiment.add_figure
|
return logger.experiment.add_figure
|
||||||
|
|
||||||
@ -282,7 +277,7 @@ def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
|
|||||||
TableLogger = Callable[[str, pd.DataFrame, int], None]
|
TableLogger = Callable[[str, pd.DataFrame, int], None]
|
||||||
|
|
||||||
|
|
||||||
def get_table_logger(logger: Logger) -> Optional[TableLogger]:
|
def get_table_logger(logger: Logger) -> TableLogger | None:
|
||||||
if isinstance(logger, TensorBoardLogger):
|
if isinstance(logger, TensorBoardLogger):
|
||||||
return partial(save_table, dir=Path(logger.log_dir))
|
return partial(save_table, dir=Path(logger.log_dir))
|
||||||
|
|
||||||
|
|||||||
@ -43,10 +43,7 @@ from batdetect2.models.bottleneck import (
|
|||||||
BottleneckConfig,
|
BottleneckConfig,
|
||||||
build_bottleneck,
|
build_bottleneck,
|
||||||
)
|
)
|
||||||
from batdetect2.models.config import (
|
from batdetect2.models.config import BackboneConfig, load_backbone_config
|
||||||
BackboneConfig,
|
|
||||||
load_backbone_config,
|
|
||||||
)
|
|
||||||
from batdetect2.models.decoder import (
|
from batdetect2.models.decoder import (
|
||||||
DEFAULT_DECODER_CONFIG,
|
DEFAULT_DECODER_CONFIG,
|
||||||
DecoderConfig,
|
DecoderConfig,
|
||||||
@ -122,10 +119,10 @@ class Model(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def build_model(
|
def build_model(
|
||||||
config: Optional[BackboneConfig] = None,
|
config: BackboneConfig | None = None,
|
||||||
targets: Optional[TargetProtocol] = None,
|
targets: TargetProtocol | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
postprocessor: Optional[PostprocessorProtocol] = None,
|
postprocessor: PostprocessorProtocol | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
|||||||
@ -78,8 +78,8 @@ class Bottleneck(nn.Module):
|
|||||||
input_height: int,
|
input_height: int,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
out_channels: int,
|
out_channels: int,
|
||||||
bottleneck_channels: Optional[int] = None,
|
bottleneck_channels: int | None = None,
|
||||||
layers: Optional[List[torch.nn.Module]] = None,
|
layers: List[torch.nn.Module] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the base Bottleneck layer."""
|
"""Initialize the base Bottleneck layer."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -127,7 +127,7 @@ class Bottleneck(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
BottleneckLayerConfig = Annotated[
|
BottleneckLayerConfig = Annotated[
|
||||||
Union[SelfAttentionConfig,],
|
SelfAttentionConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||||
@ -171,7 +171,7 @@ DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
|
|||||||
def build_bottleneck(
|
def build_bottleneck(
|
||||||
input_height: int,
|
input_height: int,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
config: Optional[BottleneckConfig] = None,
|
config: BottleneckConfig | None = None,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
"""Factory function to build the Bottleneck module from configuration.
|
"""Factory function to build the Bottleneck module from configuration.
|
||||||
|
|
||||||
|
|||||||
@ -63,7 +63,7 @@ class BackboneConfig(BaseConfig):
|
|||||||
|
|
||||||
def load_backbone_config(
|
def load_backbone_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
) -> BackboneConfig:
|
) -> BackboneConfig:
|
||||||
"""Load the backbone configuration from a file.
|
"""Load the backbone configuration from a file.
|
||||||
|
|
||||||
|
|||||||
@ -41,12 +41,7 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
DecoderLayerConfig = Annotated[
|
DecoderLayerConfig = Annotated[
|
||||||
Union[
|
ConvConfig | FreqCoordConvUpConfig | StandardConvUpConfig | LayerGroupConfig,
|
||||||
ConvConfig,
|
|
||||||
FreqCoordConvUpConfig,
|
|
||||||
StandardConvUpConfig,
|
|
||||||
LayerGroupConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||||
@ -216,7 +211,7 @@ convolutional block.
|
|||||||
def build_decoder(
|
def build_decoder(
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
input_height: int,
|
input_height: int,
|
||||||
config: Optional[DecoderConfig] = None,
|
config: DecoderConfig | None = None,
|
||||||
) -> Decoder:
|
) -> Decoder:
|
||||||
"""Factory function to build a Decoder instance from configuration.
|
"""Factory function to build a Decoder instance from configuration.
|
||||||
|
|
||||||
|
|||||||
@ -127,7 +127,7 @@ class Detector(DetectionModel):
|
|||||||
|
|
||||||
|
|
||||||
def build_detector(
|
def build_detector(
|
||||||
num_classes: int, config: Optional[BackboneConfig] = None
|
num_classes: int, config: BackboneConfig | None = None
|
||||||
) -> DetectionModel:
|
) -> DetectionModel:
|
||||||
"""Build the complete BatDetect2 detection model.
|
"""Build the complete BatDetect2 detection model.
|
||||||
|
|
||||||
|
|||||||
@ -43,12 +43,7 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
EncoderLayerConfig = Annotated[
|
EncoderLayerConfig = Annotated[
|
||||||
Union[
|
ConvConfig | FreqCoordConvDownConfig | StandardConvDownConfig | LayerGroupConfig,
|
||||||
ConvConfig,
|
|
||||||
FreqCoordConvDownConfig,
|
|
||||||
StandardConvDownConfig,
|
|
||||||
LayerGroupConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||||
@ -252,7 +247,7 @@ Specifies an architecture typically used in BatDetect2:
|
|||||||
def build_encoder(
|
def build_encoder(
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
input_height: int,
|
input_height: int,
|
||||||
config: Optional[EncoderConfig] = None,
|
config: EncoderConfig | None = None,
|
||||||
) -> Encoder:
|
) -> Encoder:
|
||||||
"""Factory function to build an Encoder instance from configuration.
|
"""Factory function to build an Encoder instance from configuration.
|
||||||
|
|
||||||
|
|||||||
@ -15,10 +15,10 @@ __all__ = [
|
|||||||
|
|
||||||
def plot_clip_annotation(
|
def plot_clip_annotation(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
ax: Optional[Axes] = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
cmap: str = "gray",
|
cmap: str = "gray",
|
||||||
alpha: float = 1,
|
alpha: float = 1,
|
||||||
@ -50,8 +50,8 @@ def plot_clip_annotation(
|
|||||||
def plot_anchor_points(
|
def plot_anchor_points(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
ax: Optional[Axes] = None,
|
ax: Axes | None = None,
|
||||||
size: int = 1,
|
size: int = 1,
|
||||||
color: str = "red",
|
color: str = "red",
|
||||||
marker: str = "x",
|
marker: str = "x",
|
||||||
|
|||||||
@ -17,10 +17,10 @@ __all__ = [
|
|||||||
|
|
||||||
def plot_clip_prediction(
|
def plot_clip_prediction(
|
||||||
clip_prediction: data.ClipPrediction,
|
clip_prediction: data.ClipPrediction,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
ax: Optional[Axes] = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
add_legend: bool = False,
|
add_legend: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
linewidth: float = 1,
|
linewidth: float = 1,
|
||||||
@ -50,14 +50,14 @@ def plot_clip_prediction(
|
|||||||
|
|
||||||
def plot_predictions(
|
def plot_predictions(
|
||||||
predictions: Iterable[data.SoundEventPrediction],
|
predictions: Iterable[data.SoundEventPrediction],
|
||||||
ax: Optional[Axes] = None,
|
ax: Axes | None = None,
|
||||||
position: Positions = "top-right",
|
position: Positions = "top-right",
|
||||||
color_mapper: Optional[TagColorMapper] = None,
|
color_mapper: TagColorMapper | None = None,
|
||||||
time_offset: float = 0.001,
|
time_offset: float = 0.001,
|
||||||
freq_offset: float = 1000,
|
freq_offset: float = 1000,
|
||||||
legend: bool = True,
|
legend: bool = True,
|
||||||
max_alpha: float = 0.5,
|
max_alpha: float = 0.5,
|
||||||
color: Optional[str] = None,
|
color: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Plot an prediction."""
|
"""Plot an prediction."""
|
||||||
@ -88,14 +88,14 @@ def plot_predictions(
|
|||||||
|
|
||||||
def plot_prediction(
|
def plot_prediction(
|
||||||
prediction: data.SoundEventPrediction,
|
prediction: data.SoundEventPrediction,
|
||||||
ax: Optional[Axes] = None,
|
ax: Axes | None = None,
|
||||||
position: Positions = "top-right",
|
position: Positions = "top-right",
|
||||||
color_mapper: Optional[TagColorMapper] = None,
|
color_mapper: TagColorMapper | None = None,
|
||||||
time_offset: float = 0.001,
|
time_offset: float = 0.001,
|
||||||
freq_offset: float = 1000,
|
freq_offset: float = 1000,
|
||||||
max_alpha: float = 0.5,
|
max_alpha: float = 0.5,
|
||||||
alpha: Optional[float] = None,
|
alpha: float | None = None,
|
||||||
color: Optional[str] = None,
|
color: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
"""Plot an annotation."""
|
"""Plot an annotation."""
|
||||||
|
|||||||
@ -17,11 +17,11 @@ __all__ = [
|
|||||||
|
|
||||||
def plot_clip(
|
def plot_clip(
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
ax: Optional[Axes] = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
if ax is None:
|
if ax is None:
|
||||||
|
|||||||
@ -13,8 +13,8 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
def create_ax(
|
def create_ax(
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
"""Create a new axis if none is provided"""
|
"""Create a new axis if none is provided"""
|
||||||
@ -25,17 +25,17 @@ def create_ax(
|
|||||||
|
|
||||||
|
|
||||||
def plot_spectrogram(
|
def plot_spectrogram(
|
||||||
spec: Union[torch.Tensor, np.ndarray],
|
spec: torch.Tensor | np.ndarray,
|
||||||
start_time: Optional[float] = None,
|
start_time: float | None = None,
|
||||||
end_time: Optional[float] = None,
|
end_time: float | None = None,
|
||||||
min_freq: Optional[float] = None,
|
min_freq: float | None = None,
|
||||||
max_freq: Optional[float] = None,
|
max_freq: float | None = None,
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
add_colorbar: bool = False,
|
add_colorbar: bool = False,
|
||||||
colorbar_kwargs: Optional[dict] = None,
|
colorbar_kwargs: dict | None = None,
|
||||||
vmin: Optional[float] = None,
|
vmin: float | None = None,
|
||||||
vmax: Optional[float] = None,
|
vmax: float | None = None,
|
||||||
cmap="gray",
|
cmap="gray",
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
if isinstance(spec, torch.Tensor):
|
if isinstance(spec, torch.Tensor):
|
||||||
|
|||||||
@ -19,9 +19,9 @@ __all__ = [
|
|||||||
def plot_clip_detections(
|
def plot_clip_detections(
|
||||||
clip_eval: ClipEval,
|
clip_eval: ClipEval,
|
||||||
figsize: tuple[int, int] = (10, 10),
|
figsize: tuple[int, int] = (10, 10),
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
threshold: float = 0.2,
|
threshold: float = 0.2,
|
||||||
add_legend: bool = True,
|
add_legend: bool = True,
|
||||||
add_title: bool = True,
|
add_title: bool = True,
|
||||||
|
|||||||
@ -20,11 +20,11 @@ def plot_match_gallery(
|
|||||||
false_positives: Sequence[MatchProtocol],
|
false_positives: Sequence[MatchProtocol],
|
||||||
false_negatives: Sequence[MatchProtocol],
|
false_negatives: Sequence[MatchProtocol],
|
||||||
cross_triggers: Sequence[MatchProtocol],
|
cross_triggers: Sequence[MatchProtocol],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
n_examples: int = 5,
|
n_examples: int = 5,
|
||||||
duration: float = 0.1,
|
duration: float = 0.1,
|
||||||
fig: Optional[Figure] = None,
|
fig: Figure | None = None,
|
||||||
):
|
):
|
||||||
if fig is None:
|
if fig is None:
|
||||||
fig = plt.figure(figsize=(20, 20))
|
fig = plt.figure(figsize=(20, 20))
|
||||||
|
|||||||
@ -12,13 +12,13 @@ from batdetect2.plotting.common import create_ax
|
|||||||
|
|
||||||
|
|
||||||
def plot_detection_heatmap(
|
def plot_detection_heatmap(
|
||||||
heatmap: Union[torch.Tensor, np.ndarray],
|
heatmap: torch.Tensor | np.ndarray,
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] = (10, 10),
|
figsize: Tuple[int, int] = (10, 10),
|
||||||
threshold: Optional[float] = None,
|
threshold: float | None = None,
|
||||||
alpha: float = 1,
|
alpha: float = 1,
|
||||||
cmap: Union[str, Colormap] = "jet",
|
cmap: str | Colormap = "jet",
|
||||||
color: Optional[str] = None,
|
color: str | None = None,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
ax = create_ax(ax, figsize=figsize)
|
ax = create_ax(ax, figsize=figsize)
|
||||||
|
|
||||||
@ -48,13 +48,13 @@ def plot_detection_heatmap(
|
|||||||
|
|
||||||
|
|
||||||
def plot_classification_heatmap(
|
def plot_classification_heatmap(
|
||||||
heatmap: Union[torch.Tensor, np.ndarray],
|
heatmap: torch.Tensor | np.ndarray,
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Tuple[int, int] = (10, 10),
|
figsize: Tuple[int, int] = (10, 10),
|
||||||
class_names: Optional[List[str]] = None,
|
class_names: List[str] | None = None,
|
||||||
threshold: Optional[float] = 0.1,
|
threshold: float | None = 0.1,
|
||||||
alpha: float = 1,
|
alpha: float = 1,
|
||||||
cmap: Union[str, Colormap] = "tab20",
|
cmap: str | Colormap = "tab20",
|
||||||
):
|
):
|
||||||
ax = create_ax(ax, figsize=figsize)
|
ax = create_ax(ax, figsize=figsize)
|
||||||
|
|
||||||
|
|||||||
@ -24,10 +24,10 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
def spectrogram(
|
def spectrogram(
|
||||||
spec: Union[torch.Tensor, np.ndarray],
|
spec: torch.Tensor | np.ndarray,
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: ProcessingConfiguration | None = None,
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
cmap: str = "plasma",
|
cmap: str = "plasma",
|
||||||
start_time: float = 0,
|
start_time: float = 0,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
@ -103,11 +103,11 @@ def spectrogram(
|
|||||||
|
|
||||||
|
|
||||||
def spectrogram_with_detections(
|
def spectrogram_with_detections(
|
||||||
spec: Union[torch.Tensor, np.ndarray],
|
spec: torch.Tensor | np.ndarray,
|
||||||
dets: List[Annotation],
|
dets: List[Annotation],
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: ProcessingConfiguration | None = None,
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
cmap: str = "plasma",
|
cmap: str = "plasma",
|
||||||
with_names: bool = True,
|
with_names: bool = True,
|
||||||
start_time: float = 0,
|
start_time: float = 0,
|
||||||
@ -168,8 +168,8 @@ def spectrogram_with_detections(
|
|||||||
|
|
||||||
def detections(
|
def detections(
|
||||||
dets: List[Annotation],
|
dets: List[Annotation],
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
with_names: bool = True,
|
with_names: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
@ -213,8 +213,8 @@ def detections(
|
|||||||
|
|
||||||
def detection(
|
def detection(
|
||||||
det: Annotation,
|
det: Annotation,
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
linewidth: float = 1,
|
linewidth: float = 1,
|
||||||
edgecolor: str = "w",
|
edgecolor: str = "w",
|
||||||
facecolor: str = "none",
|
facecolor: str = "none",
|
||||||
|
|||||||
@ -21,10 +21,10 @@ __all__ = [
|
|||||||
|
|
||||||
class MatchProtocol(Protocol):
|
class MatchProtocol(Protocol):
|
||||||
clip: data.Clip
|
clip: data.Clip
|
||||||
gt: Optional[data.SoundEventAnnotation]
|
gt: data.SoundEventAnnotation | None
|
||||||
pred: Optional[RawPrediction]
|
pred: RawPrediction | None
|
||||||
score: float
|
score: float
|
||||||
true_class: Optional[str]
|
true_class: str | None
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_DURATION = 0.05
|
DEFAULT_DURATION = 0.05
|
||||||
@ -38,11 +38,11 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
|
|||||||
|
|
||||||
def plot_false_positive_match(
|
def plot_false_positive_match(
|
||||||
match: MatchProtocol,
|
match: MatchProtocol,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
ax: Optional[Axes] = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
use_score: bool = True,
|
use_score: bool = True,
|
||||||
add_spectrogram: bool = True,
|
add_spectrogram: bool = True,
|
||||||
@ -52,7 +52,7 @@ def plot_false_positive_match(
|
|||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
||||||
fontsize: Union[float, str] = "small",
|
fontsize: float | str = "small",
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
assert match.pred is not None
|
assert match.pred is not None
|
||||||
|
|
||||||
@ -109,11 +109,11 @@ def plot_false_positive_match(
|
|||||||
|
|
||||||
def plot_false_negative_match(
|
def plot_false_negative_match(
|
||||||
match: MatchProtocol,
|
match: MatchProtocol,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
ax: Optional[Axes] = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
add_spectrogram: bool = True,
|
add_spectrogram: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
@ -169,11 +169,11 @@ def plot_false_negative_match(
|
|||||||
|
|
||||||
def plot_true_positive_match(
|
def plot_true_positive_match(
|
||||||
match: MatchProtocol,
|
match: MatchProtocol,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
ax: Optional[Axes] = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
use_score: bool = True,
|
use_score: bool = True,
|
||||||
add_spectrogram: bool = True,
|
add_spectrogram: bool = True,
|
||||||
@ -182,7 +182,7 @@ def plot_true_positive_match(
|
|||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||||
fontsize: Union[float, str] = "small",
|
fontsize: float | str = "small",
|
||||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||||
add_title: bool = True,
|
add_title: bool = True,
|
||||||
@ -257,11 +257,11 @@ def plot_true_positive_match(
|
|||||||
|
|
||||||
def plot_cross_trigger_match(
|
def plot_cross_trigger_match(
|
||||||
match: MatchProtocol,
|
match: MatchProtocol,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
ax: Optional[Axes] = None,
|
ax: Axes | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
use_score: bool = True,
|
use_score: bool = True,
|
||||||
add_spectrogram: bool = True,
|
add_spectrogram: bool = True,
|
||||||
@ -271,7 +271,7 @@ def plot_cross_trigger_match(
|
|||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
||||||
fontsize: Union[float, str] = "small",
|
fontsize: float | str = "small",
|
||||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
|
|||||||
@ -33,16 +33,16 @@ def plot_pr_curve(
|
|||||||
precision: np.ndarray,
|
precision: np.ndarray,
|
||||||
recall: np.ndarray,
|
recall: np.ndarray,
|
||||||
thresholds: np.ndarray,
|
thresholds: np.ndarray,
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
color: Union[str, Tuple[float, float, float], None] = None,
|
color: str | Tuple[float, float, float] | None = None,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
add_legend: bool = False,
|
add_legend: bool = False,
|
||||||
marker: Union[str, Tuple[int, int, float], None] = "o",
|
marker: str | Tuple[int, int, float] | None = "o",
|
||||||
markeredgecolor: Union[str, Tuple[float, float, float], None] = None,
|
markeredgecolor: str | Tuple[float, float, float] | None = None,
|
||||||
markersize: Optional[float] = None,
|
markersize: float | None = None,
|
||||||
linestyle: Union[str, Tuple[int, ...], None] = None,
|
linestyle: str | Tuple[int, ...] | None = None,
|
||||||
linewidth: Optional[float] = None,
|
linewidth: float | None = None,
|
||||||
label: str = "PR Curve",
|
label: str = "PR Curve",
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
ax = create_ax(ax=ax, figsize=figsize)
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
@ -77,8 +77,8 @@ def plot_pr_curve(
|
|||||||
|
|
||||||
def plot_pr_curves(
|
def plot_pr_curves(
|
||||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
add_legend: bool = True,
|
add_legend: bool = True,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
include_ap: bool = False,
|
include_ap: bool = False,
|
||||||
@ -118,8 +118,8 @@ def plot_pr_curves(
|
|||||||
def plot_threshold_precision_curve(
|
def plot_threshold_precision_curve(
|
||||||
threshold: np.ndarray,
|
threshold: np.ndarray,
|
||||||
precision: np.ndarray,
|
precision: np.ndarray,
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
):
|
):
|
||||||
ax = create_ax(ax=ax, figsize=figsize)
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
@ -140,8 +140,8 @@ def plot_threshold_precision_curve(
|
|||||||
|
|
||||||
def plot_threshold_precision_curves(
|
def plot_threshold_precision_curves(
|
||||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
add_legend: bool = True,
|
add_legend: bool = True,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
):
|
):
|
||||||
@ -176,8 +176,8 @@ def plot_threshold_precision_curves(
|
|||||||
def plot_threshold_recall_curve(
|
def plot_threshold_recall_curve(
|
||||||
threshold: np.ndarray,
|
threshold: np.ndarray,
|
||||||
recall: np.ndarray,
|
recall: np.ndarray,
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
):
|
):
|
||||||
ax = create_ax(ax=ax, figsize=figsize)
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
@ -198,8 +198,8 @@ def plot_threshold_recall_curve(
|
|||||||
|
|
||||||
def plot_threshold_recall_curves(
|
def plot_threshold_recall_curves(
|
||||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
add_legend: bool = True,
|
add_legend: bool = True,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
):
|
):
|
||||||
@ -235,8 +235,8 @@ def plot_roc_curve(
|
|||||||
fpr: np.ndarray,
|
fpr: np.ndarray,
|
||||||
tpr: np.ndarray,
|
tpr: np.ndarray,
|
||||||
thresholds: np.ndarray,
|
thresholds: np.ndarray,
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
ax = create_ax(ax=ax, figsize=figsize)
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
@ -261,8 +261,8 @@ def plot_roc_curve(
|
|||||||
|
|
||||||
def plot_roc_curves(
|
def plot_roc_curves(
|
||||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: axes.Axes | None = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Tuple[int, int] | None = None,
|
||||||
add_legend: bool = True,
|
add_legend: bool = True,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
|
|||||||
@ -57,7 +57,7 @@ class PostprocessConfig(BaseConfig):
|
|||||||
|
|
||||||
def load_postprocess_config(
|
def load_postprocess_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
) -> PostprocessConfig:
|
) -> PostprocessConfig:
|
||||||
"""Load the postprocessing configuration from a file.
|
"""Load the postprocessing configuration from a file.
|
||||||
|
|
||||||
|
|||||||
@ -88,9 +88,7 @@ def convert_raw_prediction_to_sound_event_prediction(
|
|||||||
raw_prediction: RawPrediction,
|
raw_prediction: RawPrediction,
|
||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
classification_threshold: Optional[
|
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
float
|
|
||||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
):
|
):
|
||||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
|
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
|
||||||
@ -150,7 +148,7 @@ def get_class_tags(
|
|||||||
class_scores: np.ndarray,
|
class_scores: np.ndarray,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
) -> List[data.PredictedTag]:
|
) -> List[data.PredictedTag]:
|
||||||
"""Generate specific PredictedTags based on class scores and decoder.
|
"""Generate specific PredictedTags based on class scores and decoder.
|
||||||
|
|
||||||
|
|||||||
@ -32,7 +32,7 @@ def extract_detection_peaks(
|
|||||||
feature_heatmap: torch.Tensor,
|
feature_heatmap: torch.Tensor,
|
||||||
classification_heatmap: torch.Tensor,
|
classification_heatmap: torch.Tensor,
|
||||||
max_detections: int = 200,
|
max_detections: int = 200,
|
||||||
threshold: Optional[float] = None,
|
threshold: float | None = None,
|
||||||
) -> List[ClipDetectionsTensor]:
|
) -> List[ClipDetectionsTensor]:
|
||||||
height = detection_heatmap.shape[-2]
|
height = detection_heatmap.shape[-2]
|
||||||
width = detection_heatmap.shape[-1]
|
width = detection_heatmap.shape[-1]
|
||||||
|
|||||||
@ -27,7 +27,7 @@ BatDetect2.
|
|||||||
|
|
||||||
def non_max_suppression(
|
def non_max_suppression(
|
||||||
tensor: torch.Tensor,
|
tensor: torch.Tensor,
|
||||||
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
|
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,7 @@ __all__ = [
|
|||||||
|
|
||||||
def build_postprocessor(
|
def build_postprocessor(
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
config: Optional[PostprocessConfig] = None,
|
config: PostprocessConfig | None = None,
|
||||||
) -> PostprocessorProtocol:
|
) -> PostprocessorProtocol:
|
||||||
"""Factory function to build the standard postprocessor."""
|
"""Factory function to build the standard postprocessor."""
|
||||||
config = config or PostprocessConfig()
|
config = config or PostprocessConfig()
|
||||||
@ -51,7 +51,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
|||||||
max_freq: float,
|
max_freq: float,
|
||||||
top_k_per_sec: int = 200,
|
top_k_per_sec: int = 200,
|
||||||
detection_threshold: float = 0.01,
|
detection_threshold: float = 0.01,
|
||||||
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
nms_kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
|
||||||
):
|
):
|
||||||
"""Initialize the Postprocessor."""
|
"""Initialize the Postprocessor."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -66,7 +66,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
output: ModelOutput,
|
output: ModelOutput,
|
||||||
start_times: Optional[List[float]] = None,
|
start_times: List[float] | None = None,
|
||||||
) -> List[ClipDetectionsTensor]:
|
) -> List[ClipDetectionsTensor]:
|
||||||
detection_heatmap = non_max_suppression(
|
detection_heatmap = non_max_suppression(
|
||||||
output.detection_probs.detach(),
|
output.detection_probs.detach(),
|
||||||
|
|||||||
@ -31,14 +31,14 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
def to_xarray(
|
def to_xarray(
|
||||||
array: Union[torch.Tensor, np.ndarray],
|
array: torch.Tensor | np.ndarray,
|
||||||
start_time: float,
|
start_time: float,
|
||||||
end_time: float,
|
end_time: float,
|
||||||
min_freq: float = MIN_FREQ,
|
min_freq: float = MIN_FREQ,
|
||||||
max_freq: float = MAX_FREQ,
|
max_freq: float = MAX_FREQ,
|
||||||
name: str = "xarray",
|
name: str = "xarray",
|
||||||
extra_dims: Optional[List[str]] = None,
|
extra_dims: List[str] | None = None,
|
||||||
extra_coords: Optional[Dict[str, np.ndarray]] = None,
|
extra_coords: Dict[str, np.ndarray] | None = None,
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
if isinstance(array, torch.Tensor):
|
if isinstance(array, torch.Tensor):
|
||||||
array = array.detach().cpu().numpy()
|
array = array.detach().cpu().numpy()
|
||||||
|
|||||||
@ -78,11 +78,7 @@ class FixDuration(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
AudioTransform = Annotated[
|
AudioTransform = Annotated[
|
||||||
Union[
|
FixDurationConfig | ScaleAudioConfig | CenterAudioConfig,
|
||||||
FixDurationConfig,
|
|
||||||
ScaleAudioConfig,
|
|
||||||
CenterAudioConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -57,6 +57,6 @@ class PreprocessingConfig(BaseConfig):
|
|||||||
|
|
||||||
def load_preprocessing_config(
|
def load_preprocessing_config(
|
||||||
path: PathLike,
|
path: PathLike,
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
) -> PreprocessingConfig:
|
) -> PreprocessingConfig:
|
||||||
return load_config(path, schema=PreprocessingConfig, field=field)
|
return load_config(path, schema=PreprocessingConfig, field=field)
|
||||||
|
|||||||
@ -102,7 +102,7 @@ def compute_output_samplerate(
|
|||||||
|
|
||||||
|
|
||||||
def build_preprocessor(
|
def build_preprocessor(
|
||||||
config: Optional[PreprocessingConfig] = None,
|
config: PreprocessingConfig | None = None,
|
||||||
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
) -> PreprocessorProtocol:
|
) -> PreprocessorProtocol:
|
||||||
"""Factory function to build the standard preprocessor from configuration."""
|
"""Factory function to build the standard preprocessor from configuration."""
|
||||||
|
|||||||
@ -98,7 +98,7 @@ def _frequency_to_index(
|
|||||||
freq: float,
|
freq: float,
|
||||||
n_fft: int,
|
n_fft: int,
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
) -> Optional[int]:
|
) -> int | None:
|
||||||
alpha = freq * 2 / samplerate
|
alpha = freq * 2 / samplerate
|
||||||
height = np.floor(n_fft / 2) + 1
|
height = np.floor(n_fft / 2) + 1
|
||||||
index = int(np.floor(alpha * height))
|
index = int(np.floor(alpha * height))
|
||||||
@ -134,8 +134,8 @@ class FrequencyCrop(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
n_fft: int,
|
n_fft: int,
|
||||||
min_freq: Optional[int] = None,
|
min_freq: int | None = None,
|
||||||
max_freq: Optional[int] = None,
|
max_freq: int | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_fft = n_fft
|
self.n_fft = n_fft
|
||||||
@ -181,7 +181,7 @@ class FrequencyCrop(torch.nn.Module):
|
|||||||
|
|
||||||
def build_spectrogram_crop(
|
def build_spectrogram_crop(
|
||||||
config: FrequencyConfig,
|
config: FrequencyConfig,
|
||||||
stft: Optional[STFTConfig] = None,
|
stft: STFTConfig | None = None,
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
stft = stft or STFTConfig()
|
stft = stft or STFTConfig()
|
||||||
@ -377,12 +377,7 @@ class PeakNormalize(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
SpectrogramTransform = Annotated[
|
SpectrogramTransform = Annotated[
|
||||||
Union[
|
PcenConfig | ScaleAmplitudeConfig | SpectralMeanSubstractionConfig | PeakNormalizeConfig,
|
||||||
PcenConfig,
|
|
||||||
ScaleAmplitudeConfig,
|
|
||||||
SpectralMeanSubstractionConfig,
|
|
||||||
PeakNormalizeConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -30,16 +30,16 @@ class TargetClassConfig(BaseConfig):
|
|||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
condition_input: Optional[SoundEventConditionConfig] = Field(
|
condition_input: SoundEventConditionConfig | None = Field(
|
||||||
alias="match_if",
|
alias="match_if",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
tags: Optional[List[data.Tag]] = Field(default=None, exclude=True)
|
tags: List[data.Tag] | None = Field(default=None, exclude=True)
|
||||||
|
|
||||||
assign_tags: List[data.Tag] = Field(default_factory=list)
|
assign_tags: List[data.Tag] = Field(default_factory=list)
|
||||||
|
|
||||||
roi: Optional[ROIMapperConfig] = None
|
roi: ROIMapperConfig | None = None
|
||||||
|
|
||||||
_match_if: SoundEventConditionConfig = PrivateAttr()
|
_match_if: SoundEventConditionConfig = PrivateAttr()
|
||||||
|
|
||||||
@ -202,7 +202,7 @@ class SoundEventClassifier:
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
) -> Optional[str]:
|
) -> str | None:
|
||||||
for name, condition in self.mapping.items():
|
for name, condition in self.mapping.items():
|
||||||
if condition(sound_event_annotation):
|
if condition(sound_event_annotation):
|
||||||
return name
|
return name
|
||||||
|
|||||||
@ -48,7 +48,7 @@ class TargetConfig(BaseConfig):
|
|||||||
|
|
||||||
def load_target_config(
|
def load_target_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
) -> TargetConfig:
|
) -> TargetConfig:
|
||||||
"""Load the unified target configuration from a file.
|
"""Load the unified target configuration from a file.
|
||||||
|
|
||||||
|
|||||||
@ -414,10 +414,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
|||||||
|
|
||||||
|
|
||||||
ROIMapperConfig = Annotated[
|
ROIMapperConfig = Annotated[
|
||||||
Union[
|
AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig,
|
||||||
AnchorBBoxMapperConfig,
|
|
||||||
PeakEnergyBBoxMapperConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""A discriminated union of all supported ROI mapper configurations.
|
"""A discriminated union of all supported ROI mapper configurations.
|
||||||
@ -428,7 +425,7 @@ implementations by using the `name` field as a discriminator.
|
|||||||
|
|
||||||
|
|
||||||
def build_roi_mapper(
|
def build_roi_mapper(
|
||||||
config: Optional[ROIMapperConfig] = None,
|
config: ROIMapperConfig | None = None,
|
||||||
) -> ROITargetMapper:
|
) -> ROITargetMapper:
|
||||||
"""Factory function to create an ROITargetMapper from a config object.
|
"""Factory function to create an ROITargetMapper from a config object.
|
||||||
|
|
||||||
@ -572,9 +569,9 @@ def get_peak_energy_coordinates(
|
|||||||
audio_loader: AudioLoader,
|
audio_loader: AudioLoader,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
start_time: float = 0,
|
start_time: float = 0,
|
||||||
end_time: Optional[float] = None,
|
end_time: float | None = None,
|
||||||
low_freq: float = 0,
|
low_freq: float = 0,
|
||||||
high_freq: Optional[float] = None,
|
high_freq: float | None = None,
|
||||||
loading_buffer: float = 0.05,
|
loading_buffer: float = 0.05,
|
||||||
) -> Position:
|
) -> Position:
|
||||||
"""Find the coordinates of the highest energy point in a spectrogram.
|
"""Find the coordinates of the highest energy point in a spectrogram.
|
||||||
|
|||||||
@ -107,7 +107,7 @@ class Targets(TargetProtocol):
|
|||||||
|
|
||||||
def encode_class(
|
def encode_class(
|
||||||
self, sound_event: data.SoundEventAnnotation
|
self, sound_event: data.SoundEventAnnotation
|
||||||
) -> Optional[str]:
|
) -> str | None:
|
||||||
"""Encode a sound event annotation to its target class name.
|
"""Encode a sound event annotation to its target class name.
|
||||||
|
|
||||||
Applies the configured class definition rules (including priority)
|
Applies the configured class definition rules (including priority)
|
||||||
@ -182,7 +182,7 @@ class Targets(TargetProtocol):
|
|||||||
self,
|
self,
|
||||||
position: Position,
|
position: Position,
|
||||||
size: Size,
|
size: Size,
|
||||||
class_name: Optional[str] = None,
|
class_name: str | None = None,
|
||||||
) -> data.Geometry:
|
) -> data.Geometry:
|
||||||
"""Recover an approximate geometric ROI from a position and dimensions.
|
"""Recover an approximate geometric ROI from a position and dimensions.
|
||||||
|
|
||||||
@ -219,7 +219,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
def build_targets(config: TargetConfig | None = None) -> Targets:
|
||||||
"""Build a Targets object from a loaded TargetConfig.
|
"""Build a Targets object from a loaded TargetConfig.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -251,7 +251,7 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
|||||||
|
|
||||||
def load_targets(
|
def load_targets(
|
||||||
config_path: data.PathLike,
|
config_path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
) -> Targets:
|
) -> Targets:
|
||||||
"""Load a Targets object directly from a configuration file.
|
"""Load a Targets object directly from a configuration file.
|
||||||
|
|
||||||
@ -292,7 +292,7 @@ def load_targets(
|
|||||||
def iterate_encoded_sound_events(
|
def iterate_encoded_sound_events(
|
||||||
sound_events: Iterable[data.SoundEventAnnotation],
|
sound_events: Iterable[data.SoundEventAnnotation],
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
) -> Iterable[Tuple[Optional[str], Position, Size]]:
|
) -> Iterable[Tuple[str | None, Position, Size]]:
|
||||||
for sound_event in sound_events:
|
for sound_event in sound_events:
|
||||||
if not targets.filter(sound_event):
|
if not targets.filter(sound_event):
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -42,7 +42,7 @@ __all__ = [
|
|||||||
|
|
||||||
AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
|
AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
|
||||||
|
|
||||||
audio_augmentations: Registry[Augmentation, [int, Optional[AudioSource]]] = (
|
audio_augmentations: Registry[Augmentation, [int, AudioSource | None]] = (
|
||||||
Registry(name="audio_augmentation")
|
Registry(name="audio_augmentation")
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -103,7 +103,7 @@ class MixAudio(torch.nn.Module):
|
|||||||
def from_config(
|
def from_config(
|
||||||
config: MixAudioConfig,
|
config: MixAudioConfig,
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
source: Optional[AudioSource],
|
source: AudioSource | None,
|
||||||
):
|
):
|
||||||
if source is None:
|
if source is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@ -207,7 +207,7 @@ class AddEcho(torch.nn.Module):
|
|||||||
def from_config(
|
def from_config(
|
||||||
config: AddEchoConfig,
|
config: AddEchoConfig,
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
source: Optional[AudioSource],
|
source: AudioSource | None,
|
||||||
):
|
):
|
||||||
return AddEcho(
|
return AddEcho(
|
||||||
samplerate=samplerate,
|
samplerate=samplerate,
|
||||||
@ -487,33 +487,18 @@ def mask_frequency(
|
|||||||
|
|
||||||
|
|
||||||
AudioAugmentationConfig = Annotated[
|
AudioAugmentationConfig = Annotated[
|
||||||
Union[
|
MixAudioConfig | AddEchoConfig,
|
||||||
MixAudioConfig,
|
|
||||||
AddEchoConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
SpectrogramAugmentationConfig = Annotated[
|
SpectrogramAugmentationConfig = Annotated[
|
||||||
Union[
|
ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig,
|
||||||
ScaleVolumeConfig,
|
|
||||||
WarpConfig,
|
|
||||||
MaskFrequencyConfig,
|
|
||||||
MaskTimeConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
AugmentationConfig = Annotated[
|
AugmentationConfig = Annotated[
|
||||||
Union[
|
MixAudioConfig | AddEchoConfig | ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig,
|
||||||
MixAudioConfig,
|
|
||||||
AddEchoConfig,
|
|
||||||
ScaleVolumeConfig,
|
|
||||||
WarpConfig,
|
|
||||||
MaskFrequencyConfig,
|
|
||||||
MaskTimeConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of individual augmentation config."""
|
"""Type alias for the discriminated union of individual augmentation config."""
|
||||||
@ -559,8 +544,8 @@ class MaybeApply(torch.nn.Module):
|
|||||||
def build_augmentation_from_config(
|
def build_augmentation_from_config(
|
||||||
config: AugmentationConfig,
|
config: AugmentationConfig,
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
audio_source: Optional[AudioSource] = None,
|
audio_source: AudioSource | None = None,
|
||||||
) -> Optional[Augmentation]:
|
) -> Augmentation | None:
|
||||||
"""Factory function to build a single augmentation from its config."""
|
"""Factory function to build a single augmentation from its config."""
|
||||||
if config.name == "mix_audio":
|
if config.name == "mix_audio":
|
||||||
if audio_source is None:
|
if audio_source is None:
|
||||||
@ -645,10 +630,10 @@ class AugmentationSequence(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def build_audio_augmentations(
|
def build_audio_augmentations(
|
||||||
steps: Optional[Sequence[AudioAugmentationConfig]] = None,
|
steps: Sequence[AudioAugmentationConfig] | None = None,
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
audio_source: Optional[AudioSource] = None,
|
audio_source: AudioSource | None = None,
|
||||||
) -> Optional[Augmentation]:
|
) -> Augmentation | None:
|
||||||
if not steps:
|
if not steps:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -673,8 +658,8 @@ def build_audio_augmentations(
|
|||||||
|
|
||||||
|
|
||||||
def build_spectrogram_augmentations(
|
def build_spectrogram_augmentations(
|
||||||
steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None,
|
steps: Sequence[SpectrogramAugmentationConfig] | None = None,
|
||||||
) -> Optional[Augmentation]:
|
) -> Augmentation | None:
|
||||||
if not steps:
|
if not steps:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -698,9 +683,9 @@ def build_spectrogram_augmentations(
|
|||||||
|
|
||||||
def build_augmentations(
|
def build_augmentations(
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
config: Optional[AugmentationsConfig] = None,
|
config: AugmentationsConfig | None = None,
|
||||||
audio_source: Optional[AudioSource] = None,
|
audio_source: AudioSource | None = None,
|
||||||
) -> Tuple[Optional[Augmentation], Optional[Augmentation]]:
|
) -> Tuple[Augmentation | None, Augmentation | None]:
|
||||||
"""Build a composite augmentation pipeline function from configuration."""
|
"""Build a composite augmentation pipeline function from configuration."""
|
||||||
config = config or DEFAULT_AUGMENTATION_CONFIG
|
config = config or DEFAULT_AUGMENTATION_CONFIG
|
||||||
|
|
||||||
@ -723,7 +708,7 @@ def build_augmentations(
|
|||||||
|
|
||||||
|
|
||||||
def load_augmentation_config(
|
def load_augmentation_config(
|
||||||
path: data.PathLike, field: Optional[str] = None
|
path: data.PathLike, field: str | None = None
|
||||||
) -> AugmentationsConfig:
|
) -> AugmentationsConfig:
|
||||||
"""Load the augmentations configuration from a file."""
|
"""Load the augmentations configuration from a file."""
|
||||||
return load_config(path, schema=AugmentationsConfig, field=field)
|
return load_config(path, schema=AugmentationsConfig, field=field)
|
||||||
|
|||||||
@ -18,14 +18,14 @@ class CheckpointConfig(BaseConfig):
|
|||||||
monitor: str = "classification/mean_average_precision"
|
monitor: str = "classification/mean_average_precision"
|
||||||
mode: str = "max"
|
mode: str = "max"
|
||||||
save_top_k: int = 1
|
save_top_k: int = 1
|
||||||
filename: Optional[str] = None
|
filename: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def build_checkpoint_callback(
|
def build_checkpoint_callback(
|
||||||
config: Optional[CheckpointConfig] = None,
|
config: CheckpointConfig | None = None,
|
||||||
checkpoint_dir: Optional[Path] = None,
|
checkpoint_dir: Path | None = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
) -> Callback:
|
) -> Callback:
|
||||||
config = config or CheckpointConfig()
|
config = config or CheckpointConfig()
|
||||||
|
|
||||||
|
|||||||
@ -22,20 +22,20 @@ class PLTrainerConfig(BaseConfig):
|
|||||||
accumulate_grad_batches: int = 1
|
accumulate_grad_batches: int = 1
|
||||||
deterministic: bool = True
|
deterministic: bool = True
|
||||||
check_val_every_n_epoch: int = 1
|
check_val_every_n_epoch: int = 1
|
||||||
devices: Union[str, int] = "auto"
|
devices: str | int = "auto"
|
||||||
enable_checkpointing: bool = True
|
enable_checkpointing: bool = True
|
||||||
gradient_clip_val: Optional[float] = None
|
gradient_clip_val: float | None = None
|
||||||
limit_train_batches: Optional[Union[int, float]] = None
|
limit_train_batches: int | float | None = None
|
||||||
limit_test_batches: Optional[Union[int, float]] = None
|
limit_test_batches: int | float | None = None
|
||||||
limit_val_batches: Optional[Union[int, float]] = None
|
limit_val_batches: int | float | None = None
|
||||||
log_every_n_steps: Optional[int] = None
|
log_every_n_steps: int | None = None
|
||||||
max_epochs: Optional[int] = 200
|
max_epochs: int | None = 200
|
||||||
min_epochs: Optional[int] = None
|
min_epochs: int | None = None
|
||||||
max_steps: Optional[int] = None
|
max_steps: int | None = None
|
||||||
min_steps: Optional[int] = None
|
min_steps: int | None = None
|
||||||
max_time: Optional[str] = None
|
max_time: str | None = None
|
||||||
precision: Optional[str] = None
|
precision: str | None = None
|
||||||
val_check_interval: Optional[Union[int, float]] = None
|
val_check_interval: int | float | None = None
|
||||||
|
|
||||||
|
|
||||||
class OptimizerConfig(BaseConfig):
|
class OptimizerConfig(BaseConfig):
|
||||||
@ -57,6 +57,6 @@ class TrainingConfig(BaseConfig):
|
|||||||
|
|
||||||
def load_train_config(
|
def load_train_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: str | None = None,
|
||||||
) -> TrainingConfig:
|
) -> TrainingConfig:
|
||||||
return load_config(path, schema=TrainingConfig, field=field)
|
return load_config(path, schema=TrainingConfig, field=field)
|
||||||
|
|||||||
@ -44,10 +44,10 @@ class TrainingDataset(Dataset):
|
|||||||
audio_loader: AudioLoader,
|
audio_loader: AudioLoader,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
labeller: ClipLabeller,
|
labeller: ClipLabeller,
|
||||||
clipper: Optional[ClipperProtocol] = None,
|
clipper: ClipperProtocol | None = None,
|
||||||
audio_augmentation: Optional[Augmentation] = None,
|
audio_augmentation: Augmentation | None = None,
|
||||||
spectrogram_augmentation: Optional[Augmentation] = None,
|
spectrogram_augmentation: Augmentation | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
):
|
):
|
||||||
self.clip_annotations = clip_annotations
|
self.clip_annotations = clip_annotations
|
||||||
self.clipper = clipper
|
self.clipper = clipper
|
||||||
@ -108,8 +108,8 @@ class ValidationDataset(Dataset):
|
|||||||
audio_loader: AudioLoader,
|
audio_loader: AudioLoader,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
labeller: ClipLabeller,
|
labeller: ClipLabeller,
|
||||||
clipper: Optional[ClipperProtocol] = None,
|
clipper: ClipperProtocol | None = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
):
|
):
|
||||||
self.clip_annotations = clip_annotations
|
self.clip_annotations = clip_annotations
|
||||||
self.labeller = labeller
|
self.labeller = labeller
|
||||||
@ -165,11 +165,11 @@ class TrainLoaderConfig(BaseConfig):
|
|||||||
|
|
||||||
def build_train_loader(
|
def build_train_loader(
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
labeller: Optional[ClipLabeller] = None,
|
labeller: ClipLabeller | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: Optional[TrainLoaderConfig] = None,
|
config: TrainLoaderConfig | None = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: int | None = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
config = config or TrainLoaderConfig()
|
config = config or TrainLoaderConfig()
|
||||||
|
|
||||||
@ -207,11 +207,11 @@ class ValLoaderConfig(BaseConfig):
|
|||||||
|
|
||||||
def build_val_loader(
|
def build_val_loader(
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
labeller: Optional[ClipLabeller] = None,
|
labeller: ClipLabeller | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: Optional[ValLoaderConfig] = None,
|
config: ValLoaderConfig | None = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: int | None = None,
|
||||||
):
|
):
|
||||||
logger.info("Building validation data loader...")
|
logger.info("Building validation data loader...")
|
||||||
config = config or ValLoaderConfig()
|
config = config or ValLoaderConfig()
|
||||||
@ -240,10 +240,10 @@ def build_val_loader(
|
|||||||
|
|
||||||
def build_train_dataset(
|
def build_train_dataset(
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
labeller: Optional[ClipLabeller] = None,
|
labeller: ClipLabeller | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: Optional[TrainLoaderConfig] = None,
|
config: TrainLoaderConfig | None = None,
|
||||||
) -> TrainingDataset:
|
) -> TrainingDataset:
|
||||||
logger.info("Building training dataset...")
|
logger.info("Building training dataset...")
|
||||||
config = config or TrainLoaderConfig()
|
config = config or TrainLoaderConfig()
|
||||||
@ -291,10 +291,10 @@ def build_train_dataset(
|
|||||||
|
|
||||||
def build_val_dataset(
|
def build_val_dataset(
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
labeller: Optional[ClipLabeller] = None,
|
labeller: ClipLabeller | None = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: Optional[ValLoaderConfig] = None,
|
config: ValLoaderConfig | None = None,
|
||||||
) -> ValidationDataset:
|
) -> ValidationDataset:
|
||||||
logger.info("Building validation dataset...")
|
logger.info("Building validation dataset...")
|
||||||
config = config or ValLoaderConfig()
|
config = config or ValLoaderConfig()
|
||||||
|
|||||||
@ -42,10 +42,10 @@ class LabelConfig(BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
def build_clip_labeler(
|
def build_clip_labeler(
|
||||||
targets: Optional[TargetProtocol] = None,
|
targets: TargetProtocol | None = None,
|
||||||
min_freq: float = MIN_FREQ,
|
min_freq: float = MIN_FREQ,
|
||||||
max_freq: float = MAX_FREQ,
|
max_freq: float = MAX_FREQ,
|
||||||
config: Optional[LabelConfig] = None,
|
config: LabelConfig | None = None,
|
||||||
) -> ClipLabeller:
|
) -> ClipLabeller:
|
||||||
"""Construct the final clip labelling function."""
|
"""Construct the final clip labelling function."""
|
||||||
config = config or LabelConfig()
|
config = config or LabelConfig()
|
||||||
@ -153,7 +153,7 @@ def generate_heatmaps(
|
|||||||
|
|
||||||
|
|
||||||
def load_label_config(
|
def load_label_config(
|
||||||
path: data.PathLike, field: Optional[str] = None
|
path: data.PathLike, field: str | None = None
|
||||||
) -> LabelConfig:
|
) -> LabelConfig:
|
||||||
"""Load the heatmap label generation configuration from a file.
|
"""Load the heatmap label generation configuration from a file.
|
||||||
|
|
||||||
|
|||||||
@ -21,7 +21,7 @@ def train_loop(
|
|||||||
model: DetectionModel,
|
model: DetectionModel,
|
||||||
train_dataset: LabeledDataset[TrainInputs],
|
train_dataset: LabeledDataset[TrainInputs],
|
||||||
validation_dataset: LabeledDataset[TrainInputs],
|
validation_dataset: LabeledDataset[TrainInputs],
|
||||||
device: Optional[torch.device] = None,
|
device: torch.device | None = None,
|
||||||
num_epochs: int = 100,
|
num_epochs: int = 100,
|
||||||
learning_rate: float = 1e-4,
|
learning_rate: float = 1e-4,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -106,10 +106,10 @@ def standardize_low_freq(
|
|||||||
|
|
||||||
def format_annotation(
|
def format_annotation(
|
||||||
annotation: types.FileAnnotation,
|
annotation: types.FileAnnotation,
|
||||||
events_of_interest: Optional[List[str]] = None,
|
events_of_interest: List[str] | None = None,
|
||||||
name_replace: Optional[Dict[str, str]] = None,
|
name_replace: Dict[str, str] | None = None,
|
||||||
convert_to_genus: bool = False,
|
convert_to_genus: bool = False,
|
||||||
classes_to_ignore: Optional[List[str]] = None,
|
classes_to_ignore: List[str] | None = None,
|
||||||
) -> types.FileAnnotation:
|
) -> types.FileAnnotation:
|
||||||
formated = []
|
formated = []
|
||||||
for aa in annotation["annotation"]:
|
for aa in annotation["annotation"]:
|
||||||
@ -154,7 +154,7 @@ def format_annotation(
|
|||||||
|
|
||||||
def get_class_names(
|
def get_class_names(
|
||||||
data: List[types.FileAnnotation],
|
data: List[types.FileAnnotation],
|
||||||
classes_to_ignore: Optional[List[str]] = None,
|
classes_to_ignore: List[str] | None = None,
|
||||||
) -> Tuple[StringCounter, List[float]]:
|
) -> Tuple[StringCounter, List[float]]:
|
||||||
"""Extracts class names and their inverse frequencies.
|
"""Extracts class names and their inverse frequencies.
|
||||||
|
|
||||||
@ -201,9 +201,9 @@ def load_set_of_anns(
|
|||||||
*,
|
*,
|
||||||
convert_to_genus: bool = False,
|
convert_to_genus: bool = False,
|
||||||
filter_issues: bool = False,
|
filter_issues: bool = False,
|
||||||
events_of_interest: Optional[List[str]] = None,
|
events_of_interest: List[str] | None = None,
|
||||||
classes_to_ignore: Optional[List[str]] = None,
|
classes_to_ignore: List[str] | None = None,
|
||||||
name_replace: Optional[Dict[str, str]] = None,
|
name_replace: Dict[str, str] | None = None,
|
||||||
) -> List[types.FileAnnotation]:
|
) -> List[types.FileAnnotation]:
|
||||||
# load the annotations
|
# load the annotations
|
||||||
anns = []
|
anns = []
|
||||||
|
|||||||
@ -26,10 +26,10 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Optional[dict] = None,
|
config: dict | None = None,
|
||||||
t_max: int = 100,
|
t_max: int = 100,
|
||||||
model: Optional[Model] = None,
|
model: Model | None = None,
|
||||||
loss: Optional[torch.nn.Module] = None,
|
loss: torch.nn.Module | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.config import validate_config
|
from batdetect2.config import validate_config
|
||||||
|
|
||||||
@ -103,7 +103,7 @@ def load_model_from_checkpoint(
|
|||||||
|
|
||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
config: Optional[dict] = None,
|
config: dict | None = None,
|
||||||
t_max: int = 200,
|
t_max: int = 200,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
return TrainingModule(config=config, t_max=t_max)
|
return TrainingModule(config=config, t_max=t_max)
|
||||||
|
|||||||
@ -151,7 +151,7 @@ class FocalLoss(nn.Module):
|
|||||||
eps: float = 1e-5,
|
eps: float = 1e-5,
|
||||||
beta: float = 4,
|
beta: float = 4,
|
||||||
alpha: float = 2,
|
alpha: float = 2,
|
||||||
class_weights: Optional[torch.Tensor] = None,
|
class_weights: torch.Tensor | None = None,
|
||||||
mask_zero: bool = False,
|
mask_zero: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -422,8 +422,8 @@ class LossFunction(nn.Module, LossProtocol):
|
|||||||
|
|
||||||
|
|
||||||
def build_loss(
|
def build_loss(
|
||||||
config: Optional[LossConfig] = None,
|
config: LossConfig | None = None,
|
||||||
class_weights: Optional[np.ndarray] = None,
|
class_weights: np.ndarray | None = None,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
"""Factory function to build the main LossFunction from configuration.
|
"""Factory function to build the main LossFunction from configuration.
|
||||||
|
|
||||||
|
|||||||
@ -35,21 +35,21 @@ __all__ = [
|
|||||||
|
|
||||||
def train(
|
def train(
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
targets: Optional["TargetProtocol"] = None,
|
targets: Optional["TargetProtocol"] = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
labeller: Optional["ClipLabeller"] = None,
|
labeller: Optional["ClipLabeller"] = None,
|
||||||
config: Optional["BatDetect2Config"] = None,
|
config: Optional["BatDetect2Config"] = None,
|
||||||
trainer: Optional[Trainer] = None,
|
trainer: Trainer | None = None,
|
||||||
train_workers: Optional[int] = None,
|
train_workers: int | None = None,
|
||||||
val_workers: Optional[int] = None,
|
val_workers: int | None = None,
|
||||||
checkpoint_dir: Optional[Path] = None,
|
checkpoint_dir: Path | None = None,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Path | None = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
num_epochs: Optional[int] = None,
|
num_epochs: int | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
seed: Optional[int] = None,
|
seed: int | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
@ -126,11 +126,11 @@ def train(
|
|||||||
def build_trainer(
|
def build_trainer(
|
||||||
config: "BatDetect2Config",
|
config: "BatDetect2Config",
|
||||||
evaluator: "EvaluatorProtocol",
|
evaluator: "EvaluatorProtocol",
|
||||||
checkpoint_dir: Optional[Path] = None,
|
checkpoint_dir: Path | None = None,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Path | None = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: str | None = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: str | None = None,
|
||||||
num_epochs: Optional[int] = None,
|
num_epochs: int | None = None,
|
||||||
) -> Trainer:
|
) -> Trainer:
|
||||||
trainer_conf = config.train.trainer
|
trainer_conf = config.train.trainer
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
|
|||||||
@ -240,7 +240,7 @@ class ProcessingConfiguration(TypedDict):
|
|||||||
detection_threshold: float
|
detection_threshold: float
|
||||||
"""Threshold for detection probability."""
|
"""Threshold for detection probability."""
|
||||||
|
|
||||||
time_expansion: Optional[float]
|
time_expansion: float | None
|
||||||
"""Time expansion factor of the processed recordings."""
|
"""Time expansion factor of the processed recordings."""
|
||||||
|
|
||||||
top_n: int
|
top_n: int
|
||||||
@ -249,7 +249,7 @@ class ProcessingConfiguration(TypedDict):
|
|||||||
return_raw_preds: bool
|
return_raw_preds: bool
|
||||||
"""Whether to return raw predictions."""
|
"""Whether to return raw predictions."""
|
||||||
|
|
||||||
max_duration: Optional[float]
|
max_duration: float | None
|
||||||
"""Maximum duration of audio file to process in seconds."""
|
"""Maximum duration of audio file to process in seconds."""
|
||||||
|
|
||||||
nms_kernel_size: int
|
nms_kernel_size: int
|
||||||
|
|||||||
@ -20,7 +20,7 @@ class OutputFormatterProtocol(Protocol, Generic[T]):
|
|||||||
self,
|
self,
|
||||||
predictions: Sequence[T],
|
predictions: Sequence[T],
|
||||||
path: PathLike,
|
path: PathLike,
|
||||||
audio_dir: Optional[PathLike] = None,
|
audio_dir: PathLike | None = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
def load(self, path: PathLike) -> List[T]: ...
|
def load(self, path: PathLike) -> List[T]: ...
|
||||||
|
|||||||
@ -28,19 +28,19 @@ __all__ = [
|
|||||||
class MatchEvaluation:
|
class MatchEvaluation:
|
||||||
clip: data.Clip
|
clip: data.Clip
|
||||||
|
|
||||||
sound_event_annotation: Optional[data.SoundEventAnnotation]
|
sound_event_annotation: data.SoundEventAnnotation | None
|
||||||
gt_det: bool
|
gt_det: bool
|
||||||
gt_class: Optional[str]
|
gt_class: str | None
|
||||||
gt_geometry: Optional[data.Geometry]
|
gt_geometry: data.Geometry | None
|
||||||
|
|
||||||
pred_score: float
|
pred_score: float
|
||||||
pred_class_scores: Dict[str, float]
|
pred_class_scores: Dict[str, float]
|
||||||
pred_geometry: Optional[data.Geometry]
|
pred_geometry: data.Geometry | None
|
||||||
|
|
||||||
affinity: float
|
affinity: float
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def top_class(self) -> Optional[str]:
|
def top_class(self) -> str | None:
|
||||||
if not self.pred_class_scores:
|
if not self.pred_class_scores:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -76,7 +76,7 @@ class MatcherProtocol(Protocol):
|
|||||||
ground_truth: Sequence[data.Geometry],
|
ground_truth: Sequence[data.Geometry],
|
||||||
predictions: Sequence[data.Geometry],
|
predictions: Sequence[data.Geometry],
|
||||||
scores: Sequence[float],
|
scores: Sequence[float],
|
||||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ...
|
) -> Iterable[Tuple[int | None, int | None, float]]: ...
|
||||||
|
|
||||||
|
|
||||||
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
||||||
|
|||||||
@ -42,7 +42,7 @@ class GeometryDecoder(Protocol):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, position: Position, size: Size, class_name: Optional[str] = None
|
self, position: Position, size: Size, class_name: str | None = None
|
||||||
) -> data.Geometry: ...
|
) -> data.Geometry: ...
|
||||||
|
|
||||||
|
|
||||||
@ -93,5 +93,5 @@ class PostprocessorProtocol(Protocol):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
output: ModelOutput,
|
output: ModelOutput,
|
||||||
start_times: Optional[Sequence[float]] = None,
|
start_times: Sequence[float] | None = None,
|
||||||
) -> List[ClipDetectionsTensor]: ...
|
) -> List[ClipDetectionsTensor]: ...
|
||||||
|
|||||||
@ -37,7 +37,7 @@ class AudioLoader(Protocol):
|
|||||||
def load_file(
|
def load_file(
|
||||||
self,
|
self,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess audio directly from a file path.
|
"""Load and preprocess audio directly from a file path.
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ class AudioLoader(Protocol):
|
|||||||
def load_recording(
|
def load_recording(
|
||||||
self,
|
self,
|
||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess the entire audio for a Recording object.
|
"""Load and preprocess the entire audio for a Recording object.
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ class AudioLoader(Protocol):
|
|||||||
def load_clip(
|
def load_clip(
|
||||||
self,
|
self,
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load and preprocess the audio segment defined by a Clip object.
|
"""Load and preprocess the audio segment defined by a Clip object.
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user