From 2563f26ed35dbba432f7ee336363bde611b2c6da Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 8 Dec 2025 17:14:50 +0000 Subject: [PATCH] Update type hints to python 3.10 --- pyproject.toml | 4 +- src/batdetect2/api.py | 12 ++-- src/batdetect2/api_v2.py | 42 +++++++------- src/batdetect2/audio/clips.py | 8 +-- src/batdetect2/audio/loader.py | 28 ++++----- src/batdetect2/cli/data.py | 10 ++-- src/batdetect2/cli/evaluate.py | 8 +-- src/batdetect2/cli/train.py | 22 +++---- src/batdetect2/compat/data.py | 8 +-- src/batdetect2/config.py | 4 +- src/batdetect2/core/arrays.py | 4 +- src/batdetect2/core/configs.py | 2 +- src/batdetect2/data/annotations/__init__.py | 8 +-- src/batdetect2/data/annotations/aoef.py | 4 +- src/batdetect2/data/annotations/batdetect2.py | 12 ++-- src/batdetect2/data/annotations/legacy.py | 8 +-- src/batdetect2/data/conditions.py | 11 +--- src/batdetect2/data/datasets.py | 16 +++--- src/batdetect2/data/iterators.py | 2 +- src/batdetect2/data/predictions/__init__.py | 23 +++----- src/batdetect2/data/predictions/batdetect2.py | 2 +- src/batdetect2/data/predictions/parquet.py | 2 +- src/batdetect2/data/predictions/raw.py | 4 +- src/batdetect2/data/predictions/soundevent.py | 10 ++-- src/batdetect2/data/split.py | 2 +- src/batdetect2/data/transforms.py | 12 +--- src/batdetect2/detector/compute_features.py | 10 ++-- src/batdetect2/detector/parameters.py | 4 +- src/batdetect2/detector/post_process.py | 2 +- src/batdetect2/evaluate/affinity.py | 9 +-- src/batdetect2/evaluate/config.py | 2 +- src/batdetect2/evaluate/dataset.py | 18 +++--- src/batdetect2/evaluate/evaluate.py | 6 +- src/batdetect2/evaluate/evaluator.py | 4 +- src/batdetect2/evaluate/match.py | 23 +++----- .../evaluate/metrics/classification.py | 27 ++++----- .../evaluate/metrics/clip_classification.py | 5 +- .../evaluate/metrics/clip_detection.py | 7 +-- src/batdetect2/evaluate/metrics/common.py | 4 +- src/batdetect2/evaluate/metrics/detection.py | 11 +--- src/batdetect2/evaluate/metrics/top_class.py | 16 ++---- src/batdetect2/evaluate/plots/base.py | 4 +- .../evaluate/plots/classification.py | 15 ++--- .../evaluate/plots/clip_classification.py | 9 +-- .../evaluate/plots/clip_detection.py | 12 ++-- src/batdetect2/evaluate/plots/detection.py | 15 ++--- src/batdetect2/evaluate/plots/top_class.py | 15 ++--- src/batdetect2/evaluate/tables.py | 2 +- src/batdetect2/evaluate/tasks/__init__.py | 14 ++--- src/batdetect2/evaluate/tasks/base.py | 8 +-- src/batdetect2/finetune/finetune_model.py | 4 +- src/batdetect2/finetune/prep_data_finetune.py | 2 +- src/batdetect2/inference/batch.py | 6 +- src/batdetect2/inference/dataset.py | 16 +++--- src/batdetect2/logging.py | 57 +++++++++---------- src/batdetect2/models/__init__.py | 13 ++--- src/batdetect2/models/bottleneck.py | 8 +-- src/batdetect2/models/config.py | 2 +- src/batdetect2/models/decoder.py | 9 +-- src/batdetect2/models/detectors.py | 2 +- src/batdetect2/models/encoder.py | 9 +-- src/batdetect2/plotting/clip_annotations.py | 12 ++-- src/batdetect2/plotting/clip_predictions.py | 22 +++---- src/batdetect2/plotting/clips.py | 10 ++-- src/batdetect2/plotting/common.py | 24 ++++---- src/batdetect2/plotting/detections.py | 6 +- src/batdetect2/plotting/gallery.py | 6 +- src/batdetect2/plotting/heatmaps.py | 20 +++---- src/batdetect2/plotting/legacy/plot.py | 24 ++++---- src/batdetect2/plotting/matches.py | 52 ++++++++--------- src/batdetect2/plotting/metrics.py | 44 +++++++------- src/batdetect2/postprocess/config.py | 2 +- src/batdetect2/postprocess/decoding.py | 6 +- src/batdetect2/postprocess/extraction.py | 2 +- src/batdetect2/postprocess/nms.py | 2 +- src/batdetect2/postprocess/postprocessor.py | 6 +- src/batdetect2/postprocess/remapping.py | 6 +- src/batdetect2/preprocess/audio.py | 6 +- src/batdetect2/preprocess/config.py | 2 +- src/batdetect2/preprocess/preprocessor.py | 2 +- src/batdetect2/preprocess/spectrogram.py | 15 ++--- src/batdetect2/targets/classes.py | 8 +-- src/batdetect2/targets/config.py | 2 +- src/batdetect2/targets/rois.py | 11 ++-- src/batdetect2/targets/targets.py | 10 ++-- src/batdetect2/train/augmentations.py | 49 ++++++---------- src/batdetect2/train/checkpoints.py | 10 ++-- src/batdetect2/train/config.py | 28 ++++----- src/batdetect2/train/dataset.py | 48 ++++++++-------- src/batdetect2/train/labels.py | 6 +- src/batdetect2/train/legacy/train.py | 2 +- src/batdetect2/train/legacy/train_utils.py | 14 ++--- src/batdetect2/train/lightning.py | 8 +-- src/batdetect2/train/losses.py | 6 +- src/batdetect2/train/train.py | 30 +++++----- src/batdetect2/types.py | 4 +- src/batdetect2/typing/data.py | 2 +- src/batdetect2/typing/evaluate.py | 12 ++-- src/batdetect2/typing/postprocess.py | 4 +- src/batdetect2/typing/preprocess.py | 6 +- src/batdetect2/typing/targets.py | 6 +- src/batdetect2/utils/audio_utils.py | 4 +- src/batdetect2/utils/detector_utils.py | 6 +- src/batdetect2/utils/tensors.py | 6 +- 104 files changed, 525 insertions(+), 664 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c410467..a0ad77c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,7 @@ mlflow = ["mlflow>=3.1.1"] [tool.ruff] line-length = 79 -target-version = "py39" +target-version = "py310" [tool.ruff.format] docstring-code-format = true @@ -107,7 +107,7 @@ convention = "numpy" [tool.pyright] include = ["src", "tests"] -pythonVersion = "3.9" +pythonVersion = "3.10" pythonPlatform = "All" exclude = [ "src/batdetect2/detector/", diff --git a/src/batdetect2/api.py b/src/batdetect2/api.py index 9920006..d2bcc41 100644 --- a/src/batdetect2/api.py +++ b/src/batdetect2/api.py @@ -165,7 +165,7 @@ def load_audio( time_exp_fact: float = 1, target_samp_rate: int = TARGET_SAMPLERATE_HZ, scale: bool = False, - max_duration: Optional[float] = None, + max_duration: float | None = None, ) -> np.ndarray: """Load audio from file. @@ -203,7 +203,7 @@ def load_audio( def generate_spectrogram( audio: np.ndarray, samp_rate: int = TARGET_SAMPLERATE_HZ, - config: Optional[SpectrogramParameters] = None, + config: SpectrogramParameters | None = None, device: torch.device = DEVICE, ) -> torch.Tensor: """Generate spectrogram from audio array. @@ -240,7 +240,7 @@ def generate_spectrogram( def process_file( audio_file: str, model: DetectionModel = MODEL, - config: Optional[ProcessingConfiguration] = None, + config: ProcessingConfiguration | None = None, device: torch.device = DEVICE, ) -> du.RunResults: """Process audio file with model. @@ -271,7 +271,7 @@ def process_spectrogram( spec: torch.Tensor, samp_rate: int = TARGET_SAMPLERATE_HZ, model: DetectionModel = MODEL, - config: Optional[ProcessingConfiguration] = None, + config: ProcessingConfiguration | None = None, ) -> Tuple[List[Annotation], np.ndarray]: """Process spectrogram with model. @@ -312,7 +312,7 @@ def process_audio( audio: np.ndarray, samp_rate: int = TARGET_SAMPLERATE_HZ, model: DetectionModel = MODEL, - config: Optional[ProcessingConfiguration] = None, + config: ProcessingConfiguration | None = None, device: torch.device = DEVICE, ) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]: """Process audio array with model. @@ -356,7 +356,7 @@ def process_audio( def postprocess( outputs: ModelOutput, samp_rate: int = TARGET_SAMPLERATE_HZ, - config: Optional[ProcessingConfiguration] = None, + config: ProcessingConfiguration | None = None, ) -> Tuple[List[Annotation], np.ndarray]: """Postprocess model outputs. diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 38c3386..529c5db 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -67,22 +67,22 @@ class BatDetect2API: def load_annotations( self, path: data.PathLike, - base_dir: Optional[data.PathLike] = None, + base_dir: data.PathLike | None = None, ) -> Dataset: return load_dataset_from_config(path, base_dir=base_dir) def train( self, train_annotations: Sequence[data.ClipAnnotation], - val_annotations: Optional[Sequence[data.ClipAnnotation]] = None, - train_workers: Optional[int] = None, - val_workers: Optional[int] = None, - checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR, - log_dir: Optional[Path] = DEFAULT_LOGS_DIR, - experiment_name: Optional[str] = None, - num_epochs: Optional[int] = None, - run_name: Optional[str] = None, - seed: Optional[int] = None, + val_annotations: Sequence[data.ClipAnnotation] | None = None, + train_workers: int | None = None, + val_workers: int | None = None, + checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR, + log_dir: Path | None = DEFAULT_LOGS_DIR, + experiment_name: str | None = None, + num_epochs: int | None = None, + run_name: str | None = None, + seed: int | None = None, ): train( train_annotations=train_annotations, @@ -105,10 +105,10 @@ class BatDetect2API: def evaluate( self, test_annotations: Sequence[data.ClipAnnotation], - num_workers: Optional[int] = None, + num_workers: int | None = None, output_dir: data.PathLike = DEFAULT_EVAL_DIR, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + experiment_name: str | None = None, + run_name: str | None = None, save_predictions: bool = True, ) -> Tuple[Dict[str, float], List[List[RawPrediction]]]: return evaluate( @@ -129,7 +129,7 @@ class BatDetect2API: self, annotations: Sequence[data.ClipAnnotation], predictions: Sequence[BatDetect2Prediction], - output_dir: Optional[data.PathLike] = None, + output_dir: data.PathLike | None = None, ): clip_evals = self.evaluator.evaluate( annotations, @@ -221,7 +221,7 @@ class BatDetect2API: def process_files( self, audio_files: Sequence[data.PathLike], - num_workers: Optional[int] = None, + num_workers: int | None = None, ) -> List[BatDetect2Prediction]: return process_file_list( self.model, @@ -236,8 +236,8 @@ class BatDetect2API: def process_clips( self, clips: Sequence[data.Clip], - batch_size: Optional[int] = None, - num_workers: Optional[int] = None, + batch_size: int | None = None, + num_workers: int | None = None, ) -> List[BatDetect2Prediction]: return run_batch_inference( self.model, @@ -254,9 +254,9 @@ class BatDetect2API: self, predictions: Sequence[BatDetect2Prediction], path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, - format: Optional[str] = None, - config: Optional[OutputFormatConfig] = None, + audio_dir: data.PathLike | None = None, + format: str | None = None, + config: OutputFormatConfig | None = None, ): formatter = self.formatter @@ -331,7 +331,7 @@ class BatDetect2API: def from_checkpoint( cls, path: data.PathLike, - config: Optional[BatDetect2Config] = None, + config: BatDetect2Config | None = None, ): model, stored_config = load_model_from_checkpoint(path) diff --git a/src/batdetect2/audio/clips.py b/src/batdetect2/audio/clips.py index 1a2a41e..ecfe5da 100644 --- a/src/batdetect2/audio/clips.py +++ b/src/batdetect2/audio/clips.py @@ -245,16 +245,12 @@ class FixedDurationClip: ClipConfig = Annotated[ - Union[ - RandomClipConfig, - PaddedClipConfig, - FixedDurationClipConfig, - ], + RandomClipConfig | PaddedClipConfig | FixedDurationClipConfig, Field(discriminator="name"), ] -def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol: +def build_clipper(config: ClipConfig | None = None) -> ClipperProtocol: config = config or RandomClipConfig() logger.opt(lazy=True).debug( diff --git a/src/batdetect2/audio/loader.py b/src/batdetect2/audio/loader.py index 76d60a5..c02307d 100644 --- a/src/batdetect2/audio/loader.py +++ b/src/batdetect2/audio/loader.py @@ -50,7 +50,7 @@ class AudioConfig(BaseConfig): resample: ResampleConfig = Field(default_factory=ResampleConfig) -def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader: +def build_audio_loader(config: AudioConfig | None = None) -> AudioLoader: """Factory function to create an AudioLoader based on configuration.""" config = config or AudioConfig() return SoundEventAudioLoader( @@ -65,7 +65,7 @@ class SoundEventAudioLoader(AudioLoader): def __init__( self, samplerate: int = TARGET_SAMPLERATE_HZ, - config: Optional[ResampleConfig] = None, + config: ResampleConfig | None = None, ): self.samplerate = samplerate self.config = config or ResampleConfig() @@ -73,7 +73,7 @@ class SoundEventAudioLoader(AudioLoader): def load_file( self, path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ) -> np.ndarray: """Load and preprocess audio directly from a file path.""" return load_file_audio( @@ -86,7 +86,7 @@ class SoundEventAudioLoader(AudioLoader): def load_recording( self, recording: data.Recording, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ) -> np.ndarray: """Load and preprocess the entire audio for a Recording object.""" return load_recording_audio( @@ -99,7 +99,7 @@ class SoundEventAudioLoader(AudioLoader): def load_clip( self, clip: data.Clip, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ) -> np.ndarray: """Load and preprocess the audio segment defined by a Clip object.""" return load_clip_audio( @@ -112,9 +112,9 @@ class SoundEventAudioLoader(AudioLoader): def load_file_audio( path: data.PathLike, - samplerate: Optional[int] = None, - config: Optional[ResampleConfig] = None, - audio_dir: Optional[data.PathLike] = None, + samplerate: int | None = None, + config: ResampleConfig | None = None, + audio_dir: data.PathLike | None = None, dtype: DTypeLike = np.float32, # type: ignore ) -> np.ndarray: """Load and preprocess audio from a file path using specified config.""" @@ -136,9 +136,9 @@ def load_file_audio( def load_recording_audio( recording: data.Recording, - samplerate: Optional[int] = None, - config: Optional[ResampleConfig] = None, - audio_dir: Optional[data.PathLike] = None, + samplerate: int | None = None, + config: ResampleConfig | None = None, + audio_dir: data.PathLike | None = None, dtype: DTypeLike = np.float32, # type: ignore ) -> np.ndarray: """Load and preprocess the entire audio content of a recording using config.""" @@ -158,9 +158,9 @@ def load_recording_audio( def load_clip_audio( clip: data.Clip, - samplerate: Optional[int] = None, - config: Optional[ResampleConfig] = None, - audio_dir: Optional[data.PathLike] = None, + samplerate: int | None = None, + config: ResampleConfig | None = None, + audio_dir: data.PathLike | None = None, dtype: DTypeLike = np.float32, # type: ignore ) -> np.ndarray: """Load and preprocess a specific audio clip segment based on config.""" diff --git a/src/batdetect2/cli/data.py b/src/batdetect2/cli/data.py index f571e5c..2aba78f 100644 --- a/src/batdetect2/cli/data.py +++ b/src/batdetect2/cli/data.py @@ -34,9 +34,9 @@ def data(): ... ) def summary( dataset_config: Path, - field: Optional[str] = None, - targets_path: Optional[Path] = None, - base_dir: Optional[Path] = None, + field: str | None = None, + targets_path: Path | None = None, + base_dir: Path | None = None, ): from batdetect2.data import compute_class_summary, load_dataset_from_config from batdetect2.targets import load_targets @@ -83,9 +83,9 @@ def summary( ) def convert( dataset_config: Path, - field: Optional[str] = None, + field: str | None = None, output: Path = Path("annotations.json"), - base_dir: Optional[Path] = None, + base_dir: Path | None = None, ): """Convert a dataset config file to soundevent format.""" from soundevent import data, io diff --git a/src/batdetect2/cli/evaluate.py b/src/batdetect2/cli/evaluate.py index 28c771f..5874231 100644 --- a/src/batdetect2/cli/evaluate.py +++ b/src/batdetect2/cli/evaluate.py @@ -25,11 +25,11 @@ def evaluate_command( model_path: Path, test_dataset: Path, base_dir: Path, - config_path: Optional[Path], + config_path: Path | None, output_dir: Path = DEFAULT_OUTPUT_DIR, - num_workers: Optional[int] = None, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + num_workers: int | None = None, + experiment_name: str | None = None, + run_name: str | None = None, ): from batdetect2.api_v2 import BatDetect2API from batdetect2.config import load_full_config diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 7cb7d98..96f87f7 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -26,19 +26,19 @@ __all__ = ["train_command"] @click.option("--seed", type=int) def train_command( train_dataset: Path, - val_dataset: Optional[Path] = None, - model_path: Optional[Path] = None, - ckpt_dir: Optional[Path] = None, - log_dir: Optional[Path] = None, - config: Optional[Path] = None, - targets_config: Optional[Path] = None, - config_field: Optional[str] = None, - seed: Optional[int] = None, - num_epochs: Optional[int] = None, + val_dataset: Path | None = None, + model_path: Path | None = None, + ckpt_dir: Path | None = None, + log_dir: Path | None = None, + config: Path | None = None, + targets_config: Path | None = None, + config_field: str | None = None, + seed: int | None = None, + num_epochs: int | None = None, train_workers: int = 0, val_workers: int = 0, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + experiment_name: str | None = None, + run_name: str | None = None, ): from batdetect2.api_v2 import BatDetect2API from batdetect2.config import ( diff --git a/src/batdetect2/compat/data.py b/src/batdetect2/compat/data.py index 6473eb6..9cb540a 100644 --- a/src/batdetect2/compat/data.py +++ b/src/batdetect2/compat/data.py @@ -4,7 +4,7 @@ import json import os import uuid from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Callable, List import numpy as np from soundevent import data @@ -17,7 +17,7 @@ from batdetect2.types import ( FileAnnotation, ) -PathLike = Union[Path, str, os.PathLike] +PathLike = Path | str | os.PathLike __all__ = [ "convert_to_annotation_group", @@ -33,7 +33,7 @@ UNKNOWN_CLASS = "__UNKNOWN__" NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242") -EventFn = Callable[[data.SoundEventAnnotation], Optional[str]] +EventFn = Callable[[data.SoundEventAnnotation], str | None] ClassFn = Callable[[data.Recording], int] @@ -221,7 +221,7 @@ def annotation_to_sound_event_prediction( def file_annotation_to_clip( file_annotation: FileAnnotation, - audio_dir: Optional[PathLike] = None, + audio_dir: PathLike | None = None, label_key: str = "class", ) -> data.Clip: """Convert file annotation to recording.""" diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py index 6a17b68..1044d86 100644 --- a/src/batdetect2/config.py +++ b/src/batdetect2/config.py @@ -43,7 +43,7 @@ class BatDetect2Config(BaseConfig): output: OutputFormatConfig = Field(default_factory=RawOutputConfig) -def validate_config(config: Optional[dict]) -> BatDetect2Config: +def validate_config(config: dict | None) -> BatDetect2Config: if config is None: return BatDetect2Config() @@ -52,6 +52,6 @@ def validate_config(config: Optional[dict]) -> BatDetect2Config: def load_full_config( path: PathLike, - field: Optional[str] = None, + field: str | None = None, ) -> BatDetect2Config: return load_config(path, schema=BatDetect2Config, field=field) diff --git a/src/batdetect2/core/arrays.py b/src/batdetect2/core/arrays.py index c3204c5..a146f14 100644 --- a/src/batdetect2/core/arrays.py +++ b/src/batdetect2/core/arrays.py @@ -86,8 +86,8 @@ def adjust_width( def slice_tensor( tensor: torch.Tensor, - start: Optional[int] = None, - end: Optional[int] = None, + start: int | None = None, + end: int | None = None, dim: int = -1, ) -> torch.Tensor: slices = [slice(None)] * tensor.ndim diff --git a/src/batdetect2/core/configs.py b/src/batdetect2/core/configs.py index fd2583c..252c6e5 100644 --- a/src/batdetect2/core/configs.py +++ b/src/batdetect2/core/configs.py @@ -128,7 +128,7 @@ def get_object_field(obj: dict, current_key: str) -> Any: def load_config( path: PathLike, schema: Type[T], - field: Optional[str] = None, + field: str | None = None, ) -> T: """Load and validate configuration data from a file against a schema. diff --git a/src/batdetect2/data/annotations/__init__.py b/src/batdetect2/data/annotations/__init__.py index 1d499c3..b44b84b 100644 --- a/src/batdetect2/data/annotations/__init__.py +++ b/src/batdetect2/data/annotations/__init__.py @@ -43,11 +43,7 @@ __all__ = [ AnnotationFormats = Annotated[ - Union[ - BatDetect2MergedAnnotations, - BatDetect2FilesAnnotations, - AOEFAnnotations, - ], + BatDetect2MergedAnnotations | BatDetect2FilesAnnotations | AOEFAnnotations, Field(discriminator="format"), ] """Type Alias representing all supported data source configurations. @@ -63,7 +59,7 @@ source configuration represents. def load_annotated_dataset( dataset: AnnotatedDataset, - base_dir: Optional[data.PathLike] = None, + base_dir: data.PathLike | None = None, ) -> data.AnnotationSet: """Load annotations for a single data source based on its configuration. diff --git a/src/batdetect2/data/annotations/aoef.py b/src/batdetect2/data/annotations/aoef.py index e73cf60..35b65ea 100644 --- a/src/batdetect2/data/annotations/aoef.py +++ b/src/batdetect2/data/annotations/aoef.py @@ -77,14 +77,14 @@ class AOEFAnnotations(AnnotatedDataset): annotations_path: Path - filter: Optional[AnnotationTaskFilter] = Field( + filter: AnnotationTaskFilter | None = Field( default_factory=AnnotationTaskFilter ) def load_aoef_annotated_dataset( dataset: AOEFAnnotations, - base_dir: Optional[data.PathLike] = None, + base_dir: data.PathLike | None = None, ) -> data.AnnotationSet: """Load annotations from an AnnotationSet or AnnotationProject file. diff --git a/src/batdetect2/data/annotations/batdetect2.py b/src/batdetect2/data/annotations/batdetect2.py index 5982f42..1dd3727 100644 --- a/src/batdetect2/data/annotations/batdetect2.py +++ b/src/batdetect2/data/annotations/batdetect2.py @@ -27,7 +27,7 @@ aggregated into a `soundevent.data.AnnotationSet`. import json import os from pathlib import Path -from typing import Literal, Optional, Union +from typing import Literal from loguru import logger from pydantic import Field, ValidationError @@ -43,7 +43,7 @@ from batdetect2.data.annotations.legacy import ( ) from batdetect2.data.annotations.types import AnnotatedDataset -PathLike = Union[Path, str, os.PathLike] +PathLike = Path | str | os.PathLike __all__ = [ @@ -102,7 +102,7 @@ class BatDetect2FilesAnnotations(AnnotatedDataset): format: Literal["batdetect2"] = "batdetect2" annotations_dir: Path - filter: Optional[AnnotationFilter] = Field( + filter: AnnotationFilter | None = Field( default_factory=AnnotationFilter, ) @@ -133,14 +133,14 @@ class BatDetect2MergedAnnotations(AnnotatedDataset): format: Literal["batdetect2_file"] = "batdetect2_file" annotations_path: Path - filter: Optional[AnnotationFilter] = Field( + filter: AnnotationFilter | None = Field( default_factory=AnnotationFilter, ) def load_batdetect2_files_annotated_dataset( dataset: BatDetect2FilesAnnotations, - base_dir: Optional[PathLike] = None, + base_dir: PathLike | None = None, ) -> data.AnnotationSet: """Load and convert 'batdetect2_file' annotations into an AnnotationSet. @@ -244,7 +244,7 @@ def load_batdetect2_files_annotated_dataset( def load_batdetect2_merged_annotated_dataset( dataset: BatDetect2MergedAnnotations, - base_dir: Optional[PathLike] = None, + base_dir: PathLike | None = None, ) -> data.AnnotationSet: """Load and convert 'batdetect2_merged' annotations into an AnnotationSet. diff --git a/src/batdetect2/data/annotations/legacy.py b/src/batdetect2/data/annotations/legacy.py index e689a92..ad946fb 100644 --- a/src/batdetect2/data/annotations/legacy.py +++ b/src/batdetect2/data/annotations/legacy.py @@ -3,12 +3,12 @@ import os import uuid from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Callable, List from pydantic import BaseModel, Field from soundevent import data -PathLike = Union[Path, str, os.PathLike] +PathLike = Path | str | os.PathLike __all__ = [] @@ -27,7 +27,7 @@ SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5( ) -EventFn = Callable[[data.SoundEventAnnotation], Optional[str]] +EventFn = Callable[[data.SoundEventAnnotation], str | None] ClassFn = Callable[[data.Recording], int] @@ -130,7 +130,7 @@ def get_sound_event_tags( def file_annotation_to_clip( file_annotation: FileAnnotation, - audio_dir: Optional[PathLike] = None, + audio_dir: PathLike | None = None, label_key: str = "class", ) -> data.Clip: """Convert file annotation to recording.""" diff --git a/src/batdetect2/data/conditions.py b/src/batdetect2/data/conditions.py index f9bca3d..04a8791 100644 --- a/src/batdetect2/data/conditions.py +++ b/src/batdetect2/data/conditions.py @@ -264,16 +264,7 @@ class Not: SoundEventConditionConfig = Annotated[ - Union[ - HasTagConfig, - HasAllTagsConfig, - HasAnyTagConfig, - DurationConfig, - FrequencyConfig, - AllOfConfig, - AnyOfConfig, - NotConfig, - ], + HasTagConfig | HasAllTagsConfig | HasAnyTagConfig | DurationConfig | FrequencyConfig | AllOfConfig | AnyOfConfig | NotConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/data/datasets.py b/src/batdetect2/data/datasets.py index 3a442ab..ef10077 100644 --- a/src/batdetect2/data/datasets.py +++ b/src/batdetect2/data/datasets.py @@ -69,7 +69,7 @@ class DatasetConfig(BaseConfig): description: str sources: List[AnnotationFormats] - sound_event_filter: Optional[SoundEventConditionConfig] = None + sound_event_filter: SoundEventConditionConfig | None = None sound_event_transforms: List[SoundEventTransformConfig] = Field( default_factory=list ) @@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig): def load_dataset( config: DatasetConfig, - base_dir: Optional[data.PathLike] = None, + base_dir: data.PathLike | None = None, ) -> Dataset: """Load all clip annotations from the sources defined in a DatasetConfig.""" clip_annotations = [] @@ -161,14 +161,14 @@ def insert_source_tag( ) -def load_dataset_config(path: data.PathLike, field: Optional[str] = None): +def load_dataset_config(path: data.PathLike, field: str | None = None): return load_config(path=path, schema=DatasetConfig, field=field) def load_dataset_from_config( path: data.PathLike, - field: Optional[str] = None, - base_dir: Optional[data.PathLike] = None, + field: str | None = None, + base_dir: data.PathLike | None = None, ) -> Dataset: """Load dataset annotation metadata from a configuration file. @@ -215,9 +215,9 @@ def load_dataset_from_config( def save_dataset( dataset: Dataset, path: data.PathLike, - name: Optional[str] = None, - description: Optional[str] = None, - audio_dir: Optional[Path] = None, + name: str | None = None, + description: str | None = None, + audio_dir: Path | None = None, ) -> None: """Save a loaded dataset (list of ClipAnnotations) to a file. diff --git a/src/batdetect2/data/iterators.py b/src/batdetect2/data/iterators.py index f3f3ff7..5a53441 100644 --- a/src/batdetect2/data/iterators.py +++ b/src/batdetect2/data/iterators.py @@ -10,7 +10,7 @@ from batdetect2.typing.targets import TargetProtocol def iterate_over_sound_events( dataset: Dataset, targets: TargetProtocol, -) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]: +) -> Generator[Tuple[str | None, data.SoundEventAnnotation], None, None]: """Iterate over sound events in a dataset. Parameters diff --git a/src/batdetect2/data/predictions/__init__.py b/src/batdetect2/data/predictions/__init__.py index 1006476..758c10d 100644 --- a/src/batdetect2/data/predictions/__init__.py +++ b/src/batdetect2/data/predictions/__init__.py @@ -24,19 +24,14 @@ __all__ = [ OutputFormatConfig = Annotated[ - Union[ - BatDetect2OutputConfig, - ParquetOutputConfig, - SoundEventOutputConfig, - RawOutputConfig, - ], + BatDetect2OutputConfig | ParquetOutputConfig | SoundEventOutputConfig | RawOutputConfig, Field(discriminator="name"), ] def build_output_formatter( - targets: Optional[TargetProtocol] = None, - config: Optional[OutputFormatConfig] = None, + targets: TargetProtocol | None = None, + config: OutputFormatConfig | None = None, ) -> OutputFormatterProtocol: """Construct the final output formatter.""" from batdetect2.targets import build_targets @@ -48,9 +43,9 @@ def build_output_formatter( def get_output_formatter( - name: Optional[str] = None, - targets: Optional[TargetProtocol] = None, - config: Optional[OutputFormatConfig] = None, + name: str | None = None, + targets: TargetProtocol | None = None, + config: OutputFormatConfig | None = None, ) -> OutputFormatterProtocol: """Get the output formatter by name.""" @@ -71,9 +66,9 @@ def get_output_formatter( def load_predictions( path: PathLike, - format: Optional[str] = "raw", - config: Optional[OutputFormatConfig] = None, - targets: Optional[TargetProtocol] = None, + format: str | None = "raw", + config: OutputFormatConfig | None = None, + targets: TargetProtocol | None = None, ): """Load predictions from a file.""" from batdetect2.targets import build_targets diff --git a/src/batdetect2/data/predictions/batdetect2.py b/src/batdetect2/data/predictions/batdetect2.py index 0f4b382..67d4796 100644 --- a/src/batdetect2/data/predictions/batdetect2.py +++ b/src/batdetect2/data/predictions/batdetect2.py @@ -123,7 +123,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): self, predictions: Sequence[FileAnnotation], path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ) -> None: path = Path(path) diff --git a/src/batdetect2/data/predictions/parquet.py b/src/batdetect2/data/predictions/parquet.py index c7c1030..3b8d820 100644 --- a/src/batdetect2/data/predictions/parquet.py +++ b/src/batdetect2/data/predictions/parquet.py @@ -53,7 +53,7 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]): self, predictions: Sequence[BatDetect2Prediction], path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ) -> None: path = Path(path) diff --git a/src/batdetect2/data/predictions/raw.py b/src/batdetect2/data/predictions/raw.py index 814784d..8205a0d 100644 --- a/src/batdetect2/data/predictions/raw.py +++ b/src/batdetect2/data/predictions/raw.py @@ -55,7 +55,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): self, predictions: Sequence[BatDetect2Prediction], path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ) -> None: path = Path(path) @@ -84,7 +84,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): def pred_to_xr( self, prediction: BatDetect2Prediction, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ) -> xr.Dataset: clip = prediction.clip recording = clip.recording diff --git a/src/batdetect2/data/predictions/soundevent.py b/src/batdetect2/data/predictions/soundevent.py index f400f42..36cb407 100644 --- a/src/batdetect2/data/predictions/soundevent.py +++ b/src/batdetect2/data/predictions/soundevent.py @@ -18,16 +18,16 @@ from batdetect2.typing import ( class SoundEventOutputConfig(BaseConfig): name: Literal["soundevent"] = "soundevent" - top_k: Optional[int] = 1 - min_score: Optional[float] = None + top_k: int | None = 1 + min_score: float | None = None class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]): def __init__( self, targets: TargetProtocol, - top_k: Optional[int] = 1, - min_score: Optional[float] = 0, + top_k: int | None = 1, + min_score: float | None = 0, ): self.targets = targets self.top_k = top_k @@ -45,7 +45,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]): self, predictions: Sequence[data.ClipPrediction], path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ) -> None: run = data.PredictionSet(clip_predictions=list(predictions)) diff --git a/src/batdetect2/data/split.py b/src/batdetect2/data/split.py index adf426d..81348d9 100644 --- a/src/batdetect2/data/split.py +++ b/src/batdetect2/data/split.py @@ -14,7 +14,7 @@ def split_dataset_by_recordings( dataset: Dataset, targets: TargetProtocol, train_size: float = 0.75, - random_state: Optional[int] = None, + random_state: int | None = None, ) -> Tuple[Dataset, Dataset]: recordings = extract_recordings_df(dataset) diff --git a/src/batdetect2/data/transforms.py b/src/batdetect2/data/transforms.py index a826b7d..8e94a36 100644 --- a/src/batdetect2/data/transforms.py +++ b/src/batdetect2/data/transforms.py @@ -142,7 +142,7 @@ class MapTagValueConfig(BaseConfig): name: Literal["map_tag_value"] = "map_tag_value" tag_key: str value_mapping: Dict[str, str] - target_key: Optional[str] = None + target_key: str | None = None class MapTagValue: @@ -150,7 +150,7 @@ class MapTagValue: self, tag_key: str, value_mapping: Dict[str, str], - target_key: Optional[str] = None, + target_key: str | None = None, ): self.tag_key = tag_key self.value_mapping = value_mapping @@ -221,13 +221,7 @@ class ApplyAll: SoundEventTransformConfig = Annotated[ - Union[ - SetFrequencyBoundConfig, - ReplaceTagConfig, - MapTagValueConfig, - ApplyIfConfig, - ApplyAllConfig, - ], + SetFrequencyBoundConfig | ReplaceTagConfig | MapTagValueConfig | ApplyIfConfig | ApplyAllConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/detector/compute_features.py b/src/batdetect2/detector/compute_features.py index f9d1da5..d7d18e6 100644 --- a/src/batdetect2/detector/compute_features.py +++ b/src/batdetect2/detector/compute_features.py @@ -86,7 +86,7 @@ def compute_bandwidth( def compute_max_power_bb( prediction: types.Prediction, - spec: Optional[np.ndarray] = None, + spec: np.ndarray | None = None, min_freq: int = MIN_FREQ_HZ, max_freq: int = MAX_FREQ_HZ, **_, @@ -131,7 +131,7 @@ def compute_max_power_bb( def compute_max_power( prediction: types.Prediction, - spec: Optional[np.ndarray] = None, + spec: np.ndarray | None = None, min_freq: int = MIN_FREQ_HZ, max_freq: int = MAX_FREQ_HZ, **_, @@ -157,7 +157,7 @@ def compute_max_power( def compute_max_power_first( prediction: types.Prediction, - spec: Optional[np.ndarray] = None, + spec: np.ndarray | None = None, min_freq: int = MIN_FREQ_HZ, max_freq: int = MAX_FREQ_HZ, **_, @@ -184,7 +184,7 @@ def compute_max_power_first( def compute_max_power_second( prediction: types.Prediction, - spec: Optional[np.ndarray] = None, + spec: np.ndarray | None = None, min_freq: int = MIN_FREQ_HZ, max_freq: int = MAX_FREQ_HZ, **_, @@ -211,7 +211,7 @@ def compute_max_power_second( def compute_call_interval( prediction: types.Prediction, - previous: Optional[types.Prediction] = None, + previous: types.Prediction | None = None, **_, ) -> float: """Compute time between this call and the previous call in seconds.""" diff --git a/src/batdetect2/detector/parameters.py b/src/batdetect2/detector/parameters.py index 3e4f091..ba2b430 100644 --- a/src/batdetect2/detector/parameters.py +++ b/src/batdetect2/detector/parameters.py @@ -198,8 +198,8 @@ class TrainingParameters(BaseModel): def get_params( make_dirs: bool = False, exps_dir: str = "../../experiments/", - model_name: Optional[str] = None, - experiment: Union[Path, str, None] = None, + model_name: str | None = None, + experiment: Path | str | None = None, **kwargs, ) -> TrainingParameters: experiments_dir = Path(exps_dir) diff --git a/src/batdetect2/detector/post_process.py b/src/batdetect2/detector/post_process.py index eaa8152..fc5a14e 100644 --- a/src/batdetect2/detector/post_process.py +++ b/src/batdetect2/detector/post_process.py @@ -151,7 +151,7 @@ def run_nms( def non_max_suppression( heat: torch.Tensor, - kernel_size: Union[int, Tuple[int, int]], + kernel_size: int | Tuple[int, int], ): # kernel can be an int or list/tuple if isinstance(kernel_size, int): diff --git a/src/batdetect2/evaluate/affinity.py b/src/batdetect2/evaluate/affinity.py index c8ff253..a9ee307 100644 --- a/src/batdetect2/evaluate/affinity.py +++ b/src/batdetect2/evaluate/affinity.py @@ -213,18 +213,13 @@ class GeometricIOU(AffinityFunction): AffinityConfig = Annotated[ - Union[ - TimeAffinityConfig, - IntervalIOUConfig, - BBoxIOUConfig, - GeometricIOUConfig, - ], + TimeAffinityConfig | IntervalIOUConfig | BBoxIOUConfig | GeometricIOUConfig, Field(discriminator="name"), ] def build_affinity_function( - config: Optional[AffinityConfig] = None, + config: AffinityConfig | None = None, ) -> AffinityFunction: config = config or GeometricIOUConfig() return affinity_functions.build(config) diff --git a/src/batdetect2/evaluate/config.py b/src/batdetect2/evaluate/config.py index fbcb281..71c32b6 100644 --- a/src/batdetect2/evaluate/config.py +++ b/src/batdetect2/evaluate/config.py @@ -51,6 +51,6 @@ def get_default_eval_config() -> EvaluationConfig: def load_evaluation_config( path: data.PathLike, - field: Optional[str] = None, + field: str | None = None, ) -> EvaluationConfig: return load_config(path, schema=EvaluationConfig, field=field) diff --git a/src/batdetect2/evaluate/dataset.py b/src/batdetect2/evaluate/dataset.py index fb3458c..36f15af 100644 --- a/src/batdetect2/evaluate/dataset.py +++ b/src/batdetect2/evaluate/dataset.py @@ -39,8 +39,8 @@ class TestDataset(Dataset[TestExample]): clip_annotations: Sequence[data.ClipAnnotation], audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, - clipper: Optional[ClipperProtocol] = None, - audio_dir: Optional[data.PathLike] = None, + clipper: ClipperProtocol | None = None, + audio_dir: data.PathLike | None = None, ): self.clip_annotations = list(clip_annotations) self.clipper = clipper @@ -78,10 +78,10 @@ class TestLoaderConfig(BaseConfig): def build_test_loader( clip_annotations: Sequence[data.ClipAnnotation], - audio_loader: Optional[AudioLoader] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - config: Optional[TestLoaderConfig] = None, - num_workers: Optional[int] = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, + config: TestLoaderConfig | None = None, + num_workers: int | None = None, ) -> DataLoader[TestExample]: logger.info("Building test data loader...") config = config or TestLoaderConfig() @@ -109,9 +109,9 @@ def build_test_loader( def build_test_dataset( clip_annotations: Sequence[data.ClipAnnotation], - audio_loader: Optional[AudioLoader] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - config: Optional[TestLoaderConfig] = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, + config: TestLoaderConfig | None = None, ) -> TestDataset: logger.info("Building training dataset...") config = config or TestLoaderConfig() diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 0029b7b..3a400f6 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -34,10 +34,10 @@ def evaluate( preprocessor: Optional["PreprocessorProtocol"] = None, config: Optional["BatDetect2Config"] = None, formatter: Optional["OutputFormatterProtocol"] = None, - num_workers: Optional[int] = None, + num_workers: int | None = None, output_dir: data.PathLike = DEFAULT_EVAL_DIR, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + experiment_name: str | None = None, + run_name: str | None = None, ) -> Tuple[Dict[str, float], List[List[RawPrediction]]]: from batdetect2.config import BatDetect2Config diff --git a/src/batdetect2/evaluate/evaluator.py b/src/batdetect2/evaluate/evaluator.py index 181994b..d29df07 100644 --- a/src/batdetect2/evaluate/evaluator.py +++ b/src/batdetect2/evaluate/evaluator.py @@ -51,8 +51,8 @@ class Evaluator: def build_evaluator( - config: Optional[Union[EvaluationConfig, dict]] = None, - targets: Optional[TargetProtocol] = None, + config: EvaluationConfig | dict | None = None, + targets: TargetProtocol | None = None, ) -> EvaluatorProtocol: targets = targets or build_targets() diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index d93cb7a..fe0b0c5 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -35,9 +35,9 @@ def match( sound_event_annotations: Sequence[data.SoundEventAnnotation], raw_predictions: Sequence[RawPrediction], clip: data.Clip, - scores: Optional[Sequence[float]] = None, - targets: Optional[TargetProtocol] = None, - matcher: Optional[MatcherProtocol] = None, + scores: Sequence[float] | None = None, + targets: TargetProtocol | None = None, + matcher: MatcherProtocol | None = None, ) -> ClipMatches: if matcher is None: matcher = build_matcher() @@ -151,7 +151,7 @@ def match_start_times( predictions: Sequence[data.Geometry], scores: Sequence[float], distance_threshold: float = 0.01, -) -> Iterable[Tuple[Optional[int], Optional[int], float]]: +) -> Iterable[Tuple[int | None, int | None, float]]: if not ground_truth: for index in range(len(predictions)): yield index, None, 0 @@ -287,7 +287,7 @@ def greedy_match( scores: Sequence[float], affinity_threshold: float = 0.5, affinity_function: AffinityFunction = compute_affinity, -) -> Iterable[Tuple[Optional[int], Optional[int], float]]: +) -> Iterable[Tuple[int | None, int | None, float]]: """Performs a greedy, one-to-one matching of source to target geometries. Iterates through source geometries, prioritizing by score if provided. Each @@ -514,12 +514,7 @@ class OptimalMatcher(MatcherProtocol): MatchConfig = Annotated[ - Union[ - GreedyMatchConfig, - StartTimeMatchConfig, - OptimalMatchConfig, - GreedyAffinityMatchConfig, - ], + GreedyMatchConfig | StartTimeMatchConfig | OptimalMatchConfig | GreedyAffinityMatchConfig, Field(discriminator="name"), ] @@ -558,7 +553,7 @@ def compute_affinity_matrix( def select_optimal_matches( affinity_matrix: np.ndarray, affinity_threshold: float = 0.5, -) -> Iterable[Tuple[Optional[int], Optional[int], float]]: +) -> Iterable[Tuple[int | None, int | None, float]]: num_gt, num_pred = affinity_matrix.shape gts = set(range(num_gt)) preds = set(range(num_pred)) @@ -588,7 +583,7 @@ def select_optimal_matches( def select_greedy_matches( affinity_matrix: np.ndarray, affinity_threshold: float = 0.5, -) -> Iterable[Tuple[Optional[int], Optional[int], float]]: +) -> Iterable[Tuple[int | None, int | None, float]]: num_gt, num_pred = affinity_matrix.shape unmatched_pred = set(range(num_pred)) @@ -612,6 +607,6 @@ def select_greedy_matches( yield pred_idx, None, 0 -def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol: +def build_matcher(config: MatchConfig | None = None) -> MatcherProtocol: config = config or StartTimeMatchConfig() return matching_strategies.build(config) diff --git a/src/batdetect2/evaluate/metrics/classification.py b/src/batdetect2/evaluate/metrics/classification.py index c18a9b1..cbab6a6 100644 --- a/src/batdetect2/evaluate/metrics/classification.py +++ b/src/batdetect2/evaluate/metrics/classification.py @@ -36,13 +36,13 @@ __all__ = [ @dataclass class MatchEval: clip: data.Clip - gt: Optional[data.SoundEventAnnotation] - pred: Optional[RawPrediction] + gt: data.SoundEventAnnotation | None + pred: RawPrediction | None is_prediction: bool is_ground_truth: bool is_generic: bool - true_class: Optional[str] + true_class: str | None score: float @@ -61,16 +61,16 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = ( class BaseClassificationConfig(BaseConfig): - include: Optional[List[str]] = None - exclude: Optional[List[str]] = None + include: List[str] | None = None + exclude: List[str] | None = None class BaseClassificationMetric: def __init__( self, targets: TargetProtocol, - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, + include: List[str] | None = None, + exclude: List[str] | None = None, ): self.targets = targets self.include = include @@ -100,8 +100,8 @@ class ClassificationAveragePrecision(BaseClassificationMetric): ignore_non_predictions: bool = True, ignore_generic: bool = True, label: str = "average_precision", - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, + include: List[str] | None = None, + exclude: List[str] | None = None, ): super().__init__(include=include, exclude=exclude, targets=targets) self.ignore_non_predictions = ignore_non_predictions @@ -169,8 +169,8 @@ class ClassificationROCAUC(BaseClassificationMetric): ignore_non_predictions: bool = True, ignore_generic: bool = True, label: str = "roc_auc", - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, + include: List[str] | None = None, + exclude: List[str] | None = None, ): self.targets = targets self.ignore_non_predictions = ignore_non_predictions @@ -225,10 +225,7 @@ class ClassificationROCAUC(BaseClassificationMetric): ClassificationMetricConfig = Annotated[ - Union[ - ClassificationAveragePrecisionConfig, - ClassificationROCAUCConfig, - ], + ClassificationAveragePrecisionConfig | ClassificationROCAUCConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/metrics/clip_classification.py b/src/batdetect2/evaluate/metrics/clip_classification.py index 5554b9c..a566c97 100644 --- a/src/batdetect2/evaluate/metrics/clip_classification.py +++ b/src/batdetect2/evaluate/metrics/clip_classification.py @@ -123,10 +123,7 @@ class ClipClassificationROCAUC: ClipClassificationMetricConfig = Annotated[ - Union[ - ClipClassificationAveragePrecisionConfig, - ClipClassificationROCAUCConfig, - ], + ClipClassificationAveragePrecisionConfig | ClipClassificationROCAUCConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/metrics/clip_detection.py b/src/batdetect2/evaluate/metrics/clip_detection.py index df4b99d..728152c 100644 --- a/src/batdetect2/evaluate/metrics/clip_detection.py +++ b/src/batdetect2/evaluate/metrics/clip_detection.py @@ -159,12 +159,7 @@ class ClipDetectionPrecision: ClipDetectionMetricConfig = Annotated[ - Union[ - ClipDetectionAveragePrecisionConfig, - ClipDetectionROCAUCConfig, - ClipDetectionRecallConfig, - ClipDetectionPrecisionConfig, - ], + ClipDetectionAveragePrecisionConfig | ClipDetectionROCAUCConfig | ClipDetectionRecallConfig | ClipDetectionPrecisionConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/metrics/common.py b/src/batdetect2/evaluate/metrics/common.py index dfc47bc..77e8aa5 100644 --- a/src/batdetect2/evaluate/metrics/common.py +++ b/src/batdetect2/evaluate/metrics/common.py @@ -11,7 +11,7 @@ __all__ = [ def compute_precision_recall( y_true, y_score, - num_positives: Optional[int] = None, + num_positives: int | None = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: y_true = np.array(y_true) y_score = np.array(y_score) @@ -41,7 +41,7 @@ def compute_precision_recall( def average_precision( y_true, y_score, - num_positives: Optional[int] = None, + num_positives: int | None = None, ) -> float: if num_positives == 0: return np.nan diff --git a/src/batdetect2/evaluate/metrics/detection.py b/src/batdetect2/evaluate/metrics/detection.py index f687392..e625435 100644 --- a/src/batdetect2/evaluate/metrics/detection.py +++ b/src/batdetect2/evaluate/metrics/detection.py @@ -28,8 +28,8 @@ __all__ = [ @dataclass class MatchEval: - gt: Optional[data.SoundEventAnnotation] - pred: Optional[RawPrediction] + gt: data.SoundEventAnnotation | None + pred: RawPrediction | None is_prediction: bool is_ground_truth: bool @@ -212,12 +212,7 @@ class DetectionPrecision: DetectionMetricConfig = Annotated[ - Union[ - DetectionAveragePrecisionConfig, - DetectionROCAUCConfig, - DetectionRecallConfig, - DetectionPrecisionConfig, - ], + DetectionAveragePrecisionConfig | DetectionROCAUCConfig | DetectionRecallConfig | DetectionPrecisionConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/metrics/top_class.py b/src/batdetect2/evaluate/metrics/top_class.py index d16575b..96fec6e 100644 --- a/src/batdetect2/evaluate/metrics/top_class.py +++ b/src/batdetect2/evaluate/metrics/top_class.py @@ -31,14 +31,14 @@ __all__ = [ @dataclass class MatchEval: clip: data.Clip - gt: Optional[data.SoundEventAnnotation] - pred: Optional[RawPrediction] + gt: data.SoundEventAnnotation | None + pred: RawPrediction | None is_ground_truth: bool is_generic: bool is_prediction: bool - pred_class: Optional[str] - true_class: Optional[str] + pred_class: str | None + true_class: str | None score: float @@ -301,13 +301,7 @@ class BalancedAccuracy: TopClassMetricConfig = Annotated[ - Union[ - TopClassAveragePrecisionConfig, - TopClassROCAUCConfig, - TopClassRecallConfig, - TopClassPrecisionConfig, - BalancedAccuracyConfig, - ], + TopClassAveragePrecisionConfig | TopClassROCAUCConfig | TopClassRecallConfig | TopClassPrecisionConfig | BalancedAccuracyConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/plots/base.py b/src/batdetect2/evaluate/plots/base.py index c01406b..bce076b 100644 --- a/src/batdetect2/evaluate/plots/base.py +++ b/src/batdetect2/evaluate/plots/base.py @@ -10,7 +10,7 @@ from batdetect2.typing import TargetProtocol class BasePlotConfig(BaseConfig): label: str = "plot" theme: str = "default" - title: Optional[str] = None + title: str | None = None figsize: tuple[int, int] = (10, 10) dpi: int = 100 @@ -21,7 +21,7 @@ class BasePlot: targets: TargetProtocol, label: str = "plot", figsize: tuple[int, int] = (10, 10), - title: Optional[str] = None, + title: str | None = None, dpi: int = 100, theme: str = "default", ): diff --git a/src/batdetect2/evaluate/plots/classification.py b/src/batdetect2/evaluate/plots/classification.py index dc9a1e4..98f92b6 100644 --- a/src/batdetect2/evaluate/plots/classification.py +++ b/src/batdetect2/evaluate/plots/classification.py @@ -45,7 +45,7 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = ( class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" - title: Optional[str] = "Classification Precision-Recall Curve" + title: str | None = "Classification Precision-Recall Curve" ignore_non_predictions: bool = True ignore_generic: bool = True separate_figures: bool = False @@ -108,7 +108,7 @@ class PRCurve(BasePlot): class ThresholdPrecisionCurveConfig(BasePlotConfig): name: Literal["threshold_precision_curve"] = "threshold_precision_curve" label: str = "threshold_precision_curve" - title: Optional[str] = "Classification Threshold-Precision Curve" + title: str | None = "Classification Threshold-Precision Curve" ignore_non_predictions: bool = True ignore_generic: bool = True separate_figures: bool = False @@ -181,7 +181,7 @@ class ThresholdPrecisionCurve(BasePlot): class ThresholdRecallCurveConfig(BasePlotConfig): name: Literal["threshold_recall_curve"] = "threshold_recall_curve" label: str = "threshold_recall_curve" - title: Optional[str] = "Classification Threshold-Recall Curve" + title: str | None = "Classification Threshold-Recall Curve" ignore_non_predictions: bool = True ignore_generic: bool = True separate_figures: bool = False @@ -254,7 +254,7 @@ class ThresholdRecallCurve(BasePlot): class ROCCurveConfig(BasePlotConfig): name: Literal["roc_curve"] = "roc_curve" label: str = "roc_curve" - title: Optional[str] = "Classification ROC Curve" + title: str | None = "Classification ROC Curve" ignore_non_predictions: bool = True ignore_generic: bool = True separate_figures: bool = False @@ -326,12 +326,7 @@ class ROCCurve(BasePlot): ClassificationPlotConfig = Annotated[ - Union[ - PRCurveConfig, - ROCCurveConfig, - ThresholdPrecisionCurveConfig, - ThresholdRecallCurveConfig, - ], + PRCurveConfig | ROCCurveConfig | ThresholdPrecisionCurveConfig | ThresholdRecallCurveConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/plots/clip_classification.py b/src/batdetect2/evaluate/plots/clip_classification.py index 388e999..40bef6e 100644 --- a/src/batdetect2/evaluate/plots/clip_classification.py +++ b/src/batdetect2/evaluate/plots/clip_classification.py @@ -44,7 +44,7 @@ clip_classification_plots: Registry[ class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" - title: Optional[str] = "Clip Classification Precision-Recall Curve" + title: str | None = "Clip Classification Precision-Recall Curve" separate_figures: bool = False @@ -111,7 +111,7 @@ class PRCurve(BasePlot): class ROCCurveConfig(BasePlotConfig): name: Literal["roc_curve"] = "roc_curve" label: str = "roc_curve" - title: Optional[str] = "Clip Classification ROC Curve" + title: str | None = "Clip Classification ROC Curve" separate_figures: bool = False @@ -174,10 +174,7 @@ class ROCCurve(BasePlot): ClipClassificationPlotConfig = Annotated[ - Union[ - PRCurveConfig, - ROCCurveConfig, - ], + PRCurveConfig | ROCCurveConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/plots/clip_detection.py b/src/batdetect2/evaluate/plots/clip_detection.py index cfcfb58..1feec3e 100644 --- a/src/batdetect2/evaluate/plots/clip_detection.py +++ b/src/batdetect2/evaluate/plots/clip_detection.py @@ -41,7 +41,7 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = ( class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" - title: Optional[str] = "Clip Detection Precision-Recall Curve" + title: str | None = "Clip Detection Precision-Recall Curve" class PRCurve(BasePlot): @@ -74,7 +74,7 @@ class PRCurve(BasePlot): class ROCCurveConfig(BasePlotConfig): name: Literal["roc_curve"] = "roc_curve" label: str = "roc_curve" - title: Optional[str] = "Clip Detection ROC Curve" + title: str | None = "Clip Detection ROC Curve" class ROCCurve(BasePlot): @@ -107,7 +107,7 @@ class ROCCurve(BasePlot): class ScoreDistributionPlotConfig(BasePlotConfig): name: Literal["score_distribution"] = "score_distribution" label: str = "score_distribution" - title: Optional[str] = "Clip Detection Score Distribution" + title: str | None = "Clip Detection Score Distribution" class ScoreDistributionPlot(BasePlot): @@ -147,11 +147,7 @@ class ScoreDistributionPlot(BasePlot): ClipDetectionPlotConfig = Annotated[ - Union[ - PRCurveConfig, - ROCCurveConfig, - ScoreDistributionPlotConfig, - ], + PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/plots/detection.py b/src/batdetect2/evaluate/plots/detection.py index 29e2b86..34756ea 100644 --- a/src/batdetect2/evaluate/plots/detection.py +++ b/src/batdetect2/evaluate/plots/detection.py @@ -37,7 +37,7 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry( class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" - title: Optional[str] = "Detection Precision-Recall Curve" + title: str | None = "Detection Precision-Recall Curve" ignore_non_predictions: bool = True ignore_generic: bool = True @@ -100,7 +100,7 @@ class PRCurve(BasePlot): class ROCCurveConfig(BasePlotConfig): name: Literal["roc_curve"] = "roc_curve" label: str = "roc_curve" - title: Optional[str] = "Detection ROC Curve" + title: str | None = "Detection ROC Curve" ignore_non_predictions: bool = True ignore_generic: bool = True @@ -159,7 +159,7 @@ class ROCCurve(BasePlot): class ScoreDistributionPlotConfig(BasePlotConfig): name: Literal["score_distribution"] = "score_distribution" label: str = "score_distribution" - title: Optional[str] = "Detection Score Distribution" + title: str | None = "Detection Score Distribution" ignore_non_predictions: bool = True ignore_generic: bool = True @@ -226,7 +226,7 @@ class ScoreDistributionPlot(BasePlot): class ExampleDetectionPlotConfig(BasePlotConfig): name: Literal["example_detection"] = "example_detection" label: str = "example_detection" - title: Optional[str] = "Example Detection" + title: str | None = "Example Detection" figsize: tuple[int, int] = (10, 4) num_examples: int = 5 threshold: float = 0.2 @@ -292,12 +292,7 @@ class ExampleDetectionPlot(BasePlot): DetectionPlotConfig = Annotated[ - Union[ - PRCurveConfig, - ROCCurveConfig, - ScoreDistributionPlotConfig, - ExampleDetectionPlotConfig, - ], + PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig | ExampleDetectionPlotConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/plots/top_class.py b/src/batdetect2/evaluate/plots/top_class.py index fe43263..5bbc721 100644 --- a/src/batdetect2/evaluate/plots/top_class.py +++ b/src/batdetect2/evaluate/plots/top_class.py @@ -44,7 +44,7 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry( class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" - title: Optional[str] = "Top Class Precision-Recall Curve" + title: str | None = "Top Class Precision-Recall Curve" ignore_non_predictions: bool = True ignore_generic: bool = True @@ -111,7 +111,7 @@ class PRCurve(BasePlot): class ROCCurveConfig(BasePlotConfig): name: Literal["roc_curve"] = "roc_curve" label: str = "roc_curve" - title: Optional[str] = "Top Class ROC Curve" + title: str | None = "Top Class ROC Curve" ignore_non_predictions: bool = True ignore_generic: bool = True @@ -173,7 +173,7 @@ class ROCCurve(BasePlot): class ConfusionMatrixConfig(BasePlotConfig): name: Literal["confusion_matrix"] = "confusion_matrix" - title: Optional[str] = "Top Class Confusion Matrix" + title: str | None = "Top Class Confusion Matrix" figsize: tuple[int, int] = (10, 10) label: str = "confusion_matrix" exclude_generic: bool = True @@ -257,7 +257,7 @@ class ConfusionMatrix(BasePlot): class ExampleClassificationPlotConfig(BasePlotConfig): name: Literal["example_classification"] = "example_classification" label: str = "example_classification" - title: Optional[str] = "Example Classification" + title: str | None = "Example Classification" num_examples: int = 4 threshold: float = 0.2 audio: AudioConfig = Field(default_factory=AudioConfig) @@ -348,12 +348,7 @@ class ExampleClassificationPlot(BasePlot): TopClassPlotConfig = Annotated[ - Union[ - PRCurveConfig, - ROCCurveConfig, - ConfusionMatrixConfig, - ExampleClassificationPlotConfig, - ], + PRCurveConfig | ROCCurveConfig | ConfusionMatrixConfig | ExampleClassificationPlotConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/tables.py b/src/batdetect2/evaluate/tables.py index d623529..8724d8f 100644 --- a/src/batdetect2/evaluate/tables.py +++ b/src/batdetect2/evaluate/tables.py @@ -96,7 +96,7 @@ def extract_matches_dataframe( EvaluationTableConfig = Annotated[ - Union[FullEvaluationTableConfig,], Field(discriminator="name") + FullEvaluationTableConfig, Field(discriminator="name") ] diff --git a/src/batdetect2/evaluate/tasks/__init__.py b/src/batdetect2/evaluate/tasks/__init__.py index 5f4947d..9c310d9 100644 --- a/src/batdetect2/evaluate/tasks/__init__.py +++ b/src/batdetect2/evaluate/tasks/__init__.py @@ -26,20 +26,14 @@ __all__ = [ TaskConfig = Annotated[ - Union[ - ClassificationTaskConfig, - DetectionTaskConfig, - ClipDetectionTaskConfig, - ClipClassificationTaskConfig, - TopClassDetectionTaskConfig, - ], + ClassificationTaskConfig | DetectionTaskConfig | ClipDetectionTaskConfig | ClipClassificationTaskConfig | TopClassDetectionTaskConfig, Field(discriminator="name"), ] def build_task( config: TaskConfig, - targets: Optional[TargetProtocol] = None, + targets: TargetProtocol | None = None, ) -> EvaluatorProtocol: targets = targets or build_targets() return tasks_registry.build(config, targets) @@ -49,8 +43,8 @@ def evaluate_task( clip_annotations: Sequence[data.ClipAnnotation], predictions: Sequence[BatDetect2Prediction], task: Optional["str"] = None, - targets: Optional[TargetProtocol] = None, - config: Optional[Union[TaskConfig, dict]] = None, + targets: TargetProtocol | None = None, + config: TaskConfig | dict | None = None, ): if isinstance(config, BaseTaskConfig): task_obj = build_task(config, targets) diff --git a/src/batdetect2/evaluate/tasks/base.py b/src/batdetect2/evaluate/tasks/base.py index 3835f86..2c793f2 100644 --- a/src/batdetect2/evaluate/tasks/base.py +++ b/src/batdetect2/evaluate/tasks/base.py @@ -67,9 +67,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], prefix: str, ignore_start_end: float = 0.01, - plots: Optional[ - List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] - ] = None, + plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None, ): self.matcher = matcher self.metrics = metrics @@ -147,9 +145,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): config: BaseTaskConfig, targets: TargetProtocol, metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], - plots: Optional[ - List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] - ] = None, + plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None, **kwargs, ): matcher = build_matcher(config.matching_strategy) diff --git a/src/batdetect2/finetune/finetune_model.py b/src/batdetect2/finetune/finetune_model.py index 3135d78..b049b7a 100644 --- a/src/batdetect2/finetune/finetune_model.py +++ b/src/batdetect2/finetune/finetune_model.py @@ -98,8 +98,8 @@ def load_annotations( dataset_name: str, ann_path: str, audio_path: str, - classes_to_ignore: Optional[List[str]] = None, - events_of_interest: Optional[List[str]] = None, + classes_to_ignore: List[str] | None = None, + events_of_interest: List[str] | None = None, ) -> List[types.FileAnnotation]: train_sets: List[types.DatasetDict] = [] train_sets.append( diff --git a/src/batdetect2/finetune/prep_data_finetune.py b/src/batdetect2/finetune/prep_data_finetune.py index 1aea005..1df46eb 100644 --- a/src/batdetect2/finetune/prep_data_finetune.py +++ b/src/batdetect2/finetune/prep_data_finetune.py @@ -13,7 +13,7 @@ from batdetect2 import types def print_dataset_stats( data: List[types.FileAnnotation], - classes_to_ignore: Optional[List[str]] = None, + classes_to_ignore: List[str] | None = None, ) -> Counter[str]: print("Num files:", len(data)) counts, _ = tu.get_class_names(data, classes_to_ignore) diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index e3aee99..b97848d 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -28,8 +28,8 @@ def run_batch_inference( audio_loader: Optional["AudioLoader"] = None, preprocessor: Optional["PreprocessorProtocol"] = None, config: Optional["BatDetect2Config"] = None, - num_workers: Optional[int] = None, - batch_size: Optional[int] = None, + num_workers: int | None = None, + batch_size: int | None = None, ) -> List[BatDetect2Prediction]: from batdetect2.config import BatDetect2Config @@ -69,7 +69,7 @@ def process_file_list( targets: Optional["TargetProtocol"] = None, audio_loader: Optional["AudioLoader"] = None, preprocessor: Optional["PreprocessorProtocol"] = None, - num_workers: Optional[int] = None, + num_workers: int | None = None, ) -> List[BatDetect2Prediction]: clip_config = config.inference.clipping clips = get_clips_from_files( diff --git a/src/batdetect2/inference/dataset.py b/src/batdetect2/inference/dataset.py index b2fba56..62b696a 100644 --- a/src/batdetect2/inference/dataset.py +++ b/src/batdetect2/inference/dataset.py @@ -36,7 +36,7 @@ class InferenceDataset(Dataset[DatasetItem]): clips: Sequence[data.Clip], audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ): self.clips = list(clips) self.preprocessor = preprocessor @@ -66,11 +66,11 @@ class InferenceLoaderConfig(BaseConfig): def build_inference_loader( clips: Sequence[data.Clip], - audio_loader: Optional[AudioLoader] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - config: Optional[InferenceLoaderConfig] = None, - num_workers: Optional[int] = None, - batch_size: Optional[int] = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, + config: InferenceLoaderConfig | None = None, + num_workers: int | None = None, + batch_size: int | None = None, ) -> DataLoader[DatasetItem]: logger.info("Building inference data loader...") config = config or InferenceLoaderConfig() @@ -95,8 +95,8 @@ def build_inference_loader( def build_inference_dataset( clips: Sequence[data.Clip], - audio_loader: Optional[AudioLoader] = None, - preprocessor: Optional[PreprocessorProtocol] = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, ) -> InferenceDataset: if audio_loader is None: audio_loader = build_audio_loader() diff --git a/src/batdetect2/logging.py b/src/batdetect2/logging.py index a423ca2..01f8acc 100644 --- a/src/batdetect2/logging.py +++ b/src/batdetect2/logging.py @@ -49,14 +49,14 @@ def enable_logging(level: int): class BaseLoggerConfig(BaseConfig): log_dir: Path = DEFAULT_LOGS_DIR - experiment_name: Optional[str] = None - run_name: Optional[str] = None + experiment_name: str | None = None + run_name: str | None = None class DVCLiveConfig(BaseLoggerConfig): name: Literal["dvclive"] = "dvclive" prefix: str = "" - log_model: Union[bool, Literal["all"]] = False + log_model: bool | Literal["all"] = False monitor_system: bool = False @@ -72,18 +72,13 @@ class TensorBoardLoggerConfig(BaseLoggerConfig): class MLFlowLoggerConfig(BaseLoggerConfig): name: Literal["mlflow"] = "mlflow" - tracking_uri: Optional[str] = "http://localhost:5000" - tags: Optional[dict[str, Any]] = None + tracking_uri: str | None = "http://localhost:5000" + tags: dict[str, Any] | None = None log_model: bool = False LoggerConfig = Annotated[ - Union[ - DVCLiveConfig, - CSVLoggerConfig, - TensorBoardLoggerConfig, - MLFlowLoggerConfig, - ], + DVCLiveConfig | CSVLoggerConfig | TensorBoardLoggerConfig | MLFlowLoggerConfig, Field(discriminator="name"), ] @@ -95,17 +90,17 @@ class LoggerBuilder(Protocol, Generic[T]): def __call__( self, config: T, - log_dir: Optional[Path] = None, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + log_dir: Path | None = None, + experiment_name: str | None = None, + run_name: str | None = None, ) -> Logger: ... def create_dvclive_logger( config: DVCLiveConfig, - log_dir: Optional[Path] = None, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + log_dir: Path | None = None, + experiment_name: str | None = None, + run_name: str | None = None, ) -> Logger: try: from dvclive.lightning import DVCLiveLogger # type: ignore @@ -130,9 +125,9 @@ def create_dvclive_logger( def create_csv_logger( config: CSVLoggerConfig, - log_dir: Optional[Path] = None, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + log_dir: Path | None = None, + experiment_name: str | None = None, + run_name: str | None = None, ) -> Logger: from lightning.pytorch.loggers import CSVLogger @@ -159,9 +154,9 @@ def create_csv_logger( def create_tensorboard_logger( config: TensorBoardLoggerConfig, - log_dir: Optional[Path] = None, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + log_dir: Path | None = None, + experiment_name: str | None = None, + run_name: str | None = None, ) -> Logger: from lightning.pytorch.loggers import TensorBoardLogger @@ -191,9 +186,9 @@ def create_tensorboard_logger( def create_mlflow_logger( config: MLFlowLoggerConfig, - log_dir: Optional[data.PathLike] = None, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + log_dir: data.PathLike | None = None, + experiment_name: str | None = None, + run_name: str | None = None, ) -> Logger: try: from lightning.pytorch.loggers import MLFlowLogger @@ -232,9 +227,9 @@ LOGGER_FACTORY: Dict[str, LoggerBuilder] = { def build_logger( config: LoggerConfig, - log_dir: Optional[Path] = None, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + log_dir: Path | None = None, + experiment_name: str | None = None, + run_name: str | None = None, ) -> Logger: logger.opt(lazy=True).debug( "Building logger with config: \n{}", @@ -257,7 +252,7 @@ def build_logger( PlotLogger = Callable[[str, Figure, int], None] -def get_image_logger(logger: Logger) -> Optional[PlotLogger]: +def get_image_logger(logger: Logger) -> PlotLogger | None: if isinstance(logger, TensorBoardLogger): return logger.experiment.add_figure @@ -282,7 +277,7 @@ def get_image_logger(logger: Logger) -> Optional[PlotLogger]: TableLogger = Callable[[str, pd.DataFrame, int], None] -def get_table_logger(logger: Logger) -> Optional[TableLogger]: +def get_table_logger(logger: Logger) -> TableLogger | None: if isinstance(logger, TensorBoardLogger): return partial(save_table, dir=Path(logger.log_dir)) diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 7ab11e7..fc15357 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -43,10 +43,7 @@ from batdetect2.models.bottleneck import ( BottleneckConfig, build_bottleneck, ) -from batdetect2.models.config import ( - BackboneConfig, - load_backbone_config, -) +from batdetect2.models.config import BackboneConfig, load_backbone_config from batdetect2.models.decoder import ( DEFAULT_DECODER_CONFIG, DecoderConfig, @@ -122,10 +119,10 @@ class Model(torch.nn.Module): def build_model( - config: Optional[BackboneConfig] = None, - targets: Optional[TargetProtocol] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - postprocessor: Optional[PostprocessorProtocol] = None, + config: BackboneConfig | None = None, + targets: TargetProtocol | None = None, + preprocessor: PreprocessorProtocol | None = None, + postprocessor: PostprocessorProtocol | None = None, ): from batdetect2.postprocess import build_postprocessor from batdetect2.preprocess import build_preprocessor diff --git a/src/batdetect2/models/bottleneck.py b/src/batdetect2/models/bottleneck.py index 253e702..879fc18 100644 --- a/src/batdetect2/models/bottleneck.py +++ b/src/batdetect2/models/bottleneck.py @@ -78,8 +78,8 @@ class Bottleneck(nn.Module): input_height: int, in_channels: int, out_channels: int, - bottleneck_channels: Optional[int] = None, - layers: Optional[List[torch.nn.Module]] = None, + bottleneck_channels: int | None = None, + layers: List[torch.nn.Module] | None = None, ) -> None: """Initialize the base Bottleneck layer.""" super().__init__() @@ -127,7 +127,7 @@ class Bottleneck(nn.Module): BottleneckLayerConfig = Annotated[ - Union[SelfAttentionConfig,], + SelfAttentionConfig, Field(discriminator="name"), ] """Type alias for the discriminated union of block configs usable in Decoder.""" @@ -171,7 +171,7 @@ DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig( def build_bottleneck( input_height: int, in_channels: int, - config: Optional[BottleneckConfig] = None, + config: BottleneckConfig | None = None, ) -> nn.Module: """Factory function to build the Bottleneck module from configuration. diff --git a/src/batdetect2/models/config.py b/src/batdetect2/models/config.py index 0f34b34..12c0f26 100644 --- a/src/batdetect2/models/config.py +++ b/src/batdetect2/models/config.py @@ -63,7 +63,7 @@ class BackboneConfig(BaseConfig): def load_backbone_config( path: data.PathLike, - field: Optional[str] = None, + field: str | None = None, ) -> BackboneConfig: """Load the backbone configuration from a file. diff --git a/src/batdetect2/models/decoder.py b/src/batdetect2/models/decoder.py index dd74270..233b0d3 100644 --- a/src/batdetect2/models/decoder.py +++ b/src/batdetect2/models/decoder.py @@ -41,12 +41,7 @@ __all__ = [ ] DecoderLayerConfig = Annotated[ - Union[ - ConvConfig, - FreqCoordConvUpConfig, - StandardConvUpConfig, - LayerGroupConfig, - ], + ConvConfig | FreqCoordConvUpConfig | StandardConvUpConfig | LayerGroupConfig, Field(discriminator="name"), ] """Type alias for the discriminated union of block configs usable in Decoder.""" @@ -216,7 +211,7 @@ convolutional block. def build_decoder( in_channels: int, input_height: int, - config: Optional[DecoderConfig] = None, + config: DecoderConfig | None = None, ) -> Decoder: """Factory function to build a Decoder instance from configuration. diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index 518ae94..7bad757 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -127,7 +127,7 @@ class Detector(DetectionModel): def build_detector( - num_classes: int, config: Optional[BackboneConfig] = None + num_classes: int, config: BackboneConfig | None = None ) -> DetectionModel: """Build the complete BatDetect2 detection model. diff --git a/src/batdetect2/models/encoder.py b/src/batdetect2/models/encoder.py index e7da745..70aa6c8 100644 --- a/src/batdetect2/models/encoder.py +++ b/src/batdetect2/models/encoder.py @@ -43,12 +43,7 @@ __all__ = [ ] EncoderLayerConfig = Annotated[ - Union[ - ConvConfig, - FreqCoordConvDownConfig, - StandardConvDownConfig, - LayerGroupConfig, - ], + ConvConfig | FreqCoordConvDownConfig | StandardConvDownConfig | LayerGroupConfig, Field(discriminator="name"), ] """Type alias for the discriminated union of block configs usable in Encoder.""" @@ -252,7 +247,7 @@ Specifies an architecture typically used in BatDetect2: def build_encoder( in_channels: int, input_height: int, - config: Optional[EncoderConfig] = None, + config: EncoderConfig | None = None, ) -> Encoder: """Factory function to build an Encoder instance from configuration. diff --git a/src/batdetect2/plotting/clip_annotations.py b/src/batdetect2/plotting/clip_annotations.py index 400a23a..c52de62 100644 --- a/src/batdetect2/plotting/clip_annotations.py +++ b/src/batdetect2/plotting/clip_annotations.py @@ -15,10 +15,10 @@ __all__ = [ def plot_clip_annotation( clip_annotation: data.ClipAnnotation, - preprocessor: Optional[PreprocessorProtocol] = None, - figsize: Optional[Tuple[int, int]] = None, - ax: Optional[Axes] = None, - audio_dir: Optional[data.PathLike] = None, + preprocessor: PreprocessorProtocol | None = None, + figsize: Tuple[int, int] | None = None, + ax: Axes | None = None, + audio_dir: data.PathLike | None = None, add_points: bool = False, cmap: str = "gray", alpha: float = 1, @@ -50,8 +50,8 @@ def plot_clip_annotation( def plot_anchor_points( clip_annotation: data.ClipAnnotation, targets: TargetProtocol, - figsize: Optional[Tuple[int, int]] = None, - ax: Optional[Axes] = None, + figsize: Tuple[int, int] | None = None, + ax: Axes | None = None, size: int = 1, color: str = "red", marker: str = "x", diff --git a/src/batdetect2/plotting/clip_predictions.py b/src/batdetect2/plotting/clip_predictions.py index 86f4833..1827a00 100644 --- a/src/batdetect2/plotting/clip_predictions.py +++ b/src/batdetect2/plotting/clip_predictions.py @@ -17,10 +17,10 @@ __all__ = [ def plot_clip_prediction( clip_prediction: data.ClipPrediction, - preprocessor: Optional[PreprocessorProtocol] = None, - figsize: Optional[Tuple[int, int]] = None, - ax: Optional[Axes] = None, - audio_dir: Optional[data.PathLike] = None, + preprocessor: PreprocessorProtocol | None = None, + figsize: Tuple[int, int] | None = None, + ax: Axes | None = None, + audio_dir: data.PathLike | None = None, add_legend: bool = False, spec_cmap: str = "gray", linewidth: float = 1, @@ -50,14 +50,14 @@ def plot_clip_prediction( def plot_predictions( predictions: Iterable[data.SoundEventPrediction], - ax: Optional[Axes] = None, + ax: Axes | None = None, position: Positions = "top-right", - color_mapper: Optional[TagColorMapper] = None, + color_mapper: TagColorMapper | None = None, time_offset: float = 0.001, freq_offset: float = 1000, legend: bool = True, max_alpha: float = 0.5, - color: Optional[str] = None, + color: str | None = None, **kwargs, ): """Plot an prediction.""" @@ -88,14 +88,14 @@ def plot_predictions( def plot_prediction( prediction: data.SoundEventPrediction, - ax: Optional[Axes] = None, + ax: Axes | None = None, position: Positions = "top-right", - color_mapper: Optional[TagColorMapper] = None, + color_mapper: TagColorMapper | None = None, time_offset: float = 0.001, freq_offset: float = 1000, max_alpha: float = 0.5, - alpha: Optional[float] = None, - color: Optional[str] = None, + alpha: float | None = None, + color: str | None = None, **kwargs, ) -> Axes: """Plot an annotation.""" diff --git a/src/batdetect2/plotting/clips.py b/src/batdetect2/plotting/clips.py index 63e978f..e634c37 100644 --- a/src/batdetect2/plotting/clips.py +++ b/src/batdetect2/plotting/clips.py @@ -17,11 +17,11 @@ __all__ = [ def plot_clip( clip: data.Clip, - audio_loader: Optional[AudioLoader] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - figsize: Optional[Tuple[int, int]] = None, - ax: Optional[Axes] = None, - audio_dir: Optional[data.PathLike] = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, + figsize: Tuple[int, int] | None = None, + ax: Axes | None = None, + audio_dir: data.PathLike | None = None, spec_cmap: str = "gray", ) -> Axes: if ax is None: diff --git a/src/batdetect2/plotting/common.py b/src/batdetect2/plotting/common.py index d79ae02..ef1231f 100644 --- a/src/batdetect2/plotting/common.py +++ b/src/batdetect2/plotting/common.py @@ -13,8 +13,8 @@ __all__ = [ def create_ax( - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, **kwargs, ) -> axes.Axes: """Create a new axis if none is provided""" @@ -25,17 +25,17 @@ def create_ax( def plot_spectrogram( - spec: Union[torch.Tensor, np.ndarray], - start_time: Optional[float] = None, - end_time: Optional[float] = None, - min_freq: Optional[float] = None, - max_freq: Optional[float] = None, - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + spec: torch.Tensor | np.ndarray, + start_time: float | None = None, + end_time: float | None = None, + min_freq: float | None = None, + max_freq: float | None = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, add_colorbar: bool = False, - colorbar_kwargs: Optional[dict] = None, - vmin: Optional[float] = None, - vmax: Optional[float] = None, + colorbar_kwargs: dict | None = None, + vmin: float | None = None, + vmax: float | None = None, cmap="gray", ) -> axes.Axes: if isinstance(spec, torch.Tensor): diff --git a/src/batdetect2/plotting/detections.py b/src/batdetect2/plotting/detections.py index 800b8b6..442953c 100644 --- a/src/batdetect2/plotting/detections.py +++ b/src/batdetect2/plotting/detections.py @@ -19,9 +19,9 @@ __all__ = [ def plot_clip_detections( clip_eval: ClipEval, figsize: tuple[int, int] = (10, 10), - ax: Optional[axes.Axes] = None, - audio_loader: Optional[AudioLoader] = None, - preprocessor: Optional[PreprocessorProtocol] = None, + ax: axes.Axes | None = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, threshold: float = 0.2, add_legend: bool = True, add_title: bool = True, diff --git a/src/batdetect2/plotting/gallery.py b/src/batdetect2/plotting/gallery.py index 1a06f9e..d16ac45 100644 --- a/src/batdetect2/plotting/gallery.py +++ b/src/batdetect2/plotting/gallery.py @@ -20,11 +20,11 @@ def plot_match_gallery( false_positives: Sequence[MatchProtocol], false_negatives: Sequence[MatchProtocol], cross_triggers: Sequence[MatchProtocol], - audio_loader: Optional[AudioLoader] = None, - preprocessor: Optional[PreprocessorProtocol] = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, n_examples: int = 5, duration: float = 0.1, - fig: Optional[Figure] = None, + fig: Figure | None = None, ): if fig is None: fig = plt.figure(figsize=(20, 20)) diff --git a/src/batdetect2/plotting/heatmaps.py b/src/batdetect2/plotting/heatmaps.py index 8354b38..fd67632 100644 --- a/src/batdetect2/plotting/heatmaps.py +++ b/src/batdetect2/plotting/heatmaps.py @@ -12,13 +12,13 @@ from batdetect2.plotting.common import create_ax def plot_detection_heatmap( - heatmap: Union[torch.Tensor, np.ndarray], - ax: Optional[axes.Axes] = None, + heatmap: torch.Tensor | np.ndarray, + ax: axes.Axes | None = None, figsize: Tuple[int, int] = (10, 10), - threshold: Optional[float] = None, + threshold: float | None = None, alpha: float = 1, - cmap: Union[str, Colormap] = "jet", - color: Optional[str] = None, + cmap: str | Colormap = "jet", + color: str | None = None, ) -> axes.Axes: ax = create_ax(ax, figsize=figsize) @@ -48,13 +48,13 @@ def plot_detection_heatmap( def plot_classification_heatmap( - heatmap: Union[torch.Tensor, np.ndarray], - ax: Optional[axes.Axes] = None, + heatmap: torch.Tensor | np.ndarray, + ax: axes.Axes | None = None, figsize: Tuple[int, int] = (10, 10), - class_names: Optional[List[str]] = None, - threshold: Optional[float] = 0.1, + class_names: List[str] | None = None, + threshold: float | None = 0.1, alpha: float = 1, - cmap: Union[str, Colormap] = "tab20", + cmap: str | Colormap = "tab20", ): ax = create_ax(ax, figsize=figsize) diff --git a/src/batdetect2/plotting/legacy/plot.py b/src/batdetect2/plotting/legacy/plot.py index b9e5d4e..84ad89a 100644 --- a/src/batdetect2/plotting/legacy/plot.py +++ b/src/batdetect2/plotting/legacy/plot.py @@ -24,10 +24,10 @@ __all__ = [ def spectrogram( - spec: Union[torch.Tensor, np.ndarray], - config: Optional[ProcessingConfiguration] = None, - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + spec: torch.Tensor | np.ndarray, + config: ProcessingConfiguration | None = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, cmap: str = "plasma", start_time: float = 0, ) -> axes.Axes: @@ -103,11 +103,11 @@ def spectrogram( def spectrogram_with_detections( - spec: Union[torch.Tensor, np.ndarray], + spec: torch.Tensor | np.ndarray, dets: List[Annotation], - config: Optional[ProcessingConfiguration] = None, - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + config: ProcessingConfiguration | None = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, cmap: str = "plasma", with_names: bool = True, start_time: float = 0, @@ -168,8 +168,8 @@ def spectrogram_with_detections( def detections( dets: List[Annotation], - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, with_names: bool = True, **kwargs, ) -> axes.Axes: @@ -213,8 +213,8 @@ def detections( def detection( det: Annotation, - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, linewidth: float = 1, edgecolor: str = "w", facecolor: str = "none", diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index 1803dd3..7da593c 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ -21,10 +21,10 @@ __all__ = [ class MatchProtocol(Protocol): clip: data.Clip - gt: Optional[data.SoundEventAnnotation] - pred: Optional[RawPrediction] + gt: data.SoundEventAnnotation | None + pred: RawPrediction | None score: float - true_class: Optional[str] + true_class: str | None DEFAULT_DURATION = 0.05 @@ -38,11 +38,11 @@ DEFAULT_PREDICTION_LINE_STYLE = "--" def plot_false_positive_match( match: MatchProtocol, - audio_loader: Optional[AudioLoader] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - figsize: Optional[Tuple[int, int]] = None, - ax: Optional[Axes] = None, - audio_dir: Optional[data.PathLike] = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, + figsize: Tuple[int, int] | None = None, + ax: Axes | None = None, + audio_dir: data.PathLike | None = None, duration: float = DEFAULT_DURATION, use_score: bool = True, add_spectrogram: bool = True, @@ -52,7 +52,7 @@ def plot_false_positive_match( fill: bool = False, spec_cmap: str = "gray", color: str = DEFAULT_FALSE_POSITIVE_COLOR, - fontsize: Union[float, str] = "small", + fontsize: float | str = "small", ) -> Axes: assert match.pred is not None @@ -109,11 +109,11 @@ def plot_false_positive_match( def plot_false_negative_match( match: MatchProtocol, - audio_loader: Optional[AudioLoader] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - figsize: Optional[Tuple[int, int]] = None, - ax: Optional[Axes] = None, - audio_dir: Optional[data.PathLike] = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, + figsize: Tuple[int, int] | None = None, + ax: Axes | None = None, + audio_dir: data.PathLike | None = None, duration: float = DEFAULT_DURATION, add_spectrogram: bool = True, add_points: bool = False, @@ -169,11 +169,11 @@ def plot_false_negative_match( def plot_true_positive_match( match: MatchProtocol, - preprocessor: Optional[PreprocessorProtocol] = None, - audio_loader: Optional[AudioLoader] = None, - figsize: Optional[Tuple[int, int]] = None, - ax: Optional[Axes] = None, - audio_dir: Optional[data.PathLike] = None, + preprocessor: PreprocessorProtocol | None = None, + audio_loader: AudioLoader | None = None, + figsize: Tuple[int, int] | None = None, + ax: Axes | None = None, + audio_dir: data.PathLike | None = None, duration: float = DEFAULT_DURATION, use_score: bool = True, add_spectrogram: bool = True, @@ -182,7 +182,7 @@ def plot_true_positive_match( fill: bool = False, spec_cmap: str = "gray", color: str = DEFAULT_TRUE_POSITIVE_COLOR, - fontsize: Union[float, str] = "small", + fontsize: float | str = "small", annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, add_title: bool = True, @@ -257,11 +257,11 @@ def plot_true_positive_match( def plot_cross_trigger_match( match: MatchProtocol, - preprocessor: Optional[PreprocessorProtocol] = None, - audio_loader: Optional[AudioLoader] = None, - figsize: Optional[Tuple[int, int]] = None, - ax: Optional[Axes] = None, - audio_dir: Optional[data.PathLike] = None, + preprocessor: PreprocessorProtocol | None = None, + audio_loader: AudioLoader | None = None, + figsize: Tuple[int, int] | None = None, + ax: Axes | None = None, + audio_dir: data.PathLike | None = None, duration: float = DEFAULT_DURATION, use_score: bool = True, add_spectrogram: bool = True, @@ -271,7 +271,7 @@ def plot_cross_trigger_match( fill: bool = False, spec_cmap: str = "gray", color: str = DEFAULT_CROSS_TRIGGER_COLOR, - fontsize: Union[float, str] = "small", + fontsize: float | str = "small", annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, ) -> Axes: diff --git a/src/batdetect2/plotting/metrics.py b/src/batdetect2/plotting/metrics.py index acb9099..61ce7b4 100644 --- a/src/batdetect2/plotting/metrics.py +++ b/src/batdetect2/plotting/metrics.py @@ -33,16 +33,16 @@ def plot_pr_curve( precision: np.ndarray, recall: np.ndarray, thresholds: np.ndarray, - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, - color: Union[str, Tuple[float, float, float], None] = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, + color: str | Tuple[float, float, float] | None = None, add_labels: bool = True, add_legend: bool = False, - marker: Union[str, Tuple[int, int, float], None] = "o", - markeredgecolor: Union[str, Tuple[float, float, float], None] = None, - markersize: Optional[float] = None, - linestyle: Union[str, Tuple[int, ...], None] = None, - linewidth: Optional[float] = None, + marker: str | Tuple[int, int, float] | None = "o", + markeredgecolor: str | Tuple[float, float, float] | None = None, + markersize: float | None = None, + linestyle: str | Tuple[int, ...] | None = None, + linewidth: float | None = None, label: str = "PR Curve", ) -> axes.Axes: ax = create_ax(ax=ax, figsize=figsize) @@ -77,8 +77,8 @@ def plot_pr_curve( def plot_pr_curves( data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, add_legend: bool = True, add_labels: bool = True, include_ap: bool = False, @@ -118,8 +118,8 @@ def plot_pr_curves( def plot_threshold_precision_curve( threshold: np.ndarray, precision: np.ndarray, - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, add_labels: bool = True, ): ax = create_ax(ax=ax, figsize=figsize) @@ -140,8 +140,8 @@ def plot_threshold_precision_curve( def plot_threshold_precision_curves( data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, add_legend: bool = True, add_labels: bool = True, ): @@ -176,8 +176,8 @@ def plot_threshold_precision_curves( def plot_threshold_recall_curve( threshold: np.ndarray, recall: np.ndarray, - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, add_labels: bool = True, ): ax = create_ax(ax=ax, figsize=figsize) @@ -198,8 +198,8 @@ def plot_threshold_recall_curve( def plot_threshold_recall_curves( data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, add_legend: bool = True, add_labels: bool = True, ): @@ -235,8 +235,8 @@ def plot_roc_curve( fpr: np.ndarray, tpr: np.ndarray, thresholds: np.ndarray, - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, add_labels: bool = True, ) -> axes.Axes: ax = create_ax(ax=ax, figsize=figsize) @@ -261,8 +261,8 @@ def plot_roc_curve( def plot_roc_curves( data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], - ax: Optional[axes.Axes] = None, - figsize: Optional[Tuple[int, int]] = None, + ax: axes.Axes | None = None, + figsize: Tuple[int, int] | None = None, add_legend: bool = True, add_labels: bool = True, ) -> axes.Axes: diff --git a/src/batdetect2/postprocess/config.py b/src/batdetect2/postprocess/config.py index 8299d5c..82de367 100644 --- a/src/batdetect2/postprocess/config.py +++ b/src/batdetect2/postprocess/config.py @@ -57,7 +57,7 @@ class PostprocessConfig(BaseConfig): def load_postprocess_config( path: data.PathLike, - field: Optional[str] = None, + field: str | None = None, ) -> PostprocessConfig: """Load the postprocessing configuration from a file. diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index e7e1635..5fb65e4 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -88,9 +88,7 @@ def convert_raw_prediction_to_sound_event_prediction( raw_prediction: RawPrediction, recording: data.Recording, targets: TargetProtocol, - classification_threshold: Optional[ - float - ] = DEFAULT_CLASSIFICATION_THRESHOLD, + classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD, top_class_only: bool = False, ): """Convert a single RawPrediction into a soundevent SoundEventPrediction.""" @@ -150,7 +148,7 @@ def get_class_tags( class_scores: np.ndarray, targets: TargetProtocol, top_class_only: bool = False, - threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD, + threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD, ) -> List[data.PredictedTag]: """Generate specific PredictedTags based on class scores and decoder. diff --git a/src/batdetect2/postprocess/extraction.py b/src/batdetect2/postprocess/extraction.py index a1b5a65..74fdf20 100644 --- a/src/batdetect2/postprocess/extraction.py +++ b/src/batdetect2/postprocess/extraction.py @@ -32,7 +32,7 @@ def extract_detection_peaks( feature_heatmap: torch.Tensor, classification_heatmap: torch.Tensor, max_detections: int = 200, - threshold: Optional[float] = None, + threshold: float | None = None, ) -> List[ClipDetectionsTensor]: height = detection_heatmap.shape[-2] width = detection_heatmap.shape[-1] diff --git a/src/batdetect2/postprocess/nms.py b/src/batdetect2/postprocess/nms.py index f92800d..80114d0 100644 --- a/src/batdetect2/postprocess/nms.py +++ b/src/batdetect2/postprocess/nms.py @@ -27,7 +27,7 @@ BatDetect2. def non_max_suppression( tensor: torch.Tensor, - kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, + kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE, ) -> torch.Tensor: """Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap. diff --git a/src/batdetect2/postprocess/postprocessor.py b/src/batdetect2/postprocess/postprocessor.py index 677b315..c3f7b9c 100644 --- a/src/batdetect2/postprocess/postprocessor.py +++ b/src/batdetect2/postprocess/postprocessor.py @@ -24,7 +24,7 @@ __all__ = [ def build_postprocessor( preprocessor: PreprocessorProtocol, - config: Optional[PostprocessConfig] = None, + config: PostprocessConfig | None = None, ) -> PostprocessorProtocol: """Factory function to build the standard postprocessor.""" config = config or PostprocessConfig() @@ -51,7 +51,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): max_freq: float, top_k_per_sec: int = 200, detection_threshold: float = 0.01, - nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, + nms_kernel_size: int | Tuple[int, int] = NMS_KERNEL_SIZE, ): """Initialize the Postprocessor.""" super().__init__() @@ -66,7 +66,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): def forward( self, output: ModelOutput, - start_times: Optional[List[float]] = None, + start_times: List[float] | None = None, ) -> List[ClipDetectionsTensor]: detection_heatmap = non_max_suppression( output.detection_probs.detach(), diff --git a/src/batdetect2/postprocess/remapping.py b/src/batdetect2/postprocess/remapping.py index 06f6321..516b86e 100644 --- a/src/batdetect2/postprocess/remapping.py +++ b/src/batdetect2/postprocess/remapping.py @@ -31,14 +31,14 @@ __all__ = [ def to_xarray( - array: Union[torch.Tensor, np.ndarray], + array: torch.Tensor | np.ndarray, start_time: float, end_time: float, min_freq: float = MIN_FREQ, max_freq: float = MAX_FREQ, name: str = "xarray", - extra_dims: Optional[List[str]] = None, - extra_coords: Optional[Dict[str, np.ndarray]] = None, + extra_dims: List[str] | None = None, + extra_coords: Dict[str, np.ndarray] | None = None, ) -> xr.DataArray: if isinstance(array, torch.Tensor): array = array.detach().cpu().numpy() diff --git a/src/batdetect2/preprocess/audio.py b/src/batdetect2/preprocess/audio.py index a9fd1f1..86f485b 100644 --- a/src/batdetect2/preprocess/audio.py +++ b/src/batdetect2/preprocess/audio.py @@ -78,11 +78,7 @@ class FixDuration(torch.nn.Module): AudioTransform = Annotated[ - Union[ - FixDurationConfig, - ScaleAudioConfig, - CenterAudioConfig, - ], + FixDurationConfig | ScaleAudioConfig | CenterAudioConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/preprocess/config.py b/src/batdetect2/preprocess/config.py index 2ac8150..9f2cd24 100644 --- a/src/batdetect2/preprocess/config.py +++ b/src/batdetect2/preprocess/config.py @@ -57,6 +57,6 @@ class PreprocessingConfig(BaseConfig): def load_preprocessing_config( path: PathLike, - field: Optional[str] = None, + field: str | None = None, ) -> PreprocessingConfig: return load_config(path, schema=PreprocessingConfig, field=field) diff --git a/src/batdetect2/preprocess/preprocessor.py b/src/batdetect2/preprocess/preprocessor.py index ccd0f46..4358bce 100644 --- a/src/batdetect2/preprocess/preprocessor.py +++ b/src/batdetect2/preprocess/preprocessor.py @@ -102,7 +102,7 @@ def compute_output_samplerate( def build_preprocessor( - config: Optional[PreprocessingConfig] = None, + config: PreprocessingConfig | None = None, input_samplerate: int = TARGET_SAMPLERATE_HZ, ) -> PreprocessorProtocol: """Factory function to build the standard preprocessor from configuration.""" diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index 5fcf073..e59a8c7 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -98,7 +98,7 @@ def _frequency_to_index( freq: float, n_fft: int, samplerate: int = TARGET_SAMPLERATE_HZ, -) -> Optional[int]: +) -> int | None: alpha = freq * 2 / samplerate height = np.floor(n_fft / 2) + 1 index = int(np.floor(alpha * height)) @@ -134,8 +134,8 @@ class FrequencyCrop(torch.nn.Module): self, samplerate: int, n_fft: int, - min_freq: Optional[int] = None, - max_freq: Optional[int] = None, + min_freq: int | None = None, + max_freq: int | None = None, ): super().__init__() self.n_fft = n_fft @@ -181,7 +181,7 @@ class FrequencyCrop(torch.nn.Module): def build_spectrogram_crop( config: FrequencyConfig, - stft: Optional[STFTConfig] = None, + stft: STFTConfig | None = None, samplerate: int = TARGET_SAMPLERATE_HZ, ) -> torch.nn.Module: stft = stft or STFTConfig() @@ -377,12 +377,7 @@ class PeakNormalize(torch.nn.Module): SpectrogramTransform = Annotated[ - Union[ - PcenConfig, - ScaleAmplitudeConfig, - SpectralMeanSubstractionConfig, - PeakNormalizeConfig, - ], + PcenConfig | ScaleAmplitudeConfig | SpectralMeanSubstractionConfig | PeakNormalizeConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index 47d7c98..138901b 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -30,16 +30,16 @@ class TargetClassConfig(BaseConfig): name: str - condition_input: Optional[SoundEventConditionConfig] = Field( + condition_input: SoundEventConditionConfig | None = Field( alias="match_if", default=None, ) - tags: Optional[List[data.Tag]] = Field(default=None, exclude=True) + tags: List[data.Tag] | None = Field(default=None, exclude=True) assign_tags: List[data.Tag] = Field(default_factory=list) - roi: Optional[ROIMapperConfig] = None + roi: ROIMapperConfig | None = None _match_if: SoundEventConditionConfig = PrivateAttr() @@ -202,7 +202,7 @@ class SoundEventClassifier: def __call__( self, sound_event_annotation: data.SoundEventAnnotation - ) -> Optional[str]: + ) -> str | None: for name, condition in self.mapping.items(): if condition(sound_event_annotation): return name diff --git a/src/batdetect2/targets/config.py b/src/batdetect2/targets/config.py index 73207d3..87f5491 100644 --- a/src/batdetect2/targets/config.py +++ b/src/batdetect2/targets/config.py @@ -48,7 +48,7 @@ class TargetConfig(BaseConfig): def load_target_config( path: data.PathLike, - field: Optional[str] = None, + field: str | None = None, ) -> TargetConfig: """Load the unified target configuration from a file. diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py index ee81e73..27d7d09 100644 --- a/src/batdetect2/targets/rois.py +++ b/src/batdetect2/targets/rois.py @@ -414,10 +414,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper): ROIMapperConfig = Annotated[ - Union[ - AnchorBBoxMapperConfig, - PeakEnergyBBoxMapperConfig, - ], + AnchorBBoxMapperConfig | PeakEnergyBBoxMapperConfig, Field(discriminator="name"), ] """A discriminated union of all supported ROI mapper configurations. @@ -428,7 +425,7 @@ implementations by using the `name` field as a discriminator. def build_roi_mapper( - config: Optional[ROIMapperConfig] = None, + config: ROIMapperConfig | None = None, ) -> ROITargetMapper: """Factory function to create an ROITargetMapper from a config object. @@ -572,9 +569,9 @@ def get_peak_energy_coordinates( audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, start_time: float = 0, - end_time: Optional[float] = None, + end_time: float | None = None, low_freq: float = 0, - high_freq: Optional[float] = None, + high_freq: float | None = None, loading_buffer: float = 0.05, ) -> Position: """Find the coordinates of the highest energy point in a spectrogram. diff --git a/src/batdetect2/targets/targets.py b/src/batdetect2/targets/targets.py index a692025..9aa4fe8 100644 --- a/src/batdetect2/targets/targets.py +++ b/src/batdetect2/targets/targets.py @@ -107,7 +107,7 @@ class Targets(TargetProtocol): def encode_class( self, sound_event: data.SoundEventAnnotation - ) -> Optional[str]: + ) -> str | None: """Encode a sound event annotation to its target class name. Applies the configured class definition rules (including priority) @@ -182,7 +182,7 @@ class Targets(TargetProtocol): self, position: Position, size: Size, - class_name: Optional[str] = None, + class_name: str | None = None, ) -> data.Geometry: """Recover an approximate geometric ROI from a position and dimensions. @@ -219,7 +219,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( ) -def build_targets(config: Optional[TargetConfig] = None) -> Targets: +def build_targets(config: TargetConfig | None = None) -> Targets: """Build a Targets object from a loaded TargetConfig. Parameters @@ -251,7 +251,7 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets: def load_targets( config_path: data.PathLike, - field: Optional[str] = None, + field: str | None = None, ) -> Targets: """Load a Targets object directly from a configuration file. @@ -292,7 +292,7 @@ def load_targets( def iterate_encoded_sound_events( sound_events: Iterable[data.SoundEventAnnotation], targets: TargetProtocol, -) -> Iterable[Tuple[Optional[str], Position, Size]]: +) -> Iterable[Tuple[str | None, Position, Size]]: for sound_event in sound_events: if not targets.filter(sound_event): continue diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index 36cba2c..9c67a65 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -42,7 +42,7 @@ __all__ = [ AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]] -audio_augmentations: Registry[Augmentation, [int, Optional[AudioSource]]] = ( +audio_augmentations: Registry[Augmentation, [int, AudioSource | None]] = ( Registry(name="audio_augmentation") ) @@ -103,7 +103,7 @@ class MixAudio(torch.nn.Module): def from_config( config: MixAudioConfig, samplerate: int, - source: Optional[AudioSource], + source: AudioSource | None, ): if source is None: warnings.warn( @@ -207,7 +207,7 @@ class AddEcho(torch.nn.Module): def from_config( config: AddEchoConfig, samplerate: int, - source: Optional[AudioSource], + source: AudioSource | None, ): return AddEcho( samplerate=samplerate, @@ -487,33 +487,18 @@ def mask_frequency( AudioAugmentationConfig = Annotated[ - Union[ - MixAudioConfig, - AddEchoConfig, - ], + MixAudioConfig | AddEchoConfig, Field(discriminator="name"), ] SpectrogramAugmentationConfig = Annotated[ - Union[ - ScaleVolumeConfig, - WarpConfig, - MaskFrequencyConfig, - MaskTimeConfig, - ], + ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig, Field(discriminator="name"), ] AugmentationConfig = Annotated[ - Union[ - MixAudioConfig, - AddEchoConfig, - ScaleVolumeConfig, - WarpConfig, - MaskFrequencyConfig, - MaskTimeConfig, - ], + MixAudioConfig | AddEchoConfig | ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig, Field(discriminator="name"), ] """Type alias for the discriminated union of individual augmentation config.""" @@ -559,8 +544,8 @@ class MaybeApply(torch.nn.Module): def build_augmentation_from_config( config: AugmentationConfig, samplerate: int, - audio_source: Optional[AudioSource] = None, -) -> Optional[Augmentation]: + audio_source: AudioSource | None = None, +) -> Augmentation | None: """Factory function to build a single augmentation from its config.""" if config.name == "mix_audio": if audio_source is None: @@ -645,10 +630,10 @@ class AugmentationSequence(torch.nn.Module): def build_audio_augmentations( - steps: Optional[Sequence[AudioAugmentationConfig]] = None, + steps: Sequence[AudioAugmentationConfig] | None = None, samplerate: int = TARGET_SAMPLERATE_HZ, - audio_source: Optional[AudioSource] = None, -) -> Optional[Augmentation]: + audio_source: AudioSource | None = None, +) -> Augmentation | None: if not steps: return None @@ -673,8 +658,8 @@ def build_audio_augmentations( def build_spectrogram_augmentations( - steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None, -) -> Optional[Augmentation]: + steps: Sequence[SpectrogramAugmentationConfig] | None = None, +) -> Augmentation | None: if not steps: return None @@ -698,9 +683,9 @@ def build_spectrogram_augmentations( def build_augmentations( samplerate: int, - config: Optional[AugmentationsConfig] = None, - audio_source: Optional[AudioSource] = None, -) -> Tuple[Optional[Augmentation], Optional[Augmentation]]: + config: AugmentationsConfig | None = None, + audio_source: AudioSource | None = None, +) -> Tuple[Augmentation | None, Augmentation | None]: """Build a composite augmentation pipeline function from configuration.""" config = config or DEFAULT_AUGMENTATION_CONFIG @@ -723,7 +708,7 @@ def build_augmentations( def load_augmentation_config( - path: data.PathLike, field: Optional[str] = None + path: data.PathLike, field: str | None = None ) -> AugmentationsConfig: """Load the augmentations configuration from a file.""" return load_config(path, schema=AugmentationsConfig, field=field) diff --git a/src/batdetect2/train/checkpoints.py b/src/batdetect2/train/checkpoints.py index 48b7432..f8d6127 100644 --- a/src/batdetect2/train/checkpoints.py +++ b/src/batdetect2/train/checkpoints.py @@ -18,14 +18,14 @@ class CheckpointConfig(BaseConfig): monitor: str = "classification/mean_average_precision" mode: str = "max" save_top_k: int = 1 - filename: Optional[str] = None + filename: str | None = None def build_checkpoint_callback( - config: Optional[CheckpointConfig] = None, - checkpoint_dir: Optional[Path] = None, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + config: CheckpointConfig | None = None, + checkpoint_dir: Path | None = None, + experiment_name: str | None = None, + run_name: str | None = None, ) -> Callback: config = config or CheckpointConfig() diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index 38db5ef..9645af5 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -22,20 +22,20 @@ class PLTrainerConfig(BaseConfig): accumulate_grad_batches: int = 1 deterministic: bool = True check_val_every_n_epoch: int = 1 - devices: Union[str, int] = "auto" + devices: str | int = "auto" enable_checkpointing: bool = True - gradient_clip_val: Optional[float] = None - limit_train_batches: Optional[Union[int, float]] = None - limit_test_batches: Optional[Union[int, float]] = None - limit_val_batches: Optional[Union[int, float]] = None - log_every_n_steps: Optional[int] = None - max_epochs: Optional[int] = 200 - min_epochs: Optional[int] = None - max_steps: Optional[int] = None - min_steps: Optional[int] = None - max_time: Optional[str] = None - precision: Optional[str] = None - val_check_interval: Optional[Union[int, float]] = None + gradient_clip_val: float | None = None + limit_train_batches: int | float | None = None + limit_test_batches: int | float | None = None + limit_val_batches: int | float | None = None + log_every_n_steps: int | None = None + max_epochs: int | None = 200 + min_epochs: int | None = None + max_steps: int | None = None + min_steps: int | None = None + max_time: str | None = None + precision: str | None = None + val_check_interval: int | float | None = None class OptimizerConfig(BaseConfig): @@ -57,6 +57,6 @@ class TrainingConfig(BaseConfig): def load_train_config( path: data.PathLike, - field: Optional[str] = None, + field: str | None = None, ) -> TrainingConfig: return load_config(path, schema=TrainingConfig, field=field) diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index 9fe54c9..3a6a68a 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -44,10 +44,10 @@ class TrainingDataset(Dataset): audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, labeller: ClipLabeller, - clipper: Optional[ClipperProtocol] = None, - audio_augmentation: Optional[Augmentation] = None, - spectrogram_augmentation: Optional[Augmentation] = None, - audio_dir: Optional[data.PathLike] = None, + clipper: ClipperProtocol | None = None, + audio_augmentation: Augmentation | None = None, + spectrogram_augmentation: Augmentation | None = None, + audio_dir: data.PathLike | None = None, ): self.clip_annotations = clip_annotations self.clipper = clipper @@ -108,8 +108,8 @@ class ValidationDataset(Dataset): audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, labeller: ClipLabeller, - clipper: Optional[ClipperProtocol] = None, - audio_dir: Optional[data.PathLike] = None, + clipper: ClipperProtocol | None = None, + audio_dir: data.PathLike | None = None, ): self.clip_annotations = clip_annotations self.labeller = labeller @@ -165,11 +165,11 @@ class TrainLoaderConfig(BaseConfig): def build_train_loader( clip_annotations: Sequence[data.ClipAnnotation], - audio_loader: Optional[AudioLoader] = None, - labeller: Optional[ClipLabeller] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - config: Optional[TrainLoaderConfig] = None, - num_workers: Optional[int] = None, + audio_loader: AudioLoader | None = None, + labeller: ClipLabeller | None = None, + preprocessor: PreprocessorProtocol | None = None, + config: TrainLoaderConfig | None = None, + num_workers: int | None = None, ) -> DataLoader: config = config or TrainLoaderConfig() @@ -207,11 +207,11 @@ class ValLoaderConfig(BaseConfig): def build_val_loader( clip_annotations: Sequence[data.ClipAnnotation], - audio_loader: Optional[AudioLoader] = None, - labeller: Optional[ClipLabeller] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - config: Optional[ValLoaderConfig] = None, - num_workers: Optional[int] = None, + audio_loader: AudioLoader | None = None, + labeller: ClipLabeller | None = None, + preprocessor: PreprocessorProtocol | None = None, + config: ValLoaderConfig | None = None, + num_workers: int | None = None, ): logger.info("Building validation data loader...") config = config or ValLoaderConfig() @@ -240,10 +240,10 @@ def build_val_loader( def build_train_dataset( clip_annotations: Sequence[data.ClipAnnotation], - audio_loader: Optional[AudioLoader] = None, - labeller: Optional[ClipLabeller] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - config: Optional[TrainLoaderConfig] = None, + audio_loader: AudioLoader | None = None, + labeller: ClipLabeller | None = None, + preprocessor: PreprocessorProtocol | None = None, + config: TrainLoaderConfig | None = None, ) -> TrainingDataset: logger.info("Building training dataset...") config = config or TrainLoaderConfig() @@ -291,10 +291,10 @@ def build_train_dataset( def build_val_dataset( clip_annotations: Sequence[data.ClipAnnotation], - audio_loader: Optional[AudioLoader] = None, - labeller: Optional[ClipLabeller] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - config: Optional[ValLoaderConfig] = None, + audio_loader: AudioLoader | None = None, + labeller: ClipLabeller | None = None, + preprocessor: PreprocessorProtocol | None = None, + config: ValLoaderConfig | None = None, ) -> ValidationDataset: logger.info("Building validation dataset...") config = config or ValLoaderConfig() diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index 58c6a37..e31df53 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -42,10 +42,10 @@ class LabelConfig(BaseConfig): def build_clip_labeler( - targets: Optional[TargetProtocol] = None, + targets: TargetProtocol | None = None, min_freq: float = MIN_FREQ, max_freq: float = MAX_FREQ, - config: Optional[LabelConfig] = None, + config: LabelConfig | None = None, ) -> ClipLabeller: """Construct the final clip labelling function.""" config = config or LabelConfig() @@ -153,7 +153,7 @@ def generate_heatmaps( def load_label_config( - path: data.PathLike, field: Optional[str] = None + path: data.PathLike, field: str | None = None ) -> LabelConfig: """Load the heatmap label generation configuration from a file. diff --git a/src/batdetect2/train/legacy/train.py b/src/batdetect2/train/legacy/train.py index f15fd74..799c4ce 100644 --- a/src/batdetect2/train/legacy/train.py +++ b/src/batdetect2/train/legacy/train.py @@ -21,7 +21,7 @@ def train_loop( model: DetectionModel, train_dataset: LabeledDataset[TrainInputs], validation_dataset: LabeledDataset[TrainInputs], - device: Optional[torch.device] = None, + device: torch.device | None = None, num_epochs: int = 100, learning_rate: float = 1e-4, ): diff --git a/src/batdetect2/train/legacy/train_utils.py b/src/batdetect2/train/legacy/train_utils.py index b2987d9..fd05f3a 100644 --- a/src/batdetect2/train/legacy/train_utils.py +++ b/src/batdetect2/train/legacy/train_utils.py @@ -106,10 +106,10 @@ def standardize_low_freq( def format_annotation( annotation: types.FileAnnotation, - events_of_interest: Optional[List[str]] = None, - name_replace: Optional[Dict[str, str]] = None, + events_of_interest: List[str] | None = None, + name_replace: Dict[str, str] | None = None, convert_to_genus: bool = False, - classes_to_ignore: Optional[List[str]] = None, + classes_to_ignore: List[str] | None = None, ) -> types.FileAnnotation: formated = [] for aa in annotation["annotation"]: @@ -154,7 +154,7 @@ def format_annotation( def get_class_names( data: List[types.FileAnnotation], - classes_to_ignore: Optional[List[str]] = None, + classes_to_ignore: List[str] | None = None, ) -> Tuple[StringCounter, List[float]]: """Extracts class names and their inverse frequencies. @@ -201,9 +201,9 @@ def load_set_of_anns( *, convert_to_genus: bool = False, filter_issues: bool = False, - events_of_interest: Optional[List[str]] = None, - classes_to_ignore: Optional[List[str]] = None, - name_replace: Optional[Dict[str, str]] = None, + events_of_interest: List[str] | None = None, + classes_to_ignore: List[str] | None = None, + name_replace: Dict[str, str] | None = None, ) -> List[types.FileAnnotation]: # load the annotations anns = [] diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 92d1468..499820b 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -26,10 +26,10 @@ class TrainingModule(L.LightningModule): def __init__( self, - config: Optional[dict] = None, + config: dict | None = None, t_max: int = 100, - model: Optional[Model] = None, - loss: Optional[torch.nn.Module] = None, + model: Model | None = None, + loss: torch.nn.Module | None = None, ): from batdetect2.config import validate_config @@ -103,7 +103,7 @@ def load_model_from_checkpoint( def build_training_module( - config: Optional[dict] = None, + config: dict | None = None, t_max: int = 200, ) -> TrainingModule: return TrainingModule(config=config, t_max=t_max) diff --git a/src/batdetect2/train/losses.py b/src/batdetect2/train/losses.py index e4ecd27..bcad7a2 100644 --- a/src/batdetect2/train/losses.py +++ b/src/batdetect2/train/losses.py @@ -151,7 +151,7 @@ class FocalLoss(nn.Module): eps: float = 1e-5, beta: float = 4, alpha: float = 2, - class_weights: Optional[torch.Tensor] = None, + class_weights: torch.Tensor | None = None, mask_zero: bool = False, ): super().__init__() @@ -422,8 +422,8 @@ class LossFunction(nn.Module, LossProtocol): def build_loss( - config: Optional[LossConfig] = None, - class_weights: Optional[np.ndarray] = None, + config: LossConfig | None = None, + class_weights: np.ndarray | None = None, ) -> nn.Module: """Factory function to build the main LossFunction from configuration. diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 82350c6..70ff131 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -35,21 +35,21 @@ __all__ = [ def train( train_annotations: Sequence[data.ClipAnnotation], - val_annotations: Optional[Sequence[data.ClipAnnotation]] = None, + val_annotations: Sequence[data.ClipAnnotation] | None = None, targets: Optional["TargetProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None, audio_loader: Optional["AudioLoader"] = None, labeller: Optional["ClipLabeller"] = None, config: Optional["BatDetect2Config"] = None, - trainer: Optional[Trainer] = None, - train_workers: Optional[int] = None, - val_workers: Optional[int] = None, - checkpoint_dir: Optional[Path] = None, - log_dir: Optional[Path] = None, - experiment_name: Optional[str] = None, - num_epochs: Optional[int] = None, - run_name: Optional[str] = None, - seed: Optional[int] = None, + trainer: Trainer | None = None, + train_workers: int | None = None, + val_workers: int | None = None, + checkpoint_dir: Path | None = None, + log_dir: Path | None = None, + experiment_name: str | None = None, + num_epochs: int | None = None, + run_name: str | None = None, + seed: int | None = None, ): from batdetect2.config import BatDetect2Config @@ -126,11 +126,11 @@ def train( def build_trainer( config: "BatDetect2Config", evaluator: "EvaluatorProtocol", - checkpoint_dir: Optional[Path] = None, - log_dir: Optional[Path] = None, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, - num_epochs: Optional[int] = None, + checkpoint_dir: Path | None = None, + log_dir: Path | None = None, + experiment_name: str | None = None, + run_name: str | None = None, + num_epochs: int | None = None, ) -> Trainer: trainer_conf = config.train.trainer logger.opt(lazy=True).debug( diff --git a/src/batdetect2/types.py b/src/batdetect2/types.py index 78b229a..539ebe1 100644 --- a/src/batdetect2/types.py +++ b/src/batdetect2/types.py @@ -240,7 +240,7 @@ class ProcessingConfiguration(TypedDict): detection_threshold: float """Threshold for detection probability.""" - time_expansion: Optional[float] + time_expansion: float | None """Time expansion factor of the processed recordings.""" top_n: int @@ -249,7 +249,7 @@ class ProcessingConfiguration(TypedDict): return_raw_preds: bool """Whether to return raw predictions.""" - max_duration: Optional[float] + max_duration: float | None """Maximum duration of audio file to process in seconds.""" nms_kernel_size: int diff --git a/src/batdetect2/typing/data.py b/src/batdetect2/typing/data.py index 52184ce..d148af1 100644 --- a/src/batdetect2/typing/data.py +++ b/src/batdetect2/typing/data.py @@ -20,7 +20,7 @@ class OutputFormatterProtocol(Protocol, Generic[T]): self, predictions: Sequence[T], path: PathLike, - audio_dir: Optional[PathLike] = None, + audio_dir: PathLike | None = None, ) -> None: ... def load(self, path: PathLike) -> List[T]: ... diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/typing/evaluate.py index 6c1fbfb..54a9e71 100644 --- a/src/batdetect2/typing/evaluate.py +++ b/src/batdetect2/typing/evaluate.py @@ -28,19 +28,19 @@ __all__ = [ class MatchEvaluation: clip: data.Clip - sound_event_annotation: Optional[data.SoundEventAnnotation] + sound_event_annotation: data.SoundEventAnnotation | None gt_det: bool - gt_class: Optional[str] - gt_geometry: Optional[data.Geometry] + gt_class: str | None + gt_geometry: data.Geometry | None pred_score: float pred_class_scores: Dict[str, float] - pred_geometry: Optional[data.Geometry] + pred_geometry: data.Geometry | None affinity: float @property - def top_class(self) -> Optional[str]: + def top_class(self) -> str | None: if not self.pred_class_scores: return None @@ -76,7 +76,7 @@ class MatcherProtocol(Protocol): ground_truth: Sequence[data.Geometry], predictions: Sequence[data.Geometry], scores: Sequence[float], - ) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ... + ) -> Iterable[Tuple[int | None, int | None, float]]: ... Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True) diff --git a/src/batdetect2/typing/postprocess.py b/src/batdetect2/typing/postprocess.py index ece0f15..5ba4d47 100644 --- a/src/batdetect2/typing/postprocess.py +++ b/src/batdetect2/typing/postprocess.py @@ -42,7 +42,7 @@ class GeometryDecoder(Protocol): """ def __call__( - self, position: Position, size: Size, class_name: Optional[str] = None + self, position: Position, size: Size, class_name: str | None = None ) -> data.Geometry: ... @@ -93,5 +93,5 @@ class PostprocessorProtocol(Protocol): def __call__( self, output: ModelOutput, - start_times: Optional[Sequence[float]] = None, + start_times: Sequence[float] | None = None, ) -> List[ClipDetectionsTensor]: ... diff --git a/src/batdetect2/typing/preprocess.py b/src/batdetect2/typing/preprocess.py index 1e660f3..9be9990 100644 --- a/src/batdetect2/typing/preprocess.py +++ b/src/batdetect2/typing/preprocess.py @@ -37,7 +37,7 @@ class AudioLoader(Protocol): def load_file( self, path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ) -> np.ndarray: """Load and preprocess audio directly from a file path. @@ -60,7 +60,7 @@ class AudioLoader(Protocol): def load_recording( self, recording: data.Recording, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ) -> np.ndarray: """Load and preprocess the entire audio for a Recording object. @@ -90,7 +90,7 @@ class AudioLoader(Protocol): def load_clip( self, clip: data.Clip, - audio_dir: Optional[data.PathLike] = None, + audio_dir: data.PathLike | None = None, ) -> np.ndarray: """Load and preprocess the audio segment defined by a Clip object. diff --git a/src/batdetect2/typing/targets.py b/src/batdetect2/typing/targets.py index b86573c..48db241 100644 --- a/src/batdetect2/typing/targets.py +++ b/src/batdetect2/typing/targets.py @@ -27,7 +27,7 @@ __all__ = [ "Size", ] -SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]] +SoundEventEncoder = Callable[[data.SoundEventAnnotation], str | None] """Type alias for a sound event class encoder function. An encoder function takes a sound event annotation and returns the string name @@ -125,7 +125,7 @@ class TargetProtocol(Protocol): def encode_class( self, sound_event: data.SoundEventAnnotation, - ) -> Optional[str]: + ) -> str | None: """Encode a sound event annotation to its target class name. Parameters @@ -198,7 +198,7 @@ class TargetProtocol(Protocol): self, position: Position, size: Size, - class_name: Optional[str] = None, + class_name: str | None = None, ) -> data.Geometry: """Recover the ROI geometry from a position and dimensions. diff --git a/src/batdetect2/utils/audio_utils.py b/src/batdetect2/utils/audio_utils.py index 531c8a5..e64405b 100644 --- a/src/batdetect2/utils/audio_utils.py +++ b/src/batdetect2/utils/audio_utils.py @@ -146,7 +146,7 @@ def load_audio( time_exp_fact: float, target_samp_rate: int, scale: bool = False, - max_duration: Optional[float] = None, + max_duration: float | None = None, ) -> Tuple[int, np.ndarray]: """Load an audio file and resample it to the target sampling rate. @@ -241,7 +241,7 @@ def pad_audio( window_overlap: float = parameters.FFT_OVERLAP, resize_factor: float = parameters.RESIZE_FACTOR, divide_factor: int = parameters.SPEC_DIVIDE_FACTOR, - fixed_width: Optional[int] = None, + fixed_width: int | None = None, ): """Pad audio to be evenly divisible by `divide_factor`. diff --git a/src/batdetect2/utils/detector_utils.py b/src/batdetect2/utils/detector_utils.py index 9c297c4..084c5a4 100644 --- a/src/batdetect2/utils/detector_utils.py +++ b/src/batdetect2/utils/detector_utils.py @@ -84,7 +84,7 @@ def list_audio_files(ip_dir: str) -> List[str]: def load_model( model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True, - device: Union[torch.device, str, None] = None, + device: torch.device | str | None = None, weights_only: bool = True, ) -> Tuple[DetectionModel, ModelParameters]: """Load model from file. @@ -279,7 +279,7 @@ def convert_results( spec_feats, cnn_feats, spec_slices, - nyquist_freq: Optional[float] = None, + nyquist_freq: float | None = None, ) -> RunResults: """Convert results to dictionary as expected by the annotation tool. @@ -717,7 +717,7 @@ def process_file( model: DetectionModel, config: ProcessingConfiguration, device: torch.device, -) -> Union[RunResults, Any]: +) -> RunResults | Any: """Process a single audio file with detection model. Will split the audio file into chunks if it is too long and diff --git a/src/batdetect2/utils/tensors.py b/src/batdetect2/utils/tensors.py index 8d2c77b..9bf77e4 100644 --- a/src/batdetect2/utils/tensors.py +++ b/src/batdetect2/utils/tensors.py @@ -6,7 +6,7 @@ from torch.nn import functional as F def extend_width( - array: Union[np.ndarray, torch.Tensor], + array: np.ndarray | torch.Tensor, extra: int, axis: int = -1, value: float = 0, @@ -28,7 +28,7 @@ def extend_width( def make_width_divisible( - array: Union[np.ndarray, torch.Tensor], + array: np.ndarray | torch.Tensor, factor: int, axis: int = -1, value: float = 0, @@ -46,7 +46,7 @@ def make_width_divisible( def adjust_width( - array: Union[np.ndarray, torch.Tensor], + array: np.ndarray | torch.Tensor, width: int, axis: int = -1, value: float = 0,