diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index e7fffc6..96006eb 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -7,35 +7,35 @@ import torch from soundevent import data from soundevent.audio.files import get_audio_files -from batdetect2.audio import build_audio_loader -from batdetect2.audio.types import AudioLoader +from batdetect2.audio import AudioLoader, build_audio_loader from batdetect2.config import BatDetect2Config from batdetect2.core import merge_configs -from batdetect2.data import load_dataset_from_config -from batdetect2.data.datasets import Dataset -from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, run_evaluate -from batdetect2.evaluate.types import EvaluatorProtocol +from batdetect2.data import Dataset, load_dataset_from_config +from batdetect2.evaluate import ( + DEFAULT_EVAL_DIR, + EvaluatorProtocol, + build_evaluator, + run_evaluate, +) from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.logging import DEFAULT_LOGS_DIR from batdetect2.models import Model, build_model from batdetect2.outputs import ( OutputFormatConfig, + OutputFormatterProtocol, OutputTransformProtocol, build_output_formatter, build_output_transform, get_output_formatter, ) -from batdetect2.outputs.types import OutputFormatterProtocol -from batdetect2.postprocess import build_postprocessor -from batdetect2.postprocess.types import ( +from batdetect2.postprocess import ( ClipDetections, Detection, PostprocessorProtocol, + build_postprocessor, ) -from batdetect2.preprocess import build_preprocessor -from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets import build_targets -from batdetect2.targets.types import TargetProtocol +from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor +from batdetect2.targets import TargetProtocol, build_targets from batdetect2.train import ( DEFAULT_CHECKPOINT_DIR, load_model_from_checkpoint, @@ -168,6 +168,9 @@ class BatDetect2API: def load_audio(self, path: data.PathLike) -> np.ndarray: return self.audio_loader.load_file(path) + def load_recording(self, recording: data.Recording) -> np.ndarray: + return self.audio_loader.load_recording(recording) + def load_clip(self, clip: data.Clip) -> np.ndarray: return self.audio_loader.load_clip(clip) @@ -178,7 +181,7 @@ class BatDetect2API: tensor = torch.tensor(audio).unsqueeze(0) return self.preprocessor(tensor) - def process_file(self, audio_file: str) -> ClipDetections: + def process_file(self, audio_file: data.PathLike) -> ClipDetections: recording = data.Recording.from_file(audio_file, compute_hash=False) wav = self.audio_loader.load_recording(recording) detections = self.process_audio(wav) @@ -230,6 +233,7 @@ class BatDetect2API: def process_files( self, audio_files: Sequence[data.PathLike], + batch_size: int | None = None, num_workers: int = 0, ) -> list[ClipDetections]: return process_file_list( diff --git a/src/batdetect2/audio/__init__.py b/src/batdetect2/audio/__init__.py index 96b1259..9c4c7e8 100644 --- a/src/batdetect2/audio/__init__.py +++ b/src/batdetect2/audio/__init__.py @@ -5,8 +5,11 @@ from batdetect2.audio.loader import ( SoundEventAudioLoader, build_audio_loader, ) +from batdetect2.audio.types import AudioLoader, ClipperProtocol __all__ = [ + "AudioLoader", + "ClipperProtocol", "TARGET_SAMPLERATE_HZ", "AudioConfig", "SoundEventAudioLoader", diff --git a/src/batdetect2/data/__init__.py b/src/batdetect2/data/__init__.py index f12f8b0..29c7270 100644 --- a/src/batdetect2/data/__init__.py +++ b/src/batdetect2/data/__init__.py @@ -7,6 +7,7 @@ from batdetect2.data.annotations import ( load_annotated_dataset, ) from batdetect2.data.datasets import ( + Dataset, DatasetConfig, load_dataset, load_dataset_config, @@ -19,6 +20,7 @@ from batdetect2.data.summary import ( ) __all__ = [ + "Dataset", "AOEFAnnotations", "AnnotatedDataset", "AnnotationFormats", diff --git a/src/batdetect2/evaluate/__init__.py b/src/batdetect2/evaluate/__init__.py index c35d851..25463b5 100644 --- a/src/batdetect2/evaluate/__init__.py +++ b/src/batdetect2/evaluate/__init__.py @@ -2,14 +2,30 @@ from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate from batdetect2.evaluate.evaluator import Evaluator, build_evaluator from batdetect2.evaluate.tasks import TaskConfig, build_task +from batdetect2.evaluate.types import ( + AffinityFunction, + ClipMatches, + EvaluationTaskProtocol, + EvaluatorProtocol, + MetricsProtocol, + PlotterProtocol, +) __all__ = [ + "AffinityFunction", + "ClipMatches", + "DEFAULT_EVAL_DIR", "EvaluationConfig", + "EvaluationTaskProtocol", "Evaluator", + "EvaluatorProtocol", + "MatchEvaluation", + "MatcherProtocol", + "MetricsProtocol", + "PlotterProtocol", "TaskConfig", "build_evaluator", "build_task", - "run_evaluate", "load_evaluation_config", - "DEFAULT_EVAL_DIR", + "run_evaluate", ] diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index b96fc12..3ab2bb9 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -1,49 +1,52 @@ -from typing import TYPE_CHECKING, List, Optional, Sequence +from typing import Sequence from lightning import Trainer from soundevent import data +from batdetect2.audio import AudioConfig from batdetect2.audio.loader import build_audio_loader from batdetect2.audio.types import AudioLoader +from batdetect2.inference import InferenceConfig from batdetect2.inference.clips import get_clips_from_files from batdetect2.inference.dataset import build_inference_loader from batdetect2.inference.lightning import InferenceModule from batdetect2.models import Model -from batdetect2.outputs import OutputTransformProtocol, build_output_transform +from batdetect2.outputs import ( + OutputsConfig, + OutputTransformProtocol, + build_output_transform, +) from batdetect2.postprocess.types import ClipDetections -from batdetect2.preprocess.preprocessor import build_preprocessor from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets.targets import build_targets from batdetect2.targets.types import TargetProtocol -if TYPE_CHECKING: - from batdetect2.config import BatDetect2Config - def run_batch_inference( - model, + model: Model, clips: Sequence[data.Clip], - targets: Optional["TargetProtocol"] = None, - audio_loader: Optional["AudioLoader"] = None, - preprocessor: Optional["PreprocessorProtocol"] = None, - config: Optional["BatDetect2Config"] = None, - output_transform: Optional[OutputTransformProtocol] = None, + targets: TargetProtocol | None = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, + audio_config: AudioConfig | None = None, + output_transform: OutputTransformProtocol | None = None, + output_config: OutputsConfig | None = None, + inference_config: InferenceConfig | None = None, num_workers: int = 1, batch_size: int | None = None, -) -> List[ClipDetections]: - from batdetect2.config import BatDetect2Config - - config = config or BatDetect2Config() - - audio_loader = audio_loader or build_audio_loader() - - preprocessor = preprocessor or build_preprocessor( - input_samplerate=audio_loader.samplerate, +) -> list[ClipDetections]: + audio_config = audio_config or AudioConfig( + samplerate=model.preprocessor.input_samplerate, ) + output_config = output_config or OutputsConfig() + inference_config = inference_config or InferenceConfig() + + audio_loader = audio_loader or build_audio_loader(config=audio_config) + + preprocessor = preprocessor or model.preprocessor + targets = targets or model.targets - targets = targets or build_targets() output_transform = output_transform or build_output_transform( - config=config.outputs.transform, + config=output_config.transform, targets=targets, ) @@ -51,7 +54,7 @@ def run_batch_inference( clips, audio_loader=audio_loader, preprocessor=preprocessor, - config=config.inference.loader, + config=inference_config.loader, num_workers=num_workers, batch_size=batch_size, ) @@ -72,13 +75,18 @@ def run_batch_inference( def process_file_list( model: Model, paths: Sequence[data.PathLike], - config: "BatDetect2Config", - targets: Optional["TargetProtocol"] = None, - audio_loader: Optional["AudioLoader"] = None, - preprocessor: Optional["PreprocessorProtocol"] = None, + targets: TargetProtocol | None = None, + audio_loader: AudioLoader | None = None, + audio_config: AudioConfig | None = None, + preprocessor: PreprocessorProtocol | None = None, + inference_config: InferenceConfig | None = None, + output_config: OutputsConfig | None = None, + output_transform: OutputTransformProtocol | None = None, + batch_size: int | None = None, num_workers: int = 0, -) -> List[ClipDetections]: - clip_config = config.inference.clipping +) -> list[ClipDetections]: + inference_config = inference_config or InferenceConfig() + clip_config = inference_config.clipping clips = get_clips_from_files( paths, duration=clip_config.duration, @@ -92,6 +100,10 @@ def process_file_list( targets=targets, audio_loader=audio_loader, preprocessor=preprocessor, - config=config, + batch_size=batch_size, num_workers=num_workers, + output_config=output_config, + audio_config=audio_config, + output_transform=output_transform, + inference_config=inference_config, ) diff --git a/src/batdetect2/outputs/__init__.py b/src/batdetect2/outputs/__init__.py index c6528c1..68fae45 100644 --- a/src/batdetect2/outputs/__init__.py +++ b/src/batdetect2/outputs/__init__.py @@ -13,11 +13,15 @@ from batdetect2.outputs.transforms import ( OutputTransformConfig, build_output_transform, ) -from batdetect2.outputs.types import OutputTransformProtocol +from batdetect2.outputs.types import ( + OutputFormatterProtocol, + OutputTransformProtocol, +) __all__ = [ "BatDetect2OutputConfig", "OutputFormatConfig", + "OutputFormatterProtocol", "OutputTransformConfig", "OutputTransformProtocol", "OutputsConfig", diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index 20f2744..b80fb60 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -9,10 +9,26 @@ from batdetect2.postprocess.postprocessor import ( Postprocessor, build_postprocessor, ) +from batdetect2.postprocess.types import ( + ClipDetections, + ClipDetectionsArray, + ClipDetectionsTensor, + ClipPrediction, + Detection, + GeometryDecoder, + PostprocessorProtocol, +) __all__ = [ + "ClipDetections", + "ClipDetectionsArray", + "ClipDetectionsTensor", + "ClipPrediction", + "Detection", + "GeometryDecoder", "PostprocessConfig", "Postprocessor", + "PostprocessorProtocol", "build_postprocessor", "load_postprocess_config", "non_max_suppression", diff --git a/src/batdetect2/preprocess/__init__.py b/src/batdetect2/preprocess/__init__.py index 38c34ea..fe7fba8 100644 --- a/src/batdetect2/preprocess/__init__.py +++ b/src/batdetect2/preprocess/__init__.py @@ -7,8 +7,10 @@ from batdetect2.preprocess.config import ( ) from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ +from batdetect2.preprocess.types import PreprocessorProtocol __all__ = [ + "PreprocessorProtocol", "MAX_FREQ", "MIN_FREQ", "PreprocessingConfig", diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index 8a09b66..4c43633 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -1,8 +1,6 @@ """BatDetect2 Target Definition system.""" from batdetect2.targets.classes import ( - SoundEventDecoder, - SoundEventEncoder, TargetClassConfig, build_sound_event_decoder, build_sound_event_encoder, @@ -12,7 +10,6 @@ from batdetect2.targets.config import TargetConfig, load_target_config from batdetect2.targets.rois import ( AnchorBBoxMapperConfig, ROIMapperConfig, - ROITargetMapper, build_roi_mapper, ) from batdetect2.targets.targets import ( @@ -27,15 +24,27 @@ from batdetect2.targets.terms import ( generic_class, individual, ) +from batdetect2.targets.types import ( + Position, + ROITargetMapper, + Size, + SoundEventDecoder, + SoundEventEncoder, + TargetProtocol, +) __all__ = [ "AnchorBBoxMapperConfig", + "Position", "ROIMapperConfig", "ROITargetMapper", + "Size", "SoundEventDecoder", "SoundEventEncoder", + "SoundEventFilter", "TargetClassConfig", "TargetConfig", + "TargetProtocol", "Targets", "build_roi_mapper", "build_sound_event_decoder",