Update type hints to python 3.10

This commit is contained in:
mbsantiago 2025-12-08 17:14:50 +00:00
parent 9c72537ddd
commit 2563f26ed3
104 changed files with 525 additions and 664 deletions

View File

@ -93,7 +93,7 @@ mlflow = ["mlflow>=3.1.1"]
[tool.ruff] [tool.ruff]
line-length = 79 line-length = 79
target-version = "py39" target-version = "py310"
[tool.ruff.format] [tool.ruff.format]
docstring-code-format = true docstring-code-format = true
@ -107,7 +107,7 @@ convention = "numpy"
[tool.pyright] [tool.pyright]
include = ["src", "tests"] include = ["src", "tests"]
pythonVersion = "3.9" pythonVersion = "3.10"
pythonPlatform = "All" pythonPlatform = "All"
exclude = [ exclude = [
"src/batdetect2/detector/", "src/batdetect2/detector/",

View File

@ -165,7 +165,7 @@ def load_audio(
time_exp_fact: float = 1, time_exp_fact: float = 1,
target_samp_rate: int = TARGET_SAMPLERATE_HZ, target_samp_rate: int = TARGET_SAMPLERATE_HZ,
scale: bool = False, scale: bool = False,
max_duration: Optional[float] = None, max_duration: float | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load audio from file. """Load audio from file.
@ -203,7 +203,7 @@ def load_audio(
def generate_spectrogram( def generate_spectrogram(
audio: np.ndarray, audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[SpectrogramParameters] = None, config: SpectrogramParameters | None = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> torch.Tensor: ) -> torch.Tensor:
"""Generate spectrogram from audio array. """Generate spectrogram from audio array.
@ -240,7 +240,7 @@ def generate_spectrogram(
def process_file( def process_file(
audio_file: str, audio_file: str,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: ProcessingConfiguration | None = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> du.RunResults: ) -> du.RunResults:
"""Process audio file with model. """Process audio file with model.
@ -271,7 +271,7 @@ def process_spectrogram(
spec: torch.Tensor, spec: torch.Tensor,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: ProcessingConfiguration | None = None,
) -> Tuple[List[Annotation], np.ndarray]: ) -> Tuple[List[Annotation], np.ndarray]:
"""Process spectrogram with model. """Process spectrogram with model.
@ -312,7 +312,7 @@ def process_audio(
audio: np.ndarray, audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: ProcessingConfiguration | None = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]: ) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
"""Process audio array with model. """Process audio array with model.
@ -356,7 +356,7 @@ def process_audio(
def postprocess( def postprocess(
outputs: ModelOutput, outputs: ModelOutput,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ProcessingConfiguration] = None, config: ProcessingConfiguration | None = None,
) -> Tuple[List[Annotation], np.ndarray]: ) -> Tuple[List[Annotation], np.ndarray]:
"""Postprocess model outputs. """Postprocess model outputs.

View File

@ -67,22 +67,22 @@ class BatDetect2API:
def load_annotations( def load_annotations(
self, self,
path: data.PathLike, path: data.PathLike,
base_dir: Optional[data.PathLike] = None, base_dir: data.PathLike | None = None,
) -> Dataset: ) -> Dataset:
return load_dataset_from_config(path, base_dir=base_dir) return load_dataset_from_config(path, base_dir=base_dir)
def train( def train(
self, self,
train_annotations: Sequence[data.ClipAnnotation], train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None, val_annotations: Sequence[data.ClipAnnotation] | None = None,
train_workers: Optional[int] = None, train_workers: int | None = None,
val_workers: Optional[int] = None, val_workers: int | None = None,
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR, checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
log_dir: Optional[Path] = DEFAULT_LOGS_DIR, log_dir: Path | None = DEFAULT_LOGS_DIR,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
num_epochs: Optional[int] = None, num_epochs: int | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
seed: Optional[int] = None, seed: int | None = None,
): ):
train( train(
train_annotations=train_annotations, train_annotations=train_annotations,
@ -105,10 +105,10 @@ class BatDetect2API:
def evaluate( def evaluate(
self, self,
test_annotations: Sequence[data.ClipAnnotation], test_annotations: Sequence[data.ClipAnnotation],
num_workers: Optional[int] = None, num_workers: int | None = None,
output_dir: data.PathLike = DEFAULT_EVAL_DIR, output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
save_predictions: bool = True, save_predictions: bool = True,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]: ) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
return evaluate( return evaluate(
@ -129,7 +129,7 @@ class BatDetect2API:
self, self,
annotations: Sequence[data.ClipAnnotation], annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[BatDetect2Prediction],
output_dir: Optional[data.PathLike] = None, output_dir: data.PathLike | None = None,
): ):
clip_evals = self.evaluator.evaluate( clip_evals = self.evaluator.evaluate(
annotations, annotations,
@ -221,7 +221,7 @@ class BatDetect2API:
def process_files( def process_files(
self, self,
audio_files: Sequence[data.PathLike], audio_files: Sequence[data.PathLike],
num_workers: Optional[int] = None, num_workers: int | None = None,
) -> List[BatDetect2Prediction]: ) -> List[BatDetect2Prediction]:
return process_file_list( return process_file_list(
self.model, self.model,
@ -236,8 +236,8 @@ class BatDetect2API:
def process_clips( def process_clips(
self, self,
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
batch_size: Optional[int] = None, batch_size: int | None = None,
num_workers: Optional[int] = None, num_workers: int | None = None,
) -> List[BatDetect2Prediction]: ) -> List[BatDetect2Prediction]:
return run_batch_inference( return run_batch_inference(
self.model, self.model,
@ -254,9 +254,9 @@ class BatDetect2API:
self, self,
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[BatDetect2Prediction],
path: data.PathLike, path: data.PathLike,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
format: Optional[str] = None, format: str | None = None,
config: Optional[OutputFormatConfig] = None, config: OutputFormatConfig | None = None,
): ):
formatter = self.formatter formatter = self.formatter
@ -331,7 +331,7 @@ class BatDetect2API:
def from_checkpoint( def from_checkpoint(
cls, cls,
path: data.PathLike, path: data.PathLike,
config: Optional[BatDetect2Config] = None, config: BatDetect2Config | None = None,
): ):
model, stored_config = load_model_from_checkpoint(path) model, stored_config = load_model_from_checkpoint(path)

View File

@ -245,16 +245,12 @@ class FixedDurationClip:
ClipConfig = Annotated[ ClipConfig = Annotated[
Union[ RandomClipConfig | PaddedClipConfig | FixedDurationClipConfig,
RandomClipConfig,
PaddedClipConfig,
FixedDurationClipConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol: def build_clipper(config: ClipConfig | None = None) -> ClipperProtocol:
config = config or RandomClipConfig() config = config or RandomClipConfig()
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(

View File

@ -50,7 +50,7 @@ class AudioConfig(BaseConfig):
resample: ResampleConfig = Field(default_factory=ResampleConfig) resample: ResampleConfig = Field(default_factory=ResampleConfig)
def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader: def build_audio_loader(config: AudioConfig | None = None) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration.""" """Factory function to create an AudioLoader based on configuration."""
config = config or AudioConfig() config = config or AudioConfig()
return SoundEventAudioLoader( return SoundEventAudioLoader(
@ -65,7 +65,7 @@ class SoundEventAudioLoader(AudioLoader):
def __init__( def __init__(
self, self,
samplerate: int = TARGET_SAMPLERATE_HZ, samplerate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ResampleConfig] = None, config: ResampleConfig | None = None,
): ):
self.samplerate = samplerate self.samplerate = samplerate
self.config = config or ResampleConfig() self.config = config or ResampleConfig()
@ -73,7 +73,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_file( def load_file(
self, self,
path: data.PathLike, path: data.PathLike,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess audio directly from a file path.""" """Load and preprocess audio directly from a file path."""
return load_file_audio( return load_file_audio(
@ -86,7 +86,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_recording( def load_recording(
self, self,
recording: data.Recording, recording: data.Recording,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object.""" """Load and preprocess the entire audio for a Recording object."""
return load_recording_audio( return load_recording_audio(
@ -99,7 +99,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_clip( def load_clip(
self, self,
clip: data.Clip, clip: data.Clip,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object.""" """Load and preprocess the audio segment defined by a Clip object."""
return load_clip_audio( return load_clip_audio(
@ -112,9 +112,9 @@ class SoundEventAudioLoader(AudioLoader):
def load_file_audio( def load_file_audio(
path: data.PathLike, path: data.PathLike,
samplerate: Optional[int] = None, samplerate: int | None = None,
config: Optional[ResampleConfig] = None, config: ResampleConfig | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config.""" """Load and preprocess audio from a file path using specified config."""
@ -136,9 +136,9 @@ def load_file_audio(
def load_recording_audio( def load_recording_audio(
recording: data.Recording, recording: data.Recording,
samplerate: Optional[int] = None, samplerate: int | None = None,
config: Optional[ResampleConfig] = None, config: ResampleConfig | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config.""" """Load and preprocess the entire audio content of a recording using config."""
@ -158,9 +158,9 @@ def load_recording_audio(
def load_clip_audio( def load_clip_audio(
clip: data.Clip, clip: data.Clip,
samplerate: Optional[int] = None, samplerate: int | None = None,
config: Optional[ResampleConfig] = None, config: ResampleConfig | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config.""" """Load and preprocess a specific audio clip segment based on config."""

View File

@ -34,9 +34,9 @@ def data(): ...
) )
def summary( def summary(
dataset_config: Path, dataset_config: Path,
field: Optional[str] = None, field: str | None = None,
targets_path: Optional[Path] = None, targets_path: Path | None = None,
base_dir: Optional[Path] = None, base_dir: Path | None = None,
): ):
from batdetect2.data import compute_class_summary, load_dataset_from_config from batdetect2.data import compute_class_summary, load_dataset_from_config
from batdetect2.targets import load_targets from batdetect2.targets import load_targets
@ -83,9 +83,9 @@ def summary(
) )
def convert( def convert(
dataset_config: Path, dataset_config: Path,
field: Optional[str] = None, field: str | None = None,
output: Path = Path("annotations.json"), output: Path = Path("annotations.json"),
base_dir: Optional[Path] = None, base_dir: Path | None = None,
): ):
"""Convert a dataset config file to soundevent format.""" """Convert a dataset config file to soundevent format."""
from soundevent import data, io from soundevent import data, io

View File

@ -25,11 +25,11 @@ def evaluate_command(
model_path: Path, model_path: Path,
test_dataset: Path, test_dataset: Path,
base_dir: Path, base_dir: Path,
config_path: Optional[Path], config_path: Path | None,
output_dir: Path = DEFAULT_OUTPUT_DIR, output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: Optional[int] = None, num_workers: int | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
): ):
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config from batdetect2.config import load_full_config

View File

@ -26,19 +26,19 @@ __all__ = ["train_command"]
@click.option("--seed", type=int) @click.option("--seed", type=int)
def train_command( def train_command(
train_dataset: Path, train_dataset: Path,
val_dataset: Optional[Path] = None, val_dataset: Path | None = None,
model_path: Optional[Path] = None, model_path: Path | None = None,
ckpt_dir: Optional[Path] = None, ckpt_dir: Path | None = None,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
config: Optional[Path] = None, config: Path | None = None,
targets_config: Optional[Path] = None, targets_config: Path | None = None,
config_field: Optional[str] = None, config_field: str | None = None,
seed: Optional[int] = None, seed: int | None = None,
num_epochs: Optional[int] = None, num_epochs: int | None = None,
train_workers: int = 0, train_workers: int = 0,
val_workers: int = 0, val_workers: int = 0,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
): ):
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import ( from batdetect2.config import (

View File

@ -4,7 +4,7 @@ import json
import os import os
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Union from typing import Callable, List
import numpy as np import numpy as np
from soundevent import data from soundevent import data
@ -17,7 +17,7 @@ from batdetect2.types import (
FileAnnotation, FileAnnotation,
) )
PathLike = Union[Path, str, os.PathLike] PathLike = Path | str | os.PathLike
__all__ = [ __all__ = [
"convert_to_annotation_group", "convert_to_annotation_group",
@ -33,7 +33,7 @@ UNKNOWN_CLASS = "__UNKNOWN__"
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242") NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]] EventFn = Callable[[data.SoundEventAnnotation], str | None]
ClassFn = Callable[[data.Recording], int] ClassFn = Callable[[data.Recording], int]
@ -221,7 +221,7 @@ def annotation_to_sound_event_prediction(
def file_annotation_to_clip( def file_annotation_to_clip(
file_annotation: FileAnnotation, file_annotation: FileAnnotation,
audio_dir: Optional[PathLike] = None, audio_dir: PathLike | None = None,
label_key: str = "class", label_key: str = "class",
) -> data.Clip: ) -> data.Clip:
"""Convert file annotation to recording.""" """Convert file annotation to recording."""

View File

@ -43,7 +43,7 @@ class BatDetect2Config(BaseConfig):
output: OutputFormatConfig = Field(default_factory=RawOutputConfig) output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
def validate_config(config: Optional[dict]) -> BatDetect2Config: def validate_config(config: dict | None) -> BatDetect2Config:
if config is None: if config is None:
return BatDetect2Config() return BatDetect2Config()
@ -52,6 +52,6 @@ def validate_config(config: Optional[dict]) -> BatDetect2Config:
def load_full_config( def load_full_config(
path: PathLike, path: PathLike,
field: Optional[str] = None, field: str | None = None,
) -> BatDetect2Config: ) -> BatDetect2Config:
return load_config(path, schema=BatDetect2Config, field=field) return load_config(path, schema=BatDetect2Config, field=field)

View File

@ -86,8 +86,8 @@ def adjust_width(
def slice_tensor( def slice_tensor(
tensor: torch.Tensor, tensor: torch.Tensor,
start: Optional[int] = None, start: int | None = None,
end: Optional[int] = None, end: int | None = None,
dim: int = -1, dim: int = -1,
) -> torch.Tensor: ) -> torch.Tensor:
slices = [slice(None)] * tensor.ndim slices = [slice(None)] * tensor.ndim

View File

@ -128,7 +128,7 @@ def get_object_field(obj: dict, current_key: str) -> Any:
def load_config( def load_config(
path: PathLike, path: PathLike,
schema: Type[T], schema: Type[T],
field: Optional[str] = None, field: str | None = None,
) -> T: ) -> T:
"""Load and validate configuration data from a file against a schema. """Load and validate configuration data from a file against a schema.

View File

@ -43,11 +43,7 @@ __all__ = [
AnnotationFormats = Annotated[ AnnotationFormats = Annotated[
Union[ BatDetect2MergedAnnotations | BatDetect2FilesAnnotations | AOEFAnnotations,
BatDetect2MergedAnnotations,
BatDetect2FilesAnnotations,
AOEFAnnotations,
],
Field(discriminator="format"), Field(discriminator="format"),
] ]
"""Type Alias representing all supported data source configurations. """Type Alias representing all supported data source configurations.
@ -63,7 +59,7 @@ source configuration represents.
def load_annotated_dataset( def load_annotated_dataset(
dataset: AnnotatedDataset, dataset: AnnotatedDataset,
base_dir: Optional[data.PathLike] = None, base_dir: data.PathLike | None = None,
) -> data.AnnotationSet: ) -> data.AnnotationSet:
"""Load annotations for a single data source based on its configuration. """Load annotations for a single data source based on its configuration.

View File

@ -77,14 +77,14 @@ class AOEFAnnotations(AnnotatedDataset):
annotations_path: Path annotations_path: Path
filter: Optional[AnnotationTaskFilter] = Field( filter: AnnotationTaskFilter | None = Field(
default_factory=AnnotationTaskFilter default_factory=AnnotationTaskFilter
) )
def load_aoef_annotated_dataset( def load_aoef_annotated_dataset(
dataset: AOEFAnnotations, dataset: AOEFAnnotations,
base_dir: Optional[data.PathLike] = None, base_dir: data.PathLike | None = None,
) -> data.AnnotationSet: ) -> data.AnnotationSet:
"""Load annotations from an AnnotationSet or AnnotationProject file. """Load annotations from an AnnotationSet or AnnotationProject file.

View File

@ -27,7 +27,7 @@ aggregated into a `soundevent.data.AnnotationSet`.
import json import json
import os import os
from pathlib import Path from pathlib import Path
from typing import Literal, Optional, Union from typing import Literal
from loguru import logger from loguru import logger
from pydantic import Field, ValidationError from pydantic import Field, ValidationError
@ -43,7 +43,7 @@ from batdetect2.data.annotations.legacy import (
) )
from batdetect2.data.annotations.types import AnnotatedDataset from batdetect2.data.annotations.types import AnnotatedDataset
PathLike = Union[Path, str, os.PathLike] PathLike = Path | str | os.PathLike
__all__ = [ __all__ = [
@ -102,7 +102,7 @@ class BatDetect2FilesAnnotations(AnnotatedDataset):
format: Literal["batdetect2"] = "batdetect2" format: Literal["batdetect2"] = "batdetect2"
annotations_dir: Path annotations_dir: Path
filter: Optional[AnnotationFilter] = Field( filter: AnnotationFilter | None = Field(
default_factory=AnnotationFilter, default_factory=AnnotationFilter,
) )
@ -133,14 +133,14 @@ class BatDetect2MergedAnnotations(AnnotatedDataset):
format: Literal["batdetect2_file"] = "batdetect2_file" format: Literal["batdetect2_file"] = "batdetect2_file"
annotations_path: Path annotations_path: Path
filter: Optional[AnnotationFilter] = Field( filter: AnnotationFilter | None = Field(
default_factory=AnnotationFilter, default_factory=AnnotationFilter,
) )
def load_batdetect2_files_annotated_dataset( def load_batdetect2_files_annotated_dataset(
dataset: BatDetect2FilesAnnotations, dataset: BatDetect2FilesAnnotations,
base_dir: Optional[PathLike] = None, base_dir: PathLike | None = None,
) -> data.AnnotationSet: ) -> data.AnnotationSet:
"""Load and convert 'batdetect2_file' annotations into an AnnotationSet. """Load and convert 'batdetect2_file' annotations into an AnnotationSet.
@ -244,7 +244,7 @@ def load_batdetect2_files_annotated_dataset(
def load_batdetect2_merged_annotated_dataset( def load_batdetect2_merged_annotated_dataset(
dataset: BatDetect2MergedAnnotations, dataset: BatDetect2MergedAnnotations,
base_dir: Optional[PathLike] = None, base_dir: PathLike | None = None,
) -> data.AnnotationSet: ) -> data.AnnotationSet:
"""Load and convert 'batdetect2_merged' annotations into an AnnotationSet. """Load and convert 'batdetect2_merged' annotations into an AnnotationSet.

View File

@ -3,12 +3,12 @@
import os import os
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Union from typing import Callable, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from soundevent import data from soundevent import data
PathLike = Union[Path, str, os.PathLike] PathLike = Path | str | os.PathLike
__all__ = [] __all__ = []
@ -27,7 +27,7 @@ SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
) )
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]] EventFn = Callable[[data.SoundEventAnnotation], str | None]
ClassFn = Callable[[data.Recording], int] ClassFn = Callable[[data.Recording], int]
@ -130,7 +130,7 @@ def get_sound_event_tags(
def file_annotation_to_clip( def file_annotation_to_clip(
file_annotation: FileAnnotation, file_annotation: FileAnnotation,
audio_dir: Optional[PathLike] = None, audio_dir: PathLike | None = None,
label_key: str = "class", label_key: str = "class",
) -> data.Clip: ) -> data.Clip:
"""Convert file annotation to recording.""" """Convert file annotation to recording."""

View File

@ -264,16 +264,7 @@ class Not:
SoundEventConditionConfig = Annotated[ SoundEventConditionConfig = Annotated[
Union[ HasTagConfig | HasAllTagsConfig | HasAnyTagConfig | DurationConfig | FrequencyConfig | AllOfConfig | AnyOfConfig | NotConfig,
HasTagConfig,
HasAllTagsConfig,
HasAnyTagConfig,
DurationConfig,
FrequencyConfig,
AllOfConfig,
AnyOfConfig,
NotConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -69,7 +69,7 @@ class DatasetConfig(BaseConfig):
description: str description: str
sources: List[AnnotationFormats] sources: List[AnnotationFormats]
sound_event_filter: Optional[SoundEventConditionConfig] = None sound_event_filter: SoundEventConditionConfig | None = None
sound_event_transforms: List[SoundEventTransformConfig] = Field( sound_event_transforms: List[SoundEventTransformConfig] = Field(
default_factory=list default_factory=list
) )
@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig):
def load_dataset( def load_dataset(
config: DatasetConfig, config: DatasetConfig,
base_dir: Optional[data.PathLike] = None, base_dir: data.PathLike | None = None,
) -> Dataset: ) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig.""" """Load all clip annotations from the sources defined in a DatasetConfig."""
clip_annotations = [] clip_annotations = []
@ -161,14 +161,14 @@ def insert_source_tag(
) )
def load_dataset_config(path: data.PathLike, field: Optional[str] = None): def load_dataset_config(path: data.PathLike, field: str | None = None):
return load_config(path=path, schema=DatasetConfig, field=field) return load_config(path=path, schema=DatasetConfig, field=field)
def load_dataset_from_config( def load_dataset_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: str | None = None,
base_dir: Optional[data.PathLike] = None, base_dir: data.PathLike | None = None,
) -> Dataset: ) -> Dataset:
"""Load dataset annotation metadata from a configuration file. """Load dataset annotation metadata from a configuration file.
@ -215,9 +215,9 @@ def load_dataset_from_config(
def save_dataset( def save_dataset(
dataset: Dataset, dataset: Dataset,
path: data.PathLike, path: data.PathLike,
name: Optional[str] = None, name: str | None = None,
description: Optional[str] = None, description: str | None = None,
audio_dir: Optional[Path] = None, audio_dir: Path | None = None,
) -> None: ) -> None:
"""Save a loaded dataset (list of ClipAnnotations) to a file. """Save a loaded dataset (list of ClipAnnotations) to a file.

View File

@ -10,7 +10,7 @@ from batdetect2.typing.targets import TargetProtocol
def iterate_over_sound_events( def iterate_over_sound_events(
dataset: Dataset, dataset: Dataset,
targets: TargetProtocol, targets: TargetProtocol,
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]: ) -> Generator[Tuple[str | None, data.SoundEventAnnotation], None, None]:
"""Iterate over sound events in a dataset. """Iterate over sound events in a dataset.
Parameters Parameters

View File

@ -24,19 +24,14 @@ __all__ = [
OutputFormatConfig = Annotated[ OutputFormatConfig = Annotated[
Union[ BatDetect2OutputConfig | ParquetOutputConfig | SoundEventOutputConfig | RawOutputConfig,
BatDetect2OutputConfig,
ParquetOutputConfig,
SoundEventOutputConfig,
RawOutputConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
def build_output_formatter( def build_output_formatter(
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
config: Optional[OutputFormatConfig] = None, config: OutputFormatConfig | None = None,
) -> OutputFormatterProtocol: ) -> OutputFormatterProtocol:
"""Construct the final output formatter.""" """Construct the final output formatter."""
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
@ -48,9 +43,9 @@ def build_output_formatter(
def get_output_formatter( def get_output_formatter(
name: Optional[str] = None, name: str | None = None,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
config: Optional[OutputFormatConfig] = None, config: OutputFormatConfig | None = None,
) -> OutputFormatterProtocol: ) -> OutputFormatterProtocol:
"""Get the output formatter by name.""" """Get the output formatter by name."""
@ -71,9 +66,9 @@ def get_output_formatter(
def load_predictions( def load_predictions(
path: PathLike, path: PathLike,
format: Optional[str] = "raw", format: str | None = "raw",
config: Optional[OutputFormatConfig] = None, config: OutputFormatConfig | None = None,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
): ):
"""Load predictions from a file.""" """Load predictions from a file."""
from batdetect2.targets import build_targets from batdetect2.targets import build_targets

View File

@ -123,7 +123,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
self, self,
predictions: Sequence[FileAnnotation], predictions: Sequence[FileAnnotation],
path: data.PathLike, path: data.PathLike,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> None: ) -> None:
path = Path(path) path = Path(path)

View File

@ -53,7 +53,7 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
self, self,
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[BatDetect2Prediction],
path: data.PathLike, path: data.PathLike,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> None: ) -> None:
path = Path(path) path = Path(path)

View File

@ -55,7 +55,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
self, self,
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[BatDetect2Prediction],
path: data.PathLike, path: data.PathLike,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> None: ) -> None:
path = Path(path) path = Path(path)
@ -84,7 +84,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
def pred_to_xr( def pred_to_xr(
self, self,
prediction: BatDetect2Prediction, prediction: BatDetect2Prediction,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> xr.Dataset: ) -> xr.Dataset:
clip = prediction.clip clip = prediction.clip
recording = clip.recording recording = clip.recording

View File

@ -18,16 +18,16 @@ from batdetect2.typing import (
class SoundEventOutputConfig(BaseConfig): class SoundEventOutputConfig(BaseConfig):
name: Literal["soundevent"] = "soundevent" name: Literal["soundevent"] = "soundevent"
top_k: Optional[int] = 1 top_k: int | None = 1
min_score: Optional[float] = None min_score: float | None = None
class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]): class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
def __init__( def __init__(
self, self,
targets: TargetProtocol, targets: TargetProtocol,
top_k: Optional[int] = 1, top_k: int | None = 1,
min_score: Optional[float] = 0, min_score: float | None = 0,
): ):
self.targets = targets self.targets = targets
self.top_k = top_k self.top_k = top_k
@ -45,7 +45,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
self, self,
predictions: Sequence[data.ClipPrediction], predictions: Sequence[data.ClipPrediction],
path: data.PathLike, path: data.PathLike,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> None: ) -> None:
run = data.PredictionSet(clip_predictions=list(predictions)) run = data.PredictionSet(clip_predictions=list(predictions))

View File

@ -14,7 +14,7 @@ def split_dataset_by_recordings(
dataset: Dataset, dataset: Dataset,
targets: TargetProtocol, targets: TargetProtocol,
train_size: float = 0.75, train_size: float = 0.75,
random_state: Optional[int] = None, random_state: int | None = None,
) -> Tuple[Dataset, Dataset]: ) -> Tuple[Dataset, Dataset]:
recordings = extract_recordings_df(dataset) recordings = extract_recordings_df(dataset)

View File

@ -142,7 +142,7 @@ class MapTagValueConfig(BaseConfig):
name: Literal["map_tag_value"] = "map_tag_value" name: Literal["map_tag_value"] = "map_tag_value"
tag_key: str tag_key: str
value_mapping: Dict[str, str] value_mapping: Dict[str, str]
target_key: Optional[str] = None target_key: str | None = None
class MapTagValue: class MapTagValue:
@ -150,7 +150,7 @@ class MapTagValue:
self, self,
tag_key: str, tag_key: str,
value_mapping: Dict[str, str], value_mapping: Dict[str, str],
target_key: Optional[str] = None, target_key: str | None = None,
): ):
self.tag_key = tag_key self.tag_key = tag_key
self.value_mapping = value_mapping self.value_mapping = value_mapping
@ -221,13 +221,7 @@ class ApplyAll:
SoundEventTransformConfig = Annotated[ SoundEventTransformConfig = Annotated[
Union[ SetFrequencyBoundConfig | ReplaceTagConfig | MapTagValueConfig | ApplyIfConfig | ApplyAllConfig,
SetFrequencyBoundConfig,
ReplaceTagConfig,
MapTagValueConfig,
ApplyIfConfig,
ApplyAllConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -86,7 +86,7 @@ def compute_bandwidth(
def compute_max_power_bb( def compute_max_power_bb(
prediction: types.Prediction, prediction: types.Prediction,
spec: Optional[np.ndarray] = None, spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -131,7 +131,7 @@ def compute_max_power_bb(
def compute_max_power( def compute_max_power(
prediction: types.Prediction, prediction: types.Prediction,
spec: Optional[np.ndarray] = None, spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -157,7 +157,7 @@ def compute_max_power(
def compute_max_power_first( def compute_max_power_first(
prediction: types.Prediction, prediction: types.Prediction,
spec: Optional[np.ndarray] = None, spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -184,7 +184,7 @@ def compute_max_power_first(
def compute_max_power_second( def compute_max_power_second(
prediction: types.Prediction, prediction: types.Prediction,
spec: Optional[np.ndarray] = None, spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -211,7 +211,7 @@ def compute_max_power_second(
def compute_call_interval( def compute_call_interval(
prediction: types.Prediction, prediction: types.Prediction,
previous: Optional[types.Prediction] = None, previous: types.Prediction | None = None,
**_, **_,
) -> float: ) -> float:
"""Compute time between this call and the previous call in seconds.""" """Compute time between this call and the previous call in seconds."""

View File

@ -198,8 +198,8 @@ class TrainingParameters(BaseModel):
def get_params( def get_params(
make_dirs: bool = False, make_dirs: bool = False,
exps_dir: str = "../../experiments/", exps_dir: str = "../../experiments/",
model_name: Optional[str] = None, model_name: str | None = None,
experiment: Union[Path, str, None] = None, experiment: Path | str | None = None,
**kwargs, **kwargs,
) -> TrainingParameters: ) -> TrainingParameters:
experiments_dir = Path(exps_dir) experiments_dir = Path(exps_dir)

View File

@ -151,7 +151,7 @@ def run_nms(
def non_max_suppression( def non_max_suppression(
heat: torch.Tensor, heat: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]], kernel_size: int | Tuple[int, int],
): ):
# kernel can be an int or list/tuple # kernel can be an int or list/tuple
if isinstance(kernel_size, int): if isinstance(kernel_size, int):

View File

@ -213,18 +213,13 @@ class GeometricIOU(AffinityFunction):
AffinityConfig = Annotated[ AffinityConfig = Annotated[
Union[ TimeAffinityConfig | IntervalIOUConfig | BBoxIOUConfig | GeometricIOUConfig,
TimeAffinityConfig,
IntervalIOUConfig,
BBoxIOUConfig,
GeometricIOUConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
def build_affinity_function( def build_affinity_function(
config: Optional[AffinityConfig] = None, config: AffinityConfig | None = None,
) -> AffinityFunction: ) -> AffinityFunction:
config = config or GeometricIOUConfig() config = config or GeometricIOUConfig()
return affinity_functions.build(config) return affinity_functions.build(config)

View File

@ -51,6 +51,6 @@ def get_default_eval_config() -> EvaluationConfig:
def load_evaluation_config( def load_evaluation_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: str | None = None,
) -> EvaluationConfig: ) -> EvaluationConfig:
return load_config(path, schema=EvaluationConfig, field=field) return load_config(path, schema=EvaluationConfig, field=field)

View File

@ -39,8 +39,8 @@ class TestDataset(Dataset[TestExample]):
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader, audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
clipper: Optional[ClipperProtocol] = None, clipper: ClipperProtocol | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
): ):
self.clip_annotations = list(clip_annotations) self.clip_annotations = list(clip_annotations)
self.clipper = clipper self.clipper = clipper
@ -78,10 +78,10 @@ class TestLoaderConfig(BaseConfig):
def build_test_loader( def build_test_loader(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional[TestLoaderConfig] = None, config: TestLoaderConfig | None = None,
num_workers: Optional[int] = None, num_workers: int | None = None,
) -> DataLoader[TestExample]: ) -> DataLoader[TestExample]:
logger.info("Building test data loader...") logger.info("Building test data loader...")
config = config or TestLoaderConfig() config = config or TestLoaderConfig()
@ -109,9 +109,9 @@ def build_test_loader(
def build_test_dataset( def build_test_dataset(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional[TestLoaderConfig] = None, config: TestLoaderConfig | None = None,
) -> TestDataset: ) -> TestDataset:
logger.info("Building training dataset...") logger.info("Building training dataset...")
config = config or TestLoaderConfig() config = config or TestLoaderConfig()

View File

@ -34,10 +34,10 @@ def evaluate(
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None, config: Optional["BatDetect2Config"] = None,
formatter: Optional["OutputFormatterProtocol"] = None, formatter: Optional["OutputFormatterProtocol"] = None,
num_workers: Optional[int] = None, num_workers: int | None = None,
output_dir: data.PathLike = DEFAULT_EVAL_DIR, output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]: ) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config

View File

@ -51,8 +51,8 @@ class Evaluator:
def build_evaluator( def build_evaluator(
config: Optional[Union[EvaluationConfig, dict]] = None, config: EvaluationConfig | dict | None = None,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
) -> EvaluatorProtocol: ) -> EvaluatorProtocol:
targets = targets or build_targets() targets = targets or build_targets()

View File

@ -35,9 +35,9 @@ def match(
sound_event_annotations: Sequence[data.SoundEventAnnotation], sound_event_annotations: Sequence[data.SoundEventAnnotation],
raw_predictions: Sequence[RawPrediction], raw_predictions: Sequence[RawPrediction],
clip: data.Clip, clip: data.Clip,
scores: Optional[Sequence[float]] = None, scores: Sequence[float] | None = None,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
matcher: Optional[MatcherProtocol] = None, matcher: MatcherProtocol | None = None,
) -> ClipMatches: ) -> ClipMatches:
if matcher is None: if matcher is None:
matcher = build_matcher() matcher = build_matcher()
@ -151,7 +151,7 @@ def match_start_times(
predictions: Sequence[data.Geometry], predictions: Sequence[data.Geometry],
scores: Sequence[float], scores: Sequence[float],
distance_threshold: float = 0.01, distance_threshold: float = 0.01,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ) -> Iterable[Tuple[int | None, int | None, float]]:
if not ground_truth: if not ground_truth:
for index in range(len(predictions)): for index in range(len(predictions)):
yield index, None, 0 yield index, None, 0
@ -287,7 +287,7 @@ def greedy_match(
scores: Sequence[float], scores: Sequence[float],
affinity_threshold: float = 0.5, affinity_threshold: float = 0.5,
affinity_function: AffinityFunction = compute_affinity, affinity_function: AffinityFunction = compute_affinity,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ) -> Iterable[Tuple[int | None, int | None, float]]:
"""Performs a greedy, one-to-one matching of source to target geometries. """Performs a greedy, one-to-one matching of source to target geometries.
Iterates through source geometries, prioritizing by score if provided. Each Iterates through source geometries, prioritizing by score if provided. Each
@ -514,12 +514,7 @@ class OptimalMatcher(MatcherProtocol):
MatchConfig = Annotated[ MatchConfig = Annotated[
Union[ GreedyMatchConfig | StartTimeMatchConfig | OptimalMatchConfig | GreedyAffinityMatchConfig,
GreedyMatchConfig,
StartTimeMatchConfig,
OptimalMatchConfig,
GreedyAffinityMatchConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
@ -558,7 +553,7 @@ def compute_affinity_matrix(
def select_optimal_matches( def select_optimal_matches(
affinity_matrix: np.ndarray, affinity_matrix: np.ndarray,
affinity_threshold: float = 0.5, affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ) -> Iterable[Tuple[int | None, int | None, float]]:
num_gt, num_pred = affinity_matrix.shape num_gt, num_pred = affinity_matrix.shape
gts = set(range(num_gt)) gts = set(range(num_gt))
preds = set(range(num_pred)) preds = set(range(num_pred))
@ -588,7 +583,7 @@ def select_optimal_matches(
def select_greedy_matches( def select_greedy_matches(
affinity_matrix: np.ndarray, affinity_matrix: np.ndarray,
affinity_threshold: float = 0.5, affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ) -> Iterable[Tuple[int | None, int | None, float]]:
num_gt, num_pred = affinity_matrix.shape num_gt, num_pred = affinity_matrix.shape
unmatched_pred = set(range(num_pred)) unmatched_pred = set(range(num_pred))
@ -612,6 +607,6 @@ def select_greedy_matches(
yield pred_idx, None, 0 yield pred_idx, None, 0
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol: def build_matcher(config: MatchConfig | None = None) -> MatcherProtocol:
config = config or StartTimeMatchConfig() config = config or StartTimeMatchConfig()
return matching_strategies.build(config) return matching_strategies.build(config)

View File

@ -36,13 +36,13 @@ __all__ = [
@dataclass @dataclass
class MatchEval: class MatchEval:
clip: data.Clip clip: data.Clip
gt: Optional[data.SoundEventAnnotation] gt: data.SoundEventAnnotation | None
pred: Optional[RawPrediction] pred: RawPrediction | None
is_prediction: bool is_prediction: bool
is_ground_truth: bool is_ground_truth: bool
is_generic: bool is_generic: bool
true_class: Optional[str] true_class: str | None
score: float score: float
@ -61,16 +61,16 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
class BaseClassificationConfig(BaseConfig): class BaseClassificationConfig(BaseConfig):
include: Optional[List[str]] = None include: List[str] | None = None
exclude: Optional[List[str]] = None exclude: List[str] | None = None
class BaseClassificationMetric: class BaseClassificationMetric:
def __init__( def __init__(
self, self,
targets: TargetProtocol, targets: TargetProtocol,
include: Optional[List[str]] = None, include: List[str] | None = None,
exclude: Optional[List[str]] = None, exclude: List[str] | None = None,
): ):
self.targets = targets self.targets = targets
self.include = include self.include = include
@ -100,8 +100,8 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
ignore_non_predictions: bool = True, ignore_non_predictions: bool = True,
ignore_generic: bool = True, ignore_generic: bool = True,
label: str = "average_precision", label: str = "average_precision",
include: Optional[List[str]] = None, include: List[str] | None = None,
exclude: Optional[List[str]] = None, exclude: List[str] | None = None,
): ):
super().__init__(include=include, exclude=exclude, targets=targets) super().__init__(include=include, exclude=exclude, targets=targets)
self.ignore_non_predictions = ignore_non_predictions self.ignore_non_predictions = ignore_non_predictions
@ -169,8 +169,8 @@ class ClassificationROCAUC(BaseClassificationMetric):
ignore_non_predictions: bool = True, ignore_non_predictions: bool = True,
ignore_generic: bool = True, ignore_generic: bool = True,
label: str = "roc_auc", label: str = "roc_auc",
include: Optional[List[str]] = None, include: List[str] | None = None,
exclude: Optional[List[str]] = None, exclude: List[str] | None = None,
): ):
self.targets = targets self.targets = targets
self.ignore_non_predictions = ignore_non_predictions self.ignore_non_predictions = ignore_non_predictions
@ -225,10 +225,7 @@ class ClassificationROCAUC(BaseClassificationMetric):
ClassificationMetricConfig = Annotated[ ClassificationMetricConfig = Annotated[
Union[ ClassificationAveragePrecisionConfig | ClassificationROCAUCConfig,
ClassificationAveragePrecisionConfig,
ClassificationROCAUCConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -123,10 +123,7 @@ class ClipClassificationROCAUC:
ClipClassificationMetricConfig = Annotated[ ClipClassificationMetricConfig = Annotated[
Union[ ClipClassificationAveragePrecisionConfig | ClipClassificationROCAUCConfig,
ClipClassificationAveragePrecisionConfig,
ClipClassificationROCAUCConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -159,12 +159,7 @@ class ClipDetectionPrecision:
ClipDetectionMetricConfig = Annotated[ ClipDetectionMetricConfig = Annotated[
Union[ ClipDetectionAveragePrecisionConfig | ClipDetectionROCAUCConfig | ClipDetectionRecallConfig | ClipDetectionPrecisionConfig,
ClipDetectionAveragePrecisionConfig,
ClipDetectionROCAUCConfig,
ClipDetectionRecallConfig,
ClipDetectionPrecisionConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -11,7 +11,7 @@ __all__ = [
def compute_precision_recall( def compute_precision_recall(
y_true, y_true,
y_score, y_score,
num_positives: Optional[int] = None, num_positives: int | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
y_true = np.array(y_true) y_true = np.array(y_true)
y_score = np.array(y_score) y_score = np.array(y_score)
@ -41,7 +41,7 @@ def compute_precision_recall(
def average_precision( def average_precision(
y_true, y_true,
y_score, y_score,
num_positives: Optional[int] = None, num_positives: int | None = None,
) -> float: ) -> float:
if num_positives == 0: if num_positives == 0:
return np.nan return np.nan

View File

@ -28,8 +28,8 @@ __all__ = [
@dataclass @dataclass
class MatchEval: class MatchEval:
gt: Optional[data.SoundEventAnnotation] gt: data.SoundEventAnnotation | None
pred: Optional[RawPrediction] pred: RawPrediction | None
is_prediction: bool is_prediction: bool
is_ground_truth: bool is_ground_truth: bool
@ -212,12 +212,7 @@ class DetectionPrecision:
DetectionMetricConfig = Annotated[ DetectionMetricConfig = Annotated[
Union[ DetectionAveragePrecisionConfig | DetectionROCAUCConfig | DetectionRecallConfig | DetectionPrecisionConfig,
DetectionAveragePrecisionConfig,
DetectionROCAUCConfig,
DetectionRecallConfig,
DetectionPrecisionConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -31,14 +31,14 @@ __all__ = [
@dataclass @dataclass
class MatchEval: class MatchEval:
clip: data.Clip clip: data.Clip
gt: Optional[data.SoundEventAnnotation] gt: data.SoundEventAnnotation | None
pred: Optional[RawPrediction] pred: RawPrediction | None
is_ground_truth: bool is_ground_truth: bool
is_generic: bool is_generic: bool
is_prediction: bool is_prediction: bool
pred_class: Optional[str] pred_class: str | None
true_class: Optional[str] true_class: str | None
score: float score: float
@ -301,13 +301,7 @@ class BalancedAccuracy:
TopClassMetricConfig = Annotated[ TopClassMetricConfig = Annotated[
Union[ TopClassAveragePrecisionConfig | TopClassROCAUCConfig | TopClassRecallConfig | TopClassPrecisionConfig | BalancedAccuracyConfig,
TopClassAveragePrecisionConfig,
TopClassROCAUCConfig,
TopClassRecallConfig,
TopClassPrecisionConfig,
BalancedAccuracyConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -10,7 +10,7 @@ from batdetect2.typing import TargetProtocol
class BasePlotConfig(BaseConfig): class BasePlotConfig(BaseConfig):
label: str = "plot" label: str = "plot"
theme: str = "default" theme: str = "default"
title: Optional[str] = None title: str | None = None
figsize: tuple[int, int] = (10, 10) figsize: tuple[int, int] = (10, 10)
dpi: int = 100 dpi: int = 100
@ -21,7 +21,7 @@ class BasePlot:
targets: TargetProtocol, targets: TargetProtocol,
label: str = "plot", label: str = "plot",
figsize: tuple[int, int] = (10, 10), figsize: tuple[int, int] = (10, 10),
title: Optional[str] = None, title: str | None = None,
dpi: int = 100, dpi: int = 100,
theme: str = "default", theme: str = "default",
): ):

View File

@ -45,7 +45,7 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"
title: Optional[str] = "Classification Precision-Recall Curve" title: str | None = "Classification Precision-Recall Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
separate_figures: bool = False separate_figures: bool = False
@ -108,7 +108,7 @@ class PRCurve(BasePlot):
class ThresholdPrecisionCurveConfig(BasePlotConfig): class ThresholdPrecisionCurveConfig(BasePlotConfig):
name: Literal["threshold_precision_curve"] = "threshold_precision_curve" name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
label: str = "threshold_precision_curve" label: str = "threshold_precision_curve"
title: Optional[str] = "Classification Threshold-Precision Curve" title: str | None = "Classification Threshold-Precision Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
separate_figures: bool = False separate_figures: bool = False
@ -181,7 +181,7 @@ class ThresholdPrecisionCurve(BasePlot):
class ThresholdRecallCurveConfig(BasePlotConfig): class ThresholdRecallCurveConfig(BasePlotConfig):
name: Literal["threshold_recall_curve"] = "threshold_recall_curve" name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
label: str = "threshold_recall_curve" label: str = "threshold_recall_curve"
title: Optional[str] = "Classification Threshold-Recall Curve" title: str | None = "Classification Threshold-Recall Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
separate_figures: bool = False separate_figures: bool = False
@ -254,7 +254,7 @@ class ThresholdRecallCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig): class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve" name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve" label: str = "roc_curve"
title: Optional[str] = "Classification ROC Curve" title: str | None = "Classification ROC Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
separate_figures: bool = False separate_figures: bool = False
@ -326,12 +326,7 @@ class ROCCurve(BasePlot):
ClassificationPlotConfig = Annotated[ ClassificationPlotConfig = Annotated[
Union[ PRCurveConfig | ROCCurveConfig | ThresholdPrecisionCurveConfig | ThresholdRecallCurveConfig,
PRCurveConfig,
ROCCurveConfig,
ThresholdPrecisionCurveConfig,
ThresholdRecallCurveConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -44,7 +44,7 @@ clip_classification_plots: Registry[
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"
title: Optional[str] = "Clip Classification Precision-Recall Curve" title: str | None = "Clip Classification Precision-Recall Curve"
separate_figures: bool = False separate_figures: bool = False
@ -111,7 +111,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig): class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve" name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve" label: str = "roc_curve"
title: Optional[str] = "Clip Classification ROC Curve" title: str | None = "Clip Classification ROC Curve"
separate_figures: bool = False separate_figures: bool = False
@ -174,10 +174,7 @@ class ROCCurve(BasePlot):
ClipClassificationPlotConfig = Annotated[ ClipClassificationPlotConfig = Annotated[
Union[ PRCurveConfig | ROCCurveConfig,
PRCurveConfig,
ROCCurveConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -41,7 +41,7 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"
title: Optional[str] = "Clip Detection Precision-Recall Curve" title: str | None = "Clip Detection Precision-Recall Curve"
class PRCurve(BasePlot): class PRCurve(BasePlot):
@ -74,7 +74,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig): class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve" name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve" label: str = "roc_curve"
title: Optional[str] = "Clip Detection ROC Curve" title: str | None = "Clip Detection ROC Curve"
class ROCCurve(BasePlot): class ROCCurve(BasePlot):
@ -107,7 +107,7 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig): class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution" name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution" label: str = "score_distribution"
title: Optional[str] = "Clip Detection Score Distribution" title: str | None = "Clip Detection Score Distribution"
class ScoreDistributionPlot(BasePlot): class ScoreDistributionPlot(BasePlot):
@ -147,11 +147,7 @@ class ScoreDistributionPlot(BasePlot):
ClipDetectionPlotConfig = Annotated[ ClipDetectionPlotConfig = Annotated[
Union[ PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig,
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -37,7 +37,7 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"
title: Optional[str] = "Detection Precision-Recall Curve" title: str | None = "Detection Precision-Recall Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
@ -100,7 +100,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig): class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve" name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve" label: str = "roc_curve"
title: Optional[str] = "Detection ROC Curve" title: str | None = "Detection ROC Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
@ -159,7 +159,7 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig): class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution" name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution" label: str = "score_distribution"
title: Optional[str] = "Detection Score Distribution" title: str | None = "Detection Score Distribution"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
@ -226,7 +226,7 @@ class ScoreDistributionPlot(BasePlot):
class ExampleDetectionPlotConfig(BasePlotConfig): class ExampleDetectionPlotConfig(BasePlotConfig):
name: Literal["example_detection"] = "example_detection" name: Literal["example_detection"] = "example_detection"
label: str = "example_detection" label: str = "example_detection"
title: Optional[str] = "Example Detection" title: str | None = "Example Detection"
figsize: tuple[int, int] = (10, 4) figsize: tuple[int, int] = (10, 4)
num_examples: int = 5 num_examples: int = 5
threshold: float = 0.2 threshold: float = 0.2
@ -292,12 +292,7 @@ class ExampleDetectionPlot(BasePlot):
DetectionPlotConfig = Annotated[ DetectionPlotConfig = Annotated[
Union[ PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig | ExampleDetectionPlotConfig,
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
ExampleDetectionPlotConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -44,7 +44,7 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"
title: Optional[str] = "Top Class Precision-Recall Curve" title: str | None = "Top Class Precision-Recall Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
@ -111,7 +111,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig): class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve" name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve" label: str = "roc_curve"
title: Optional[str] = "Top Class ROC Curve" title: str | None = "Top Class ROC Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
@ -173,7 +173,7 @@ class ROCCurve(BasePlot):
class ConfusionMatrixConfig(BasePlotConfig): class ConfusionMatrixConfig(BasePlotConfig):
name: Literal["confusion_matrix"] = "confusion_matrix" name: Literal["confusion_matrix"] = "confusion_matrix"
title: Optional[str] = "Top Class Confusion Matrix" title: str | None = "Top Class Confusion Matrix"
figsize: tuple[int, int] = (10, 10) figsize: tuple[int, int] = (10, 10)
label: str = "confusion_matrix" label: str = "confusion_matrix"
exclude_generic: bool = True exclude_generic: bool = True
@ -257,7 +257,7 @@ class ConfusionMatrix(BasePlot):
class ExampleClassificationPlotConfig(BasePlotConfig): class ExampleClassificationPlotConfig(BasePlotConfig):
name: Literal["example_classification"] = "example_classification" name: Literal["example_classification"] = "example_classification"
label: str = "example_classification" label: str = "example_classification"
title: Optional[str] = "Example Classification" title: str | None = "Example Classification"
num_examples: int = 4 num_examples: int = 4
threshold: float = 0.2 threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig) audio: AudioConfig = Field(default_factory=AudioConfig)
@ -348,12 +348,7 @@ class ExampleClassificationPlot(BasePlot):
TopClassPlotConfig = Annotated[ TopClassPlotConfig = Annotated[
Union[ PRCurveConfig | ROCCurveConfig | ConfusionMatrixConfig | ExampleClassificationPlotConfig,
PRCurveConfig,
ROCCurveConfig,
ConfusionMatrixConfig,
ExampleClassificationPlotConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -96,7 +96,7 @@ def extract_matches_dataframe(
EvaluationTableConfig = Annotated[ EvaluationTableConfig = Annotated[
Union[FullEvaluationTableConfig,], Field(discriminator="name") FullEvaluationTableConfig, Field(discriminator="name")
] ]

View File

@ -26,20 +26,14 @@ __all__ = [
TaskConfig = Annotated[ TaskConfig = Annotated[
Union[ ClassificationTaskConfig | DetectionTaskConfig | ClipDetectionTaskConfig | ClipClassificationTaskConfig | TopClassDetectionTaskConfig,
ClassificationTaskConfig,
DetectionTaskConfig,
ClipDetectionTaskConfig,
ClipClassificationTaskConfig,
TopClassDetectionTaskConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
def build_task( def build_task(
config: TaskConfig, config: TaskConfig,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
) -> EvaluatorProtocol: ) -> EvaluatorProtocol:
targets = targets or build_targets() targets = targets or build_targets()
return tasks_registry.build(config, targets) return tasks_registry.build(config, targets)
@ -49,8 +43,8 @@ def evaluate_task(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[BatDetect2Prediction],
task: Optional["str"] = None, task: Optional["str"] = None,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
config: Optional[Union[TaskConfig, dict]] = None, config: TaskConfig | dict | None = None,
): ):
if isinstance(config, BaseTaskConfig): if isinstance(config, BaseTaskConfig):
task_obj = build_task(config, targets) task_obj = build_task(config, targets)

View File

@ -67,9 +67,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
prefix: str, prefix: str,
ignore_start_end: float = 0.01, ignore_start_end: float = 0.01,
plots: Optional[ plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
): ):
self.matcher = matcher self.matcher = matcher
self.metrics = metrics self.metrics = metrics
@ -147,9 +145,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
config: BaseTaskConfig, config: BaseTaskConfig,
targets: TargetProtocol, targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
plots: Optional[ plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
**kwargs, **kwargs,
): ):
matcher = build_matcher(config.matching_strategy) matcher = build_matcher(config.matching_strategy)

View File

@ -98,8 +98,8 @@ def load_annotations(
dataset_name: str, dataset_name: str,
ann_path: str, ann_path: str,
audio_path: str, audio_path: str,
classes_to_ignore: Optional[List[str]] = None, classes_to_ignore: List[str] | None = None,
events_of_interest: Optional[List[str]] = None, events_of_interest: List[str] | None = None,
) -> List[types.FileAnnotation]: ) -> List[types.FileAnnotation]:
train_sets: List[types.DatasetDict] = [] train_sets: List[types.DatasetDict] = []
train_sets.append( train_sets.append(

View File

@ -13,7 +13,7 @@ from batdetect2 import types
def print_dataset_stats( def print_dataset_stats(
data: List[types.FileAnnotation], data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None, classes_to_ignore: List[str] | None = None,
) -> Counter[str]: ) -> Counter[str]:
print("Num files:", len(data)) print("Num files:", len(data))
counts, _ = tu.get_class_names(data, classes_to_ignore) counts, _ = tu.get_class_names(data, classes_to_ignore)

View File

@ -28,8 +28,8 @@ def run_batch_inference(
audio_loader: Optional["AudioLoader"] = None, audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None, config: Optional["BatDetect2Config"] = None,
num_workers: Optional[int] = None, num_workers: int | None = None,
batch_size: Optional[int] = None, batch_size: int | None = None,
) -> List[BatDetect2Prediction]: ) -> List[BatDetect2Prediction]:
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
@ -69,7 +69,7 @@ def process_file_list(
targets: Optional["TargetProtocol"] = None, targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None, audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
num_workers: Optional[int] = None, num_workers: int | None = None,
) -> List[BatDetect2Prediction]: ) -> List[BatDetect2Prediction]:
clip_config = config.inference.clipping clip_config = config.inference.clipping
clips = get_clips_from_files( clips = get_clips_from_files(

View File

@ -36,7 +36,7 @@ class InferenceDataset(Dataset[DatasetItem]):
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
audio_loader: AudioLoader, audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
): ):
self.clips = list(clips) self.clips = list(clips)
self.preprocessor = preprocessor self.preprocessor = preprocessor
@ -66,11 +66,11 @@ class InferenceLoaderConfig(BaseConfig):
def build_inference_loader( def build_inference_loader(
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional[InferenceLoaderConfig] = None, config: InferenceLoaderConfig | None = None,
num_workers: Optional[int] = None, num_workers: int | None = None,
batch_size: Optional[int] = None, batch_size: int | None = None,
) -> DataLoader[DatasetItem]: ) -> DataLoader[DatasetItem]:
logger.info("Building inference data loader...") logger.info("Building inference data loader...")
config = config or InferenceLoaderConfig() config = config or InferenceLoaderConfig()
@ -95,8 +95,8 @@ def build_inference_loader(
def build_inference_dataset( def build_inference_dataset(
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
) -> InferenceDataset: ) -> InferenceDataset:
if audio_loader is None: if audio_loader is None:
audio_loader = build_audio_loader() audio_loader = build_audio_loader()

View File

@ -49,14 +49,14 @@ def enable_logging(level: int):
class BaseLoggerConfig(BaseConfig): class BaseLoggerConfig(BaseConfig):
log_dir: Path = DEFAULT_LOGS_DIR log_dir: Path = DEFAULT_LOGS_DIR
experiment_name: Optional[str] = None experiment_name: str | None = None
run_name: Optional[str] = None run_name: str | None = None
class DVCLiveConfig(BaseLoggerConfig): class DVCLiveConfig(BaseLoggerConfig):
name: Literal["dvclive"] = "dvclive" name: Literal["dvclive"] = "dvclive"
prefix: str = "" prefix: str = ""
log_model: Union[bool, Literal["all"]] = False log_model: bool | Literal["all"] = False
monitor_system: bool = False monitor_system: bool = False
@ -72,18 +72,13 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):
class MLFlowLoggerConfig(BaseLoggerConfig): class MLFlowLoggerConfig(BaseLoggerConfig):
name: Literal["mlflow"] = "mlflow" name: Literal["mlflow"] = "mlflow"
tracking_uri: Optional[str] = "http://localhost:5000" tracking_uri: str | None = "http://localhost:5000"
tags: Optional[dict[str, Any]] = None tags: dict[str, Any] | None = None
log_model: bool = False log_model: bool = False
LoggerConfig = Annotated[ LoggerConfig = Annotated[
Union[ DVCLiveConfig | CSVLoggerConfig | TensorBoardLoggerConfig | MLFlowLoggerConfig,
DVCLiveConfig,
CSVLoggerConfig,
TensorBoardLoggerConfig,
MLFlowLoggerConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
@ -95,17 +90,17 @@ class LoggerBuilder(Protocol, Generic[T]):
def __call__( def __call__(
self, self,
config: T, config: T,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ... ) -> Logger: ...
def create_dvclive_logger( def create_dvclive_logger(
config: DVCLiveConfig, config: DVCLiveConfig,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ) -> Logger:
try: try:
from dvclive.lightning import DVCLiveLogger # type: ignore from dvclive.lightning import DVCLiveLogger # type: ignore
@ -130,9 +125,9 @@ def create_dvclive_logger(
def create_csv_logger( def create_csv_logger(
config: CSVLoggerConfig, config: CSVLoggerConfig,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ) -> Logger:
from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers import CSVLogger
@ -159,9 +154,9 @@ def create_csv_logger(
def create_tensorboard_logger( def create_tensorboard_logger(
config: TensorBoardLoggerConfig, config: TensorBoardLoggerConfig,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ) -> Logger:
from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import TensorBoardLogger
@ -191,9 +186,9 @@ def create_tensorboard_logger(
def create_mlflow_logger( def create_mlflow_logger(
config: MLFlowLoggerConfig, config: MLFlowLoggerConfig,
log_dir: Optional[data.PathLike] = None, log_dir: data.PathLike | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ) -> Logger:
try: try:
from lightning.pytorch.loggers import MLFlowLogger from lightning.pytorch.loggers import MLFlowLogger
@ -232,9 +227,9 @@ LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
def build_logger( def build_logger(
config: LoggerConfig, config: LoggerConfig,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ) -> Logger:
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Building logger with config: \n{}", "Building logger with config: \n{}",
@ -257,7 +252,7 @@ def build_logger(
PlotLogger = Callable[[str, Figure, int], None] PlotLogger = Callable[[str, Figure, int], None]
def get_image_logger(logger: Logger) -> Optional[PlotLogger]: def get_image_logger(logger: Logger) -> PlotLogger | None:
if isinstance(logger, TensorBoardLogger): if isinstance(logger, TensorBoardLogger):
return logger.experiment.add_figure return logger.experiment.add_figure
@ -282,7 +277,7 @@ def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
TableLogger = Callable[[str, pd.DataFrame, int], None] TableLogger = Callable[[str, pd.DataFrame, int], None]
def get_table_logger(logger: Logger) -> Optional[TableLogger]: def get_table_logger(logger: Logger) -> TableLogger | None:
if isinstance(logger, TensorBoardLogger): if isinstance(logger, TensorBoardLogger):
return partial(save_table, dir=Path(logger.log_dir)) return partial(save_table, dir=Path(logger.log_dir))

View File

@ -43,10 +43,7 @@ from batdetect2.models.bottleneck import (
BottleneckConfig, BottleneckConfig,
build_bottleneck, build_bottleneck,
) )
from batdetect2.models.config import ( from batdetect2.models.config import BackboneConfig, load_backbone_config
BackboneConfig,
load_backbone_config,
)
from batdetect2.models.decoder import ( from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG, DEFAULT_DECODER_CONFIG,
DecoderConfig, DecoderConfig,
@ -122,10 +119,10 @@ class Model(torch.nn.Module):
def build_model( def build_model(
config: Optional[BackboneConfig] = None, config: BackboneConfig | None = None,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
postprocessor: Optional[PostprocessorProtocol] = None, postprocessor: PostprocessorProtocol | None = None,
): ):
from batdetect2.postprocess import build_postprocessor from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor

View File

@ -78,8 +78,8 @@ class Bottleneck(nn.Module):
input_height: int, input_height: int,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
bottleneck_channels: Optional[int] = None, bottleneck_channels: int | None = None,
layers: Optional[List[torch.nn.Module]] = None, layers: List[torch.nn.Module] | None = None,
) -> None: ) -> None:
"""Initialize the base Bottleneck layer.""" """Initialize the base Bottleneck layer."""
super().__init__() super().__init__()
@ -127,7 +127,7 @@ class Bottleneck(nn.Module):
BottleneckLayerConfig = Annotated[ BottleneckLayerConfig = Annotated[
Union[SelfAttentionConfig,], SelfAttentionConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of block configs usable in Decoder.""" """Type alias for the discriminated union of block configs usable in Decoder."""
@ -171,7 +171,7 @@ DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
def build_bottleneck( def build_bottleneck(
input_height: int, input_height: int,
in_channels: int, in_channels: int,
config: Optional[BottleneckConfig] = None, config: BottleneckConfig | None = None,
) -> nn.Module: ) -> nn.Module:
"""Factory function to build the Bottleneck module from configuration. """Factory function to build the Bottleneck module from configuration.

View File

@ -63,7 +63,7 @@ class BackboneConfig(BaseConfig):
def load_backbone_config( def load_backbone_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: str | None = None,
) -> BackboneConfig: ) -> BackboneConfig:
"""Load the backbone configuration from a file. """Load the backbone configuration from a file.

View File

@ -41,12 +41,7 @@ __all__ = [
] ]
DecoderLayerConfig = Annotated[ DecoderLayerConfig = Annotated[
Union[ ConvConfig | FreqCoordConvUpConfig | StandardConvUpConfig | LayerGroupConfig,
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
LayerGroupConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of block configs usable in Decoder.""" """Type alias for the discriminated union of block configs usable in Decoder."""
@ -216,7 +211,7 @@ convolutional block.
def build_decoder( def build_decoder(
in_channels: int, in_channels: int,
input_height: int, input_height: int,
config: Optional[DecoderConfig] = None, config: DecoderConfig | None = None,
) -> Decoder: ) -> Decoder:
"""Factory function to build a Decoder instance from configuration. """Factory function to build a Decoder instance from configuration.

View File

@ -127,7 +127,7 @@ class Detector(DetectionModel):
def build_detector( def build_detector(
num_classes: int, config: Optional[BackboneConfig] = None num_classes: int, config: BackboneConfig | None = None
) -> DetectionModel: ) -> DetectionModel:
"""Build the complete BatDetect2 detection model. """Build the complete BatDetect2 detection model.

View File

@ -43,12 +43,7 @@ __all__ = [
] ]
EncoderLayerConfig = Annotated[ EncoderLayerConfig = Annotated[
Union[ ConvConfig | FreqCoordConvDownConfig | StandardConvDownConfig | LayerGroupConfig,
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
LayerGroupConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of block configs usable in Encoder.""" """Type alias for the discriminated union of block configs usable in Encoder."""
@ -252,7 +247,7 @@ Specifies an architecture typically used in BatDetect2:
def build_encoder( def build_encoder(
in_channels: int, in_channels: int,
input_height: int, input_height: int,
config: Optional[EncoderConfig] = None, config: EncoderConfig | None = None,
) -> Encoder: ) -> Encoder:
"""Factory function to build an Encoder instance from configuration. """Factory function to build an Encoder instance from configuration.

View File

@ -15,10 +15,10 @@ __all__ = [
def plot_clip_annotation( def plot_clip_annotation(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
ax: Optional[Axes] = None, ax: Axes | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
add_points: bool = False, add_points: bool = False,
cmap: str = "gray", cmap: str = "gray",
alpha: float = 1, alpha: float = 1,
@ -50,8 +50,8 @@ def plot_clip_annotation(
def plot_anchor_points( def plot_anchor_points(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
targets: TargetProtocol, targets: TargetProtocol,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
ax: Optional[Axes] = None, ax: Axes | None = None,
size: int = 1, size: int = 1,
color: str = "red", color: str = "red",
marker: str = "x", marker: str = "x",

View File

@ -17,10 +17,10 @@ __all__ = [
def plot_clip_prediction( def plot_clip_prediction(
clip_prediction: data.ClipPrediction, clip_prediction: data.ClipPrediction,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
ax: Optional[Axes] = None, ax: Axes | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
add_legend: bool = False, add_legend: bool = False,
spec_cmap: str = "gray", spec_cmap: str = "gray",
linewidth: float = 1, linewidth: float = 1,
@ -50,14 +50,14 @@ def plot_clip_prediction(
def plot_predictions( def plot_predictions(
predictions: Iterable[data.SoundEventPrediction], predictions: Iterable[data.SoundEventPrediction],
ax: Optional[Axes] = None, ax: Axes | None = None,
position: Positions = "top-right", position: Positions = "top-right",
color_mapper: Optional[TagColorMapper] = None, color_mapper: TagColorMapper | None = None,
time_offset: float = 0.001, time_offset: float = 0.001,
freq_offset: float = 1000, freq_offset: float = 1000,
legend: bool = True, legend: bool = True,
max_alpha: float = 0.5, max_alpha: float = 0.5,
color: Optional[str] = None, color: str | None = None,
**kwargs, **kwargs,
): ):
"""Plot an prediction.""" """Plot an prediction."""
@ -88,14 +88,14 @@ def plot_predictions(
def plot_prediction( def plot_prediction(
prediction: data.SoundEventPrediction, prediction: data.SoundEventPrediction,
ax: Optional[Axes] = None, ax: Axes | None = None,
position: Positions = "top-right", position: Positions = "top-right",
color_mapper: Optional[TagColorMapper] = None, color_mapper: TagColorMapper | None = None,
time_offset: float = 0.001, time_offset: float = 0.001,
freq_offset: float = 1000, freq_offset: float = 1000,
max_alpha: float = 0.5, max_alpha: float = 0.5,
alpha: Optional[float] = None, alpha: float | None = None,
color: Optional[str] = None, color: str | None = None,
**kwargs, **kwargs,
) -> Axes: ) -> Axes:
"""Plot an annotation.""" """Plot an annotation."""

View File

@ -17,11 +17,11 @@ __all__ = [
def plot_clip( def plot_clip(
clip: data.Clip, clip: data.Clip,
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
ax: Optional[Axes] = None, ax: Axes | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
spec_cmap: str = "gray", spec_cmap: str = "gray",
) -> Axes: ) -> Axes:
if ax is None: if ax is None:

View File

@ -13,8 +13,8 @@ __all__ = [
def create_ax( def create_ax(
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
**kwargs, **kwargs,
) -> axes.Axes: ) -> axes.Axes:
"""Create a new axis if none is provided""" """Create a new axis if none is provided"""
@ -25,17 +25,17 @@ def create_ax(
def plot_spectrogram( def plot_spectrogram(
spec: Union[torch.Tensor, np.ndarray], spec: torch.Tensor | np.ndarray,
start_time: Optional[float] = None, start_time: float | None = None,
end_time: Optional[float] = None, end_time: float | None = None,
min_freq: Optional[float] = None, min_freq: float | None = None,
max_freq: Optional[float] = None, max_freq: float | None = None,
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
add_colorbar: bool = False, add_colorbar: bool = False,
colorbar_kwargs: Optional[dict] = None, colorbar_kwargs: dict | None = None,
vmin: Optional[float] = None, vmin: float | None = None,
vmax: Optional[float] = None, vmax: float | None = None,
cmap="gray", cmap="gray",
) -> axes.Axes: ) -> axes.Axes:
if isinstance(spec, torch.Tensor): if isinstance(spec, torch.Tensor):

View File

@ -19,9 +19,9 @@ __all__ = [
def plot_clip_detections( def plot_clip_detections(
clip_eval: ClipEval, clip_eval: ClipEval,
figsize: tuple[int, int] = (10, 10), figsize: tuple[int, int] = (10, 10),
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
threshold: float = 0.2, threshold: float = 0.2,
add_legend: bool = True, add_legend: bool = True,
add_title: bool = True, add_title: bool = True,

View File

@ -20,11 +20,11 @@ def plot_match_gallery(
false_positives: Sequence[MatchProtocol], false_positives: Sequence[MatchProtocol],
false_negatives: Sequence[MatchProtocol], false_negatives: Sequence[MatchProtocol],
cross_triggers: Sequence[MatchProtocol], cross_triggers: Sequence[MatchProtocol],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
n_examples: int = 5, n_examples: int = 5,
duration: float = 0.1, duration: float = 0.1,
fig: Optional[Figure] = None, fig: Figure | None = None,
): ):
if fig is None: if fig is None:
fig = plt.figure(figsize=(20, 20)) fig = plt.figure(figsize=(20, 20))

View File

@ -12,13 +12,13 @@ from batdetect2.plotting.common import create_ax
def plot_detection_heatmap( def plot_detection_heatmap(
heatmap: Union[torch.Tensor, np.ndarray], heatmap: torch.Tensor | np.ndarray,
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] = (10, 10), figsize: Tuple[int, int] = (10, 10),
threshold: Optional[float] = None, threshold: float | None = None,
alpha: float = 1, alpha: float = 1,
cmap: Union[str, Colormap] = "jet", cmap: str | Colormap = "jet",
color: Optional[str] = None, color: str | None = None,
) -> axes.Axes: ) -> axes.Axes:
ax = create_ax(ax, figsize=figsize) ax = create_ax(ax, figsize=figsize)
@ -48,13 +48,13 @@ def plot_detection_heatmap(
def plot_classification_heatmap( def plot_classification_heatmap(
heatmap: Union[torch.Tensor, np.ndarray], heatmap: torch.Tensor | np.ndarray,
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Tuple[int, int] = (10, 10), figsize: Tuple[int, int] = (10, 10),
class_names: Optional[List[str]] = None, class_names: List[str] | None = None,
threshold: Optional[float] = 0.1, threshold: float | None = 0.1,
alpha: float = 1, alpha: float = 1,
cmap: Union[str, Colormap] = "tab20", cmap: str | Colormap = "tab20",
): ):
ax = create_ax(ax, figsize=figsize) ax = create_ax(ax, figsize=figsize)

View File

@ -24,10 +24,10 @@ __all__ = [
def spectrogram( def spectrogram(
spec: Union[torch.Tensor, np.ndarray], spec: torch.Tensor | np.ndarray,
config: Optional[ProcessingConfiguration] = None, config: ProcessingConfiguration | None = None,
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
cmap: str = "plasma", cmap: str = "plasma",
start_time: float = 0, start_time: float = 0,
) -> axes.Axes: ) -> axes.Axes:
@ -103,11 +103,11 @@ def spectrogram(
def spectrogram_with_detections( def spectrogram_with_detections(
spec: Union[torch.Tensor, np.ndarray], spec: torch.Tensor | np.ndarray,
dets: List[Annotation], dets: List[Annotation],
config: Optional[ProcessingConfiguration] = None, config: ProcessingConfiguration | None = None,
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
cmap: str = "plasma", cmap: str = "plasma",
with_names: bool = True, with_names: bool = True,
start_time: float = 0, start_time: float = 0,
@ -168,8 +168,8 @@ def spectrogram_with_detections(
def detections( def detections(
dets: List[Annotation], dets: List[Annotation],
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
with_names: bool = True, with_names: bool = True,
**kwargs, **kwargs,
) -> axes.Axes: ) -> axes.Axes:
@ -213,8 +213,8 @@ def detections(
def detection( def detection(
det: Annotation, det: Annotation,
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
linewidth: float = 1, linewidth: float = 1,
edgecolor: str = "w", edgecolor: str = "w",
facecolor: str = "none", facecolor: str = "none",

View File

@ -21,10 +21,10 @@ __all__ = [
class MatchProtocol(Protocol): class MatchProtocol(Protocol):
clip: data.Clip clip: data.Clip
gt: Optional[data.SoundEventAnnotation] gt: data.SoundEventAnnotation | None
pred: Optional[RawPrediction] pred: RawPrediction | None
score: float score: float
true_class: Optional[str] true_class: str | None
DEFAULT_DURATION = 0.05 DEFAULT_DURATION = 0.05
@ -38,11 +38,11 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_false_positive_match( def plot_false_positive_match(
match: MatchProtocol, match: MatchProtocol,
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
ax: Optional[Axes] = None, ax: Axes | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION, duration: float = DEFAULT_DURATION,
use_score: bool = True, use_score: bool = True,
add_spectrogram: bool = True, add_spectrogram: bool = True,
@ -52,7 +52,7 @@ def plot_false_positive_match(
fill: bool = False, fill: bool = False,
spec_cmap: str = "gray", spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_POSITIVE_COLOR, color: str = DEFAULT_FALSE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small", fontsize: float | str = "small",
) -> Axes: ) -> Axes:
assert match.pred is not None assert match.pred is not None
@ -109,11 +109,11 @@ def plot_false_positive_match(
def plot_false_negative_match( def plot_false_negative_match(
match: MatchProtocol, match: MatchProtocol,
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
ax: Optional[Axes] = None, ax: Axes | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION, duration: float = DEFAULT_DURATION,
add_spectrogram: bool = True, add_spectrogram: bool = True,
add_points: bool = False, add_points: bool = False,
@ -169,11 +169,11 @@ def plot_false_negative_match(
def plot_true_positive_match( def plot_true_positive_match(
match: MatchProtocol, match: MatchProtocol,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
ax: Optional[Axes] = None, ax: Axes | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION, duration: float = DEFAULT_DURATION,
use_score: bool = True, use_score: bool = True,
add_spectrogram: bool = True, add_spectrogram: bool = True,
@ -182,7 +182,7 @@ def plot_true_positive_match(
fill: bool = False, fill: bool = False,
spec_cmap: str = "gray", spec_cmap: str = "gray",
color: str = DEFAULT_TRUE_POSITIVE_COLOR, color: str = DEFAULT_TRUE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small", fontsize: float | str = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
add_title: bool = True, add_title: bool = True,
@ -257,11 +257,11 @@ def plot_true_positive_match(
def plot_cross_trigger_match( def plot_cross_trigger_match(
match: MatchProtocol, match: MatchProtocol,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
ax: Optional[Axes] = None, ax: Axes | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION, duration: float = DEFAULT_DURATION,
use_score: bool = True, use_score: bool = True,
add_spectrogram: bool = True, add_spectrogram: bool = True,
@ -271,7 +271,7 @@ def plot_cross_trigger_match(
fill: bool = False, fill: bool = False,
spec_cmap: str = "gray", spec_cmap: str = "gray",
color: str = DEFAULT_CROSS_TRIGGER_COLOR, color: str = DEFAULT_CROSS_TRIGGER_COLOR,
fontsize: Union[float, str] = "small", fontsize: float | str = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes: ) -> Axes:

View File

@ -33,16 +33,16 @@ def plot_pr_curve(
precision: np.ndarray, precision: np.ndarray,
recall: np.ndarray, recall: np.ndarray,
thresholds: np.ndarray, thresholds: np.ndarray,
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
color: Union[str, Tuple[float, float, float], None] = None, color: str | Tuple[float, float, float] | None = None,
add_labels: bool = True, add_labels: bool = True,
add_legend: bool = False, add_legend: bool = False,
marker: Union[str, Tuple[int, int, float], None] = "o", marker: str | Tuple[int, int, float] | None = "o",
markeredgecolor: Union[str, Tuple[float, float, float], None] = None, markeredgecolor: str | Tuple[float, float, float] | None = None,
markersize: Optional[float] = None, markersize: float | None = None,
linestyle: Union[str, Tuple[int, ...], None] = None, linestyle: str | Tuple[int, ...] | None = None,
linewidth: Optional[float] = None, linewidth: float | None = None,
label: str = "PR Curve", label: str = "PR Curve",
) -> axes.Axes: ) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
@ -77,8 +77,8 @@ def plot_pr_curve(
def plot_pr_curves( def plot_pr_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
add_legend: bool = True, add_legend: bool = True,
add_labels: bool = True, add_labels: bool = True,
include_ap: bool = False, include_ap: bool = False,
@ -118,8 +118,8 @@ def plot_pr_curves(
def plot_threshold_precision_curve( def plot_threshold_precision_curve(
threshold: np.ndarray, threshold: np.ndarray,
precision: np.ndarray, precision: np.ndarray,
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
add_labels: bool = True, add_labels: bool = True,
): ):
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
@ -140,8 +140,8 @@ def plot_threshold_precision_curve(
def plot_threshold_precision_curves( def plot_threshold_precision_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
add_legend: bool = True, add_legend: bool = True,
add_labels: bool = True, add_labels: bool = True,
): ):
@ -176,8 +176,8 @@ def plot_threshold_precision_curves(
def plot_threshold_recall_curve( def plot_threshold_recall_curve(
threshold: np.ndarray, threshold: np.ndarray,
recall: np.ndarray, recall: np.ndarray,
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
add_labels: bool = True, add_labels: bool = True,
): ):
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
@ -198,8 +198,8 @@ def plot_threshold_recall_curve(
def plot_threshold_recall_curves( def plot_threshold_recall_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
add_legend: bool = True, add_legend: bool = True,
add_labels: bool = True, add_labels: bool = True,
): ):
@ -235,8 +235,8 @@ def plot_roc_curve(
fpr: np.ndarray, fpr: np.ndarray,
tpr: np.ndarray, tpr: np.ndarray,
thresholds: np.ndarray, thresholds: np.ndarray,
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
add_labels: bool = True, add_labels: bool = True,
) -> axes.Axes: ) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
@ -261,8 +261,8 @@ def plot_roc_curve(
def plot_roc_curves( def plot_roc_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None, ax: axes.Axes | None = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Tuple[int, int] | None = None,
add_legend: bool = True, add_legend: bool = True,
add_labels: bool = True, add_labels: bool = True,
) -> axes.Axes: ) -> axes.Axes:

View File

@ -57,7 +57,7 @@ class PostprocessConfig(BaseConfig):
def load_postprocess_config( def load_postprocess_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: str | None = None,
) -> PostprocessConfig: ) -> PostprocessConfig:
"""Load the postprocessing configuration from a file. """Load the postprocessing configuration from a file.

View File

@ -88,9 +88,7 @@ def convert_raw_prediction_to_sound_event_prediction(
raw_prediction: RawPrediction, raw_prediction: RawPrediction,
recording: data.Recording, recording: data.Recording,
targets: TargetProtocol, targets: TargetProtocol,
classification_threshold: Optional[ classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
float
] = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False, top_class_only: bool = False,
): ):
"""Convert a single RawPrediction into a soundevent SoundEventPrediction.""" """Convert a single RawPrediction into a soundevent SoundEventPrediction."""
@ -150,7 +148,7 @@ def get_class_tags(
class_scores: np.ndarray, class_scores: np.ndarray,
targets: TargetProtocol, targets: TargetProtocol,
top_class_only: bool = False, top_class_only: bool = False,
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD, threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.PredictedTag]: ) -> List[data.PredictedTag]:
"""Generate specific PredictedTags based on class scores and decoder. """Generate specific PredictedTags based on class scores and decoder.

View File

@ -32,7 +32,7 @@ def extract_detection_peaks(
feature_heatmap: torch.Tensor, feature_heatmap: torch.Tensor,
classification_heatmap: torch.Tensor, classification_heatmap: torch.Tensor,
max_detections: int = 200, max_detections: int = 200,
threshold: Optional[float] = None, threshold: float | None = None,
) -> List[ClipDetectionsTensor]: ) -> List[ClipDetectionsTensor]:
height = detection_heatmap.shape[-2] height = detection_heatmap.shape[-2]
width = detection_heatmap.shape[-1] width = detection_heatmap.shape[-1]

View File

@ -27,7 +27,7 @@ BatDetect2.
def non_max_suppression( def non_max_suppression(
tensor: torch.Tensor, tensor: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
) -> torch.Tensor: ) -> torch.Tensor:
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap. """Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.

View File

@ -24,7 +24,7 @@ __all__ = [
def build_postprocessor( def build_postprocessor(
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
config: Optional[PostprocessConfig] = None, config: PostprocessConfig | None = None,
) -> PostprocessorProtocol: ) -> PostprocessorProtocol:
"""Factory function to build the standard postprocessor.""" """Factory function to build the standard postprocessor."""
config = config or PostprocessConfig() config = config or PostprocessConfig()
@ -51,7 +51,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
max_freq: float, max_freq: float,
top_k_per_sec: int = 200, top_k_per_sec: int = 200,
detection_threshold: float = 0.01, detection_threshold: float = 0.01,
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, nms_kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
): ):
"""Initialize the Postprocessor.""" """Initialize the Postprocessor."""
super().__init__() super().__init__()
@ -66,7 +66,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
def forward( def forward(
self, self,
output: ModelOutput, output: ModelOutput,
start_times: Optional[List[float]] = None, start_times: List[float] | None = None,
) -> List[ClipDetectionsTensor]: ) -> List[ClipDetectionsTensor]:
detection_heatmap = non_max_suppression( detection_heatmap = non_max_suppression(
output.detection_probs.detach(), output.detection_probs.detach(),

View File

@ -31,14 +31,14 @@ __all__ = [
def to_xarray( def to_xarray(
array: Union[torch.Tensor, np.ndarray], array: torch.Tensor | np.ndarray,
start_time: float, start_time: float,
end_time: float, end_time: float,
min_freq: float = MIN_FREQ, min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ, max_freq: float = MAX_FREQ,
name: str = "xarray", name: str = "xarray",
extra_dims: Optional[List[str]] = None, extra_dims: List[str] | None = None,
extra_coords: Optional[Dict[str, np.ndarray]] = None, extra_coords: Dict[str, np.ndarray] | None = None,
) -> xr.DataArray: ) -> xr.DataArray:
if isinstance(array, torch.Tensor): if isinstance(array, torch.Tensor):
array = array.detach().cpu().numpy() array = array.detach().cpu().numpy()

View File

@ -78,11 +78,7 @@ class FixDuration(torch.nn.Module):
AudioTransform = Annotated[ AudioTransform = Annotated[
Union[ FixDurationConfig | ScaleAudioConfig | CenterAudioConfig,
FixDurationConfig,
ScaleAudioConfig,
CenterAudioConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -57,6 +57,6 @@ class PreprocessingConfig(BaseConfig):
def load_preprocessing_config( def load_preprocessing_config(
path: PathLike, path: PathLike,
field: Optional[str] = None, field: str | None = None,
) -> PreprocessingConfig: ) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field) return load_config(path, schema=PreprocessingConfig, field=field)

View File

@ -102,7 +102,7 @@ def compute_output_samplerate(
def build_preprocessor( def build_preprocessor(
config: Optional[PreprocessingConfig] = None, config: PreprocessingConfig | None = None,
input_samplerate: int = TARGET_SAMPLERATE_HZ, input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> PreprocessorProtocol: ) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration.""" """Factory function to build the standard preprocessor from configuration."""

View File

@ -98,7 +98,7 @@ def _frequency_to_index(
freq: float, freq: float,
n_fft: int, n_fft: int,
samplerate: int = TARGET_SAMPLERATE_HZ, samplerate: int = TARGET_SAMPLERATE_HZ,
) -> Optional[int]: ) -> int | None:
alpha = freq * 2 / samplerate alpha = freq * 2 / samplerate
height = np.floor(n_fft / 2) + 1 height = np.floor(n_fft / 2) + 1
index = int(np.floor(alpha * height)) index = int(np.floor(alpha * height))
@ -134,8 +134,8 @@ class FrequencyCrop(torch.nn.Module):
self, self,
samplerate: int, samplerate: int,
n_fft: int, n_fft: int,
min_freq: Optional[int] = None, min_freq: int | None = None,
max_freq: Optional[int] = None, max_freq: int | None = None,
): ):
super().__init__() super().__init__()
self.n_fft = n_fft self.n_fft = n_fft
@ -181,7 +181,7 @@ class FrequencyCrop(torch.nn.Module):
def build_spectrogram_crop( def build_spectrogram_crop(
config: FrequencyConfig, config: FrequencyConfig,
stft: Optional[STFTConfig] = None, stft: STFTConfig | None = None,
samplerate: int = TARGET_SAMPLERATE_HZ, samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module: ) -> torch.nn.Module:
stft = stft or STFTConfig() stft = stft or STFTConfig()
@ -377,12 +377,7 @@ class PeakNormalize(torch.nn.Module):
SpectrogramTransform = Annotated[ SpectrogramTransform = Annotated[
Union[ PcenConfig | ScaleAmplitudeConfig | SpectralMeanSubstractionConfig | PeakNormalizeConfig,
PcenConfig,
ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig,
PeakNormalizeConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -30,16 +30,16 @@ class TargetClassConfig(BaseConfig):
name: str name: str
condition_input: Optional[SoundEventConditionConfig] = Field( condition_input: SoundEventConditionConfig | None = Field(
alias="match_if", alias="match_if",
default=None, default=None,
) )
tags: Optional[List[data.Tag]] = Field(default=None, exclude=True) tags: List[data.Tag] | None = Field(default=None, exclude=True)
assign_tags: List[data.Tag] = Field(default_factory=list) assign_tags: List[data.Tag] = Field(default_factory=list)
roi: Optional[ROIMapperConfig] = None roi: ROIMapperConfig | None = None
_match_if: SoundEventConditionConfig = PrivateAttr() _match_if: SoundEventConditionConfig = PrivateAttr()
@ -202,7 +202,7 @@ class SoundEventClassifier:
def __call__( def __call__(
self, sound_event_annotation: data.SoundEventAnnotation self, sound_event_annotation: data.SoundEventAnnotation
) -> Optional[str]: ) -> str | None:
for name, condition in self.mapping.items(): for name, condition in self.mapping.items():
if condition(sound_event_annotation): if condition(sound_event_annotation):
return name return name

View File

@ -48,7 +48,7 @@ class TargetConfig(BaseConfig):
def load_target_config( def load_target_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: str | None = None,
) -> TargetConfig: ) -> TargetConfig:
"""Load the unified target configuration from a file. """Load the unified target configuration from a file.

View File

@ -414,10 +414,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
ROIMapperConfig = Annotated[ ROIMapperConfig = Annotated[
Union[ AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig,
AnchorBBoxMapperConfig,
PeakEnergyBBoxMapperConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""A discriminated union of all supported ROI mapper configurations. """A discriminated union of all supported ROI mapper configurations.
@ -428,7 +425,7 @@ implementations by using the `name` field as a discriminator.
def build_roi_mapper( def build_roi_mapper(
config: Optional[ROIMapperConfig] = None, config: ROIMapperConfig | None = None,
) -> ROITargetMapper: ) -> ROITargetMapper:
"""Factory function to create an ROITargetMapper from a config object. """Factory function to create an ROITargetMapper from a config object.
@ -572,9 +569,9 @@ def get_peak_energy_coordinates(
audio_loader: AudioLoader, audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
start_time: float = 0, start_time: float = 0,
end_time: Optional[float] = None, end_time: float | None = None,
low_freq: float = 0, low_freq: float = 0,
high_freq: Optional[float] = None, high_freq: float | None = None,
loading_buffer: float = 0.05, loading_buffer: float = 0.05,
) -> Position: ) -> Position:
"""Find the coordinates of the highest energy point in a spectrogram. """Find the coordinates of the highest energy point in a spectrogram.

View File

@ -107,7 +107,7 @@ class Targets(TargetProtocol):
def encode_class( def encode_class(
self, sound_event: data.SoundEventAnnotation self, sound_event: data.SoundEventAnnotation
) -> Optional[str]: ) -> str | None:
"""Encode a sound event annotation to its target class name. """Encode a sound event annotation to its target class name.
Applies the configured class definition rules (including priority) Applies the configured class definition rules (including priority)
@ -182,7 +182,7 @@ class Targets(TargetProtocol):
self, self,
position: Position, position: Position,
size: Size, size: Size,
class_name: Optional[str] = None, class_name: str | None = None,
) -> data.Geometry: ) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions. """Recover an approximate geometric ROI from a position and dimensions.
@ -219,7 +219,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
) )
def build_targets(config: Optional[TargetConfig] = None) -> Targets: def build_targets(config: TargetConfig | None = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig. """Build a Targets object from a loaded TargetConfig.
Parameters Parameters
@ -251,7 +251,7 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
def load_targets( def load_targets(
config_path: data.PathLike, config_path: data.PathLike,
field: Optional[str] = None, field: str | None = None,
) -> Targets: ) -> Targets:
"""Load a Targets object directly from a configuration file. """Load a Targets object directly from a configuration file.
@ -292,7 +292,7 @@ def load_targets(
def iterate_encoded_sound_events( def iterate_encoded_sound_events(
sound_events: Iterable[data.SoundEventAnnotation], sound_events: Iterable[data.SoundEventAnnotation],
targets: TargetProtocol, targets: TargetProtocol,
) -> Iterable[Tuple[Optional[str], Position, Size]]: ) -> Iterable[Tuple[str | None, Position, Size]]:
for sound_event in sound_events: for sound_event in sound_events:
if not targets.filter(sound_event): if not targets.filter(sound_event):
continue continue

View File

@ -42,7 +42,7 @@ __all__ = [
AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]] AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
audio_augmentations: Registry[Augmentation, [int, Optional[AudioSource]]] = ( audio_augmentations: Registry[Augmentation, [int, AudioSource | None]] = (
Registry(name="audio_augmentation") Registry(name="audio_augmentation")
) )
@ -103,7 +103,7 @@ class MixAudio(torch.nn.Module):
def from_config( def from_config(
config: MixAudioConfig, config: MixAudioConfig,
samplerate: int, samplerate: int,
source: Optional[AudioSource], source: AudioSource | None,
): ):
if source is None: if source is None:
warnings.warn( warnings.warn(
@ -207,7 +207,7 @@ class AddEcho(torch.nn.Module):
def from_config( def from_config(
config: AddEchoConfig, config: AddEchoConfig,
samplerate: int, samplerate: int,
source: Optional[AudioSource], source: AudioSource | None,
): ):
return AddEcho( return AddEcho(
samplerate=samplerate, samplerate=samplerate,
@ -487,33 +487,18 @@ def mask_frequency(
AudioAugmentationConfig = Annotated[ AudioAugmentationConfig = Annotated[
Union[ MixAudioConfig | AddEchoConfig,
MixAudioConfig,
AddEchoConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
SpectrogramAugmentationConfig = Annotated[ SpectrogramAugmentationConfig = Annotated[
Union[ ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig,
ScaleVolumeConfig,
WarpConfig,
MaskFrequencyConfig,
MaskTimeConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
AugmentationConfig = Annotated[ AugmentationConfig = Annotated[
Union[ MixAudioConfig | AddEchoConfig | ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig,
MixAudioConfig,
AddEchoConfig,
ScaleVolumeConfig,
WarpConfig,
MaskFrequencyConfig,
MaskTimeConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of individual augmentation config.""" """Type alias for the discriminated union of individual augmentation config."""
@ -559,8 +544,8 @@ class MaybeApply(torch.nn.Module):
def build_augmentation_from_config( def build_augmentation_from_config(
config: AugmentationConfig, config: AugmentationConfig,
samplerate: int, samplerate: int,
audio_source: Optional[AudioSource] = None, audio_source: AudioSource | None = None,
) -> Optional[Augmentation]: ) -> Augmentation | None:
"""Factory function to build a single augmentation from its config.""" """Factory function to build a single augmentation from its config."""
if config.name == "mix_audio": if config.name == "mix_audio":
if audio_source is None: if audio_source is None:
@ -645,10 +630,10 @@ class AugmentationSequence(torch.nn.Module):
def build_audio_augmentations( def build_audio_augmentations(
steps: Optional[Sequence[AudioAugmentationConfig]] = None, steps: Sequence[AudioAugmentationConfig] | None = None,
samplerate: int = TARGET_SAMPLERATE_HZ, samplerate: int = TARGET_SAMPLERATE_HZ,
audio_source: Optional[AudioSource] = None, audio_source: AudioSource | None = None,
) -> Optional[Augmentation]: ) -> Augmentation | None:
if not steps: if not steps:
return None return None
@ -673,8 +658,8 @@ def build_audio_augmentations(
def build_spectrogram_augmentations( def build_spectrogram_augmentations(
steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None, steps: Sequence[SpectrogramAugmentationConfig] | None = None,
) -> Optional[Augmentation]: ) -> Augmentation | None:
if not steps: if not steps:
return None return None
@ -698,9 +683,9 @@ def build_spectrogram_augmentations(
def build_augmentations( def build_augmentations(
samplerate: int, samplerate: int,
config: Optional[AugmentationsConfig] = None, config: AugmentationsConfig | None = None,
audio_source: Optional[AudioSource] = None, audio_source: AudioSource | None = None,
) -> Tuple[Optional[Augmentation], Optional[Augmentation]]: ) -> Tuple[Augmentation | None, Augmentation | None]:
"""Build a composite augmentation pipeline function from configuration.""" """Build a composite augmentation pipeline function from configuration."""
config = config or DEFAULT_AUGMENTATION_CONFIG config = config or DEFAULT_AUGMENTATION_CONFIG
@ -723,7 +708,7 @@ def build_augmentations(
def load_augmentation_config( def load_augmentation_config(
path: data.PathLike, field: Optional[str] = None path: data.PathLike, field: str | None = None
) -> AugmentationsConfig: ) -> AugmentationsConfig:
"""Load the augmentations configuration from a file.""" """Load the augmentations configuration from a file."""
return load_config(path, schema=AugmentationsConfig, field=field) return load_config(path, schema=AugmentationsConfig, field=field)

View File

@ -18,14 +18,14 @@ class CheckpointConfig(BaseConfig):
monitor: str = "classification/mean_average_precision" monitor: str = "classification/mean_average_precision"
mode: str = "max" mode: str = "max"
save_top_k: int = 1 save_top_k: int = 1
filename: Optional[str] = None filename: str | None = None
def build_checkpoint_callback( def build_checkpoint_callback(
config: Optional[CheckpointConfig] = None, config: CheckpointConfig | None = None,
checkpoint_dir: Optional[Path] = None, checkpoint_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Callback: ) -> Callback:
config = config or CheckpointConfig() config = config or CheckpointConfig()

View File

@ -22,20 +22,20 @@ class PLTrainerConfig(BaseConfig):
accumulate_grad_batches: int = 1 accumulate_grad_batches: int = 1
deterministic: bool = True deterministic: bool = True
check_val_every_n_epoch: int = 1 check_val_every_n_epoch: int = 1
devices: Union[str, int] = "auto" devices: str | int = "auto"
enable_checkpointing: bool = True enable_checkpointing: bool = True
gradient_clip_val: Optional[float] = None gradient_clip_val: float | None = None
limit_train_batches: Optional[Union[int, float]] = None limit_train_batches: int | float | None = None
limit_test_batches: Optional[Union[int, float]] = None limit_test_batches: int | float | None = None
limit_val_batches: Optional[Union[int, float]] = None limit_val_batches: int | float | None = None
log_every_n_steps: Optional[int] = None log_every_n_steps: int | None = None
max_epochs: Optional[int] = 200 max_epochs: int | None = 200
min_epochs: Optional[int] = None min_epochs: int | None = None
max_steps: Optional[int] = None max_steps: int | None = None
min_steps: Optional[int] = None min_steps: int | None = None
max_time: Optional[str] = None max_time: str | None = None
precision: Optional[str] = None precision: str | None = None
val_check_interval: Optional[Union[int, float]] = None val_check_interval: int | float | None = None
class OptimizerConfig(BaseConfig): class OptimizerConfig(BaseConfig):
@ -57,6 +57,6 @@ class TrainingConfig(BaseConfig):
def load_train_config( def load_train_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: str | None = None,
) -> TrainingConfig: ) -> TrainingConfig:
return load_config(path, schema=TrainingConfig, field=field) return load_config(path, schema=TrainingConfig, field=field)

View File

@ -44,10 +44,10 @@ class TrainingDataset(Dataset):
audio_loader: AudioLoader, audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
labeller: ClipLabeller, labeller: ClipLabeller,
clipper: Optional[ClipperProtocol] = None, clipper: ClipperProtocol | None = None,
audio_augmentation: Optional[Augmentation] = None, audio_augmentation: Augmentation | None = None,
spectrogram_augmentation: Optional[Augmentation] = None, spectrogram_augmentation: Augmentation | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
): ):
self.clip_annotations = clip_annotations self.clip_annotations = clip_annotations
self.clipper = clipper self.clipper = clipper
@ -108,8 +108,8 @@ class ValidationDataset(Dataset):
audio_loader: AudioLoader, audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
labeller: ClipLabeller, labeller: ClipLabeller,
clipper: Optional[ClipperProtocol] = None, clipper: ClipperProtocol | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
): ):
self.clip_annotations = clip_annotations self.clip_annotations = clip_annotations
self.labeller = labeller self.labeller = labeller
@ -165,11 +165,11 @@ class TrainLoaderConfig(BaseConfig):
def build_train_loader( def build_train_loader(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
labeller: Optional[ClipLabeller] = None, labeller: ClipLabeller | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional[TrainLoaderConfig] = None, config: TrainLoaderConfig | None = None,
num_workers: Optional[int] = None, num_workers: int | None = None,
) -> DataLoader: ) -> DataLoader:
config = config or TrainLoaderConfig() config = config or TrainLoaderConfig()
@ -207,11 +207,11 @@ class ValLoaderConfig(BaseConfig):
def build_val_loader( def build_val_loader(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
labeller: Optional[ClipLabeller] = None, labeller: ClipLabeller | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional[ValLoaderConfig] = None, config: ValLoaderConfig | None = None,
num_workers: Optional[int] = None, num_workers: int | None = None,
): ):
logger.info("Building validation data loader...") logger.info("Building validation data loader...")
config = config or ValLoaderConfig() config = config or ValLoaderConfig()
@ -240,10 +240,10 @@ def build_val_loader(
def build_train_dataset( def build_train_dataset(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
labeller: Optional[ClipLabeller] = None, labeller: ClipLabeller | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional[TrainLoaderConfig] = None, config: TrainLoaderConfig | None = None,
) -> TrainingDataset: ) -> TrainingDataset:
logger.info("Building training dataset...") logger.info("Building training dataset...")
config = config or TrainLoaderConfig() config = config or TrainLoaderConfig()
@ -291,10 +291,10 @@ def build_train_dataset(
def build_val_dataset( def build_val_dataset(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
labeller: Optional[ClipLabeller] = None, labeller: ClipLabeller | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional[ValLoaderConfig] = None, config: ValLoaderConfig | None = None,
) -> ValidationDataset: ) -> ValidationDataset:
logger.info("Building validation dataset...") logger.info("Building validation dataset...")
config = config or ValLoaderConfig() config = config or ValLoaderConfig()

View File

@ -42,10 +42,10 @@ class LabelConfig(BaseConfig):
def build_clip_labeler( def build_clip_labeler(
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
min_freq: float = MIN_FREQ, min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ, max_freq: float = MAX_FREQ,
config: Optional[LabelConfig] = None, config: LabelConfig | None = None,
) -> ClipLabeller: ) -> ClipLabeller:
"""Construct the final clip labelling function.""" """Construct the final clip labelling function."""
config = config or LabelConfig() config = config or LabelConfig()
@ -153,7 +153,7 @@ def generate_heatmaps(
def load_label_config( def load_label_config(
path: data.PathLike, field: Optional[str] = None path: data.PathLike, field: str | None = None
) -> LabelConfig: ) -> LabelConfig:
"""Load the heatmap label generation configuration from a file. """Load the heatmap label generation configuration from a file.

View File

@ -21,7 +21,7 @@ def train_loop(
model: DetectionModel, model: DetectionModel,
train_dataset: LabeledDataset[TrainInputs], train_dataset: LabeledDataset[TrainInputs],
validation_dataset: LabeledDataset[TrainInputs], validation_dataset: LabeledDataset[TrainInputs],
device: Optional[torch.device] = None, device: torch.device | None = None,
num_epochs: int = 100, num_epochs: int = 100,
learning_rate: float = 1e-4, learning_rate: float = 1e-4,
): ):

View File

@ -106,10 +106,10 @@ def standardize_low_freq(
def format_annotation( def format_annotation(
annotation: types.FileAnnotation, annotation: types.FileAnnotation,
events_of_interest: Optional[List[str]] = None, events_of_interest: List[str] | None = None,
name_replace: Optional[Dict[str, str]] = None, name_replace: Dict[str, str] | None = None,
convert_to_genus: bool = False, convert_to_genus: bool = False,
classes_to_ignore: Optional[List[str]] = None, classes_to_ignore: List[str] | None = None,
) -> types.FileAnnotation: ) -> types.FileAnnotation:
formated = [] formated = []
for aa in annotation["annotation"]: for aa in annotation["annotation"]:
@ -154,7 +154,7 @@ def format_annotation(
def get_class_names( def get_class_names(
data: List[types.FileAnnotation], data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None, classes_to_ignore: List[str] | None = None,
) -> Tuple[StringCounter, List[float]]: ) -> Tuple[StringCounter, List[float]]:
"""Extracts class names and their inverse frequencies. """Extracts class names and their inverse frequencies.
@ -201,9 +201,9 @@ def load_set_of_anns(
*, *,
convert_to_genus: bool = False, convert_to_genus: bool = False,
filter_issues: bool = False, filter_issues: bool = False,
events_of_interest: Optional[List[str]] = None, events_of_interest: List[str] | None = None,
classes_to_ignore: Optional[List[str]] = None, classes_to_ignore: List[str] | None = None,
name_replace: Optional[Dict[str, str]] = None, name_replace: Dict[str, str] | None = None,
) -> List[types.FileAnnotation]: ) -> List[types.FileAnnotation]:
# load the annotations # load the annotations
anns = [] anns = []

View File

@ -26,10 +26,10 @@ class TrainingModule(L.LightningModule):
def __init__( def __init__(
self, self,
config: Optional[dict] = None, config: dict | None = None,
t_max: int = 100, t_max: int = 100,
model: Optional[Model] = None, model: Model | None = None,
loss: Optional[torch.nn.Module] = None, loss: torch.nn.Module | None = None,
): ):
from batdetect2.config import validate_config from batdetect2.config import validate_config
@ -103,7 +103,7 @@ def load_model_from_checkpoint(
def build_training_module( def build_training_module(
config: Optional[dict] = None, config: dict | None = None,
t_max: int = 200, t_max: int = 200,
) -> TrainingModule: ) -> TrainingModule:
return TrainingModule(config=config, t_max=t_max) return TrainingModule(config=config, t_max=t_max)

View File

@ -151,7 +151,7 @@ class FocalLoss(nn.Module):
eps: float = 1e-5, eps: float = 1e-5,
beta: float = 4, beta: float = 4,
alpha: float = 2, alpha: float = 2,
class_weights: Optional[torch.Tensor] = None, class_weights: torch.Tensor | None = None,
mask_zero: bool = False, mask_zero: bool = False,
): ):
super().__init__() super().__init__()
@ -422,8 +422,8 @@ class LossFunction(nn.Module, LossProtocol):
def build_loss( def build_loss(
config: Optional[LossConfig] = None, config: LossConfig | None = None,
class_weights: Optional[np.ndarray] = None, class_weights: np.ndarray | None = None,
) -> nn.Module: ) -> nn.Module:
"""Factory function to build the main LossFunction from configuration. """Factory function to build the main LossFunction from configuration.

View File

@ -35,21 +35,21 @@ __all__ = [
def train( def train(
train_annotations: Sequence[data.ClipAnnotation], train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None, val_annotations: Sequence[data.ClipAnnotation] | None = None,
targets: Optional["TargetProtocol"] = None, targets: Optional["TargetProtocol"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None, audio_loader: Optional["AudioLoader"] = None,
labeller: Optional["ClipLabeller"] = None, labeller: Optional["ClipLabeller"] = None,
config: Optional["BatDetect2Config"] = None, config: Optional["BatDetect2Config"] = None,
trainer: Optional[Trainer] = None, trainer: Trainer | None = None,
train_workers: Optional[int] = None, train_workers: int | None = None,
val_workers: Optional[int] = None, val_workers: int | None = None,
checkpoint_dir: Optional[Path] = None, checkpoint_dir: Path | None = None,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
num_epochs: Optional[int] = None, num_epochs: int | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
seed: Optional[int] = None, seed: int | None = None,
): ):
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
@ -126,11 +126,11 @@ def train(
def build_trainer( def build_trainer(
config: "BatDetect2Config", config: "BatDetect2Config",
evaluator: "EvaluatorProtocol", evaluator: "EvaluatorProtocol",
checkpoint_dir: Optional[Path] = None, checkpoint_dir: Path | None = None,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
num_epochs: Optional[int] = None, num_epochs: int | None = None,
) -> Trainer: ) -> Trainer:
trainer_conf = config.train.trainer trainer_conf = config.train.trainer
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(

View File

@ -240,7 +240,7 @@ class ProcessingConfiguration(TypedDict):
detection_threshold: float detection_threshold: float
"""Threshold for detection probability.""" """Threshold for detection probability."""
time_expansion: Optional[float] time_expansion: float | None
"""Time expansion factor of the processed recordings.""" """Time expansion factor of the processed recordings."""
top_n: int top_n: int
@ -249,7 +249,7 @@ class ProcessingConfiguration(TypedDict):
return_raw_preds: bool return_raw_preds: bool
"""Whether to return raw predictions.""" """Whether to return raw predictions."""
max_duration: Optional[float] max_duration: float | None
"""Maximum duration of audio file to process in seconds.""" """Maximum duration of audio file to process in seconds."""
nms_kernel_size: int nms_kernel_size: int

View File

@ -20,7 +20,7 @@ class OutputFormatterProtocol(Protocol, Generic[T]):
self, self,
predictions: Sequence[T], predictions: Sequence[T],
path: PathLike, path: PathLike,
audio_dir: Optional[PathLike] = None, audio_dir: PathLike | None = None,
) -> None: ... ) -> None: ...
def load(self, path: PathLike) -> List[T]: ... def load(self, path: PathLike) -> List[T]: ...

View File

@ -28,19 +28,19 @@ __all__ = [
class MatchEvaluation: class MatchEvaluation:
clip: data.Clip clip: data.Clip
sound_event_annotation: Optional[data.SoundEventAnnotation] sound_event_annotation: data.SoundEventAnnotation | None
gt_det: bool gt_det: bool
gt_class: Optional[str] gt_class: str | None
gt_geometry: Optional[data.Geometry] gt_geometry: data.Geometry | None
pred_score: float pred_score: float
pred_class_scores: Dict[str, float] pred_class_scores: Dict[str, float]
pred_geometry: Optional[data.Geometry] pred_geometry: data.Geometry | None
affinity: float affinity: float
@property @property
def top_class(self) -> Optional[str]: def top_class(self) -> str | None:
if not self.pred_class_scores: if not self.pred_class_scores:
return None return None
@ -76,7 +76,7 @@ class MatcherProtocol(Protocol):
ground_truth: Sequence[data.Geometry], ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry], predictions: Sequence[data.Geometry],
scores: Sequence[float], scores: Sequence[float],
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ... ) -> Iterable[Tuple[int | None, int | None, float]]: ...
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True) Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)

View File

@ -42,7 +42,7 @@ class GeometryDecoder(Protocol):
""" """
def __call__( def __call__(
self, position: Position, size: Size, class_name: Optional[str] = None self, position: Position, size: Size, class_name: str | None = None
) -> data.Geometry: ... ) -> data.Geometry: ...
@ -93,5 +93,5 @@ class PostprocessorProtocol(Protocol):
def __call__( def __call__(
self, self,
output: ModelOutput, output: ModelOutput,
start_times: Optional[Sequence[float]] = None, start_times: Sequence[float] | None = None,
) -> List[ClipDetectionsTensor]: ... ) -> List[ClipDetectionsTensor]: ...

View File

@ -37,7 +37,7 @@ class AudioLoader(Protocol):
def load_file( def load_file(
self, self,
path: data.PathLike, path: data.PathLike,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess audio directly from a file path. """Load and preprocess audio directly from a file path.
@ -60,7 +60,7 @@ class AudioLoader(Protocol):
def load_recording( def load_recording(
self, self,
recording: data.Recording, recording: data.Recording,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object. """Load and preprocess the entire audio for a Recording object.
@ -90,7 +90,7 @@ class AudioLoader(Protocol):
def load_clip( def load_clip(
self, self,
clip: data.Clip, clip: data.Clip,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object. """Load and preprocess the audio segment defined by a Clip object.

Some files were not shown because too many files have changed in this diff Show More