Update type hints to Python 3.10

This commit is contained in:
mbsantiago 2025-12-08 17:14:50 +00:00
parent 9c72537ddd
commit 2563f26ed3
104 changed files with 525 additions and 664 deletions

View File

@ -93,7 +93,7 @@ mlflow = ["mlflow>=3.1.1"]
[tool.ruff]
line-length = 79
target-version = "py39"
target-version = "py310"
[tool.ruff.format]
docstring-code-format = true
@ -107,7 +107,7 @@ convention = "numpy"
[tool.pyright]
include = ["src", "tests"]
pythonVersion = "3.9"
pythonVersion = "3.10"
pythonPlatform = "All"
exclude = [
"src/batdetect2/detector/",

View File

@ -165,7 +165,7 @@ def load_audio(
time_exp_fact: float = 1,
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
scale: bool = False,
max_duration: Optional[float] = None,
max_duration: float | None = None,
) -> np.ndarray:
"""Load audio from file.
@ -203,7 +203,7 @@ def load_audio(
def generate_spectrogram(
audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[SpectrogramParameters] = None,
config: SpectrogramParameters | None = None,
device: torch.device = DEVICE,
) -> torch.Tensor:
"""Generate spectrogram from audio array.
@ -240,7 +240,7 @@ def generate_spectrogram(
def process_file(
audio_file: str,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
config: ProcessingConfiguration | None = None,
device: torch.device = DEVICE,
) -> du.RunResults:
"""Process audio file with model.
@ -271,7 +271,7 @@ def process_spectrogram(
spec: torch.Tensor,
samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
config: ProcessingConfiguration | None = None,
) -> Tuple[List[Annotation], np.ndarray]:
"""Process spectrogram with model.
@ -312,7 +312,7 @@ def process_audio(
audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
config: ProcessingConfiguration | None = None,
device: torch.device = DEVICE,
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
"""Process audio array with model.
@ -356,7 +356,7 @@ def process_audio(
def postprocess(
outputs: ModelOutput,
samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ProcessingConfiguration] = None,
config: ProcessingConfiguration | None = None,
) -> Tuple[List[Annotation], np.ndarray]:
"""Postprocess model outputs.

View File

@ -67,22 +67,22 @@ class BatDetect2API:
def load_annotations(
self,
path: data.PathLike,
base_dir: Optional[data.PathLike] = None,
base_dir: data.PathLike | None = None,
) -> Dataset:
return load_dataset_from_config(path, base_dir=base_dir)
def train(
self,
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
experiment_name: Optional[str] = None,
num_epochs: Optional[int] = None,
run_name: Optional[str] = None,
seed: Optional[int] = None,
val_annotations: Sequence[data.ClipAnnotation] | None = None,
train_workers: int | None = None,
val_workers: int | None = None,
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
log_dir: Path | None = DEFAULT_LOGS_DIR,
experiment_name: str | None = None,
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
):
train(
train_annotations=train_annotations,
@ -105,10 +105,10 @@ class BatDetect2API:
def evaluate(
self,
test_annotations: Sequence[data.ClipAnnotation],
num_workers: Optional[int] = None,
num_workers: int | None = None,
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
experiment_name: str | None = None,
run_name: str | None = None,
save_predictions: bool = True,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
return evaluate(
@ -129,7 +129,7 @@ class BatDetect2API:
self,
annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
output_dir: Optional[data.PathLike] = None,
output_dir: data.PathLike | None = None,
):
clip_evals = self.evaluator.evaluate(
annotations,
@ -221,7 +221,7 @@ class BatDetect2API:
def process_files(
self,
audio_files: Sequence[data.PathLike],
num_workers: Optional[int] = None,
num_workers: int | None = None,
) -> List[BatDetect2Prediction]:
return process_file_list(
self.model,
@ -236,8 +236,8 @@ class BatDetect2API:
def process_clips(
self,
clips: Sequence[data.Clip],
batch_size: Optional[int] = None,
num_workers: Optional[int] = None,
batch_size: int | None = None,
num_workers: int | None = None,
) -> List[BatDetect2Prediction]:
return run_batch_inference(
self.model,
@ -254,9 +254,9 @@ class BatDetect2API:
self,
predictions: Sequence[BatDetect2Prediction],
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
format: Optional[str] = None,
config: Optional[OutputFormatConfig] = None,
audio_dir: data.PathLike | None = None,
format: str | None = None,
config: OutputFormatConfig | None = None,
):
formatter = self.formatter
@ -331,7 +331,7 @@ class BatDetect2API:
def from_checkpoint(
cls,
path: data.PathLike,
config: Optional[BatDetect2Config] = None,
config: BatDetect2Config | None = None,
):
model, stored_config = load_model_from_checkpoint(path)

View File

@ -245,16 +245,12 @@ class FixedDurationClip:
ClipConfig = Annotated[
Union[
RandomClipConfig,
PaddedClipConfig,
FixedDurationClipConfig,
],
RandomClipConfig | PaddedClipConfig | FixedDurationClipConfig,
Field(discriminator="name"),
]
def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
def build_clipper(config: ClipConfig | None = None) -> ClipperProtocol:
config = config or RandomClipConfig()
logger.opt(lazy=True).debug(

View File

@ -50,7 +50,7 @@ class AudioConfig(BaseConfig):
resample: ResampleConfig = Field(default_factory=ResampleConfig)
def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
def build_audio_loader(config: AudioConfig | None = None) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration."""
config = config or AudioConfig()
return SoundEventAudioLoader(
@ -65,7 +65,7 @@ class SoundEventAudioLoader(AudioLoader):
def __init__(
self,
samplerate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ResampleConfig] = None,
config: ResampleConfig | None = None,
):
self.samplerate = samplerate
self.config = config or ResampleConfig()
@ -73,7 +73,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess audio directly from a file path."""
return load_file_audio(
@ -86,7 +86,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object."""
return load_recording_audio(
@ -99,7 +99,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object."""
return load_clip_audio(
@ -112,9 +112,9 @@ class SoundEventAudioLoader(AudioLoader):
def load_file_audio(
path: data.PathLike,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
samplerate: int | None = None,
config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config."""
@ -136,9 +136,9 @@ def load_file_audio(
def load_recording_audio(
recording: data.Recording,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
samplerate: int | None = None,
config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config."""
@ -158,9 +158,9 @@ def load_recording_audio(
def load_clip_audio(
clip: data.Clip,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
samplerate: int | None = None,
config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config."""

View File

@ -34,9 +34,9 @@ def data(): ...
)
def summary(
dataset_config: Path,
field: Optional[str] = None,
targets_path: Optional[Path] = None,
base_dir: Optional[Path] = None,
field: str | None = None,
targets_path: Path | None = None,
base_dir: Path | None = None,
):
from batdetect2.data import compute_class_summary, load_dataset_from_config
from batdetect2.targets import load_targets
@ -83,9 +83,9 @@ def summary(
)
def convert(
dataset_config: Path,
field: Optional[str] = None,
field: str | None = None,
output: Path = Path("annotations.json"),
base_dir: Optional[Path] = None,
base_dir: Path | None = None,
):
"""Convert a dataset config file to soundevent format."""
from soundevent import data, io

View File

@ -25,11 +25,11 @@ def evaluate_command(
model_path: Path,
test_dataset: Path,
base_dir: Path,
config_path: Optional[Path],
config_path: Path | None,
output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: Optional[int] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
num_workers: int | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config

View File

@ -26,19 +26,19 @@ __all__ = ["train_command"]
@click.option("--seed", type=int)
def train_command(
train_dataset: Path,
val_dataset: Optional[Path] = None,
model_path: Optional[Path] = None,
ckpt_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
config: Optional[Path] = None,
targets_config: Optional[Path] = None,
config_field: Optional[str] = None,
seed: Optional[int] = None,
num_epochs: Optional[int] = None,
val_dataset: Path | None = None,
model_path: Path | None = None,
ckpt_dir: Path | None = None,
log_dir: Path | None = None,
config: Path | None = None,
targets_config: Path | None = None,
config_field: str | None = None,
seed: int | None = None,
num_epochs: int | None = None,
train_workers: int = 0,
val_workers: int = 0,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
experiment_name: str | None = None,
run_name: str | None = None,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import (

View File

@ -4,7 +4,7 @@ import json
import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, List
import numpy as np
from soundevent import data
@ -17,7 +17,7 @@ from batdetect2.types import (
FileAnnotation,
)
PathLike = Union[Path, str, os.PathLike]
PathLike = Path | str | os.PathLike
__all__ = [
"convert_to_annotation_group",
@ -33,7 +33,7 @@ UNKNOWN_CLASS = "__UNKNOWN__"
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
EventFn = Callable[[data.SoundEventAnnotation], str | None]
ClassFn = Callable[[data.Recording], int]
@ -221,7 +221,7 @@ def annotation_to_sound_event_prediction(
def file_annotation_to_clip(
file_annotation: FileAnnotation,
audio_dir: Optional[PathLike] = None,
audio_dir: PathLike | None = None,
label_key: str = "class",
) -> data.Clip:
"""Convert file annotation to recording."""

View File

@ -43,7 +43,7 @@ class BatDetect2Config(BaseConfig):
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
def validate_config(config: Optional[dict]) -> BatDetect2Config:
def validate_config(config: dict | None) -> BatDetect2Config:
if config is None:
return BatDetect2Config()
@ -52,6 +52,6 @@ def validate_config(config: Optional[dict]) -> BatDetect2Config:
def load_full_config(
path: PathLike,
field: Optional[str] = None,
field: str | None = None,
) -> BatDetect2Config:
return load_config(path, schema=BatDetect2Config, field=field)

View File

@ -86,8 +86,8 @@ def adjust_width(
def slice_tensor(
tensor: torch.Tensor,
start: Optional[int] = None,
end: Optional[int] = None,
start: int | None = None,
end: int | None = None,
dim: int = -1,
) -> torch.Tensor:
slices = [slice(None)] * tensor.ndim

View File

@ -128,7 +128,7 @@ def get_object_field(obj: dict, current_key: str) -> Any:
def load_config(
path: PathLike,
schema: Type[T],
field: Optional[str] = None,
field: str | None = None,
) -> T:
"""Load and validate configuration data from a file against a schema.

View File

@ -43,11 +43,7 @@ __all__ = [
AnnotationFormats = Annotated[
Union[
BatDetect2MergedAnnotations,
BatDetect2FilesAnnotations,
AOEFAnnotations,
],
BatDetect2MergedAnnotations | BatDetect2FilesAnnotations | AOEFAnnotations,
Field(discriminator="format"),
]
"""Type Alias representing all supported data source configurations.
@ -63,7 +59,7 @@ source configuration represents.
def load_annotated_dataset(
dataset: AnnotatedDataset,
base_dir: Optional[data.PathLike] = None,
base_dir: data.PathLike | None = None,
) -> data.AnnotationSet:
"""Load annotations for a single data source based on its configuration.

View File

@ -77,14 +77,14 @@ class AOEFAnnotations(AnnotatedDataset):
annotations_path: Path
filter: Optional[AnnotationTaskFilter] = Field(
filter: AnnotationTaskFilter | None = Field(
default_factory=AnnotationTaskFilter
)
def load_aoef_annotated_dataset(
dataset: AOEFAnnotations,
base_dir: Optional[data.PathLike] = None,
base_dir: data.PathLike | None = None,
) -> data.AnnotationSet:
"""Load annotations from an AnnotationSet or AnnotationProject file.

View File

@ -27,7 +27,7 @@ aggregated into a `soundevent.data.AnnotationSet`.
import json
import os
from pathlib import Path
from typing import Literal, Optional, Union
from typing import Literal
from loguru import logger
from pydantic import Field, ValidationError
@ -43,7 +43,7 @@ from batdetect2.data.annotations.legacy import (
)
from batdetect2.data.annotations.types import AnnotatedDataset
PathLike = Union[Path, str, os.PathLike]
PathLike = Path | str | os.PathLike
__all__ = [
@ -102,7 +102,7 @@ class BatDetect2FilesAnnotations(AnnotatedDataset):
format: Literal["batdetect2"] = "batdetect2"
annotations_dir: Path
filter: Optional[AnnotationFilter] = Field(
filter: AnnotationFilter | None = Field(
default_factory=AnnotationFilter,
)
@ -133,14 +133,14 @@ class BatDetect2MergedAnnotations(AnnotatedDataset):
format: Literal["batdetect2_file"] = "batdetect2_file"
annotations_path: Path
filter: Optional[AnnotationFilter] = Field(
filter: AnnotationFilter | None = Field(
default_factory=AnnotationFilter,
)
def load_batdetect2_files_annotated_dataset(
dataset: BatDetect2FilesAnnotations,
base_dir: Optional[PathLike] = None,
base_dir: PathLike | None = None,
) -> data.AnnotationSet:
"""Load and convert 'batdetect2_file' annotations into an AnnotationSet.
@ -244,7 +244,7 @@ def load_batdetect2_files_annotated_dataset(
def load_batdetect2_merged_annotated_dataset(
dataset: BatDetect2MergedAnnotations,
base_dir: Optional[PathLike] = None,
base_dir: PathLike | None = None,
) -> data.AnnotationSet:
"""Load and convert 'batdetect2_merged' annotations into an AnnotationSet.

View File

@ -3,12 +3,12 @@
import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, List
from pydantic import BaseModel, Field
from soundevent import data
PathLike = Union[Path, str, os.PathLike]
PathLike = Path | str | os.PathLike
__all__ = []
@ -27,7 +27,7 @@ SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
)
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
EventFn = Callable[[data.SoundEventAnnotation], str | None]
ClassFn = Callable[[data.Recording], int]
@ -130,7 +130,7 @@ def get_sound_event_tags(
def file_annotation_to_clip(
file_annotation: FileAnnotation,
audio_dir: Optional[PathLike] = None,
audio_dir: PathLike | None = None,
label_key: str = "class",
) -> data.Clip:
"""Convert file annotation to recording."""

View File

@ -264,16 +264,7 @@ class Not:
SoundEventConditionConfig = Annotated[
Union[
HasTagConfig,
HasAllTagsConfig,
HasAnyTagConfig,
DurationConfig,
FrequencyConfig,
AllOfConfig,
AnyOfConfig,
NotConfig,
],
HasTagConfig | HasAllTagsConfig | HasAnyTagConfig | DurationConfig | FrequencyConfig | AllOfConfig | AnyOfConfig | NotConfig,
Field(discriminator="name"),
]

View File

@ -69,7 +69,7 @@ class DatasetConfig(BaseConfig):
description: str
sources: List[AnnotationFormats]
sound_event_filter: Optional[SoundEventConditionConfig] = None
sound_event_filter: SoundEventConditionConfig | None = None
sound_event_transforms: List[SoundEventTransformConfig] = Field(
default_factory=list
)
@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig):
def load_dataset(
config: DatasetConfig,
base_dir: Optional[data.PathLike] = None,
base_dir: data.PathLike | None = None,
) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig."""
clip_annotations = []
@ -161,14 +161,14 @@ def insert_source_tag(
)
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
def load_dataset_config(path: data.PathLike, field: str | None = None):
return load_config(path=path, schema=DatasetConfig, field=field)
def load_dataset_from_config(
path: data.PathLike,
field: Optional[str] = None,
base_dir: Optional[data.PathLike] = None,
field: str | None = None,
base_dir: data.PathLike | None = None,
) -> Dataset:
"""Load dataset annotation metadata from a configuration file.
@ -215,9 +215,9 @@ def load_dataset_from_config(
def save_dataset(
dataset: Dataset,
path: data.PathLike,
name: Optional[str] = None,
description: Optional[str] = None,
audio_dir: Optional[Path] = None,
name: str | None = None,
description: str | None = None,
audio_dir: Path | None = None,
) -> None:
"""Save a loaded dataset (list of ClipAnnotations) to a file.

View File

@ -10,7 +10,7 @@ from batdetect2.typing.targets import TargetProtocol
def iterate_over_sound_events(
dataset: Dataset,
targets: TargetProtocol,
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
) -> Generator[Tuple[str | None, data.SoundEventAnnotation], None, None]:
"""Iterate over sound events in a dataset.
Parameters

View File

@ -24,19 +24,14 @@ __all__ = [
OutputFormatConfig = Annotated[
Union[
BatDetect2OutputConfig,
ParquetOutputConfig,
SoundEventOutputConfig,
RawOutputConfig,
],
BatDetect2OutputConfig | ParquetOutputConfig | SoundEventOutputConfig | RawOutputConfig,
Field(discriminator="name"),
]
def build_output_formatter(
targets: Optional[TargetProtocol] = None,
config: Optional[OutputFormatConfig] = None,
targets: TargetProtocol | None = None,
config: OutputFormatConfig | None = None,
) -> OutputFormatterProtocol:
"""Construct the final output formatter."""
from batdetect2.targets import build_targets
@ -48,9 +43,9 @@ def build_output_formatter(
def get_output_formatter(
name: Optional[str] = None,
targets: Optional[TargetProtocol] = None,
config: Optional[OutputFormatConfig] = None,
name: str | None = None,
targets: TargetProtocol | None = None,
config: OutputFormatConfig | None = None,
) -> OutputFormatterProtocol:
"""Get the output formatter by name."""
@ -71,9 +66,9 @@ def get_output_formatter(
def load_predictions(
path: PathLike,
format: Optional[str] = "raw",
config: Optional[OutputFormatConfig] = None,
targets: Optional[TargetProtocol] = None,
format: str | None = "raw",
config: OutputFormatConfig | None = None,
targets: TargetProtocol | None = None,
):
"""Load predictions from a file."""
from batdetect2.targets import build_targets

View File

@ -123,7 +123,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
self,
predictions: Sequence[FileAnnotation],
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> None:
path = Path(path)

View File

@ -53,7 +53,7 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
self,
predictions: Sequence[BatDetect2Prediction],
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> None:
path = Path(path)

View File

@ -55,7 +55,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
self,
predictions: Sequence[BatDetect2Prediction],
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> None:
path = Path(path)
@ -84,7 +84,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
def pred_to_xr(
self,
prediction: BatDetect2Prediction,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> xr.Dataset:
clip = prediction.clip
recording = clip.recording

View File

@ -18,16 +18,16 @@ from batdetect2.typing import (
class SoundEventOutputConfig(BaseConfig):
name: Literal["soundevent"] = "soundevent"
top_k: Optional[int] = 1
min_score: Optional[float] = None
top_k: int | None = 1
min_score: float | None = None
class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
def __init__(
self,
targets: TargetProtocol,
top_k: Optional[int] = 1,
min_score: Optional[float] = 0,
top_k: int | None = 1,
min_score: float | None = 0,
):
self.targets = targets
self.top_k = top_k
@ -45,7 +45,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
self,
predictions: Sequence[data.ClipPrediction],
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> None:
run = data.PredictionSet(clip_predictions=list(predictions))

View File

@ -14,7 +14,7 @@ def split_dataset_by_recordings(
dataset: Dataset,
targets: TargetProtocol,
train_size: float = 0.75,
random_state: Optional[int] = None,
random_state: int | None = None,
) -> Tuple[Dataset, Dataset]:
recordings = extract_recordings_df(dataset)

View File

@ -142,7 +142,7 @@ class MapTagValueConfig(BaseConfig):
name: Literal["map_tag_value"] = "map_tag_value"
tag_key: str
value_mapping: Dict[str, str]
target_key: Optional[str] = None
target_key: str | None = None
class MapTagValue:
@ -150,7 +150,7 @@ class MapTagValue:
self,
tag_key: str,
value_mapping: Dict[str, str],
target_key: Optional[str] = None,
target_key: str | None = None,
):
self.tag_key = tag_key
self.value_mapping = value_mapping
@ -221,13 +221,7 @@ class ApplyAll:
SoundEventTransformConfig = Annotated[
Union[
SetFrequencyBoundConfig,
ReplaceTagConfig,
MapTagValueConfig,
ApplyIfConfig,
ApplyAllConfig,
],
SetFrequencyBoundConfig | ReplaceTagConfig | MapTagValueConfig | ApplyIfConfig | ApplyAllConfig,
Field(discriminator="name"),
]

View File

@ -86,7 +86,7 @@ def compute_bandwidth(
def compute_max_power_bb(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
@ -131,7 +131,7 @@ def compute_max_power_bb(
def compute_max_power(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
@ -157,7 +157,7 @@ def compute_max_power(
def compute_max_power_first(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
@ -184,7 +184,7 @@ def compute_max_power_first(
def compute_max_power_second(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
@ -211,7 +211,7 @@ def compute_max_power_second(
def compute_call_interval(
prediction: types.Prediction,
previous: Optional[types.Prediction] = None,
previous: types.Prediction | None = None,
**_,
) -> float:
"""Compute time between this call and the previous call in seconds."""

View File

@ -198,8 +198,8 @@ class TrainingParameters(BaseModel):
def get_params(
make_dirs: bool = False,
exps_dir: str = "../../experiments/",
model_name: Optional[str] = None,
experiment: Union[Path, str, None] = None,
model_name: str | None = None,
experiment: Path | str | None = None,
**kwargs,
) -> TrainingParameters:
experiments_dir = Path(exps_dir)

View File

@ -151,7 +151,7 @@ def run_nms(
def non_max_suppression(
heat: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]],
kernel_size: int | Tuple[int, int],
):
# kernel can be an int or list/tuple
if isinstance(kernel_size, int):

View File

@ -213,18 +213,13 @@ class GeometricIOU(AffinityFunction):
AffinityConfig = Annotated[
Union[
TimeAffinityConfig,
IntervalIOUConfig,
BBoxIOUConfig,
GeometricIOUConfig,
],
TimeAffinityConfig | IntervalIOUConfig | BBoxIOUConfig | GeometricIOUConfig,
Field(discriminator="name"),
]
def build_affinity_function(
config: Optional[AffinityConfig] = None,
config: AffinityConfig | None = None,
) -> AffinityFunction:
config = config or GeometricIOUConfig()
return affinity_functions.build(config)

View File

@ -51,6 +51,6 @@ def get_default_eval_config() -> EvaluationConfig:
def load_evaluation_config(
path: data.PathLike,
field: Optional[str] = None,
field: str | None = None,
) -> EvaluationConfig:
return load_config(path, schema=EvaluationConfig, field=field)

View File

@ -39,8 +39,8 @@ class TestDataset(Dataset[TestExample]):
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
clipper: Optional[ClipperProtocol] = None,
audio_dir: Optional[data.PathLike] = None,
clipper: ClipperProtocol | None = None,
audio_dir: data.PathLike | None = None,
):
self.clip_annotations = list(clip_annotations)
self.clipper = clipper
@ -78,10 +78,10 @@ class TestLoaderConfig(BaseConfig):
def build_test_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
num_workers: Optional[int] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: TestLoaderConfig | None = None,
num_workers: int | None = None,
) -> DataLoader[TestExample]:
logger.info("Building test data loader...")
config = config or TestLoaderConfig()
@ -109,9 +109,9 @@ def build_test_loader(
def build_test_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: TestLoaderConfig | None = None,
) -> TestDataset:
logger.info("Building training dataset...")
config = config or TestLoaderConfig()

View File

@ -34,10 +34,10 @@ def evaluate(
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
formatter: Optional["OutputFormatterProtocol"] = None,
num_workers: Optional[int] = None,
num_workers: int | None = None,
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
from batdetect2.config import BatDetect2Config

View File

@ -51,8 +51,8 @@ class Evaluator:
def build_evaluator(
config: Optional[Union[EvaluationConfig, dict]] = None,
targets: Optional[TargetProtocol] = None,
config: EvaluationConfig | dict | None = None,
targets: TargetProtocol | None = None,
) -> EvaluatorProtocol:
targets = targets or build_targets()

View File

@ -35,9 +35,9 @@ def match(
sound_event_annotations: Sequence[data.SoundEventAnnotation],
raw_predictions: Sequence[RawPrediction],
clip: data.Clip,
scores: Optional[Sequence[float]] = None,
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
scores: Sequence[float] | None = None,
targets: TargetProtocol | None = None,
matcher: MatcherProtocol | None = None,
) -> ClipMatches:
if matcher is None:
matcher = build_matcher()
@ -151,7 +151,7 @@ def match_start_times(
predictions: Sequence[data.Geometry],
scores: Sequence[float],
distance_threshold: float = 0.01,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
) -> Iterable[Tuple[int | None, int | None, float]]:
if not ground_truth:
for index in range(len(predictions)):
yield index, None, 0
@ -287,7 +287,7 @@ def greedy_match(
scores: Sequence[float],
affinity_threshold: float = 0.5,
affinity_function: AffinityFunction = compute_affinity,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
) -> Iterable[Tuple[int | None, int | None, float]]:
"""Performs a greedy, one-to-one matching of source to target geometries.
Iterates through source geometries, prioritizing by score if provided. Each
@ -514,12 +514,7 @@ class OptimalMatcher(MatcherProtocol):
MatchConfig = Annotated[
Union[
GreedyMatchConfig,
StartTimeMatchConfig,
OptimalMatchConfig,
GreedyAffinityMatchConfig,
],
GreedyMatchConfig | StartTimeMatchConfig | OptimalMatchConfig | GreedyAffinityMatchConfig,
Field(discriminator="name"),
]
@ -558,7 +553,7 @@ def compute_affinity_matrix(
def select_optimal_matches(
affinity_matrix: np.ndarray,
affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
) -> Iterable[Tuple[int | None, int | None, float]]:
num_gt, num_pred = affinity_matrix.shape
gts = set(range(num_gt))
preds = set(range(num_pred))
@ -588,7 +583,7 @@ def select_optimal_matches(
def select_greedy_matches(
affinity_matrix: np.ndarray,
affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
) -> Iterable[Tuple[int | None, int | None, float]]:
num_gt, num_pred = affinity_matrix.shape
unmatched_pred = set(range(num_pred))
@ -612,6 +607,6 @@ def select_greedy_matches(
yield pred_idx, None, 0
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
def build_matcher(config: MatchConfig | None = None) -> MatcherProtocol:
config = config or StartTimeMatchConfig()
return matching_strategies.build(config)

View File

@ -36,13 +36,13 @@ __all__ = [
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
gt: data.SoundEventAnnotation | None
pred: RawPrediction | None
is_prediction: bool
is_ground_truth: bool
is_generic: bool
true_class: Optional[str]
true_class: str | None
score: float
@ -61,16 +61,16 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
class BaseClassificationConfig(BaseConfig):
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
include: List[str] | None = None
exclude: List[str] | None = None
class BaseClassificationMetric:
def __init__(
self,
targets: TargetProtocol,
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
include: List[str] | None = None,
exclude: List[str] | None = None,
):
self.targets = targets
self.include = include
@ -100,8 +100,8 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "average_precision",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
include: List[str] | None = None,
exclude: List[str] | None = None,
):
super().__init__(include=include, exclude=exclude, targets=targets)
self.ignore_non_predictions = ignore_non_predictions
@ -169,8 +169,8 @@ class ClassificationROCAUC(BaseClassificationMetric):
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "roc_auc",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
include: List[str] | None = None,
exclude: List[str] | None = None,
):
self.targets = targets
self.ignore_non_predictions = ignore_non_predictions
@ -225,10 +225,7 @@ class ClassificationROCAUC(BaseClassificationMetric):
ClassificationMetricConfig = Annotated[
Union[
ClassificationAveragePrecisionConfig,
ClassificationROCAUCConfig,
],
ClassificationAveragePrecisionConfig | ClassificationROCAUCConfig,
Field(discriminator="name"),
]

View File

@ -123,10 +123,7 @@ class ClipClassificationROCAUC:
ClipClassificationMetricConfig = Annotated[
Union[
ClipClassificationAveragePrecisionConfig,
ClipClassificationROCAUCConfig,
],
ClipClassificationAveragePrecisionConfig | ClipClassificationROCAUCConfig,
Field(discriminator="name"),
]

View File

@ -159,12 +159,7 @@ class ClipDetectionPrecision:
ClipDetectionMetricConfig = Annotated[
Union[
ClipDetectionAveragePrecisionConfig,
ClipDetectionROCAUCConfig,
ClipDetectionRecallConfig,
ClipDetectionPrecisionConfig,
],
ClipDetectionAveragePrecisionConfig | ClipDetectionROCAUCConfig | ClipDetectionRecallConfig | ClipDetectionPrecisionConfig,
Field(discriminator="name"),
]

View File

@ -11,7 +11,7 @@ __all__ = [
def compute_precision_recall(
y_true,
y_score,
num_positives: Optional[int] = None,
num_positives: int | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
y_true = np.array(y_true)
y_score = np.array(y_score)
@ -41,7 +41,7 @@ def compute_precision_recall(
def average_precision(
y_true,
y_score,
num_positives: Optional[int] = None,
num_positives: int | None = None,
) -> float:
if num_positives == 0:
return np.nan

View File

@ -28,8 +28,8 @@ __all__ = [
@dataclass
class MatchEval:
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
gt: data.SoundEventAnnotation | None
pred: RawPrediction | None
is_prediction: bool
is_ground_truth: bool
@ -212,12 +212,7 @@ class DetectionPrecision:
DetectionMetricConfig = Annotated[
Union[
DetectionAveragePrecisionConfig,
DetectionROCAUCConfig,
DetectionRecallConfig,
DetectionPrecisionConfig,
],
DetectionAveragePrecisionConfig | DetectionROCAUCConfig | DetectionRecallConfig | DetectionPrecisionConfig,
Field(discriminator="name"),
]

View File

@ -31,14 +31,14 @@ __all__ = [
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
gt: data.SoundEventAnnotation | None
pred: RawPrediction | None
is_ground_truth: bool
is_generic: bool
is_prediction: bool
pred_class: Optional[str]
true_class: Optional[str]
pred_class: str | None
true_class: str | None
score: float
@ -301,13 +301,7 @@ class BalancedAccuracy:
TopClassMetricConfig = Annotated[
Union[
TopClassAveragePrecisionConfig,
TopClassROCAUCConfig,
TopClassRecallConfig,
TopClassPrecisionConfig,
BalancedAccuracyConfig,
],
TopClassAveragePrecisionConfig | TopClassROCAUCConfig | TopClassRecallConfig | TopClassPrecisionConfig | BalancedAccuracyConfig,
Field(discriminator="name"),
]

View File

@ -10,7 +10,7 @@ from batdetect2.typing import TargetProtocol
class BasePlotConfig(BaseConfig):
label: str = "plot"
theme: str = "default"
title: Optional[str] = None
title: str | None = None
figsize: tuple[int, int] = (10, 10)
dpi: int = 100
@ -21,7 +21,7 @@ class BasePlot:
targets: TargetProtocol,
label: str = "plot",
figsize: tuple[int, int] = (10, 10),
title: Optional[str] = None,
title: str | None = None,
dpi: int = 100,
theme: str = "default",
):

View File

@ -45,7 +45,7 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Classification Precision-Recall Curve"
title: str | None = "Classification Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
@ -108,7 +108,7 @@ class PRCurve(BasePlot):
class ThresholdPrecisionCurveConfig(BasePlotConfig):
name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
label: str = "threshold_precision_curve"
title: Optional[str] = "Classification Threshold-Precision Curve"
title: str | None = "Classification Threshold-Precision Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
@ -181,7 +181,7 @@ class ThresholdPrecisionCurve(BasePlot):
class ThresholdRecallCurveConfig(BasePlotConfig):
name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
label: str = "threshold_recall_curve"
title: Optional[str] = "Classification Threshold-Recall Curve"
title: str | None = "Classification Threshold-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
@ -254,7 +254,7 @@ class ThresholdRecallCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Classification ROC Curve"
title: str | None = "Classification ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
@ -326,12 +326,7 @@ class ROCCurve(BasePlot):
ClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ThresholdPrecisionCurveConfig,
ThresholdRecallCurveConfig,
],
PRCurveConfig | ROCCurveConfig | ThresholdPrecisionCurveConfig | ThresholdRecallCurveConfig,
Field(discriminator="name"),
]

View File

@ -44,7 +44,7 @@ clip_classification_plots: Registry[
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Clip Classification Precision-Recall Curve"
title: str | None = "Clip Classification Precision-Recall Curve"
separate_figures: bool = False
@ -111,7 +111,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Clip Classification ROC Curve"
title: str | None = "Clip Classification ROC Curve"
separate_figures: bool = False
@ -174,10 +174,7 @@ class ROCCurve(BasePlot):
ClipClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
],
PRCurveConfig | ROCCurveConfig,
Field(discriminator="name"),
]

View File

@ -41,7 +41,7 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Clip Detection Precision-Recall Curve"
title: str | None = "Clip Detection Precision-Recall Curve"
class PRCurve(BasePlot):
@ -74,7 +74,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Clip Detection ROC Curve"
title: str | None = "Clip Detection ROC Curve"
class ROCCurve(BasePlot):
@ -107,7 +107,7 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Clip Detection Score Distribution"
title: str | None = "Clip Detection Score Distribution"
class ScoreDistributionPlot(BasePlot):
@ -147,11 +147,7 @@ class ScoreDistributionPlot(BasePlot):
ClipDetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
],
PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig,
Field(discriminator="name"),
]

View File

@ -37,7 +37,7 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Detection Precision-Recall Curve"
title: str | None = "Detection Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -100,7 +100,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Detection ROC Curve"
title: str | None = "Detection ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -159,7 +159,7 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Detection Score Distribution"
title: str | None = "Detection Score Distribution"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -226,7 +226,7 @@ class ScoreDistributionPlot(BasePlot):
class ExampleDetectionPlotConfig(BasePlotConfig):
name: Literal["example_detection"] = "example_detection"
label: str = "example_detection"
title: Optional[str] = "Example Detection"
title: str | None = "Example Detection"
figsize: tuple[int, int] = (10, 4)
num_examples: int = 5
threshold: float = 0.2
@ -292,12 +292,7 @@ class ExampleDetectionPlot(BasePlot):
DetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
ExampleDetectionPlotConfig,
],
PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig | ExampleDetectionPlotConfig,
Field(discriminator="name"),
]

View File

@ -44,7 +44,7 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Top Class Precision-Recall Curve"
title: str | None = "Top Class Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -111,7 +111,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Top Class ROC Curve"
title: str | None = "Top Class ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -173,7 +173,7 @@ class ROCCurve(BasePlot):
class ConfusionMatrixConfig(BasePlotConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
title: Optional[str] = "Top Class Confusion Matrix"
title: str | None = "Top Class Confusion Matrix"
figsize: tuple[int, int] = (10, 10)
label: str = "confusion_matrix"
exclude_generic: bool = True
@ -257,7 +257,7 @@ class ConfusionMatrix(BasePlot):
class ExampleClassificationPlotConfig(BasePlotConfig):
name: Literal["example_classification"] = "example_classification"
label: str = "example_classification"
title: Optional[str] = "Example Classification"
title: str | None = "Example Classification"
num_examples: int = 4
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
@ -348,12 +348,7 @@ class ExampleClassificationPlot(BasePlot):
TopClassPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ConfusionMatrixConfig,
ExampleClassificationPlotConfig,
],
PRCurveConfig | ROCCurveConfig | ConfusionMatrixConfig | ExampleClassificationPlotConfig,
Field(discriminator="name"),
]

View File

@ -96,7 +96,7 @@ def extract_matches_dataframe(
EvaluationTableConfig = Annotated[
Union[FullEvaluationTableConfig,], Field(discriminator="name")
FullEvaluationTableConfig, Field(discriminator="name")
]

View File

@ -26,20 +26,14 @@ __all__ = [
TaskConfig = Annotated[
Union[
ClassificationTaskConfig,
DetectionTaskConfig,
ClipDetectionTaskConfig,
ClipClassificationTaskConfig,
TopClassDetectionTaskConfig,
],
ClassificationTaskConfig | DetectionTaskConfig | ClipDetectionTaskConfig | ClipClassificationTaskConfig | TopClassDetectionTaskConfig,
Field(discriminator="name"),
]
def build_task(
config: TaskConfig,
targets: Optional[TargetProtocol] = None,
targets: TargetProtocol | None = None,
) -> EvaluatorProtocol:
targets = targets or build_targets()
return tasks_registry.build(config, targets)
@ -49,8 +43,8 @@ def evaluate_task(
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
task: Optional["str"] = None,
targets: Optional[TargetProtocol] = None,
config: Optional[Union[TaskConfig, dict]] = None,
targets: TargetProtocol | None = None,
config: TaskConfig | dict | None = None,
):
if isinstance(config, BaseTaskConfig):
task_obj = build_task(config, targets)

View File

@ -67,9 +67,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
prefix: str,
ignore_start_end: float = 0.01,
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
):
self.matcher = matcher
self.metrics = metrics
@ -147,9 +145,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
config: BaseTaskConfig,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
**kwargs,
):
matcher = build_matcher(config.matching_strategy)

View File

@ -98,8 +98,8 @@ def load_annotations(
dataset_name: str,
ann_path: str,
audio_path: str,
classes_to_ignore: Optional[List[str]] = None,
events_of_interest: Optional[List[str]] = None,
classes_to_ignore: List[str] | None = None,
events_of_interest: List[str] | None = None,
) -> List[types.FileAnnotation]:
train_sets: List[types.DatasetDict] = []
train_sets.append(

View File

@ -13,7 +13,7 @@ from batdetect2 import types
def print_dataset_stats(
data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None,
classes_to_ignore: List[str] | None = None,
) -> Counter[str]:
print("Num files:", len(data))
counts, _ = tu.get_class_names(data, classes_to_ignore)

View File

@ -28,8 +28,8 @@ def run_batch_inference(
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
num_workers: Optional[int] = None,
batch_size: Optional[int] = None,
num_workers: int | None = None,
batch_size: int | None = None,
) -> List[BatDetect2Prediction]:
from batdetect2.config import BatDetect2Config
@ -69,7 +69,7 @@ def process_file_list(
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
num_workers: Optional[int] = None,
num_workers: int | None = None,
) -> List[BatDetect2Prediction]:
clip_config = config.inference.clipping
clips = get_clips_from_files(

View File

@ -36,7 +36,7 @@ class InferenceDataset(Dataset[DatasetItem]):
clips: Sequence[data.Clip],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
):
self.clips = list(clips)
self.preprocessor = preprocessor
@ -66,11 +66,11 @@ class InferenceLoaderConfig(BaseConfig):
def build_inference_loader(
clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[InferenceLoaderConfig] = None,
num_workers: Optional[int] = None,
batch_size: Optional[int] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: InferenceLoaderConfig | None = None,
num_workers: int | None = None,
batch_size: int | None = None,
) -> DataLoader[DatasetItem]:
logger.info("Building inference data loader...")
config = config or InferenceLoaderConfig()
@ -95,8 +95,8 @@ def build_inference_loader(
def build_inference_dataset(
clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
) -> InferenceDataset:
if audio_loader is None:
audio_loader = build_audio_loader()

View File

@ -49,14 +49,14 @@ def enable_logging(level: int):
class BaseLoggerConfig(BaseConfig):
log_dir: Path = DEFAULT_LOGS_DIR
experiment_name: Optional[str] = None
run_name: Optional[str] = None
experiment_name: str | None = None
run_name: str | None = None
class DVCLiveConfig(BaseLoggerConfig):
name: Literal["dvclive"] = "dvclive"
prefix: str = ""
log_model: Union[bool, Literal["all"]] = False
log_model: bool | Literal["all"] = False
monitor_system: bool = False
@ -72,18 +72,13 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):
class MLFlowLoggerConfig(BaseLoggerConfig):
name: Literal["mlflow"] = "mlflow"
tracking_uri: Optional[str] = "http://localhost:5000"
tags: Optional[dict[str, Any]] = None
tracking_uri: str | None = "http://localhost:5000"
tags: dict[str, Any] | None = None
log_model: bool = False
LoggerConfig = Annotated[
Union[
DVCLiveConfig,
CSVLoggerConfig,
TensorBoardLoggerConfig,
MLFlowLoggerConfig,
],
DVCLiveConfig | CSVLoggerConfig | TensorBoardLoggerConfig | MLFlowLoggerConfig,
Field(discriminator="name"),
]
@ -95,17 +90,17 @@ class LoggerBuilder(Protocol, Generic[T]):
def __call__(
self,
config: T,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger: ...
def create_dvclive_logger(
config: DVCLiveConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger:
try:
from dvclive.lightning import DVCLiveLogger # type: ignore
@ -130,9 +125,9 @@ def create_dvclive_logger(
def create_csv_logger(
config: CSVLoggerConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger:
from lightning.pytorch.loggers import CSVLogger
@ -159,9 +154,9 @@ def create_csv_logger(
def create_tensorboard_logger(
config: TensorBoardLoggerConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger:
from lightning.pytorch.loggers import TensorBoardLogger
@ -191,9 +186,9 @@ def create_tensorboard_logger(
def create_mlflow_logger(
config: MLFlowLoggerConfig,
log_dir: Optional[data.PathLike] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: data.PathLike | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger:
try:
from lightning.pytorch.loggers import MLFlowLogger
@ -232,9 +227,9 @@ LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
def build_logger(
config: LoggerConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger:
logger.opt(lazy=True).debug(
"Building logger with config: \n{}",
@ -257,7 +252,7 @@ def build_logger(
PlotLogger = Callable[[str, Figure, int], None]
def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
def get_image_logger(logger: Logger) -> PlotLogger | None:
if isinstance(logger, TensorBoardLogger):
return logger.experiment.add_figure
@ -282,7 +277,7 @@ def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
TableLogger = Callable[[str, pd.DataFrame, int], None]
def get_table_logger(logger: Logger) -> Optional[TableLogger]:
def get_table_logger(logger: Logger) -> TableLogger | None:
if isinstance(logger, TensorBoardLogger):
return partial(save_table, dir=Path(logger.log_dir))

View File

@ -43,10 +43,7 @@ from batdetect2.models.bottleneck import (
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.config import (
BackboneConfig,
load_backbone_config,
)
from batdetect2.models.config import BackboneConfig, load_backbone_config
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
@ -122,10 +119,10 @@ class Model(torch.nn.Module):
def build_model(
config: Optional[BackboneConfig] = None,
targets: Optional[TargetProtocol] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
postprocessor: Optional[PostprocessorProtocol] = None,
config: BackboneConfig | None = None,
targets: TargetProtocol | None = None,
preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None,
):
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor

View File

@ -78,8 +78,8 @@ class Bottleneck(nn.Module):
input_height: int,
in_channels: int,
out_channels: int,
bottleneck_channels: Optional[int] = None,
layers: Optional[List[torch.nn.Module]] = None,
bottleneck_channels: int | None = None,
layers: List[torch.nn.Module] | None = None,
) -> None:
"""Initialize the base Bottleneck layer."""
super().__init__()
@ -127,7 +127,7 @@ class Bottleneck(nn.Module):
BottleneckLayerConfig = Annotated[
Union[SelfAttentionConfig,],
SelfAttentionConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
@ -171,7 +171,7 @@ DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
def build_bottleneck(
input_height: int,
in_channels: int,
config: Optional[BottleneckConfig] = None,
config: BottleneckConfig | None = None,
) -> nn.Module:
"""Factory function to build the Bottleneck module from configuration.

View File

@ -63,7 +63,7 @@ class BackboneConfig(BaseConfig):
def load_backbone_config(
path: data.PathLike,
field: Optional[str] = None,
field: str | None = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.

View File

@ -41,12 +41,7 @@ __all__ = [
]
DecoderLayerConfig = Annotated[
Union[
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
LayerGroupConfig,
],
ConvConfig | FreqCoordConvUpConfig | StandardConvUpConfig | LayerGroupConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
@ -216,7 +211,7 @@ convolutional block.
def build_decoder(
in_channels: int,
input_height: int,
config: Optional[DecoderConfig] = None,
config: DecoderConfig | None = None,
) -> Decoder:
"""Factory function to build a Decoder instance from configuration.

View File

@ -127,7 +127,7 @@ class Detector(DetectionModel):
def build_detector(
num_classes: int, config: Optional[BackboneConfig] = None
num_classes: int, config: BackboneConfig | None = None
) -> DetectionModel:
"""Build the complete BatDetect2 detection model.

View File

@ -43,12 +43,7 @@ __all__ = [
]
EncoderLayerConfig = Annotated[
Union[
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
LayerGroupConfig,
],
ConvConfig | FreqCoordConvDownConfig | StandardConvDownConfig | LayerGroupConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Encoder."""
@ -252,7 +247,7 @@ Specifies an architecture typically used in BatDetect2:
def build_encoder(
in_channels: int,
input_height: int,
config: Optional[EncoderConfig] = None,
config: EncoderConfig | None = None,
) -> Encoder:
"""Factory function to build an Encoder instance from configuration.

View File

@ -15,10 +15,10 @@ __all__ = [
def plot_clip_annotation(
clip_annotation: data.ClipAnnotation,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
add_points: bool = False,
cmap: str = "gray",
alpha: float = 1,
@ -50,8 +50,8 @@ def plot_clip_annotation(
def plot_anchor_points(
clip_annotation: data.ClipAnnotation,
targets: TargetProtocol,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
figsize: Tuple[int, int] | None = None,
ax: Axes | None = None,
size: int = 1,
color: str = "red",
marker: str = "x",

View File

@ -17,10 +17,10 @@ __all__ = [
def plot_clip_prediction(
clip_prediction: data.ClipPrediction,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
add_legend: bool = False,
spec_cmap: str = "gray",
linewidth: float = 1,
@ -50,14 +50,14 @@ def plot_clip_prediction(
def plot_predictions(
predictions: Iterable[data.SoundEventPrediction],
ax: Optional[Axes] = None,
ax: Axes | None = None,
position: Positions = "top-right",
color_mapper: Optional[TagColorMapper] = None,
color_mapper: TagColorMapper | None = None,
time_offset: float = 0.001,
freq_offset: float = 1000,
legend: bool = True,
max_alpha: float = 0.5,
color: Optional[str] = None,
color: str | None = None,
**kwargs,
):
"""Plot an prediction."""
@ -88,14 +88,14 @@ def plot_predictions(
def plot_prediction(
prediction: data.SoundEventPrediction,
ax: Optional[Axes] = None,
ax: Axes | None = None,
position: Positions = "top-right",
color_mapper: Optional[TagColorMapper] = None,
color_mapper: TagColorMapper | None = None,
time_offset: float = 0.001,
freq_offset: float = 1000,
max_alpha: float = 0.5,
alpha: Optional[float] = None,
color: Optional[str] = None,
alpha: float | None = None,
color: str | None = None,
**kwargs,
) -> Axes:
"""Plot an annotation."""

View File

@ -17,11 +17,11 @@ __all__ = [
def plot_clip(
clip: data.Clip,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
spec_cmap: str = "gray",
) -> Axes:
if ax is None:

View File

@ -13,8 +13,8 @@ __all__ = [
def create_ax(
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
**kwargs,
) -> axes.Axes:
"""Create a new axis if none is provided"""
@ -25,17 +25,17 @@ def create_ax(
def plot_spectrogram(
spec: Union[torch.Tensor, np.ndarray],
start_time: Optional[float] = None,
end_time: Optional[float] = None,
min_freq: Optional[float] = None,
max_freq: Optional[float] = None,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
spec: torch.Tensor | np.ndarray,
start_time: float | None = None,
end_time: float | None = None,
min_freq: float | None = None,
max_freq: float | None = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
add_colorbar: bool = False,
colorbar_kwargs: Optional[dict] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
colorbar_kwargs: dict | None = None,
vmin: float | None = None,
vmax: float | None = None,
cmap="gray",
) -> axes.Axes:
if isinstance(spec, torch.Tensor):

View File

@ -19,9 +19,9 @@ __all__ = [
def plot_clip_detections(
clip_eval: ClipEval,
figsize: tuple[int, int] = (10, 10),
ax: Optional[axes.Axes] = None,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
ax: axes.Axes | None = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
threshold: float = 0.2,
add_legend: bool = True,
add_title: bool = True,

View File

@ -20,11 +20,11 @@ def plot_match_gallery(
false_positives: Sequence[MatchProtocol],
false_negatives: Sequence[MatchProtocol],
cross_triggers: Sequence[MatchProtocol],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
n_examples: int = 5,
duration: float = 0.1,
fig: Optional[Figure] = None,
fig: Figure | None = None,
):
if fig is None:
fig = plt.figure(figsize=(20, 20))

View File

@ -12,13 +12,13 @@ from batdetect2.plotting.common import create_ax
def plot_detection_heatmap(
heatmap: Union[torch.Tensor, np.ndarray],
ax: Optional[axes.Axes] = None,
heatmap: torch.Tensor | np.ndarray,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] = (10, 10),
threshold: Optional[float] = None,
threshold: float | None = None,
alpha: float = 1,
cmap: Union[str, Colormap] = "jet",
color: Optional[str] = None,
cmap: str | Colormap = "jet",
color: str | None = None,
) -> axes.Axes:
ax = create_ax(ax, figsize=figsize)
@ -48,13 +48,13 @@ def plot_detection_heatmap(
def plot_classification_heatmap(
heatmap: Union[torch.Tensor, np.ndarray],
ax: Optional[axes.Axes] = None,
heatmap: torch.Tensor | np.ndarray,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] = (10, 10),
class_names: Optional[List[str]] = None,
threshold: Optional[float] = 0.1,
class_names: List[str] | None = None,
threshold: float | None = 0.1,
alpha: float = 1,
cmap: Union[str, Colormap] = "tab20",
cmap: str | Colormap = "tab20",
):
ax = create_ax(ax, figsize=figsize)

View File

@ -24,10 +24,10 @@ __all__ = [
def spectrogram(
spec: Union[torch.Tensor, np.ndarray],
config: Optional[ProcessingConfiguration] = None,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
spec: torch.Tensor | np.ndarray,
config: ProcessingConfiguration | None = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
cmap: str = "plasma",
start_time: float = 0,
) -> axes.Axes:
@ -103,11 +103,11 @@ def spectrogram(
def spectrogram_with_detections(
spec: Union[torch.Tensor, np.ndarray],
spec: torch.Tensor | np.ndarray,
dets: List[Annotation],
config: Optional[ProcessingConfiguration] = None,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
config: ProcessingConfiguration | None = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
cmap: str = "plasma",
with_names: bool = True,
start_time: float = 0,
@ -168,8 +168,8 @@ def spectrogram_with_detections(
def detections(
dets: List[Annotation],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
with_names: bool = True,
**kwargs,
) -> axes.Axes:
@ -213,8 +213,8 @@ def detections(
def detection(
det: Annotation,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
linewidth: float = 1,
edgecolor: str = "w",
facecolor: str = "none",

View File

@ -21,10 +21,10 @@ __all__ = [
class MatchProtocol(Protocol):
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
gt: data.SoundEventAnnotation | None
pred: RawPrediction | None
score: float
true_class: Optional[str]
true_class: str | None
DEFAULT_DURATION = 0.05
@ -38,11 +38,11 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_false_positive_match(
match: MatchProtocol,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION,
use_score: bool = True,
add_spectrogram: bool = True,
@ -52,7 +52,7 @@ def plot_false_positive_match(
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small",
fontsize: float | str = "small",
) -> Axes:
assert match.pred is not None
@ -109,11 +109,11 @@ def plot_false_positive_match(
def plot_false_negative_match(
match: MatchProtocol,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
figsize: Tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION,
add_spectrogram: bool = True,
add_points: bool = False,
@ -169,11 +169,11 @@ def plot_false_negative_match(
def plot_true_positive_match(
match: MatchProtocol,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
preprocessor: PreprocessorProtocol | None = None,
audio_loader: AudioLoader | None = None,
figsize: Tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION,
use_score: bool = True,
add_spectrogram: bool = True,
@ -182,7 +182,7 @@ def plot_true_positive_match(
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small",
fontsize: float | str = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
add_title: bool = True,
@ -257,11 +257,11 @@ def plot_true_positive_match(
def plot_cross_trigger_match(
match: MatchProtocol,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
preprocessor: PreprocessorProtocol | None = None,
audio_loader: AudioLoader | None = None,
figsize: Tuple[int, int] | None = None,
ax: Axes | None = None,
audio_dir: data.PathLike | None = None,
duration: float = DEFAULT_DURATION,
use_score: bool = True,
add_spectrogram: bool = True,
@ -271,7 +271,7 @@ def plot_cross_trigger_match(
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
fontsize: Union[float, str] = "small",
fontsize: float | str = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:

View File

@ -33,16 +33,16 @@ def plot_pr_curve(
precision: np.ndarray,
recall: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
color: Union[str, Tuple[float, float, float], None] = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
color: str | Tuple[float, float, float] | None = None,
add_labels: bool = True,
add_legend: bool = False,
marker: Union[str, Tuple[int, int, float], None] = "o",
markeredgecolor: Union[str, Tuple[float, float, float], None] = None,
markersize: Optional[float] = None,
linestyle: Union[str, Tuple[int, ...], None] = None,
linewidth: Optional[float] = None,
marker: str | Tuple[int, int, float] | None = "o",
markeredgecolor: str | Tuple[float, float, float] | None = None,
markersize: float | None = None,
linestyle: str | Tuple[int, ...] | None = None,
linewidth: float | None = None,
label: str = "PR Curve",
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
@ -77,8 +77,8 @@ def plot_pr_curve(
def plot_pr_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
add_legend: bool = True,
add_labels: bool = True,
include_ap: bool = False,
@ -118,8 +118,8 @@ def plot_pr_curves(
def plot_threshold_precision_curve(
threshold: np.ndarray,
precision: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
@ -140,8 +140,8 @@ def plot_threshold_precision_curve(
def plot_threshold_precision_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
add_legend: bool = True,
add_labels: bool = True,
):
@ -176,8 +176,8 @@ def plot_threshold_precision_curves(
def plot_threshold_recall_curve(
threshold: np.ndarray,
recall: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
@ -198,8 +198,8 @@ def plot_threshold_recall_curve(
def plot_threshold_recall_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
add_legend: bool = True,
add_labels: bool = True,
):
@ -235,8 +235,8 @@ def plot_roc_curve(
fpr: np.ndarray,
tpr: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
@ -261,8 +261,8 @@ def plot_roc_curve(
def plot_roc_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: axes.Axes | None = None,
figsize: Tuple[int, int] | None = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:

View File

@ -57,7 +57,7 @@ class PostprocessConfig(BaseConfig):
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
field: str | None = None,
) -> PostprocessConfig:
"""Load the postprocessing configuration from a file.

View File

@ -88,9 +88,7 @@ def convert_raw_prediction_to_sound_event_prediction(
raw_prediction: RawPrediction,
recording: data.Recording,
targets: TargetProtocol,
classification_threshold: Optional[
float
] = DEFAULT_CLASSIFICATION_THRESHOLD,
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False,
):
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
@ -150,7 +148,7 @@ def get_class_tags(
class_scores: np.ndarray,
targets: TargetProtocol,
top_class_only: bool = False,
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.PredictedTag]:
"""Generate specific PredictedTags based on class scores and decoder.

View File

@ -32,7 +32,7 @@ def extract_detection_peaks(
feature_heatmap: torch.Tensor,
classification_heatmap: torch.Tensor,
max_detections: int = 200,
threshold: Optional[float] = None,
threshold: float | None = None,
) -> List[ClipDetectionsTensor]:
height = detection_heatmap.shape[-2]
width = detection_heatmap.shape[-1]

View File

@ -27,7 +27,7 @@ BatDetect2.
def non_max_suppression(
tensor: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.

View File

@ -24,7 +24,7 @@ __all__ = [
def build_postprocessor(
preprocessor: PreprocessorProtocol,
config: Optional[PostprocessConfig] = None,
config: PostprocessConfig | None = None,
) -> PostprocessorProtocol:
"""Factory function to build the standard postprocessor."""
config = config or PostprocessConfig()
@ -51,7 +51,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
nms_kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE,
):
"""Initialize the Postprocessor."""
super().__init__()
@ -66,7 +66,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
def forward(
self,
output: ModelOutput,
start_times: Optional[List[float]] = None,
start_times: List[float] | None = None,
) -> List[ClipDetectionsTensor]:
detection_heatmap = non_max_suppression(
output.detection_probs.detach(),

View File

@ -31,14 +31,14 @@ __all__ = [
def to_xarray(
array: Union[torch.Tensor, np.ndarray],
array: torch.Tensor | np.ndarray,
start_time: float,
end_time: float,
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
name: str = "xarray",
extra_dims: Optional[List[str]] = None,
extra_coords: Optional[Dict[str, np.ndarray]] = None,
extra_dims: List[str] | None = None,
extra_coords: Dict[str, np.ndarray] | None = None,
) -> xr.DataArray:
if isinstance(array, torch.Tensor):
array = array.detach().cpu().numpy()

View File

@ -78,11 +78,7 @@ class FixDuration(torch.nn.Module):
AudioTransform = Annotated[
Union[
FixDurationConfig,
ScaleAudioConfig,
CenterAudioConfig,
],
FixDurationConfig | ScaleAudioConfig | CenterAudioConfig,
Field(discriminator="name"),
]

View File

@ -57,6 +57,6 @@ class PreprocessingConfig(BaseConfig):
def load_preprocessing_config(
path: PathLike,
field: Optional[str] = None,
field: str | None = None,
) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field)

View File

@ -102,7 +102,7 @@ def compute_output_samplerate(
def build_preprocessor(
config: Optional[PreprocessingConfig] = None,
config: PreprocessingConfig | None = None,
input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration."""

View File

@ -98,7 +98,7 @@ def _frequency_to_index(
freq: float,
n_fft: int,
samplerate: int = TARGET_SAMPLERATE_HZ,
) -> Optional[int]:
) -> int | None:
alpha = freq * 2 / samplerate
height = np.floor(n_fft / 2) + 1
index = int(np.floor(alpha * height))
@ -134,8 +134,8 @@ class FrequencyCrop(torch.nn.Module):
self,
samplerate: int,
n_fft: int,
min_freq: Optional[int] = None,
max_freq: Optional[int] = None,
min_freq: int | None = None,
max_freq: int | None = None,
):
super().__init__()
self.n_fft = n_fft
@ -181,7 +181,7 @@ class FrequencyCrop(torch.nn.Module):
def build_spectrogram_crop(
config: FrequencyConfig,
stft: Optional[STFTConfig] = None,
stft: STFTConfig | None = None,
samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
stft = stft or STFTConfig()
@ -377,12 +377,7 @@ class PeakNormalize(torch.nn.Module):
SpectrogramTransform = Annotated[
Union[
PcenConfig,
ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig,
PeakNormalizeConfig,
],
PcenConfig | ScaleAmplitudeConfig | SpectralMeanSubstractionConfig | PeakNormalizeConfig,
Field(discriminator="name"),
]

View File

@ -30,16 +30,16 @@ class TargetClassConfig(BaseConfig):
name: str
condition_input: Optional[SoundEventConditionConfig] = Field(
condition_input: SoundEventConditionConfig | None = Field(
alias="match_if",
default=None,
)
tags: Optional[List[data.Tag]] = Field(default=None, exclude=True)
tags: List[data.Tag] | None = Field(default=None, exclude=True)
assign_tags: List[data.Tag] = Field(default_factory=list)
roi: Optional[ROIMapperConfig] = None
roi: ROIMapperConfig | None = None
_match_if: SoundEventConditionConfig = PrivateAttr()
@ -202,7 +202,7 @@ class SoundEventClassifier:
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> Optional[str]:
) -> str | None:
for name, condition in self.mapping.items():
if condition(sound_event_annotation):
return name

View File

@ -48,7 +48,7 @@ class TargetConfig(BaseConfig):
def load_target_config(
path: data.PathLike,
field: Optional[str] = None,
field: str | None = None,
) -> TargetConfig:
"""Load the unified target configuration from a file.

View File

@ -414,10 +414,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
ROIMapperConfig = Annotated[
Union[
AnchorBBoxMapperConfig,
PeakEnergyBBoxMapperConfig,
],
AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig,
Field(discriminator="name"),
]
"""A discriminated union of all supported ROI mapper configurations.
@ -428,7 +425,7 @@ implementations by using the `name` field as a discriminator.
def build_roi_mapper(
config: Optional[ROIMapperConfig] = None,
config: ROIMapperConfig | None = None,
) -> ROITargetMapper:
"""Factory function to create an ROITargetMapper from a config object.
@ -572,9 +569,9 @@ def get_peak_energy_coordinates(
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
start_time: float = 0,
end_time: Optional[float] = None,
end_time: float | None = None,
low_freq: float = 0,
high_freq: Optional[float] = None,
high_freq: float | None = None,
loading_buffer: float = 0.05,
) -> Position:
"""Find the coordinates of the highest energy point in a spectrogram.

View File

@ -107,7 +107,7 @@ class Targets(TargetProtocol):
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
) -> str | None:
"""Encode a sound event annotation to its target class name.
Applies the configured class definition rules (including priority)
@ -182,7 +182,7 @@ class Targets(TargetProtocol):
self,
position: Position,
size: Size,
class_name: Optional[str] = None,
class_name: str | None = None,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.
@ -219,7 +219,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
)
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
def build_targets(config: TargetConfig | None = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
Parameters
@ -251,7 +251,7 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
field: str | None = None,
) -> Targets:
"""Load a Targets object directly from a configuration file.
@ -292,7 +292,7 @@ def load_targets(
def iterate_encoded_sound_events(
sound_events: Iterable[data.SoundEventAnnotation],
targets: TargetProtocol,
) -> Iterable[Tuple[Optional[str], Position, Size]]:
) -> Iterable[Tuple[str | None, Position, Size]]:
for sound_event in sound_events:
if not targets.filter(sound_event):
continue

View File

@ -42,7 +42,7 @@ __all__ = [
AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
audio_augmentations: Registry[Augmentation, [int, Optional[AudioSource]]] = (
audio_augmentations: Registry[Augmentation, [int, AudioSource | None]] = (
Registry(name="audio_augmentation")
)
@ -103,7 +103,7 @@ class MixAudio(torch.nn.Module):
def from_config(
config: MixAudioConfig,
samplerate: int,
source: Optional[AudioSource],
source: AudioSource | None,
):
if source is None:
warnings.warn(
@ -207,7 +207,7 @@ class AddEcho(torch.nn.Module):
def from_config(
config: AddEchoConfig,
samplerate: int,
source: Optional[AudioSource],
source: AudioSource | None,
):
return AddEcho(
samplerate=samplerate,
@ -487,33 +487,18 @@ def mask_frequency(
AudioAugmentationConfig = Annotated[
Union[
MixAudioConfig,
AddEchoConfig,
],
MixAudioConfig | AddEchoConfig,
Field(discriminator="name"),
]
SpectrogramAugmentationConfig = Annotated[
Union[
ScaleVolumeConfig,
WarpConfig,
MaskFrequencyConfig,
MaskTimeConfig,
],
ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig,
Field(discriminator="name"),
]
AugmentationConfig = Annotated[
Union[
MixAudioConfig,
AddEchoConfig,
ScaleVolumeConfig,
WarpConfig,
MaskFrequencyConfig,
MaskTimeConfig,
],
MixAudioConfig | AddEchoConfig | ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of individual augmentation config."""
@ -559,8 +544,8 @@ class MaybeApply(torch.nn.Module):
def build_augmentation_from_config(
config: AugmentationConfig,
samplerate: int,
audio_source: Optional[AudioSource] = None,
) -> Optional[Augmentation]:
audio_source: AudioSource | None = None,
) -> Augmentation | None:
"""Factory function to build a single augmentation from its config."""
if config.name == "mix_audio":
if audio_source is None:
@ -645,10 +630,10 @@ class AugmentationSequence(torch.nn.Module):
def build_audio_augmentations(
steps: Optional[Sequence[AudioAugmentationConfig]] = None,
steps: Sequence[AudioAugmentationConfig] | None = None,
samplerate: int = TARGET_SAMPLERATE_HZ,
audio_source: Optional[AudioSource] = None,
) -> Optional[Augmentation]:
audio_source: AudioSource | None = None,
) -> Augmentation | None:
if not steps:
return None
@ -673,8 +658,8 @@ def build_audio_augmentations(
def build_spectrogram_augmentations(
steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None,
) -> Optional[Augmentation]:
steps: Sequence[SpectrogramAugmentationConfig] | None = None,
) -> Augmentation | None:
if not steps:
return None
@ -698,9 +683,9 @@ def build_spectrogram_augmentations(
def build_augmentations(
samplerate: int,
config: Optional[AugmentationsConfig] = None,
audio_source: Optional[AudioSource] = None,
) -> Tuple[Optional[Augmentation], Optional[Augmentation]]:
config: AugmentationsConfig | None = None,
audio_source: AudioSource | None = None,
) -> Tuple[Augmentation | None, Augmentation | None]:
"""Build a composite augmentation pipeline function from configuration."""
config = config or DEFAULT_AUGMENTATION_CONFIG
@ -723,7 +708,7 @@ def build_augmentations(
def load_augmentation_config(
path: data.PathLike, field: Optional[str] = None
path: data.PathLike, field: str | None = None
) -> AugmentationsConfig:
"""Load the augmentations configuration from a file."""
return load_config(path, schema=AugmentationsConfig, field=field)

View File

@ -18,14 +18,14 @@ class CheckpointConfig(BaseConfig):
monitor: str = "classification/mean_average_precision"
mode: str = "max"
save_top_k: int = 1
filename: Optional[str] = None
filename: str | None = None
def build_checkpoint_callback(
config: Optional[CheckpointConfig] = None,
checkpoint_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
config: CheckpointConfig | None = None,
checkpoint_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Callback:
config = config or CheckpointConfig()

View File

@ -22,20 +22,20 @@ class PLTrainerConfig(BaseConfig):
accumulate_grad_batches: int = 1
deterministic: bool = True
check_val_every_n_epoch: int = 1
devices: Union[str, int] = "auto"
devices: str | int = "auto"
enable_checkpointing: bool = True
gradient_clip_val: Optional[float] = None
limit_train_batches: Optional[Union[int, float]] = None
limit_test_batches: Optional[Union[int, float]] = None
limit_val_batches: Optional[Union[int, float]] = None
log_every_n_steps: Optional[int] = None
max_epochs: Optional[int] = 200
min_epochs: Optional[int] = None
max_steps: Optional[int] = None
min_steps: Optional[int] = None
max_time: Optional[str] = None
precision: Optional[str] = None
val_check_interval: Optional[Union[int, float]] = None
gradient_clip_val: float | None = None
limit_train_batches: int | float | None = None
limit_test_batches: int | float | None = None
limit_val_batches: int | float | None = None
log_every_n_steps: int | None = None
max_epochs: int | None = 200
min_epochs: int | None = None
max_steps: int | None = None
min_steps: int | None = None
max_time: str | None = None
precision: str | None = None
val_check_interval: int | float | None = None
class OptimizerConfig(BaseConfig):
@ -57,6 +57,6 @@ class TrainingConfig(BaseConfig):
def load_train_config(
path: data.PathLike,
field: Optional[str] = None,
field: str | None = None,
) -> TrainingConfig:
return load_config(path, schema=TrainingConfig, field=field)

View File

@ -44,10 +44,10 @@ class TrainingDataset(Dataset):
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
clipper: Optional[ClipperProtocol] = None,
audio_augmentation: Optional[Augmentation] = None,
spectrogram_augmentation: Optional[Augmentation] = None,
audio_dir: Optional[data.PathLike] = None,
clipper: ClipperProtocol | None = None,
audio_augmentation: Augmentation | None = None,
spectrogram_augmentation: Augmentation | None = None,
audio_dir: data.PathLike | None = None,
):
self.clip_annotations = clip_annotations
self.clipper = clipper
@ -108,8 +108,8 @@ class ValidationDataset(Dataset):
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
clipper: Optional[ClipperProtocol] = None,
audio_dir: Optional[data.PathLike] = None,
clipper: ClipperProtocol | None = None,
audio_dir: data.PathLike | None = None,
):
self.clip_annotations = clip_annotations
self.labeller = labeller
@ -165,11 +165,11 @@ class TrainLoaderConfig(BaseConfig):
def build_train_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TrainLoaderConfig] = None,
num_workers: Optional[int] = None,
audio_loader: AudioLoader | None = None,
labeller: ClipLabeller | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: TrainLoaderConfig | None = None,
num_workers: int | None = None,
) -> DataLoader:
config = config or TrainLoaderConfig()
@ -207,11 +207,11 @@ class ValLoaderConfig(BaseConfig):
def build_val_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[ValLoaderConfig] = None,
num_workers: Optional[int] = None,
audio_loader: AudioLoader | None = None,
labeller: ClipLabeller | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: ValLoaderConfig | None = None,
num_workers: int | None = None,
):
logger.info("Building validation data loader...")
config = config or ValLoaderConfig()
@ -240,10 +240,10 @@ def build_val_loader(
def build_train_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TrainLoaderConfig] = None,
audio_loader: AudioLoader | None = None,
labeller: ClipLabeller | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: TrainLoaderConfig | None = None,
) -> TrainingDataset:
logger.info("Building training dataset...")
config = config or TrainLoaderConfig()
@ -291,10 +291,10 @@ def build_train_dataset(
def build_val_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[ValLoaderConfig] = None,
audio_loader: AudioLoader | None = None,
labeller: ClipLabeller | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: ValLoaderConfig | None = None,
) -> ValidationDataset:
logger.info("Building validation dataset...")
config = config or ValLoaderConfig()

View File

@ -42,10 +42,10 @@ class LabelConfig(BaseConfig):
def build_clip_labeler(
targets: Optional[TargetProtocol] = None,
targets: TargetProtocol | None = None,
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
config: Optional[LabelConfig] = None,
config: LabelConfig | None = None,
) -> ClipLabeller:
"""Construct the final clip labelling function."""
config = config or LabelConfig()
@ -153,7 +153,7 @@ def generate_heatmaps(
def load_label_config(
path: data.PathLike, field: Optional[str] = None
path: data.PathLike, field: str | None = None
) -> LabelConfig:
"""Load the heatmap label generation configuration from a file.

View File

@ -21,7 +21,7 @@ def train_loop(
model: DetectionModel,
train_dataset: LabeledDataset[TrainInputs],
validation_dataset: LabeledDataset[TrainInputs],
device: Optional[torch.device] = None,
device: torch.device | None = None,
num_epochs: int = 100,
learning_rate: float = 1e-4,
):

View File

@ -106,10 +106,10 @@ def standardize_low_freq(
def format_annotation(
annotation: types.FileAnnotation,
events_of_interest: Optional[List[str]] = None,
name_replace: Optional[Dict[str, str]] = None,
events_of_interest: List[str] | None = None,
name_replace: Dict[str, str] | None = None,
convert_to_genus: bool = False,
classes_to_ignore: Optional[List[str]] = None,
classes_to_ignore: List[str] | None = None,
) -> types.FileAnnotation:
formated = []
for aa in annotation["annotation"]:
@ -154,7 +154,7 @@ def format_annotation(
def get_class_names(
data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None,
classes_to_ignore: List[str] | None = None,
) -> Tuple[StringCounter, List[float]]:
"""Extracts class names and their inverse frequencies.
@ -201,9 +201,9 @@ def load_set_of_anns(
*,
convert_to_genus: bool = False,
filter_issues: bool = False,
events_of_interest: Optional[List[str]] = None,
classes_to_ignore: Optional[List[str]] = None,
name_replace: Optional[Dict[str, str]] = None,
events_of_interest: List[str] | None = None,
classes_to_ignore: List[str] | None = None,
name_replace: Dict[str, str] | None = None,
) -> List[types.FileAnnotation]:
# load the annotations
anns = []

View File

@ -26,10 +26,10 @@ class TrainingModule(L.LightningModule):
def __init__(
self,
config: Optional[dict] = None,
config: dict | None = None,
t_max: int = 100,
model: Optional[Model] = None,
loss: Optional[torch.nn.Module] = None,
model: Model | None = None,
loss: torch.nn.Module | None = None,
):
from batdetect2.config import validate_config
@ -103,7 +103,7 @@ def load_model_from_checkpoint(
def build_training_module(
config: Optional[dict] = None,
config: dict | None = None,
t_max: int = 200,
) -> TrainingModule:
return TrainingModule(config=config, t_max=t_max)

View File

@ -151,7 +151,7 @@ class FocalLoss(nn.Module):
eps: float = 1e-5,
beta: float = 4,
alpha: float = 2,
class_weights: Optional[torch.Tensor] = None,
class_weights: torch.Tensor | None = None,
mask_zero: bool = False,
):
super().__init__()
@ -422,8 +422,8 @@ class LossFunction(nn.Module, LossProtocol):
def build_loss(
config: Optional[LossConfig] = None,
class_weights: Optional[np.ndarray] = None,
config: LossConfig | None = None,
class_weights: np.ndarray | None = None,
) -> nn.Module:
"""Factory function to build the main LossFunction from configuration.

View File

@ -35,21 +35,21 @@ __all__ = [
def train(
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
val_annotations: Sequence[data.ClipAnnotation] | None = None,
targets: Optional["TargetProtocol"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
labeller: Optional["ClipLabeller"] = None,
config: Optional["BatDetect2Config"] = None,
trainer: Optional[Trainer] = None,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
num_epochs: Optional[int] = None,
run_name: Optional[str] = None,
seed: Optional[int] = None,
trainer: Trainer | None = None,
train_workers: int | None = None,
val_workers: int | None = None,
checkpoint_dir: Path | None = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
):
from batdetect2.config import BatDetect2Config
@ -126,11 +126,11 @@ def train(
def build_trainer(
config: "BatDetect2Config",
evaluator: "EvaluatorProtocol",
checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
num_epochs: Optional[int] = None,
checkpoint_dir: Path | None = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
num_epochs: int | None = None,
) -> Trainer:
trainer_conf = config.train.trainer
logger.opt(lazy=True).debug(

View File

@ -240,7 +240,7 @@ class ProcessingConfiguration(TypedDict):
detection_threshold: float
"""Threshold for detection probability."""
time_expansion: Optional[float]
time_expansion: float | None
"""Time expansion factor of the processed recordings."""
top_n: int
@ -249,7 +249,7 @@ class ProcessingConfiguration(TypedDict):
return_raw_preds: bool
"""Whether to return raw predictions."""
max_duration: Optional[float]
max_duration: float | None
"""Maximum duration of audio file to process in seconds."""
nms_kernel_size: int

View File

@ -20,7 +20,7 @@ class OutputFormatterProtocol(Protocol, Generic[T]):
self,
predictions: Sequence[T],
path: PathLike,
audio_dir: Optional[PathLike] = None,
audio_dir: PathLike | None = None,
) -> None: ...
def load(self, path: PathLike) -> List[T]: ...

View File

@ -28,19 +28,19 @@ __all__ = [
class MatchEvaluation:
clip: data.Clip
sound_event_annotation: Optional[data.SoundEventAnnotation]
sound_event_annotation: data.SoundEventAnnotation | None
gt_det: bool
gt_class: Optional[str]
gt_geometry: Optional[data.Geometry]
gt_class: str | None
gt_geometry: data.Geometry | None
pred_score: float
pred_class_scores: Dict[str, float]
pred_geometry: Optional[data.Geometry]
pred_geometry: data.Geometry | None
affinity: float
@property
def top_class(self) -> Optional[str]:
def top_class(self) -> str | None:
if not self.pred_class_scores:
return None
@ -76,7 +76,7 @@ class MatcherProtocol(Protocol):
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ...
) -> Iterable[Tuple[int | None, int | None, float]]: ...
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)

View File

@ -42,7 +42,7 @@ class GeometryDecoder(Protocol):
"""
def __call__(
self, position: Position, size: Size, class_name: Optional[str] = None
self, position: Position, size: Size, class_name: str | None = None
) -> data.Geometry: ...
@ -93,5 +93,5 @@ class PostprocessorProtocol(Protocol):
def __call__(
self,
output: ModelOutput,
start_times: Optional[Sequence[float]] = None,
start_times: Sequence[float] | None = None,
) -> List[ClipDetectionsTensor]: ...

View File

@ -37,7 +37,7 @@ class AudioLoader(Protocol):
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess audio directly from a file path.
@ -60,7 +60,7 @@ class AudioLoader(Protocol):
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object.
@ -90,7 +90,7 @@ class AudioLoader(Protocol):
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object.

Some files were not shown because too many files have changed in this diff Show More