mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Update type hints to python 3.10
This commit is contained in:
parent
9c72537ddd
commit
2563f26ed3
@ -93,7 +93,7 @@ mlflow = ["mlflow>=3.1.1"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 79
|
||||
target-version = "py39"
|
||||
target-version = "py310"
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
@ -107,7 +107,7 @@ convention = "numpy"
|
||||
|
||||
[tool.pyright]
|
||||
include = ["src", "tests"]
|
||||
pythonVersion = "3.9"
|
||||
pythonVersion = "3.10"
|
||||
pythonPlatform = "All"
|
||||
exclude = [
|
||||
"src/batdetect2/detector/",
|
||||
|
||||
@ -165,7 +165,7 @@ def load_audio(
|
||||
time_exp_fact: float = 1,
|
||||
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
scale: bool = False,
|
||||
max_duration: Optional[float] = None,
|
||||
max_duration: float | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Load audio from file.
|
||||
|
||||
@ -203,7 +203,7 @@ def load_audio(
|
||||
def generate_spectrogram(
|
||||
audio: np.ndarray,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
config: Optional[SpectrogramParameters] = None,
|
||||
config: SpectrogramParameters | None = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> torch.Tensor:
|
||||
"""Generate spectrogram from audio array.
|
||||
@ -240,7 +240,7 @@ def generate_spectrogram(
|
||||
def process_file(
|
||||
audio_file: str,
|
||||
model: DetectionModel = MODEL,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> du.RunResults:
|
||||
"""Process audio file with model.
|
||||
@ -271,7 +271,7 @@ def process_spectrogram(
|
||||
spec: torch.Tensor,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
model: DetectionModel = MODEL,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
"""Process spectrogram with model.
|
||||
|
||||
@ -312,7 +312,7 @@ def process_audio(
|
||||
audio: np.ndarray,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
model: DetectionModel = MODEL,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
|
||||
"""Process audio array with model.
|
||||
@ -356,7 +356,7 @@ def process_audio(
|
||||
def postprocess(
|
||||
outputs: ModelOutput,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
"""Postprocess model outputs.
|
||||
|
||||
|
||||
@ -67,22 +67,22 @@ class BatDetect2API:
|
||||
def load_annotations(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
base_dir: Optional[data.PathLike] = None,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> Dataset:
|
||||
return load_dataset_from_config(path, base_dir=base_dir)
|
||||
|
||||
def train(
|
||||
self,
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
||||
train_workers: Optional[int] = None,
|
||||
val_workers: Optional[int] = None,
|
||||
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
|
||||
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
|
||||
experiment_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
run_name: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
train_workers: int | None = None,
|
||||
val_workers: int | None = None,
|
||||
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
|
||||
log_dir: Path | None = DEFAULT_LOGS_DIR,
|
||||
experiment_name: str | None = None,
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
):
|
||||
train(
|
||||
train_annotations=train_annotations,
|
||||
@ -105,10 +105,10 @@ class BatDetect2API:
|
||||
def evaluate(
|
||||
self,
|
||||
test_annotations: Sequence[data.ClipAnnotation],
|
||||
num_workers: Optional[int] = None,
|
||||
num_workers: int | None = None,
|
||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
save_predictions: bool = True,
|
||||
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
||||
return evaluate(
|
||||
@ -129,7 +129,7 @@ class BatDetect2API:
|
||||
self,
|
||||
annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
output_dir: Optional[data.PathLike] = None,
|
||||
output_dir: data.PathLike | None = None,
|
||||
):
|
||||
clip_evals = self.evaluator.evaluate(
|
||||
annotations,
|
||||
@ -221,7 +221,7 @@ class BatDetect2API:
|
||||
def process_files(
|
||||
self,
|
||||
audio_files: Sequence[data.PathLike],
|
||||
num_workers: Optional[int] = None,
|
||||
num_workers: int | None = None,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
return process_file_list(
|
||||
self.model,
|
||||
@ -236,8 +236,8 @@ class BatDetect2API:
|
||||
def process_clips(
|
||||
self,
|
||||
clips: Sequence[data.Clip],
|
||||
batch_size: Optional[int] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
batch_size: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
return run_batch_inference(
|
||||
self.model,
|
||||
@ -254,9 +254,9 @@ class BatDetect2API:
|
||||
self,
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
format: Optional[str] = None,
|
||||
config: Optional[OutputFormatConfig] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
format: str | None = None,
|
||||
config: OutputFormatConfig | None = None,
|
||||
):
|
||||
formatter = self.formatter
|
||||
|
||||
@ -331,7 +331,7 @@ class BatDetect2API:
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
path: data.PathLike,
|
||||
config: Optional[BatDetect2Config] = None,
|
||||
config: BatDetect2Config | None = None,
|
||||
):
|
||||
model, stored_config = load_model_from_checkpoint(path)
|
||||
|
||||
|
||||
@ -245,16 +245,12 @@ class FixedDurationClip:
|
||||
|
||||
|
||||
ClipConfig = Annotated[
|
||||
Union[
|
||||
RandomClipConfig,
|
||||
PaddedClipConfig,
|
||||
FixedDurationClipConfig,
|
||||
],
|
||||
RandomClipConfig | PaddedClipConfig | FixedDurationClipConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
|
||||
def build_clipper(config: ClipConfig | None = None) -> ClipperProtocol:
|
||||
config = config or RandomClipConfig()
|
||||
|
||||
logger.opt(lazy=True).debug(
|
||||
|
||||
@ -50,7 +50,7 @@ class AudioConfig(BaseConfig):
|
||||
resample: ResampleConfig = Field(default_factory=ResampleConfig)
|
||||
|
||||
|
||||
def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
|
||||
def build_audio_loader(config: AudioConfig | None = None) -> AudioLoader:
|
||||
"""Factory function to create an AudioLoader based on configuration."""
|
||||
config = config or AudioConfig()
|
||||
return SoundEventAudioLoader(
|
||||
@ -65,7 +65,7 @@ class SoundEventAudioLoader(AudioLoader):
|
||||
def __init__(
|
||||
self,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
config: Optional[ResampleConfig] = None,
|
||||
config: ResampleConfig | None = None,
|
||||
):
|
||||
self.samplerate = samplerate
|
||||
self.config = config or ResampleConfig()
|
||||
@ -73,7 +73,7 @@ class SoundEventAudioLoader(AudioLoader):
|
||||
def load_file(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess audio directly from a file path."""
|
||||
return load_file_audio(
|
||||
@ -86,7 +86,7 @@ class SoundEventAudioLoader(AudioLoader):
|
||||
def load_recording(
|
||||
self,
|
||||
recording: data.Recording,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess the entire audio for a Recording object."""
|
||||
return load_recording_audio(
|
||||
@ -99,7 +99,7 @@ class SoundEventAudioLoader(AudioLoader):
|
||||
def load_clip(
|
||||
self,
|
||||
clip: data.Clip,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess the audio segment defined by a Clip object."""
|
||||
return load_clip_audio(
|
||||
@ -112,9 +112,9 @@ class SoundEventAudioLoader(AudioLoader):
|
||||
|
||||
def load_file_audio(
|
||||
path: data.PathLike,
|
||||
samplerate: Optional[int] = None,
|
||||
config: Optional[ResampleConfig] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
samplerate: int | None = None,
|
||||
config: ResampleConfig | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess audio from a file path using specified config."""
|
||||
@ -136,9 +136,9 @@ def load_file_audio(
|
||||
|
||||
def load_recording_audio(
|
||||
recording: data.Recording,
|
||||
samplerate: Optional[int] = None,
|
||||
config: Optional[ResampleConfig] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
samplerate: int | None = None,
|
||||
config: ResampleConfig | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess the entire audio content of a recording using config."""
|
||||
@ -158,9 +158,9 @@ def load_recording_audio(
|
||||
|
||||
def load_clip_audio(
|
||||
clip: data.Clip,
|
||||
samplerate: Optional[int] = None,
|
||||
config: Optional[ResampleConfig] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
samplerate: int | None = None,
|
||||
config: ResampleConfig | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess a specific audio clip segment based on config."""
|
||||
|
||||
@ -34,9 +34,9 @@ def data(): ...
|
||||
)
|
||||
def summary(
|
||||
dataset_config: Path,
|
||||
field: Optional[str] = None,
|
||||
targets_path: Optional[Path] = None,
|
||||
base_dir: Optional[Path] = None,
|
||||
field: str | None = None,
|
||||
targets_path: Path | None = None,
|
||||
base_dir: Path | None = None,
|
||||
):
|
||||
from batdetect2.data import compute_class_summary, load_dataset_from_config
|
||||
from batdetect2.targets import load_targets
|
||||
@ -83,9 +83,9 @@ def summary(
|
||||
)
|
||||
def convert(
|
||||
dataset_config: Path,
|
||||
field: Optional[str] = None,
|
||||
field: str | None = None,
|
||||
output: Path = Path("annotations.json"),
|
||||
base_dir: Optional[Path] = None,
|
||||
base_dir: Path | None = None,
|
||||
):
|
||||
"""Convert a dataset config file to soundevent format."""
|
||||
from soundevent import data, io
|
||||
|
||||
@ -25,11 +25,11 @@ def evaluate_command(
|
||||
model_path: Path,
|
||||
test_dataset: Path,
|
||||
base_dir: Path,
|
||||
config_path: Optional[Path],
|
||||
config_path: Path | None,
|
||||
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
||||
num_workers: Optional[int] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
num_workers: int | None = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
):
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import load_full_config
|
||||
|
||||
@ -26,19 +26,19 @@ __all__ = ["train_command"]
|
||||
@click.option("--seed", type=int)
|
||||
def train_command(
|
||||
train_dataset: Path,
|
||||
val_dataset: Optional[Path] = None,
|
||||
model_path: Optional[Path] = None,
|
||||
ckpt_dir: Optional[Path] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
config: Optional[Path] = None,
|
||||
targets_config: Optional[Path] = None,
|
||||
config_field: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
val_dataset: Path | None = None,
|
||||
model_path: Path | None = None,
|
||||
ckpt_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
config: Path | None = None,
|
||||
targets_config: Path | None = None,
|
||||
config_field: str | None = None,
|
||||
seed: int | None = None,
|
||||
num_epochs: int | None = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
):
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import (
|
||||
|
||||
@ -4,7 +4,7 @@ import json
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
@ -17,7 +17,7 @@ from batdetect2.types import (
|
||||
FileAnnotation,
|
||||
)
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
PathLike = Path | str | os.PathLike
|
||||
|
||||
__all__ = [
|
||||
"convert_to_annotation_group",
|
||||
@ -33,7 +33,7 @@ UNKNOWN_CLASS = "__UNKNOWN__"
|
||||
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
|
||||
|
||||
|
||||
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||
EventFn = Callable[[data.SoundEventAnnotation], str | None]
|
||||
|
||||
ClassFn = Callable[[data.Recording], int]
|
||||
|
||||
@ -221,7 +221,7 @@ def annotation_to_sound_event_prediction(
|
||||
|
||||
def file_annotation_to_clip(
|
||||
file_annotation: FileAnnotation,
|
||||
audio_dir: Optional[PathLike] = None,
|
||||
audio_dir: PathLike | None = None,
|
||||
label_key: str = "class",
|
||||
) -> data.Clip:
|
||||
"""Convert file annotation to recording."""
|
||||
|
||||
@ -43,7 +43,7 @@ class BatDetect2Config(BaseConfig):
|
||||
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
|
||||
|
||||
|
||||
def validate_config(config: Optional[dict]) -> BatDetect2Config:
|
||||
def validate_config(config: dict | None) -> BatDetect2Config:
|
||||
if config is None:
|
||||
return BatDetect2Config()
|
||||
|
||||
@ -52,6 +52,6 @@ def validate_config(config: Optional[dict]) -> BatDetect2Config:
|
||||
|
||||
def load_full_config(
|
||||
path: PathLike,
|
||||
field: Optional[str] = None,
|
||||
field: str | None = None,
|
||||
) -> BatDetect2Config:
|
||||
return load_config(path, schema=BatDetect2Config, field=field)
|
||||
|
||||
@ -86,8 +86,8 @@ def adjust_width(
|
||||
|
||||
def slice_tensor(
|
||||
tensor: torch.Tensor,
|
||||
start: Optional[int] = None,
|
||||
end: Optional[int] = None,
|
||||
start: int | None = None,
|
||||
end: int | None = None,
|
||||
dim: int = -1,
|
||||
) -> torch.Tensor:
|
||||
slices = [slice(None)] * tensor.ndim
|
||||
|
||||
@ -128,7 +128,7 @@ def get_object_field(obj: dict, current_key: str) -> Any:
|
||||
def load_config(
|
||||
path: PathLike,
|
||||
schema: Type[T],
|
||||
field: Optional[str] = None,
|
||||
field: str | None = None,
|
||||
) -> T:
|
||||
"""Load and validate configuration data from a file against a schema.
|
||||
|
||||
|
||||
@ -43,11 +43,7 @@ __all__ = [
|
||||
|
||||
|
||||
AnnotationFormats = Annotated[
|
||||
Union[
|
||||
BatDetect2MergedAnnotations,
|
||||
BatDetect2FilesAnnotations,
|
||||
AOEFAnnotations,
|
||||
],
|
||||
BatDetect2MergedAnnotations | BatDetect2FilesAnnotations | AOEFAnnotations,
|
||||
Field(discriminator="format"),
|
||||
]
|
||||
"""Type Alias representing all supported data source configurations.
|
||||
@ -63,7 +59,7 @@ source configuration represents.
|
||||
|
||||
def load_annotated_dataset(
|
||||
dataset: AnnotatedDataset,
|
||||
base_dir: Optional[data.PathLike] = None,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> data.AnnotationSet:
|
||||
"""Load annotations for a single data source based on its configuration.
|
||||
|
||||
|
||||
@ -77,14 +77,14 @@ class AOEFAnnotations(AnnotatedDataset):
|
||||
|
||||
annotations_path: Path
|
||||
|
||||
filter: Optional[AnnotationTaskFilter] = Field(
|
||||
filter: AnnotationTaskFilter | None = Field(
|
||||
default_factory=AnnotationTaskFilter
|
||||
)
|
||||
|
||||
|
||||
def load_aoef_annotated_dataset(
|
||||
dataset: AOEFAnnotations,
|
||||
base_dir: Optional[data.PathLike] = None,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> data.AnnotationSet:
|
||||
"""Load annotations from an AnnotationSet or AnnotationProject file.
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ aggregated into a `soundevent.data.AnnotationSet`.
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field, ValidationError
|
||||
@ -43,7 +43,7 @@ from batdetect2.data.annotations.legacy import (
|
||||
)
|
||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
PathLike = Path | str | os.PathLike
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -102,7 +102,7 @@ class BatDetect2FilesAnnotations(AnnotatedDataset):
|
||||
format: Literal["batdetect2"] = "batdetect2"
|
||||
annotations_dir: Path
|
||||
|
||||
filter: Optional[AnnotationFilter] = Field(
|
||||
filter: AnnotationFilter | None = Field(
|
||||
default_factory=AnnotationFilter,
|
||||
)
|
||||
|
||||
@ -133,14 +133,14 @@ class BatDetect2MergedAnnotations(AnnotatedDataset):
|
||||
format: Literal["batdetect2_file"] = "batdetect2_file"
|
||||
annotations_path: Path
|
||||
|
||||
filter: Optional[AnnotationFilter] = Field(
|
||||
filter: AnnotationFilter | None = Field(
|
||||
default_factory=AnnotationFilter,
|
||||
)
|
||||
|
||||
|
||||
def load_batdetect2_files_annotated_dataset(
|
||||
dataset: BatDetect2FilesAnnotations,
|
||||
base_dir: Optional[PathLike] = None,
|
||||
base_dir: PathLike | None = None,
|
||||
) -> data.AnnotationSet:
|
||||
"""Load and convert 'batdetect2_file' annotations into an AnnotationSet.
|
||||
|
||||
@ -244,7 +244,7 @@ def load_batdetect2_files_annotated_dataset(
|
||||
|
||||
def load_batdetect2_merged_annotated_dataset(
|
||||
dataset: BatDetect2MergedAnnotations,
|
||||
base_dir: Optional[PathLike] = None,
|
||||
base_dir: PathLike | None = None,
|
||||
) -> data.AnnotationSet:
|
||||
"""Load and convert 'batdetect2_merged' annotations into an AnnotationSet.
|
||||
|
||||
|
||||
@ -3,12 +3,12 @@
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Callable, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from soundevent import data
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
PathLike = Path | str | os.PathLike
|
||||
|
||||
__all__ = []
|
||||
|
||||
@ -27,7 +27,7 @@ SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
|
||||
)
|
||||
|
||||
|
||||
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||
EventFn = Callable[[data.SoundEventAnnotation], str | None]
|
||||
|
||||
ClassFn = Callable[[data.Recording], int]
|
||||
|
||||
@ -130,7 +130,7 @@ def get_sound_event_tags(
|
||||
|
||||
def file_annotation_to_clip(
|
||||
file_annotation: FileAnnotation,
|
||||
audio_dir: Optional[PathLike] = None,
|
||||
audio_dir: PathLike | None = None,
|
||||
label_key: str = "class",
|
||||
) -> data.Clip:
|
||||
"""Convert file annotation to recording."""
|
||||
|
||||
@ -264,16 +264,7 @@ class Not:
|
||||
|
||||
|
||||
SoundEventConditionConfig = Annotated[
|
||||
Union[
|
||||
HasTagConfig,
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
DurationConfig,
|
||||
FrequencyConfig,
|
||||
AllOfConfig,
|
||||
AnyOfConfig,
|
||||
NotConfig,
|
||||
],
|
||||
HasTagConfig | HasAllTagsConfig | HasAnyTagConfig | DurationConfig | FrequencyConfig | AllOfConfig | AnyOfConfig | NotConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -69,7 +69,7 @@ class DatasetConfig(BaseConfig):
|
||||
description: str
|
||||
sources: List[AnnotationFormats]
|
||||
|
||||
sound_event_filter: Optional[SoundEventConditionConfig] = None
|
||||
sound_event_filter: SoundEventConditionConfig | None = None
|
||||
sound_event_transforms: List[SoundEventTransformConfig] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig):
|
||||
|
||||
def load_dataset(
|
||||
config: DatasetConfig,
|
||||
base_dir: Optional[data.PathLike] = None,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> Dataset:
|
||||
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
||||
clip_annotations = []
|
||||
@ -161,14 +161,14 @@ def insert_source_tag(
|
||||
)
|
||||
|
||||
|
||||
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
|
||||
def load_dataset_config(path: data.PathLike, field: str | None = None):
|
||||
return load_config(path=path, schema=DatasetConfig, field=field)
|
||||
|
||||
|
||||
def load_dataset_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
base_dir: Optional[data.PathLike] = None,
|
||||
field: str | None = None,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> Dataset:
|
||||
"""Load dataset annotation metadata from a configuration file.
|
||||
|
||||
@ -215,9 +215,9 @@ def load_dataset_from_config(
|
||||
def save_dataset(
|
||||
dataset: Dataset,
|
||||
path: data.PathLike,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
audio_dir: Optional[Path] = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
audio_dir: Path | None = None,
|
||||
) -> None:
|
||||
"""Save a loaded dataset (list of ClipAnnotations) to a file.
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ from batdetect2.typing.targets import TargetProtocol
|
||||
def iterate_over_sound_events(
|
||||
dataset: Dataset,
|
||||
targets: TargetProtocol,
|
||||
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
|
||||
) -> Generator[Tuple[str | None, data.SoundEventAnnotation], None, None]:
|
||||
"""Iterate over sound events in a dataset.
|
||||
|
||||
Parameters
|
||||
|
||||
@ -24,19 +24,14 @@ __all__ = [
|
||||
|
||||
|
||||
OutputFormatConfig = Annotated[
|
||||
Union[
|
||||
BatDetect2OutputConfig,
|
||||
ParquetOutputConfig,
|
||||
SoundEventOutputConfig,
|
||||
RawOutputConfig,
|
||||
],
|
||||
BatDetect2OutputConfig | ParquetOutputConfig | SoundEventOutputConfig | RawOutputConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_output_formatter(
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
config: Optional[OutputFormatConfig] = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
config: OutputFormatConfig | None = None,
|
||||
) -> OutputFormatterProtocol:
|
||||
"""Construct the final output formatter."""
|
||||
from batdetect2.targets import build_targets
|
||||
@ -48,9 +43,9 @@ def build_output_formatter(
|
||||
|
||||
|
||||
def get_output_formatter(
|
||||
name: Optional[str] = None,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
config: Optional[OutputFormatConfig] = None,
|
||||
name: str | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
config: OutputFormatConfig | None = None,
|
||||
) -> OutputFormatterProtocol:
|
||||
"""Get the output formatter by name."""
|
||||
|
||||
@ -71,9 +66,9 @@ def get_output_formatter(
|
||||
|
||||
def load_predictions(
|
||||
path: PathLike,
|
||||
format: Optional[str] = "raw",
|
||||
config: Optional[OutputFormatConfig] = None,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
format: str | None = "raw",
|
||||
config: OutputFormatConfig | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
):
|
||||
"""Load predictions from a file."""
|
||||
from batdetect2.targets import build_targets
|
||||
|
||||
@ -123,7 +123,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
self,
|
||||
predictions: Sequence[FileAnnotation],
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> None:
|
||||
path = Path(path)
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
self,
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> None:
|
||||
path = Path(path)
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
self,
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> None:
|
||||
path = Path(path)
|
||||
|
||||
@ -84,7 +84,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
def pred_to_xr(
|
||||
self,
|
||||
prediction: BatDetect2Prediction,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> xr.Dataset:
|
||||
clip = prediction.clip
|
||||
recording = clip.recording
|
||||
|
||||
@ -18,16 +18,16 @@ from batdetect2.typing import (
|
||||
|
||||
class SoundEventOutputConfig(BaseConfig):
|
||||
name: Literal["soundevent"] = "soundevent"
|
||||
top_k: Optional[int] = 1
|
||||
min_score: Optional[float] = None
|
||||
top_k: int | None = 1
|
||||
min_score: float | None = None
|
||||
|
||||
|
||||
class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
top_k: Optional[int] = 1,
|
||||
min_score: Optional[float] = 0,
|
||||
top_k: int | None = 1,
|
||||
min_score: float | None = 0,
|
||||
):
|
||||
self.targets = targets
|
||||
self.top_k = top_k
|
||||
@ -45,7 +45,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
||||
self,
|
||||
predictions: Sequence[data.ClipPrediction],
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> None:
|
||||
run = data.PredictionSet(clip_predictions=list(predictions))
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ def split_dataset_by_recordings(
|
||||
dataset: Dataset,
|
||||
targets: TargetProtocol,
|
||||
train_size: float = 0.75,
|
||||
random_state: Optional[int] = None,
|
||||
random_state: int | None = None,
|
||||
) -> Tuple[Dataset, Dataset]:
|
||||
recordings = extract_recordings_df(dataset)
|
||||
|
||||
|
||||
@ -142,7 +142,7 @@ class MapTagValueConfig(BaseConfig):
|
||||
name: Literal["map_tag_value"] = "map_tag_value"
|
||||
tag_key: str
|
||||
value_mapping: Dict[str, str]
|
||||
target_key: Optional[str] = None
|
||||
target_key: str | None = None
|
||||
|
||||
|
||||
class MapTagValue:
|
||||
@ -150,7 +150,7 @@ class MapTagValue:
|
||||
self,
|
||||
tag_key: str,
|
||||
value_mapping: Dict[str, str],
|
||||
target_key: Optional[str] = None,
|
||||
target_key: str | None = None,
|
||||
):
|
||||
self.tag_key = tag_key
|
||||
self.value_mapping = value_mapping
|
||||
@ -221,13 +221,7 @@ class ApplyAll:
|
||||
|
||||
|
||||
SoundEventTransformConfig = Annotated[
|
||||
Union[
|
||||
SetFrequencyBoundConfig,
|
||||
ReplaceTagConfig,
|
||||
MapTagValueConfig,
|
||||
ApplyIfConfig,
|
||||
ApplyAllConfig,
|
||||
],
|
||||
SetFrequencyBoundConfig | ReplaceTagConfig | MapTagValueConfig | ApplyIfConfig | ApplyAllConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -86,7 +86,7 @@ def compute_bandwidth(
|
||||
|
||||
def compute_max_power_bb(
|
||||
prediction: types.Prediction,
|
||||
spec: Optional[np.ndarray] = None,
|
||||
spec: np.ndarray | None = None,
|
||||
min_freq: int = MIN_FREQ_HZ,
|
||||
max_freq: int = MAX_FREQ_HZ,
|
||||
**_,
|
||||
@ -131,7 +131,7 @@ def compute_max_power_bb(
|
||||
|
||||
def compute_max_power(
|
||||
prediction: types.Prediction,
|
||||
spec: Optional[np.ndarray] = None,
|
||||
spec: np.ndarray | None = None,
|
||||
min_freq: int = MIN_FREQ_HZ,
|
||||
max_freq: int = MAX_FREQ_HZ,
|
||||
**_,
|
||||
@ -157,7 +157,7 @@ def compute_max_power(
|
||||
|
||||
def compute_max_power_first(
|
||||
prediction: types.Prediction,
|
||||
spec: Optional[np.ndarray] = None,
|
||||
spec: np.ndarray | None = None,
|
||||
min_freq: int = MIN_FREQ_HZ,
|
||||
max_freq: int = MAX_FREQ_HZ,
|
||||
**_,
|
||||
@ -184,7 +184,7 @@ def compute_max_power_first(
|
||||
|
||||
def compute_max_power_second(
|
||||
prediction: types.Prediction,
|
||||
spec: Optional[np.ndarray] = None,
|
||||
spec: np.ndarray | None = None,
|
||||
min_freq: int = MIN_FREQ_HZ,
|
||||
max_freq: int = MAX_FREQ_HZ,
|
||||
**_,
|
||||
@ -211,7 +211,7 @@ def compute_max_power_second(
|
||||
|
||||
def compute_call_interval(
|
||||
prediction: types.Prediction,
|
||||
previous: Optional[types.Prediction] = None,
|
||||
previous: types.Prediction | None = None,
|
||||
**_,
|
||||
) -> float:
|
||||
"""Compute time between this call and the previous call in seconds."""
|
||||
|
||||
@ -198,8 +198,8 @@ class TrainingParameters(BaseModel):
|
||||
def get_params(
|
||||
make_dirs: bool = False,
|
||||
exps_dir: str = "../../experiments/",
|
||||
model_name: Optional[str] = None,
|
||||
experiment: Union[Path, str, None] = None,
|
||||
model_name: str | None = None,
|
||||
experiment: Path | str | None = None,
|
||||
**kwargs,
|
||||
) -> TrainingParameters:
|
||||
experiments_dir = Path(exps_dir)
|
||||
|
||||
@ -151,7 +151,7 @@ def run_nms(
|
||||
|
||||
def non_max_suppression(
|
||||
heat: torch.Tensor,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
kernel_size: int | Tuple[int, int],
|
||||
):
|
||||
# kernel can be an int or list/tuple
|
||||
if isinstance(kernel_size, int):
|
||||
|
||||
@ -213,18 +213,13 @@ class GeometricIOU(AffinityFunction):
|
||||
|
||||
|
||||
AffinityConfig = Annotated[
|
||||
Union[
|
||||
TimeAffinityConfig,
|
||||
IntervalIOUConfig,
|
||||
BBoxIOUConfig,
|
||||
GeometricIOUConfig,
|
||||
],
|
||||
TimeAffinityConfig | IntervalIOUConfig | BBoxIOUConfig | GeometricIOUConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_affinity_function(
|
||||
config: Optional[AffinityConfig] = None,
|
||||
config: AffinityConfig | None = None,
|
||||
) -> AffinityFunction:
|
||||
config = config or GeometricIOUConfig()
|
||||
return affinity_functions.build(config)
|
||||
|
||||
@ -51,6 +51,6 @@ def get_default_eval_config() -> EvaluationConfig:
|
||||
|
||||
def load_evaluation_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
field: str | None = None,
|
||||
) -> EvaluationConfig:
|
||||
return load_config(path, schema=EvaluationConfig, field=field)
|
||||
|
||||
@ -39,8 +39,8 @@ class TestDataset(Dataset[TestExample]):
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: AudioLoader,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
clipper: Optional[ClipperProtocol] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
clipper: ClipperProtocol | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
):
|
||||
self.clip_annotations = list(clip_annotations)
|
||||
self.clipper = clipper
|
||||
@ -78,10 +78,10 @@ class TestLoaderConfig(BaseConfig):
|
||||
|
||||
def build_test_loader(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[TestLoaderConfig] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
config: TestLoaderConfig | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> DataLoader[TestExample]:
|
||||
logger.info("Building test data loader...")
|
||||
config = config or TestLoaderConfig()
|
||||
@ -109,9 +109,9 @@ def build_test_loader(
|
||||
|
||||
def build_test_dataset(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[TestLoaderConfig] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
config: TestLoaderConfig | None = None,
|
||||
) -> TestDataset:
|
||||
logger.info("Building training dataset...")
|
||||
config = config or TestLoaderConfig()
|
||||
|
||||
@ -34,10 +34,10 @@ def evaluate(
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
formatter: Optional["OutputFormatterProtocol"] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
num_workers: int | None = None,
|
||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
|
||||
@ -51,8 +51,8 @@ class Evaluator:
|
||||
|
||||
|
||||
def build_evaluator(
|
||||
config: Optional[Union[EvaluationConfig, dict]] = None,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
config: EvaluationConfig | dict | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
) -> EvaluatorProtocol:
|
||||
targets = targets or build_targets()
|
||||
|
||||
|
||||
@ -35,9 +35,9 @@ def match(
|
||||
sound_event_annotations: Sequence[data.SoundEventAnnotation],
|
||||
raw_predictions: Sequence[RawPrediction],
|
||||
clip: data.Clip,
|
||||
scores: Optional[Sequence[float]] = None,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
matcher: Optional[MatcherProtocol] = None,
|
||||
scores: Sequence[float] | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
matcher: MatcherProtocol | None = None,
|
||||
) -> ClipMatches:
|
||||
if matcher is None:
|
||||
matcher = build_matcher()
|
||||
@ -151,7 +151,7 @@ def match_start_times(
|
||||
predictions: Sequence[data.Geometry],
|
||||
scores: Sequence[float],
|
||||
distance_threshold: float = 0.01,
|
||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
||||
) -> Iterable[Tuple[int | None, int | None, float]]:
|
||||
if not ground_truth:
|
||||
for index in range(len(predictions)):
|
||||
yield index, None, 0
|
||||
@ -287,7 +287,7 @@ def greedy_match(
|
||||
scores: Sequence[float],
|
||||
affinity_threshold: float = 0.5,
|
||||
affinity_function: AffinityFunction = compute_affinity,
|
||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
||||
) -> Iterable[Tuple[int | None, int | None, float]]:
|
||||
"""Performs a greedy, one-to-one matching of source to target geometries.
|
||||
|
||||
Iterates through source geometries, prioritizing by score if provided. Each
|
||||
@ -514,12 +514,7 @@ class OptimalMatcher(MatcherProtocol):
|
||||
|
||||
|
||||
MatchConfig = Annotated[
|
||||
Union[
|
||||
GreedyMatchConfig,
|
||||
StartTimeMatchConfig,
|
||||
OptimalMatchConfig,
|
||||
GreedyAffinityMatchConfig,
|
||||
],
|
||||
GreedyMatchConfig | StartTimeMatchConfig | OptimalMatchConfig | GreedyAffinityMatchConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
@ -558,7 +553,7 @@ def compute_affinity_matrix(
|
||||
def select_optimal_matches(
|
||||
affinity_matrix: np.ndarray,
|
||||
affinity_threshold: float = 0.5,
|
||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
||||
) -> Iterable[Tuple[int | None, int | None, float]]:
|
||||
num_gt, num_pred = affinity_matrix.shape
|
||||
gts = set(range(num_gt))
|
||||
preds = set(range(num_pred))
|
||||
@ -588,7 +583,7 @@ def select_optimal_matches(
|
||||
def select_greedy_matches(
|
||||
affinity_matrix: np.ndarray,
|
||||
affinity_threshold: float = 0.5,
|
||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
||||
) -> Iterable[Tuple[int | None, int | None, float]]:
|
||||
num_gt, num_pred = affinity_matrix.shape
|
||||
unmatched_pred = set(range(num_pred))
|
||||
|
||||
@ -612,6 +607,6 @@ def select_greedy_matches(
|
||||
yield pred_idx, None, 0
|
||||
|
||||
|
||||
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
|
||||
def build_matcher(config: MatchConfig | None = None) -> MatcherProtocol:
|
||||
config = config or StartTimeMatchConfig()
|
||||
return matching_strategies.build(config)
|
||||
|
||||
@ -36,13 +36,13 @@ __all__ = [
|
||||
@dataclass
|
||||
class MatchEval:
|
||||
clip: data.Clip
|
||||
gt: Optional[data.SoundEventAnnotation]
|
||||
pred: Optional[RawPrediction]
|
||||
gt: data.SoundEventAnnotation | None
|
||||
pred: RawPrediction | None
|
||||
|
||||
is_prediction: bool
|
||||
is_ground_truth: bool
|
||||
is_generic: bool
|
||||
true_class: Optional[str]
|
||||
true_class: str | None
|
||||
score: float
|
||||
|
||||
|
||||
@ -61,16 +61,16 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
|
||||
|
||||
|
||||
class BaseClassificationConfig(BaseConfig):
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
include: List[str] | None = None
|
||||
exclude: List[str] | None = None
|
||||
|
||||
|
||||
class BaseClassificationMetric:
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
include: List[str] | None = None,
|
||||
exclude: List[str] | None = None,
|
||||
):
|
||||
self.targets = targets
|
||||
self.include = include
|
||||
@ -100,8 +100,8 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
label: str = "average_precision",
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
include: List[str] | None = None,
|
||||
exclude: List[str] | None = None,
|
||||
):
|
||||
super().__init__(include=include, exclude=exclude, targets=targets)
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
@ -169,8 +169,8 @@ class ClassificationROCAUC(BaseClassificationMetric):
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
label: str = "roc_auc",
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
include: List[str] | None = None,
|
||||
exclude: List[str] | None = None,
|
||||
):
|
||||
self.targets = targets
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
@ -225,10 +225,7 @@ class ClassificationROCAUC(BaseClassificationMetric):
|
||||
|
||||
|
||||
ClassificationMetricConfig = Annotated[
|
||||
Union[
|
||||
ClassificationAveragePrecisionConfig,
|
||||
ClassificationROCAUCConfig,
|
||||
],
|
||||
ClassificationAveragePrecisionConfig | ClassificationROCAUCConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -123,10 +123,7 @@ class ClipClassificationROCAUC:
|
||||
|
||||
|
||||
ClipClassificationMetricConfig = Annotated[
|
||||
Union[
|
||||
ClipClassificationAveragePrecisionConfig,
|
||||
ClipClassificationROCAUCConfig,
|
||||
],
|
||||
ClipClassificationAveragePrecisionConfig | ClipClassificationROCAUCConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -159,12 +159,7 @@ class ClipDetectionPrecision:
|
||||
|
||||
|
||||
ClipDetectionMetricConfig = Annotated[
|
||||
Union[
|
||||
ClipDetectionAveragePrecisionConfig,
|
||||
ClipDetectionROCAUCConfig,
|
||||
ClipDetectionRecallConfig,
|
||||
ClipDetectionPrecisionConfig,
|
||||
],
|
||||
ClipDetectionAveragePrecisionConfig | ClipDetectionROCAUCConfig | ClipDetectionRecallConfig | ClipDetectionPrecisionConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ __all__ = [
|
||||
def compute_precision_recall(
|
||||
y_true,
|
||||
y_score,
|
||||
num_positives: Optional[int] = None,
|
||||
num_positives: int | None = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
y_true = np.array(y_true)
|
||||
y_score = np.array(y_score)
|
||||
@ -41,7 +41,7 @@ def compute_precision_recall(
|
||||
def average_precision(
|
||||
y_true,
|
||||
y_score,
|
||||
num_positives: Optional[int] = None,
|
||||
num_positives: int | None = None,
|
||||
) -> float:
|
||||
if num_positives == 0:
|
||||
return np.nan
|
||||
|
||||
@ -28,8 +28,8 @@ __all__ = [
|
||||
|
||||
@dataclass
|
||||
class MatchEval:
|
||||
gt: Optional[data.SoundEventAnnotation]
|
||||
pred: Optional[RawPrediction]
|
||||
gt: data.SoundEventAnnotation | None
|
||||
pred: RawPrediction | None
|
||||
|
||||
is_prediction: bool
|
||||
is_ground_truth: bool
|
||||
@ -212,12 +212,7 @@ class DetectionPrecision:
|
||||
|
||||
|
||||
DetectionMetricConfig = Annotated[
|
||||
Union[
|
||||
DetectionAveragePrecisionConfig,
|
||||
DetectionROCAUCConfig,
|
||||
DetectionRecallConfig,
|
||||
DetectionPrecisionConfig,
|
||||
],
|
||||
DetectionAveragePrecisionConfig | DetectionROCAUCConfig | DetectionRecallConfig | DetectionPrecisionConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -31,14 +31,14 @@ __all__ = [
|
||||
@dataclass
|
||||
class MatchEval:
|
||||
clip: data.Clip
|
||||
gt: Optional[data.SoundEventAnnotation]
|
||||
pred: Optional[RawPrediction]
|
||||
gt: data.SoundEventAnnotation | None
|
||||
pred: RawPrediction | None
|
||||
|
||||
is_ground_truth: bool
|
||||
is_generic: bool
|
||||
is_prediction: bool
|
||||
pred_class: Optional[str]
|
||||
true_class: Optional[str]
|
||||
pred_class: str | None
|
||||
true_class: str | None
|
||||
score: float
|
||||
|
||||
|
||||
@ -301,13 +301,7 @@ class BalancedAccuracy:
|
||||
|
||||
|
||||
TopClassMetricConfig = Annotated[
|
||||
Union[
|
||||
TopClassAveragePrecisionConfig,
|
||||
TopClassROCAUCConfig,
|
||||
TopClassRecallConfig,
|
||||
TopClassPrecisionConfig,
|
||||
BalancedAccuracyConfig,
|
||||
],
|
||||
TopClassAveragePrecisionConfig | TopClassROCAUCConfig | TopClassRecallConfig | TopClassPrecisionConfig | BalancedAccuracyConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ from batdetect2.typing import TargetProtocol
|
||||
class BasePlotConfig(BaseConfig):
|
||||
label: str = "plot"
|
||||
theme: str = "default"
|
||||
title: Optional[str] = None
|
||||
title: str | None = None
|
||||
figsize: tuple[int, int] = (10, 10)
|
||||
dpi: int = 100
|
||||
|
||||
@ -21,7 +21,7 @@ class BasePlot:
|
||||
targets: TargetProtocol,
|
||||
label: str = "plot",
|
||||
figsize: tuple[int, int] = (10, 10),
|
||||
title: Optional[str] = None,
|
||||
title: str | None = None,
|
||||
dpi: int = 100,
|
||||
theme: str = "default",
|
||||
):
|
||||
|
||||
@ -45,7 +45,7 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
title: Optional[str] = "Classification Precision-Recall Curve"
|
||||
title: str | None = "Classification Precision-Recall Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
separate_figures: bool = False
|
||||
@ -108,7 +108,7 @@ class PRCurve(BasePlot):
|
||||
class ThresholdPrecisionCurveConfig(BasePlotConfig):
|
||||
name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
|
||||
label: str = "threshold_precision_curve"
|
||||
title: Optional[str] = "Classification Threshold-Precision Curve"
|
||||
title: str | None = "Classification Threshold-Precision Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
separate_figures: bool = False
|
||||
@ -181,7 +181,7 @@ class ThresholdPrecisionCurve(BasePlot):
|
||||
class ThresholdRecallCurveConfig(BasePlotConfig):
|
||||
name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
|
||||
label: str = "threshold_recall_curve"
|
||||
title: Optional[str] = "Classification Threshold-Recall Curve"
|
||||
title: str | None = "Classification Threshold-Recall Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
separate_figures: bool = False
|
||||
@ -254,7 +254,7 @@ class ThresholdRecallCurve(BasePlot):
|
||||
class ROCCurveConfig(BasePlotConfig):
|
||||
name: Literal["roc_curve"] = "roc_curve"
|
||||
label: str = "roc_curve"
|
||||
title: Optional[str] = "Classification ROC Curve"
|
||||
title: str | None = "Classification ROC Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
separate_figures: bool = False
|
||||
@ -326,12 +326,7 @@ class ROCCurve(BasePlot):
|
||||
|
||||
|
||||
ClassificationPlotConfig = Annotated[
|
||||
Union[
|
||||
PRCurveConfig,
|
||||
ROCCurveConfig,
|
||||
ThresholdPrecisionCurveConfig,
|
||||
ThresholdRecallCurveConfig,
|
||||
],
|
||||
PRCurveConfig | ROCCurveConfig | ThresholdPrecisionCurveConfig | ThresholdRecallCurveConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ clip_classification_plots: Registry[
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
title: Optional[str] = "Clip Classification Precision-Recall Curve"
|
||||
title: str | None = "Clip Classification Precision-Recall Curve"
|
||||
separate_figures: bool = False
|
||||
|
||||
|
||||
@ -111,7 +111,7 @@ class PRCurve(BasePlot):
|
||||
class ROCCurveConfig(BasePlotConfig):
|
||||
name: Literal["roc_curve"] = "roc_curve"
|
||||
label: str = "roc_curve"
|
||||
title: Optional[str] = "Clip Classification ROC Curve"
|
||||
title: str | None = "Clip Classification ROC Curve"
|
||||
separate_figures: bool = False
|
||||
|
||||
|
||||
@ -174,10 +174,7 @@ class ROCCurve(BasePlot):
|
||||
|
||||
|
||||
ClipClassificationPlotConfig = Annotated[
|
||||
Union[
|
||||
PRCurveConfig,
|
||||
ROCCurveConfig,
|
||||
],
|
||||
PRCurveConfig | ROCCurveConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
title: Optional[str] = "Clip Detection Precision-Recall Curve"
|
||||
title: str | None = "Clip Detection Precision-Recall Curve"
|
||||
|
||||
|
||||
class PRCurve(BasePlot):
|
||||
@ -74,7 +74,7 @@ class PRCurve(BasePlot):
|
||||
class ROCCurveConfig(BasePlotConfig):
|
||||
name: Literal["roc_curve"] = "roc_curve"
|
||||
label: str = "roc_curve"
|
||||
title: Optional[str] = "Clip Detection ROC Curve"
|
||||
title: str | None = "Clip Detection ROC Curve"
|
||||
|
||||
|
||||
class ROCCurve(BasePlot):
|
||||
@ -107,7 +107,7 @@ class ROCCurve(BasePlot):
|
||||
class ScoreDistributionPlotConfig(BasePlotConfig):
|
||||
name: Literal["score_distribution"] = "score_distribution"
|
||||
label: str = "score_distribution"
|
||||
title: Optional[str] = "Clip Detection Score Distribution"
|
||||
title: str | None = "Clip Detection Score Distribution"
|
||||
|
||||
|
||||
class ScoreDistributionPlot(BasePlot):
|
||||
@ -147,11 +147,7 @@ class ScoreDistributionPlot(BasePlot):
|
||||
|
||||
|
||||
ClipDetectionPlotConfig = Annotated[
|
||||
Union[
|
||||
PRCurveConfig,
|
||||
ROCCurveConfig,
|
||||
ScoreDistributionPlotConfig,
|
||||
],
|
||||
PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
title: Optional[str] = "Detection Precision-Recall Curve"
|
||||
title: str | None = "Detection Precision-Recall Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
@ -100,7 +100,7 @@ class PRCurve(BasePlot):
|
||||
class ROCCurveConfig(BasePlotConfig):
|
||||
name: Literal["roc_curve"] = "roc_curve"
|
||||
label: str = "roc_curve"
|
||||
title: Optional[str] = "Detection ROC Curve"
|
||||
title: str | None = "Detection ROC Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
@ -159,7 +159,7 @@ class ROCCurve(BasePlot):
|
||||
class ScoreDistributionPlotConfig(BasePlotConfig):
|
||||
name: Literal["score_distribution"] = "score_distribution"
|
||||
label: str = "score_distribution"
|
||||
title: Optional[str] = "Detection Score Distribution"
|
||||
title: str | None = "Detection Score Distribution"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
@ -226,7 +226,7 @@ class ScoreDistributionPlot(BasePlot):
|
||||
class ExampleDetectionPlotConfig(BasePlotConfig):
|
||||
name: Literal["example_detection"] = "example_detection"
|
||||
label: str = "example_detection"
|
||||
title: Optional[str] = "Example Detection"
|
||||
title: str | None = "Example Detection"
|
||||
figsize: tuple[int, int] = (10, 4)
|
||||
num_examples: int = 5
|
||||
threshold: float = 0.2
|
||||
@ -292,12 +292,7 @@ class ExampleDetectionPlot(BasePlot):
|
||||
|
||||
|
||||
DetectionPlotConfig = Annotated[
|
||||
Union[
|
||||
PRCurveConfig,
|
||||
ROCCurveConfig,
|
||||
ScoreDistributionPlotConfig,
|
||||
ExampleDetectionPlotConfig,
|
||||
],
|
||||
PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig | ExampleDetectionPlotConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
title: Optional[str] = "Top Class Precision-Recall Curve"
|
||||
title: str | None = "Top Class Precision-Recall Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
@ -111,7 +111,7 @@ class PRCurve(BasePlot):
|
||||
class ROCCurveConfig(BasePlotConfig):
|
||||
name: Literal["roc_curve"] = "roc_curve"
|
||||
label: str = "roc_curve"
|
||||
title: Optional[str] = "Top Class ROC Curve"
|
||||
title: str | None = "Top Class ROC Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
@ -173,7 +173,7 @@ class ROCCurve(BasePlot):
|
||||
|
||||
class ConfusionMatrixConfig(BasePlotConfig):
|
||||
name: Literal["confusion_matrix"] = "confusion_matrix"
|
||||
title: Optional[str] = "Top Class Confusion Matrix"
|
||||
title: str | None = "Top Class Confusion Matrix"
|
||||
figsize: tuple[int, int] = (10, 10)
|
||||
label: str = "confusion_matrix"
|
||||
exclude_generic: bool = True
|
||||
@ -257,7 +257,7 @@ class ConfusionMatrix(BasePlot):
|
||||
class ExampleClassificationPlotConfig(BasePlotConfig):
|
||||
name: Literal["example_classification"] = "example_classification"
|
||||
label: str = "example_classification"
|
||||
title: Optional[str] = "Example Classification"
|
||||
title: str | None = "Example Classification"
|
||||
num_examples: int = 4
|
||||
threshold: float = 0.2
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
@ -348,12 +348,7 @@ class ExampleClassificationPlot(BasePlot):
|
||||
|
||||
|
||||
TopClassPlotConfig = Annotated[
|
||||
Union[
|
||||
PRCurveConfig,
|
||||
ROCCurveConfig,
|
||||
ConfusionMatrixConfig,
|
||||
ExampleClassificationPlotConfig,
|
||||
],
|
||||
PRCurveConfig | ROCCurveConfig | ConfusionMatrixConfig | ExampleClassificationPlotConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -96,7 +96,7 @@ def extract_matches_dataframe(
|
||||
|
||||
|
||||
EvaluationTableConfig = Annotated[
|
||||
Union[FullEvaluationTableConfig,], Field(discriminator="name")
|
||||
FullEvaluationTableConfig, Field(discriminator="name")
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -26,20 +26,14 @@ __all__ = [
|
||||
|
||||
|
||||
TaskConfig = Annotated[
|
||||
Union[
|
||||
ClassificationTaskConfig,
|
||||
DetectionTaskConfig,
|
||||
ClipDetectionTaskConfig,
|
||||
ClipClassificationTaskConfig,
|
||||
TopClassDetectionTaskConfig,
|
||||
],
|
||||
ClassificationTaskConfig | DetectionTaskConfig | ClipDetectionTaskConfig | ClipClassificationTaskConfig | TopClassDetectionTaskConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_task(
|
||||
config: TaskConfig,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
) -> EvaluatorProtocol:
|
||||
targets = targets or build_targets()
|
||||
return tasks_registry.build(config, targets)
|
||||
@ -49,8 +43,8 @@ def evaluate_task(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
task: Optional["str"] = None,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
config: Optional[Union[TaskConfig, dict]] = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
config: TaskConfig | dict | None = None,
|
||||
):
|
||||
if isinstance(config, BaseTaskConfig):
|
||||
task_obj = build_task(config, targets)
|
||||
|
||||
@ -67,9 +67,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
||||
prefix: str,
|
||||
ignore_start_end: float = 0.01,
|
||||
plots: Optional[
|
||||
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
|
||||
] = None,
|
||||
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
|
||||
):
|
||||
self.matcher = matcher
|
||||
self.metrics = metrics
|
||||
@ -147,9 +145,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
config: BaseTaskConfig,
|
||||
targets: TargetProtocol,
|
||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
||||
plots: Optional[
|
||||
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
|
||||
] = None,
|
||||
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
matcher = build_matcher(config.matching_strategy)
|
||||
|
||||
@ -98,8 +98,8 @@ def load_annotations(
|
||||
dataset_name: str,
|
||||
ann_path: str,
|
||||
audio_path: str,
|
||||
classes_to_ignore: Optional[List[str]] = None,
|
||||
events_of_interest: Optional[List[str]] = None,
|
||||
classes_to_ignore: List[str] | None = None,
|
||||
events_of_interest: List[str] | None = None,
|
||||
) -> List[types.FileAnnotation]:
|
||||
train_sets: List[types.DatasetDict] = []
|
||||
train_sets.append(
|
||||
|
||||
@ -13,7 +13,7 @@ from batdetect2 import types
|
||||
|
||||
def print_dataset_stats(
|
||||
data: List[types.FileAnnotation],
|
||||
classes_to_ignore: Optional[List[str]] = None,
|
||||
classes_to_ignore: List[str] | None = None,
|
||||
) -> Counter[str]:
|
||||
print("Num files:", len(data))
|
||||
counts, _ = tu.get_class_names(data, classes_to_ignore)
|
||||
|
||||
@ -28,8 +28,8 @@ def run_batch_inference(
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
num_workers: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
@ -69,7 +69,7 @@ def process_file_list(
|
||||
targets: Optional["TargetProtocol"] = None,
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
num_workers: int | None = None,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
clip_config = config.inference.clipping
|
||||
clips = get_clips_from_files(
|
||||
|
||||
@ -36,7 +36,7 @@ class InferenceDataset(Dataset[DatasetItem]):
|
||||
clips: Sequence[data.Clip],
|
||||
audio_loader: AudioLoader,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
):
|
||||
self.clips = list(clips)
|
||||
self.preprocessor = preprocessor
|
||||
@ -66,11 +66,11 @@ class InferenceLoaderConfig(BaseConfig):
|
||||
|
||||
def build_inference_loader(
|
||||
clips: Sequence[data.Clip],
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[InferenceLoaderConfig] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
config: InferenceLoaderConfig | None = None,
|
||||
num_workers: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> DataLoader[DatasetItem]:
|
||||
logger.info("Building inference data loader...")
|
||||
config = config or InferenceLoaderConfig()
|
||||
@ -95,8 +95,8 @@ def build_inference_loader(
|
||||
|
||||
def build_inference_dataset(
|
||||
clips: Sequence[data.Clip],
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
) -> InferenceDataset:
|
||||
if audio_loader is None:
|
||||
audio_loader = build_audio_loader()
|
||||
|
||||
@ -49,14 +49,14 @@ def enable_logging(level: int):
|
||||
|
||||
class BaseLoggerConfig(BaseConfig):
|
||||
log_dir: Path = DEFAULT_LOGS_DIR
|
||||
experiment_name: Optional[str] = None
|
||||
run_name: Optional[str] = None
|
||||
experiment_name: str | None = None
|
||||
run_name: str | None = None
|
||||
|
||||
|
||||
class DVCLiveConfig(BaseLoggerConfig):
|
||||
name: Literal["dvclive"] = "dvclive"
|
||||
prefix: str = ""
|
||||
log_model: Union[bool, Literal["all"]] = False
|
||||
log_model: bool | Literal["all"] = False
|
||||
monitor_system: bool = False
|
||||
|
||||
|
||||
@ -72,18 +72,13 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):
|
||||
|
||||
class MLFlowLoggerConfig(BaseLoggerConfig):
|
||||
name: Literal["mlflow"] = "mlflow"
|
||||
tracking_uri: Optional[str] = "http://localhost:5000"
|
||||
tags: Optional[dict[str, Any]] = None
|
||||
tracking_uri: str | None = "http://localhost:5000"
|
||||
tags: dict[str, Any] | None = None
|
||||
log_model: bool = False
|
||||
|
||||
|
||||
LoggerConfig = Annotated[
|
||||
Union[
|
||||
DVCLiveConfig,
|
||||
CSVLoggerConfig,
|
||||
TensorBoardLoggerConfig,
|
||||
MLFlowLoggerConfig,
|
||||
],
|
||||
DVCLiveConfig | CSVLoggerConfig | TensorBoardLoggerConfig | MLFlowLoggerConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
@ -95,17 +90,17 @@ class LoggerBuilder(Protocol, Generic[T]):
|
||||
def __call__(
|
||||
self,
|
||||
config: T,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
log_dir: Path | None = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
) -> Logger: ...
|
||||
|
||||
|
||||
def create_dvclive_logger(
|
||||
config: DVCLiveConfig,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
log_dir: Path | None = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
) -> Logger:
|
||||
try:
|
||||
from dvclive.lightning import DVCLiveLogger # type: ignore
|
||||
@ -130,9 +125,9 @@ def create_dvclive_logger(
|
||||
|
||||
def create_csv_logger(
|
||||
config: CSVLoggerConfig,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
log_dir: Path | None = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
) -> Logger:
|
||||
from lightning.pytorch.loggers import CSVLogger
|
||||
|
||||
@ -159,9 +154,9 @@ def create_csv_logger(
|
||||
|
||||
def create_tensorboard_logger(
|
||||
config: TensorBoardLoggerConfig,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
log_dir: Path | None = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
) -> Logger:
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
|
||||
@ -191,9 +186,9 @@ def create_tensorboard_logger(
|
||||
|
||||
def create_mlflow_logger(
|
||||
config: MLFlowLoggerConfig,
|
||||
log_dir: Optional[data.PathLike] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
log_dir: data.PathLike | None = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
) -> Logger:
|
||||
try:
|
||||
from lightning.pytorch.loggers import MLFlowLogger
|
||||
@ -232,9 +227,9 @@ LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
|
||||
|
||||
def build_logger(
|
||||
config: LoggerConfig,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
log_dir: Path | None = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
) -> Logger:
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building logger with config: \n{}",
|
||||
@ -257,7 +252,7 @@ def build_logger(
|
||||
PlotLogger = Callable[[str, Figure, int], None]
|
||||
|
||||
|
||||
def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
|
||||
def get_image_logger(logger: Logger) -> PlotLogger | None:
|
||||
if isinstance(logger, TensorBoardLogger):
|
||||
return logger.experiment.add_figure
|
||||
|
||||
@ -282,7 +277,7 @@ def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
|
||||
TableLogger = Callable[[str, pd.DataFrame, int], None]
|
||||
|
||||
|
||||
def get_table_logger(logger: Logger) -> Optional[TableLogger]:
|
||||
def get_table_logger(logger: Logger) -> TableLogger | None:
|
||||
if isinstance(logger, TensorBoardLogger):
|
||||
return partial(save_table, dir=Path(logger.log_dir))
|
||||
|
||||
|
||||
@ -43,10 +43,7 @@ from batdetect2.models.bottleneck import (
|
||||
BottleneckConfig,
|
||||
build_bottleneck,
|
||||
)
|
||||
from batdetect2.models.config import (
|
||||
BackboneConfig,
|
||||
load_backbone_config,
|
||||
)
|
||||
from batdetect2.models.config import BackboneConfig, load_backbone_config
|
||||
from batdetect2.models.decoder import (
|
||||
DEFAULT_DECODER_CONFIG,
|
||||
DecoderConfig,
|
||||
@ -122,10 +119,10 @@ class Model(torch.nn.Module):
|
||||
|
||||
|
||||
def build_model(
|
||||
config: Optional[BackboneConfig] = None,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
postprocessor: Optional[PostprocessorProtocol] = None,
|
||||
config: BackboneConfig | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
postprocessor: PostprocessorProtocol | None = None,
|
||||
):
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
|
||||
@ -78,8 +78,8 @@ class Bottleneck(nn.Module):
|
||||
input_height: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
bottleneck_channels: Optional[int] = None,
|
||||
layers: Optional[List[torch.nn.Module]] = None,
|
||||
bottleneck_channels: int | None = None,
|
||||
layers: List[torch.nn.Module] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the base Bottleneck layer."""
|
||||
super().__init__()
|
||||
@ -127,7 +127,7 @@ class Bottleneck(nn.Module):
|
||||
|
||||
|
||||
BottleneckLayerConfig = Annotated[
|
||||
Union[SelfAttentionConfig,],
|
||||
SelfAttentionConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||
@ -171,7 +171,7 @@ DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
|
||||
def build_bottleneck(
|
||||
input_height: int,
|
||||
in_channels: int,
|
||||
config: Optional[BottleneckConfig] = None,
|
||||
config: BottleneckConfig | None = None,
|
||||
) -> nn.Module:
|
||||
"""Factory function to build the Bottleneck module from configuration.
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ class BackboneConfig(BaseConfig):
|
||||
|
||||
def load_backbone_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
field: str | None = None,
|
||||
) -> BackboneConfig:
|
||||
"""Load the backbone configuration from a file.
|
||||
|
||||
|
||||
@ -41,12 +41,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
DecoderLayerConfig = Annotated[
|
||||
Union[
|
||||
ConvConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvUpConfig,
|
||||
LayerGroupConfig,
|
||||
],
|
||||
ConvConfig | FreqCoordConvUpConfig | StandardConvUpConfig | LayerGroupConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||
@ -216,7 +211,7 @@ convolutional block.
|
||||
def build_decoder(
|
||||
in_channels: int,
|
||||
input_height: int,
|
||||
config: Optional[DecoderConfig] = None,
|
||||
config: DecoderConfig | None = None,
|
||||
) -> Decoder:
|
||||
"""Factory function to build a Decoder instance from configuration.
|
||||
|
||||
|
||||
@ -127,7 +127,7 @@ class Detector(DetectionModel):
|
||||
|
||||
|
||||
def build_detector(
|
||||
num_classes: int, config: Optional[BackboneConfig] = None
|
||||
num_classes: int, config: BackboneConfig | None = None
|
||||
) -> DetectionModel:
|
||||
"""Build the complete BatDetect2 detection model.
|
||||
|
||||
|
||||
@ -43,12 +43,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
EncoderLayerConfig = Annotated[
|
||||
Union[
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
StandardConvDownConfig,
|
||||
LayerGroupConfig,
|
||||
],
|
||||
ConvConfig | FreqCoordConvDownConfig | StandardConvDownConfig | LayerGroupConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||
@ -252,7 +247,7 @@ Specifies an architecture typically used in BatDetect2:
|
||||
def build_encoder(
|
||||
in_channels: int,
|
||||
input_height: int,
|
||||
config: Optional[EncoderConfig] = None,
|
||||
config: EncoderConfig | None = None,
|
||||
) -> Encoder:
|
||||
"""Factory function to build an Encoder instance from configuration.
|
||||
|
||||
|
||||
@ -15,10 +15,10 @@ __all__ = [
|
||||
|
||||
def plot_clip_annotation(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
add_points: bool = False,
|
||||
cmap: str = "gray",
|
||||
alpha: float = 1,
|
||||
@ -50,8 +50,8 @@ def plot_clip_annotation(
|
||||
def plot_anchor_points(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
targets: TargetProtocol,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
size: int = 1,
|
||||
color: str = "red",
|
||||
marker: str = "x",
|
||||
|
||||
@ -17,10 +17,10 @@ __all__ = [
|
||||
|
||||
def plot_clip_prediction(
|
||||
clip_prediction: data.ClipPrediction,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
add_legend: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
linewidth: float = 1,
|
||||
@ -50,14 +50,14 @@ def plot_clip_prediction(
|
||||
|
||||
def plot_predictions(
|
||||
predictions: Iterable[data.SoundEventPrediction],
|
||||
ax: Optional[Axes] = None,
|
||||
ax: Axes | None = None,
|
||||
position: Positions = "top-right",
|
||||
color_mapper: Optional[TagColorMapper] = None,
|
||||
color_mapper: TagColorMapper | None = None,
|
||||
time_offset: float = 0.001,
|
||||
freq_offset: float = 1000,
|
||||
legend: bool = True,
|
||||
max_alpha: float = 0.5,
|
||||
color: Optional[str] = None,
|
||||
color: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Plot an prediction."""
|
||||
@ -88,14 +88,14 @@ def plot_predictions(
|
||||
|
||||
def plot_prediction(
|
||||
prediction: data.SoundEventPrediction,
|
||||
ax: Optional[Axes] = None,
|
||||
ax: Axes | None = None,
|
||||
position: Positions = "top-right",
|
||||
color_mapper: Optional[TagColorMapper] = None,
|
||||
color_mapper: TagColorMapper | None = None,
|
||||
time_offset: float = 0.001,
|
||||
freq_offset: float = 1000,
|
||||
max_alpha: float = 0.5,
|
||||
alpha: Optional[float] = None,
|
||||
color: Optional[str] = None,
|
||||
alpha: float | None = None,
|
||||
color: str | None = None,
|
||||
**kwargs,
|
||||
) -> Axes:
|
||||
"""Plot an annotation."""
|
||||
|
||||
@ -17,11 +17,11 @@ __all__ = [
|
||||
|
||||
def plot_clip(
|
||||
clip: data.Clip,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
spec_cmap: str = "gray",
|
||||
) -> Axes:
|
||||
if ax is None:
|
||||
|
||||
@ -13,8 +13,8 @@ __all__ = [
|
||||
|
||||
|
||||
def create_ax(
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
**kwargs,
|
||||
) -> axes.Axes:
|
||||
"""Create a new axis if none is provided"""
|
||||
@ -25,17 +25,17 @@ def create_ax(
|
||||
|
||||
|
||||
def plot_spectrogram(
|
||||
spec: Union[torch.Tensor, np.ndarray],
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
min_freq: Optional[float] = None,
|
||||
max_freq: Optional[float] = None,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
spec: torch.Tensor | np.ndarray,
|
||||
start_time: float | None = None,
|
||||
end_time: float | None = None,
|
||||
min_freq: float | None = None,
|
||||
max_freq: float | None = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
add_colorbar: bool = False,
|
||||
colorbar_kwargs: Optional[dict] = None,
|
||||
vmin: Optional[float] = None,
|
||||
vmax: Optional[float] = None,
|
||||
colorbar_kwargs: dict | None = None,
|
||||
vmin: float | None = None,
|
||||
vmax: float | None = None,
|
||||
cmap="gray",
|
||||
) -> axes.Axes:
|
||||
if isinstance(spec, torch.Tensor):
|
||||
|
||||
@ -19,9 +19,9 @@ __all__ = [
|
||||
def plot_clip_detections(
|
||||
clip_eval: ClipEval,
|
||||
figsize: tuple[int, int] = (10, 10),
|
||||
ax: Optional[axes.Axes] = None,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
threshold: float = 0.2,
|
||||
add_legend: bool = True,
|
||||
add_title: bool = True,
|
||||
|
||||
@ -20,11 +20,11 @@ def plot_match_gallery(
|
||||
false_positives: Sequence[MatchProtocol],
|
||||
false_negatives: Sequence[MatchProtocol],
|
||||
cross_triggers: Sequence[MatchProtocol],
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
n_examples: int = 5,
|
||||
duration: float = 0.1,
|
||||
fig: Optional[Figure] = None,
|
||||
fig: Figure | None = None,
|
||||
):
|
||||
if fig is None:
|
||||
fig = plt.figure(figsize=(20, 20))
|
||||
|
||||
@ -12,13 +12,13 @@ from batdetect2.plotting.common import create_ax
|
||||
|
||||
|
||||
def plot_detection_heatmap(
|
||||
heatmap: Union[torch.Tensor, np.ndarray],
|
||||
ax: Optional[axes.Axes] = None,
|
||||
heatmap: torch.Tensor | np.ndarray,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] = (10, 10),
|
||||
threshold: Optional[float] = None,
|
||||
threshold: float | None = None,
|
||||
alpha: float = 1,
|
||||
cmap: Union[str, Colormap] = "jet",
|
||||
color: Optional[str] = None,
|
||||
cmap: str | Colormap = "jet",
|
||||
color: str | None = None,
|
||||
) -> axes.Axes:
|
||||
ax = create_ax(ax, figsize=figsize)
|
||||
|
||||
@ -48,13 +48,13 @@ def plot_detection_heatmap(
|
||||
|
||||
|
||||
def plot_classification_heatmap(
|
||||
heatmap: Union[torch.Tensor, np.ndarray],
|
||||
ax: Optional[axes.Axes] = None,
|
||||
heatmap: torch.Tensor | np.ndarray,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] = (10, 10),
|
||||
class_names: Optional[List[str]] = None,
|
||||
threshold: Optional[float] = 0.1,
|
||||
class_names: List[str] | None = None,
|
||||
threshold: float | None = 0.1,
|
||||
alpha: float = 1,
|
||||
cmap: Union[str, Colormap] = "tab20",
|
||||
cmap: str | Colormap = "tab20",
|
||||
):
|
||||
ax = create_ax(ax, figsize=figsize)
|
||||
|
||||
|
||||
@ -24,10 +24,10 @@ __all__ = [
|
||||
|
||||
|
||||
def spectrogram(
|
||||
spec: Union[torch.Tensor, np.ndarray],
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
spec: torch.Tensor | np.ndarray,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
cmap: str = "plasma",
|
||||
start_time: float = 0,
|
||||
) -> axes.Axes:
|
||||
@ -103,11 +103,11 @@ def spectrogram(
|
||||
|
||||
|
||||
def spectrogram_with_detections(
|
||||
spec: Union[torch.Tensor, np.ndarray],
|
||||
spec: torch.Tensor | np.ndarray,
|
||||
dets: List[Annotation],
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
cmap: str = "plasma",
|
||||
with_names: bool = True,
|
||||
start_time: float = 0,
|
||||
@ -168,8 +168,8 @@ def spectrogram_with_detections(
|
||||
|
||||
def detections(
|
||||
dets: List[Annotation],
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
with_names: bool = True,
|
||||
**kwargs,
|
||||
) -> axes.Axes:
|
||||
@ -213,8 +213,8 @@ def detections(
|
||||
|
||||
def detection(
|
||||
det: Annotation,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
linewidth: float = 1,
|
||||
edgecolor: str = "w",
|
||||
facecolor: str = "none",
|
||||
|
||||
@ -21,10 +21,10 @@ __all__ = [
|
||||
|
||||
class MatchProtocol(Protocol):
|
||||
clip: data.Clip
|
||||
gt: Optional[data.SoundEventAnnotation]
|
||||
pred: Optional[RawPrediction]
|
||||
gt: data.SoundEventAnnotation | None
|
||||
pred: RawPrediction | None
|
||||
score: float
|
||||
true_class: Optional[str]
|
||||
true_class: str | None
|
||||
|
||||
|
||||
DEFAULT_DURATION = 0.05
|
||||
@ -38,11 +38,11 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
|
||||
|
||||
def plot_false_positive_match(
|
||||
match: MatchProtocol,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
use_score: bool = True,
|
||||
add_spectrogram: bool = True,
|
||||
@ -52,7 +52,7 @@ def plot_false_positive_match(
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
||||
fontsize: Union[float, str] = "small",
|
||||
fontsize: float | str = "small",
|
||||
) -> Axes:
|
||||
assert match.pred is not None
|
||||
|
||||
@ -109,11 +109,11 @@ def plot_false_positive_match(
|
||||
|
||||
def plot_false_negative_match(
|
||||
match: MatchProtocol,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
add_spectrogram: bool = True,
|
||||
add_points: bool = False,
|
||||
@ -169,11 +169,11 @@ def plot_false_negative_match(
|
||||
|
||||
def plot_true_positive_match(
|
||||
match: MatchProtocol,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
use_score: bool = True,
|
||||
add_spectrogram: bool = True,
|
||||
@ -182,7 +182,7 @@ def plot_true_positive_match(
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||
fontsize: Union[float, str] = "small",
|
||||
fontsize: float | str = "small",
|
||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||
add_title: bool = True,
|
||||
@ -257,11 +257,11 @@ def plot_true_positive_match(
|
||||
|
||||
def plot_cross_trigger_match(
|
||||
match: MatchProtocol,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
use_score: bool = True,
|
||||
add_spectrogram: bool = True,
|
||||
@ -271,7 +271,7 @@ def plot_cross_trigger_match(
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
||||
fontsize: Union[float, str] = "small",
|
||||
fontsize: float | str = "small",
|
||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||
) -> Axes:
|
||||
|
||||
@ -33,16 +33,16 @@ def plot_pr_curve(
|
||||
precision: np.ndarray,
|
||||
recall: np.ndarray,
|
||||
thresholds: np.ndarray,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
color: Union[str, Tuple[float, float, float], None] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
color: str | Tuple[float, float, float] | None = None,
|
||||
add_labels: bool = True,
|
||||
add_legend: bool = False,
|
||||
marker: Union[str, Tuple[int, int, float], None] = "o",
|
||||
markeredgecolor: Union[str, Tuple[float, float, float], None] = None,
|
||||
markersize: Optional[float] = None,
|
||||
linestyle: Union[str, Tuple[int, ...], None] = None,
|
||||
linewidth: Optional[float] = None,
|
||||
marker: str | Tuple[int, int, float] | None = "o",
|
||||
markeredgecolor: str | Tuple[float, float, float] | None = None,
|
||||
markersize: float | None = None,
|
||||
linestyle: str | Tuple[int, ...] | None = None,
|
||||
linewidth: float | None = None,
|
||||
label: str = "PR Curve",
|
||||
) -> axes.Axes:
|
||||
ax = create_ax(ax=ax, figsize=figsize)
|
||||
@ -77,8 +77,8 @@ def plot_pr_curve(
|
||||
|
||||
def plot_pr_curves(
|
||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
add_legend: bool = True,
|
||||
add_labels: bool = True,
|
||||
include_ap: bool = False,
|
||||
@ -118,8 +118,8 @@ def plot_pr_curves(
|
||||
def plot_threshold_precision_curve(
|
||||
threshold: np.ndarray,
|
||||
precision: np.ndarray,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
add_labels: bool = True,
|
||||
):
|
||||
ax = create_ax(ax=ax, figsize=figsize)
|
||||
@ -140,8 +140,8 @@ def plot_threshold_precision_curve(
|
||||
|
||||
def plot_threshold_precision_curves(
|
||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
add_legend: bool = True,
|
||||
add_labels: bool = True,
|
||||
):
|
||||
@ -176,8 +176,8 @@ def plot_threshold_precision_curves(
|
||||
def plot_threshold_recall_curve(
|
||||
threshold: np.ndarray,
|
||||
recall: np.ndarray,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
add_labels: bool = True,
|
||||
):
|
||||
ax = create_ax(ax=ax, figsize=figsize)
|
||||
@ -198,8 +198,8 @@ def plot_threshold_recall_curve(
|
||||
|
||||
def plot_threshold_recall_curves(
|
||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
add_legend: bool = True,
|
||||
add_labels: bool = True,
|
||||
):
|
||||
@ -235,8 +235,8 @@ def plot_roc_curve(
|
||||
fpr: np.ndarray,
|
||||
tpr: np.ndarray,
|
||||
thresholds: np.ndarray,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
add_labels: bool = True,
|
||||
) -> axes.Axes:
|
||||
ax = create_ax(ax=ax, figsize=figsize)
|
||||
@ -261,8 +261,8 @@ def plot_roc_curve(
|
||||
|
||||
def plot_roc_curves(
|
||||
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: Tuple[int, int] | None = None,
|
||||
add_legend: bool = True,
|
||||
add_labels: bool = True,
|
||||
) -> axes.Axes:
|
||||
|
||||
@ -57,7 +57,7 @@ class PostprocessConfig(BaseConfig):
|
||||
|
||||
def load_postprocess_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
field: str | None = None,
|
||||
) -> PostprocessConfig:
|
||||
"""Load the postprocessing configuration from a file.
|
||||
|
||||
|
||||
@ -88,9 +88,7 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction: RawPrediction,
|
||||
recording: data.Recording,
|
||||
targets: TargetProtocol,
|
||||
classification_threshold: Optional[
|
||||
float
|
||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
):
|
||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
|
||||
@ -150,7 +148,7 @@ def get_class_tags(
|
||||
class_scores: np.ndarray,
|
||||
targets: TargetProtocol,
|
||||
top_class_only: bool = False,
|
||||
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
) -> List[data.PredictedTag]:
|
||||
"""Generate specific PredictedTags based on class scores and decoder.
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@ def extract_detection_peaks(
|
||||
feature_heatmap: torch.Tensor,
|
||||
classification_heatmap: torch.Tensor,
|
||||
max_detections: int = 200,
|
||||
threshold: Optional[float] = None,
|
||||
threshold: float | None = None,
|
||||
) -> List[ClipDetectionsTensor]:
|
||||
height = detection_heatmap.shape[-2]
|
||||
width = detection_heatmap.shape[-1]
|
||||
|
||||
@ -27,7 +27,7 @@ BatDetect2.
|
||||
|
||||
def non_max_suppression(
|
||||
tensor: torch.Tensor,
|
||||
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
||||
kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
|
||||
) -> torch.Tensor:
|
||||
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ __all__ = [
|
||||
|
||||
def build_postprocessor(
|
||||
preprocessor: PreprocessorProtocol,
|
||||
config: Optional[PostprocessConfig] = None,
|
||||
config: PostprocessConfig | None = None,
|
||||
) -> PostprocessorProtocol:
|
||||
"""Factory function to build the standard postprocessor."""
|
||||
config = config or PostprocessConfig()
|
||||
@ -51,7 +51,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||
max_freq: float,
|
||||
top_k_per_sec: int = 200,
|
||||
detection_threshold: float = 0.01,
|
||||
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
||||
nms_kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
|
||||
):
|
||||
"""Initialize the Postprocessor."""
|
||||
super().__init__()
|
||||
@ -66,7 +66,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||
def forward(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
start_times: Optional[List[float]] = None,
|
||||
start_times: List[float] | None = None,
|
||||
) -> List[ClipDetectionsTensor]:
|
||||
detection_heatmap = non_max_suppression(
|
||||
output.detection_probs.detach(),
|
||||
|
||||
@ -31,14 +31,14 @@ __all__ = [
|
||||
|
||||
|
||||
def to_xarray(
|
||||
array: Union[torch.Tensor, np.ndarray],
|
||||
array: torch.Tensor | np.ndarray,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
name: str = "xarray",
|
||||
extra_dims: Optional[List[str]] = None,
|
||||
extra_coords: Optional[Dict[str, np.ndarray]] = None,
|
||||
extra_dims: List[str] | None = None,
|
||||
extra_coords: Dict[str, np.ndarray] | None = None,
|
||||
) -> xr.DataArray:
|
||||
if isinstance(array, torch.Tensor):
|
||||
array = array.detach().cpu().numpy()
|
||||
|
||||
@ -78,11 +78,7 @@ class FixDuration(torch.nn.Module):
|
||||
|
||||
|
||||
AudioTransform = Annotated[
|
||||
Union[
|
||||
FixDurationConfig,
|
||||
ScaleAudioConfig,
|
||||
CenterAudioConfig,
|
||||
],
|
||||
FixDurationConfig | ScaleAudioConfig | CenterAudioConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -57,6 +57,6 @@ class PreprocessingConfig(BaseConfig):
|
||||
|
||||
def load_preprocessing_config(
|
||||
path: PathLike,
|
||||
field: Optional[str] = None,
|
||||
field: str | None = None,
|
||||
) -> PreprocessingConfig:
|
||||
return load_config(path, schema=PreprocessingConfig, field=field)
|
||||
|
||||
@ -102,7 +102,7 @@ def compute_output_samplerate(
|
||||
|
||||
|
||||
def build_preprocessor(
|
||||
config: Optional[PreprocessingConfig] = None,
|
||||
config: PreprocessingConfig | None = None,
|
||||
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
) -> PreprocessorProtocol:
|
||||
"""Factory function to build the standard preprocessor from configuration."""
|
||||
|
||||
@ -98,7 +98,7 @@ def _frequency_to_index(
|
||||
freq: float,
|
||||
n_fft: int,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
) -> Optional[int]:
|
||||
) -> int | None:
|
||||
alpha = freq * 2 / samplerate
|
||||
height = np.floor(n_fft / 2) + 1
|
||||
index = int(np.floor(alpha * height))
|
||||
@ -134,8 +134,8 @@ class FrequencyCrop(torch.nn.Module):
|
||||
self,
|
||||
samplerate: int,
|
||||
n_fft: int,
|
||||
min_freq: Optional[int] = None,
|
||||
max_freq: Optional[int] = None,
|
||||
min_freq: int | None = None,
|
||||
max_freq: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_fft = n_fft
|
||||
@ -181,7 +181,7 @@ class FrequencyCrop(torch.nn.Module):
|
||||
|
||||
def build_spectrogram_crop(
|
||||
config: FrequencyConfig,
|
||||
stft: Optional[STFTConfig] = None,
|
||||
stft: STFTConfig | None = None,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
) -> torch.nn.Module:
|
||||
stft = stft or STFTConfig()
|
||||
@ -377,12 +377,7 @@ class PeakNormalize(torch.nn.Module):
|
||||
|
||||
|
||||
SpectrogramTransform = Annotated[
|
||||
Union[
|
||||
PcenConfig,
|
||||
ScaleAmplitudeConfig,
|
||||
SpectralMeanSubstractionConfig,
|
||||
PeakNormalizeConfig,
|
||||
],
|
||||
PcenConfig | ScaleAmplitudeConfig | SpectralMeanSubstractionConfig | PeakNormalizeConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -30,16 +30,16 @@ class TargetClassConfig(BaseConfig):
|
||||
|
||||
name: str
|
||||
|
||||
condition_input: Optional[SoundEventConditionConfig] = Field(
|
||||
condition_input: SoundEventConditionConfig | None = Field(
|
||||
alias="match_if",
|
||||
default=None,
|
||||
)
|
||||
|
||||
tags: Optional[List[data.Tag]] = Field(default=None, exclude=True)
|
||||
tags: List[data.Tag] | None = Field(default=None, exclude=True)
|
||||
|
||||
assign_tags: List[data.Tag] = Field(default_factory=list)
|
||||
|
||||
roi: Optional[ROIMapperConfig] = None
|
||||
roi: ROIMapperConfig | None = None
|
||||
|
||||
_match_if: SoundEventConditionConfig = PrivateAttr()
|
||||
|
||||
@ -202,7 +202,7 @@ class SoundEventClassifier:
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
for name, condition in self.mapping.items():
|
||||
if condition(sound_event_annotation):
|
||||
return name
|
||||
|
||||
@ -48,7 +48,7 @@ class TargetConfig(BaseConfig):
|
||||
|
||||
def load_target_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
field: str | None = None,
|
||||
) -> TargetConfig:
|
||||
"""Load the unified target configuration from a file.
|
||||
|
||||
|
||||
@ -414,10 +414,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
||||
|
||||
|
||||
ROIMapperConfig = Annotated[
|
||||
Union[
|
||||
AnchorBBoxMapperConfig,
|
||||
PeakEnergyBBoxMapperConfig,
|
||||
],
|
||||
AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""A discriminated union of all supported ROI mapper configurations.
|
||||
@ -428,7 +425,7 @@ implementations by using the `name` field as a discriminator.
|
||||
|
||||
|
||||
def build_roi_mapper(
|
||||
config: Optional[ROIMapperConfig] = None,
|
||||
config: ROIMapperConfig | None = None,
|
||||
) -> ROITargetMapper:
|
||||
"""Factory function to create an ROITargetMapper from a config object.
|
||||
|
||||
@ -572,9 +569,9 @@ def get_peak_energy_coordinates(
|
||||
audio_loader: AudioLoader,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
start_time: float = 0,
|
||||
end_time: Optional[float] = None,
|
||||
end_time: float | None = None,
|
||||
low_freq: float = 0,
|
||||
high_freq: Optional[float] = None,
|
||||
high_freq: float | None = None,
|
||||
loading_buffer: float = 0.05,
|
||||
) -> Position:
|
||||
"""Find the coordinates of the highest energy point in a spectrogram.
|
||||
|
||||
@ -107,7 +107,7 @@ class Targets(TargetProtocol):
|
||||
|
||||
def encode_class(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""Encode a sound event annotation to its target class name.
|
||||
|
||||
Applies the configured class definition rules (including priority)
|
||||
@ -182,7 +182,7 @@ class Targets(TargetProtocol):
|
||||
self,
|
||||
position: Position,
|
||||
size: Size,
|
||||
class_name: Optional[str] = None,
|
||||
class_name: str | None = None,
|
||||
) -> data.Geometry:
|
||||
"""Recover an approximate geometric ROI from a position and dimensions.
|
||||
|
||||
@ -219,7 +219,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||
)
|
||||
|
||||
|
||||
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
||||
def build_targets(config: TargetConfig | None = None) -> Targets:
|
||||
"""Build a Targets object from a loaded TargetConfig.
|
||||
|
||||
Parameters
|
||||
@ -251,7 +251,7 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
||||
|
||||
def load_targets(
|
||||
config_path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
field: str | None = None,
|
||||
) -> Targets:
|
||||
"""Load a Targets object directly from a configuration file.
|
||||
|
||||
@ -292,7 +292,7 @@ def load_targets(
|
||||
def iterate_encoded_sound_events(
|
||||
sound_events: Iterable[data.SoundEventAnnotation],
|
||||
targets: TargetProtocol,
|
||||
) -> Iterable[Tuple[Optional[str], Position, Size]]:
|
||||
) -> Iterable[Tuple[str | None, Position, Size]]:
|
||||
for sound_event in sound_events:
|
||||
if not targets.filter(sound_event):
|
||||
continue
|
||||
|
||||
@ -42,7 +42,7 @@ __all__ = [
|
||||
|
||||
AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
|
||||
|
||||
audio_augmentations: Registry[Augmentation, [int, Optional[AudioSource]]] = (
|
||||
audio_augmentations: Registry[Augmentation, [int, AudioSource | None]] = (
|
||||
Registry(name="audio_augmentation")
|
||||
)
|
||||
|
||||
@ -103,7 +103,7 @@ class MixAudio(torch.nn.Module):
|
||||
def from_config(
|
||||
config: MixAudioConfig,
|
||||
samplerate: int,
|
||||
source: Optional[AudioSource],
|
||||
source: AudioSource | None,
|
||||
):
|
||||
if source is None:
|
||||
warnings.warn(
|
||||
@ -207,7 +207,7 @@ class AddEcho(torch.nn.Module):
|
||||
def from_config(
|
||||
config: AddEchoConfig,
|
||||
samplerate: int,
|
||||
source: Optional[AudioSource],
|
||||
source: AudioSource | None,
|
||||
):
|
||||
return AddEcho(
|
||||
samplerate=samplerate,
|
||||
@ -487,33 +487,18 @@ def mask_frequency(
|
||||
|
||||
|
||||
AudioAugmentationConfig = Annotated[
|
||||
Union[
|
||||
MixAudioConfig,
|
||||
AddEchoConfig,
|
||||
],
|
||||
MixAudioConfig | AddEchoConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
SpectrogramAugmentationConfig = Annotated[
|
||||
Union[
|
||||
ScaleVolumeConfig,
|
||||
WarpConfig,
|
||||
MaskFrequencyConfig,
|
||||
MaskTimeConfig,
|
||||
],
|
||||
ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
AugmentationConfig = Annotated[
|
||||
Union[
|
||||
MixAudioConfig,
|
||||
AddEchoConfig,
|
||||
ScaleVolumeConfig,
|
||||
WarpConfig,
|
||||
MaskFrequencyConfig,
|
||||
MaskTimeConfig,
|
||||
],
|
||||
MixAudioConfig | AddEchoConfig | ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of individual augmentation config."""
|
||||
@ -559,8 +544,8 @@ class MaybeApply(torch.nn.Module):
|
||||
def build_augmentation_from_config(
|
||||
config: AugmentationConfig,
|
||||
samplerate: int,
|
||||
audio_source: Optional[AudioSource] = None,
|
||||
) -> Optional[Augmentation]:
|
||||
audio_source: AudioSource | None = None,
|
||||
) -> Augmentation | None:
|
||||
"""Factory function to build a single augmentation from its config."""
|
||||
if config.name == "mix_audio":
|
||||
if audio_source is None:
|
||||
@ -645,10 +630,10 @@ class AugmentationSequence(torch.nn.Module):
|
||||
|
||||
|
||||
def build_audio_augmentations(
|
||||
steps: Optional[Sequence[AudioAugmentationConfig]] = None,
|
||||
steps: Sequence[AudioAugmentationConfig] | None = None,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
audio_source: Optional[AudioSource] = None,
|
||||
) -> Optional[Augmentation]:
|
||||
audio_source: AudioSource | None = None,
|
||||
) -> Augmentation | None:
|
||||
if not steps:
|
||||
return None
|
||||
|
||||
@ -673,8 +658,8 @@ def build_audio_augmentations(
|
||||
|
||||
|
||||
def build_spectrogram_augmentations(
|
||||
steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None,
|
||||
) -> Optional[Augmentation]:
|
||||
steps: Sequence[SpectrogramAugmentationConfig] | None = None,
|
||||
) -> Augmentation | None:
|
||||
if not steps:
|
||||
return None
|
||||
|
||||
@ -698,9 +683,9 @@ def build_spectrogram_augmentations(
|
||||
|
||||
def build_augmentations(
|
||||
samplerate: int,
|
||||
config: Optional[AugmentationsConfig] = None,
|
||||
audio_source: Optional[AudioSource] = None,
|
||||
) -> Tuple[Optional[Augmentation], Optional[Augmentation]]:
|
||||
config: AugmentationsConfig | None = None,
|
||||
audio_source: AudioSource | None = None,
|
||||
) -> Tuple[Augmentation | None, Augmentation | None]:
|
||||
"""Build a composite augmentation pipeline function from configuration."""
|
||||
config = config or DEFAULT_AUGMENTATION_CONFIG
|
||||
|
||||
@ -723,7 +708,7 @@ def build_augmentations(
|
||||
|
||||
|
||||
def load_augmentation_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
path: data.PathLike, field: str | None = None
|
||||
) -> AugmentationsConfig:
|
||||
"""Load the augmentations configuration from a file."""
|
||||
return load_config(path, schema=AugmentationsConfig, field=field)
|
||||
|
||||
@ -18,14 +18,14 @@ class CheckpointConfig(BaseConfig):
|
||||
monitor: str = "classification/mean_average_precision"
|
||||
mode: str = "max"
|
||||
save_top_k: int = 1
|
||||
filename: Optional[str] = None
|
||||
filename: str | None = None
|
||||
|
||||
|
||||
def build_checkpoint_callback(
|
||||
config: Optional[CheckpointConfig] = None,
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
config: CheckpointConfig | None = None,
|
||||
checkpoint_dir: Path | None = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
) -> Callback:
|
||||
config = config or CheckpointConfig()
|
||||
|
||||
|
||||
@ -22,20 +22,20 @@ class PLTrainerConfig(BaseConfig):
|
||||
accumulate_grad_batches: int = 1
|
||||
deterministic: bool = True
|
||||
check_val_every_n_epoch: int = 1
|
||||
devices: Union[str, int] = "auto"
|
||||
devices: str | int = "auto"
|
||||
enable_checkpointing: bool = True
|
||||
gradient_clip_val: Optional[float] = None
|
||||
limit_train_batches: Optional[Union[int, float]] = None
|
||||
limit_test_batches: Optional[Union[int, float]] = None
|
||||
limit_val_batches: Optional[Union[int, float]] = None
|
||||
log_every_n_steps: Optional[int] = None
|
||||
max_epochs: Optional[int] = 200
|
||||
min_epochs: Optional[int] = None
|
||||
max_steps: Optional[int] = None
|
||||
min_steps: Optional[int] = None
|
||||
max_time: Optional[str] = None
|
||||
precision: Optional[str] = None
|
||||
val_check_interval: Optional[Union[int, float]] = None
|
||||
gradient_clip_val: float | None = None
|
||||
limit_train_batches: int | float | None = None
|
||||
limit_test_batches: int | float | None = None
|
||||
limit_val_batches: int | float | None = None
|
||||
log_every_n_steps: int | None = None
|
||||
max_epochs: int | None = 200
|
||||
min_epochs: int | None = None
|
||||
max_steps: int | None = None
|
||||
min_steps: int | None = None
|
||||
max_time: str | None = None
|
||||
precision: str | None = None
|
||||
val_check_interval: int | float | None = None
|
||||
|
||||
|
||||
class OptimizerConfig(BaseConfig):
|
||||
@ -57,6 +57,6 @@ class TrainingConfig(BaseConfig):
|
||||
|
||||
def load_train_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
field: str | None = None,
|
||||
) -> TrainingConfig:
|
||||
return load_config(path, schema=TrainingConfig, field=field)
|
||||
|
||||
@ -44,10 +44,10 @@ class TrainingDataset(Dataset):
|
||||
audio_loader: AudioLoader,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
labeller: ClipLabeller,
|
||||
clipper: Optional[ClipperProtocol] = None,
|
||||
audio_augmentation: Optional[Augmentation] = None,
|
||||
spectrogram_augmentation: Optional[Augmentation] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
clipper: ClipperProtocol | None = None,
|
||||
audio_augmentation: Augmentation | None = None,
|
||||
spectrogram_augmentation: Augmentation | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
):
|
||||
self.clip_annotations = clip_annotations
|
||||
self.clipper = clipper
|
||||
@ -108,8 +108,8 @@ class ValidationDataset(Dataset):
|
||||
audio_loader: AudioLoader,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
labeller: ClipLabeller,
|
||||
clipper: Optional[ClipperProtocol] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
clipper: ClipperProtocol | None = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
):
|
||||
self.clip_annotations = clip_annotations
|
||||
self.labeller = labeller
|
||||
@ -165,11 +165,11 @@ class TrainLoaderConfig(BaseConfig):
|
||||
|
||||
def build_train_loader(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
labeller: Optional[ClipLabeller] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[TrainLoaderConfig] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
labeller: ClipLabeller | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
config: TrainLoaderConfig | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> DataLoader:
|
||||
config = config or TrainLoaderConfig()
|
||||
|
||||
@ -207,11 +207,11 @@ class ValLoaderConfig(BaseConfig):
|
||||
|
||||
def build_val_loader(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
labeller: Optional[ClipLabeller] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[ValLoaderConfig] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
labeller: ClipLabeller | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
config: ValLoaderConfig | None = None,
|
||||
num_workers: int | None = None,
|
||||
):
|
||||
logger.info("Building validation data loader...")
|
||||
config = config or ValLoaderConfig()
|
||||
@ -240,10 +240,10 @@ def build_val_loader(
|
||||
|
||||
def build_train_dataset(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
labeller: Optional[ClipLabeller] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[TrainLoaderConfig] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
labeller: ClipLabeller | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
config: TrainLoaderConfig | None = None,
|
||||
) -> TrainingDataset:
|
||||
logger.info("Building training dataset...")
|
||||
config = config or TrainLoaderConfig()
|
||||
@ -291,10 +291,10 @@ def build_train_dataset(
|
||||
|
||||
def build_val_dataset(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
labeller: Optional[ClipLabeller] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[ValLoaderConfig] = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
labeller: ClipLabeller | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
config: ValLoaderConfig | None = None,
|
||||
) -> ValidationDataset:
|
||||
logger.info("Building validation dataset...")
|
||||
config = config or ValLoaderConfig()
|
||||
|
||||
@ -42,10 +42,10 @@ class LabelConfig(BaseConfig):
|
||||
|
||||
|
||||
def build_clip_labeler(
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
config: Optional[LabelConfig] = None,
|
||||
config: LabelConfig | None = None,
|
||||
) -> ClipLabeller:
|
||||
"""Construct the final clip labelling function."""
|
||||
config = config or LabelConfig()
|
||||
@ -153,7 +153,7 @@ def generate_heatmaps(
|
||||
|
||||
|
||||
def load_label_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
path: data.PathLike, field: str | None = None
|
||||
) -> LabelConfig:
|
||||
"""Load the heatmap label generation configuration from a file.
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ def train_loop(
|
||||
model: DetectionModel,
|
||||
train_dataset: LabeledDataset[TrainInputs],
|
||||
validation_dataset: LabeledDataset[TrainInputs],
|
||||
device: Optional[torch.device] = None,
|
||||
device: torch.device | None = None,
|
||||
num_epochs: int = 100,
|
||||
learning_rate: float = 1e-4,
|
||||
):
|
||||
|
||||
@ -106,10 +106,10 @@ def standardize_low_freq(
|
||||
|
||||
def format_annotation(
|
||||
annotation: types.FileAnnotation,
|
||||
events_of_interest: Optional[List[str]] = None,
|
||||
name_replace: Optional[Dict[str, str]] = None,
|
||||
events_of_interest: List[str] | None = None,
|
||||
name_replace: Dict[str, str] | None = None,
|
||||
convert_to_genus: bool = False,
|
||||
classes_to_ignore: Optional[List[str]] = None,
|
||||
classes_to_ignore: List[str] | None = None,
|
||||
) -> types.FileAnnotation:
|
||||
formated = []
|
||||
for aa in annotation["annotation"]:
|
||||
@ -154,7 +154,7 @@ def format_annotation(
|
||||
|
||||
def get_class_names(
|
||||
data: List[types.FileAnnotation],
|
||||
classes_to_ignore: Optional[List[str]] = None,
|
||||
classes_to_ignore: List[str] | None = None,
|
||||
) -> Tuple[StringCounter, List[float]]:
|
||||
"""Extracts class names and their inverse frequencies.
|
||||
|
||||
@ -201,9 +201,9 @@ def load_set_of_anns(
|
||||
*,
|
||||
convert_to_genus: bool = False,
|
||||
filter_issues: bool = False,
|
||||
events_of_interest: Optional[List[str]] = None,
|
||||
classes_to_ignore: Optional[List[str]] = None,
|
||||
name_replace: Optional[Dict[str, str]] = None,
|
||||
events_of_interest: List[str] | None = None,
|
||||
classes_to_ignore: List[str] | None = None,
|
||||
name_replace: Dict[str, str] | None = None,
|
||||
) -> List[types.FileAnnotation]:
|
||||
# load the annotations
|
||||
anns = []
|
||||
|
||||
@ -26,10 +26,10 @@ class TrainingModule(L.LightningModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[dict] = None,
|
||||
config: dict | None = None,
|
||||
t_max: int = 100,
|
||||
model: Optional[Model] = None,
|
||||
loss: Optional[torch.nn.Module] = None,
|
||||
model: Model | None = None,
|
||||
loss: torch.nn.Module | None = None,
|
||||
):
|
||||
from batdetect2.config import validate_config
|
||||
|
||||
@ -103,7 +103,7 @@ def load_model_from_checkpoint(
|
||||
|
||||
|
||||
def build_training_module(
|
||||
config: Optional[dict] = None,
|
||||
config: dict | None = None,
|
||||
t_max: int = 200,
|
||||
) -> TrainingModule:
|
||||
return TrainingModule(config=config, t_max=t_max)
|
||||
|
||||
@ -151,7 +151,7 @@ class FocalLoss(nn.Module):
|
||||
eps: float = 1e-5,
|
||||
beta: float = 4,
|
||||
alpha: float = 2,
|
||||
class_weights: Optional[torch.Tensor] = None,
|
||||
class_weights: torch.Tensor | None = None,
|
||||
mask_zero: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
@ -422,8 +422,8 @@ class LossFunction(nn.Module, LossProtocol):
|
||||
|
||||
|
||||
def build_loss(
|
||||
config: Optional[LossConfig] = None,
|
||||
class_weights: Optional[np.ndarray] = None,
|
||||
config: LossConfig | None = None,
|
||||
class_weights: np.ndarray | None = None,
|
||||
) -> nn.Module:
|
||||
"""Factory function to build the main LossFunction from configuration.
|
||||
|
||||
|
||||
@ -35,21 +35,21 @@ __all__ = [
|
||||
|
||||
def train(
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
targets: Optional["TargetProtocol"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
labeller: Optional["ClipLabeller"] = None,
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
trainer: Optional[Trainer] = None,
|
||||
train_workers: Optional[int] = None,
|
||||
val_workers: Optional[int] = None,
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
run_name: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
trainer: Trainer | None = None,
|
||||
train_workers: int | None = None,
|
||||
val_workers: int | None = None,
|
||||
checkpoint_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
experiment_name: str | None = None,
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
):
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
@ -126,11 +126,11 @@ def train(
|
||||
def build_trainer(
|
||||
config: "BatDetect2Config",
|
||||
evaluator: "EvaluatorProtocol",
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
checkpoint_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
num_epochs: int | None = None,
|
||||
) -> Trainer:
|
||||
trainer_conf = config.train.trainer
|
||||
logger.opt(lazy=True).debug(
|
||||
|
||||
@ -240,7 +240,7 @@ class ProcessingConfiguration(TypedDict):
|
||||
detection_threshold: float
|
||||
"""Threshold for detection probability."""
|
||||
|
||||
time_expansion: Optional[float]
|
||||
time_expansion: float | None
|
||||
"""Time expansion factor of the processed recordings."""
|
||||
|
||||
top_n: int
|
||||
@ -249,7 +249,7 @@ class ProcessingConfiguration(TypedDict):
|
||||
return_raw_preds: bool
|
||||
"""Whether to return raw predictions."""
|
||||
|
||||
max_duration: Optional[float]
|
||||
max_duration: float | None
|
||||
"""Maximum duration of audio file to process in seconds."""
|
||||
|
||||
nms_kernel_size: int
|
||||
|
||||
@ -20,7 +20,7 @@ class OutputFormatterProtocol(Protocol, Generic[T]):
|
||||
self,
|
||||
predictions: Sequence[T],
|
||||
path: PathLike,
|
||||
audio_dir: Optional[PathLike] = None,
|
||||
audio_dir: PathLike | None = None,
|
||||
) -> None: ...
|
||||
|
||||
def load(self, path: PathLike) -> List[T]: ...
|
||||
|
||||
@ -28,19 +28,19 @@ __all__ = [
|
||||
class MatchEvaluation:
|
||||
clip: data.Clip
|
||||
|
||||
sound_event_annotation: Optional[data.SoundEventAnnotation]
|
||||
sound_event_annotation: data.SoundEventAnnotation | None
|
||||
gt_det: bool
|
||||
gt_class: Optional[str]
|
||||
gt_geometry: Optional[data.Geometry]
|
||||
gt_class: str | None
|
||||
gt_geometry: data.Geometry | None
|
||||
|
||||
pred_score: float
|
||||
pred_class_scores: Dict[str, float]
|
||||
pred_geometry: Optional[data.Geometry]
|
||||
pred_geometry: data.Geometry | None
|
||||
|
||||
affinity: float
|
||||
|
||||
@property
|
||||
def top_class(self) -> Optional[str]:
|
||||
def top_class(self) -> str | None:
|
||||
if not self.pred_class_scores:
|
||||
return None
|
||||
|
||||
@ -76,7 +76,7 @@ class MatcherProtocol(Protocol):
|
||||
ground_truth: Sequence[data.Geometry],
|
||||
predictions: Sequence[data.Geometry],
|
||||
scores: Sequence[float],
|
||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ...
|
||||
) -> Iterable[Tuple[int | None, int | None, float]]: ...
|
||||
|
||||
|
||||
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
||||
|
||||
@ -42,7 +42,7 @@ class GeometryDecoder(Protocol):
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, position: Position, size: Size, class_name: Optional[str] = None
|
||||
self, position: Position, size: Size, class_name: str | None = None
|
||||
) -> data.Geometry: ...
|
||||
|
||||
|
||||
@ -93,5 +93,5 @@ class PostprocessorProtocol(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
start_times: Optional[Sequence[float]] = None,
|
||||
start_times: Sequence[float] | None = None,
|
||||
) -> List[ClipDetectionsTensor]: ...
|
||||
|
||||
@ -37,7 +37,7 @@ class AudioLoader(Protocol):
|
||||
def load_file(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess audio directly from a file path.
|
||||
|
||||
@ -60,7 +60,7 @@ class AudioLoader(Protocol):
|
||||
def load_recording(
|
||||
self,
|
||||
recording: data.Recording,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess the entire audio for a Recording object.
|
||||
|
||||
@ -90,7 +90,7 @@ class AudioLoader(Protocol):
|
||||
def load_clip(
|
||||
self,
|
||||
clip: data.Clip,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess the audio segment defined by a Clip object.
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user