diff --git a/.gitignore b/.gitignore index e63be69..a06bd6a 100644 --- a/.gitignore +++ b/.gitignore @@ -102,7 +102,7 @@ experiments/* DvcLiveLogger/checkpoints logs/ mlruns/ -outputs/ +/outputs/ notebooks/lightning_logs # Jupiter notebooks diff --git a/docs/source/architecture.md b/docs/source/architecture.md index 7b8ca1f..ced2c7a 100644 --- a/docs/source/architecture.md +++ b/docs/source/architecture.md @@ -89,5 +89,5 @@ Crucial for training, this module translates physical annotations (Regions of In ## Summary To navigate this codebase effectively: 1. Follow **`api_v2.py`** to see how high-level operations invoke individual components. -2. Rely heavily on the typed **Protocols** located in `src/batdetect2/typing/` to understand the inputs and outputs of each subsystem without needing to read the specific implementations. -3. Understand that data flows structurally as `soundevent` primitives externally, and as pure `torch.Tensor` internally through the network. \ No newline at end of file +2. Rely heavily on the typed **Protocols** located in each subsystem's `types.py` module (for example `src/batdetect2/preprocess/types.py` and `src/batdetect2/postprocess/types.py`) to understand inputs and outputs without needing to read each implementation. +3. Understand that data flows structurally as `soundevent` primitives externally, and as pure `torch.Tensor` internally through the network. diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 90ed5e7..ca14288 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -8,6 +8,7 @@ from soundevent import data from soundevent.audio.files import get_audio_files from batdetect2.audio import build_audio_loader +from batdetect2.audio.types import AudioLoader from batdetect2.config import BatDetect2Config from batdetect2.core import merge_configs from batdetect2.data import ( @@ -15,6 +16,7 @@ from batdetect2.data import ( ) from batdetect2.data.datasets import Dataset from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate +from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.logging import DEFAULT_LOGS_DIR from batdetect2.models import Model, build_model @@ -25,24 +27,22 @@ from batdetect2.outputs import ( build_output_transform, get_output_formatter, ) +from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.postprocess import build_postprocessor, to_raw_predictions +from batdetect2.postprocess.types import ( + ClipDetections, + Detection, + PostprocessorProtocol, +) from batdetect2.preprocess import build_preprocessor +from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets import build_targets +from batdetect2.targets.types import TargetProtocol from batdetect2.train import ( DEFAULT_CHECKPOINT_DIR, load_model_from_checkpoint, run_train, ) -from batdetect2.typing import ( - AudioLoader, - ClipDetections, - Detection, - EvaluatorProtocol, - OutputFormatterProtocol, - PostprocessorProtocol, - PreprocessorProtocol, - TargetProtocol, -) class BatDetect2API: diff --git a/src/batdetect2/audio/clips.py b/src/batdetect2/audio/clips.py index dd70756..21390c6 100644 --- a/src/batdetect2/audio/clips.py +++ b/src/batdetect2/audio/clips.py @@ -6,13 +6,13 @@ from pydantic import Field from soundevent import data from soundevent.geometry import compute_bounds, intervals_overlap +from batdetect2.audio.types import ClipperProtocol from batdetect2.core import ( BaseConfig, ImportConfig, Registry, add_import_config, ) -from batdetect2.typing import ClipperProtocol DEFAULT_TRAIN_CLIP_DURATION = 0.256 DEFAULT_MAX_EMPTY_CLIP = 0.1 diff --git a/src/batdetect2/audio/loader.py b/src/batdetect2/audio/loader.py index 6f7ef8c..60dc8a9 100644 --- a/src/batdetect2/audio/loader.py +++ b/src/batdetect2/audio/loader.py @@ -5,8 +5,8 @@ from scipy.signal import resample, resample_poly from soundevent import audio, data from soundfile import LibsndfileError +from batdetect2.audio.types import AudioLoader from batdetect2.core import BaseConfig -from batdetect2.typing import AudioLoader __all__ = [ "SoundEventAudioLoader", diff --git a/src/batdetect2/audio/types.py b/src/batdetect2/audio/types.py new file mode 100644 index 0000000..87d46f2 --- /dev/null +++ b/src/batdetect2/audio/types.py @@ -0,0 +1,40 @@ +from typing import Protocol + +import numpy as np +from soundevent import data + +__all__ = [ + "AudioLoader", + "ClipperProtocol", +] + + +class AudioLoader(Protocol): + samplerate: int + + def load_file( + self, + path: data.PathLike, + audio_dir: data.PathLike | None = None, + ) -> np.ndarray: ... + + def load_recording( + self, + recording: data.Recording, + audio_dir: data.PathLike | None = None, + ) -> np.ndarray: ... + + def load_clip( + self, + clip: data.Clip, + audio_dir: data.PathLike | None = None, + ) -> np.ndarray: ... + + +class ClipperProtocol(Protocol): + def __call__( + self, + clip_annotation: data.ClipAnnotation, + ) -> data.ClipAnnotation: ... + + def get_subclip(self, clip: data.Clip) -> data.Clip: ... diff --git a/src/batdetect2/core/registries.py b/src/batdetect2/core/registries.py index 3fb170e..710ec63 100644 --- a/src/batdetect2/core/registries.py +++ b/src/batdetect2/core/registries.py @@ -4,6 +4,7 @@ from typing import ( Concatenate, Generic, ParamSpec, + Sequence, Type, TypeVar, ) @@ -147,6 +148,7 @@ T_Import = TypeVar("T_Import", bound=ImportConfig) def add_import_config( registry: Registry[T_Type, P_Type], + arg_names: Sequence[str] | None = None, ) -> Callable[[Type[T_Import]], Type[T_Import]]: """Decorator that registers an ImportConfig subclass as an escape hatch. @@ -181,15 +183,22 @@ def add_import_config( *args: P_Type.args, **kwargs: P_Type.kwargs, ) -> T_Type: - if len(args) > 0: + _arg_names = arg_names or [] + + if len(args) != len(_arg_names): raise ValueError( "Positional arguments are not supported " - "for import escape hatch." + "for import escape hatch unless you specify " + "the argument names. Use `arg_names` to specify " + "the names of the positional arguments." ) + args_dict = {_arg_names[i]: args[i] for i in range(len(args))} + hydra_cfg = { "_target_": config.target, **config.arguments, + **args_dict, **kwargs, } return instantiate(hydra_cfg) diff --git a/src/batdetect2/data/iterators.py b/src/batdetect2/data/iterators.py index c499b0b..9216450 100644 --- a/src/batdetect2/data/iterators.py +++ b/src/batdetect2/data/iterators.py @@ -3,7 +3,7 @@ from collections.abc import Generator from soundevent import data from batdetect2.data.datasets import Dataset -from batdetect2.typing.targets import TargetProtocol +from batdetect2.targets.types import TargetProtocol def iterate_over_sound_events( diff --git a/src/batdetect2/data/split.py b/src/batdetect2/data/split.py index fa4f871..2a0f951 100644 --- a/src/batdetect2/data/split.py +++ b/src/batdetect2/data/split.py @@ -5,7 +5,7 @@ from batdetect2.data.summary import ( extract_recordings_df, extract_sound_events_df, ) -from batdetect2.typing.targets import TargetProtocol +from batdetect2.targets.types import TargetProtocol def split_dataset_by_recordings( diff --git a/src/batdetect2/data/summary.py b/src/batdetect2/data/summary.py index 1db0948..7d81ef6 100644 --- a/src/batdetect2/data/summary.py +++ b/src/batdetect2/data/summary.py @@ -2,7 +2,7 @@ import pandas as pd from soundevent.geometry import compute_bounds from batdetect2.data.datasets import Dataset -from batdetect2.typing.targets import TargetProtocol +from batdetect2.targets.types import TargetProtocol __all__ = [ "extract_recordings_df", diff --git a/src/batdetect2/evaluate/affinity.py b/src/batdetect2/evaluate/affinity.py index a4141fa..89284eb 100644 --- a/src/batdetect2/evaluate/affinity.py +++ b/src/batdetect2/evaluate/affinity.py @@ -16,7 +16,8 @@ from batdetect2.core import ( Registry, add_import_config, ) -from batdetect2.typing import AffinityFunction, Detection +from batdetect2.evaluate.types import AffinityFunction +from batdetect2.postprocess.types import Detection affinity_functions: Registry[AffinityFunction, []] = Registry( "affinity_function" diff --git a/src/batdetect2/evaluate/dataset.py b/src/batdetect2/evaluate/dataset.py index 9bf106a..ebc7431 100644 --- a/src/batdetect2/evaluate/dataset.py +++ b/src/batdetect2/evaluate/dataset.py @@ -8,14 +8,11 @@ from torch.utils.data import DataLoader, Dataset from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper from batdetect2.audio.clips import PaddedClipConfig +from batdetect2.audio.types import AudioLoader, ClipperProtocol from batdetect2.core import BaseConfig from batdetect2.core.arrays import adjust_width from batdetect2.preprocess import build_preprocessor -from batdetect2.typing import ( - AudioLoader, - ClipperProtocol, - PreprocessorProtocol, -) +from batdetect2.preprocess.types import PreprocessorProtocol __all__ = [ "TestDataset", diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index c533e8b..99d6219 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -5,22 +5,20 @@ from lightning import Trainer from soundevent import data from batdetect2.audio import build_audio_loader +from batdetect2.audio.types import AudioLoader from batdetect2.evaluate.dataset import build_test_loader from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.evaluate.lightning import EvaluationModule from batdetect2.logging import build_logger from batdetect2.models import Model from batdetect2.outputs import build_output_transform -from batdetect2.typing import Detection +from batdetect2.outputs.types import OutputFormatterProtocol +from batdetect2.postprocess.types import Detection +from batdetect2.preprocess.types import PreprocessorProtocol +from batdetect2.targets.types import TargetProtocol if TYPE_CHECKING: from batdetect2.config import BatDetect2Config - from batdetect2.typing import ( - AudioLoader, - OutputFormatterProtocol, - PreprocessorProtocol, - TargetProtocol, - ) DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations" diff --git a/src/batdetect2/evaluate/evaluator.py b/src/batdetect2/evaluate/evaluator.py index 39e960f..30152b9 100644 --- a/src/batdetect2/evaluate/evaluator.py +++ b/src/batdetect2/evaluate/evaluator.py @@ -5,9 +5,10 @@ from soundevent import data from batdetect2.evaluate.config import EvaluationConfig from batdetect2.evaluate.tasks import build_task +from batdetect2.evaluate.types import EvaluatorProtocol +from batdetect2.postprocess.types import ClipDetections from batdetect2.targets import build_targets -from batdetect2.typing import EvaluatorProtocol, TargetProtocol -from batdetect2.typing.postprocess import ClipDetections +from batdetect2.targets.types import TargetProtocol __all__ = [ "Evaluator", diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py index 0abd367..c721e44 100644 --- a/src/batdetect2/evaluate/lightning.py +++ b/src/batdetect2/evaluate/lightning.py @@ -5,12 +5,12 @@ from soundevent import data from torch.utils.data import DataLoader from batdetect2.evaluate.dataset import TestDataset, TestExample +from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.logging import get_image_logger from batdetect2.models import Model from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.postprocess import to_raw_predictions -from batdetect2.typing import EvaluatorProtocol -from batdetect2.typing.postprocess import ClipDetections +from batdetect2.postprocess.types import ClipDetections class EvaluationModule(LightningModule): diff --git a/src/batdetect2/evaluate/metrics/classification.py b/src/batdetect2/evaluate/metrics/classification.py index 6ef5971..daf3e31 100644 --- a/src/batdetect2/evaluate/metrics/classification.py +++ b/src/batdetect2/evaluate/metrics/classification.py @@ -26,7 +26,8 @@ from batdetect2.evaluate.metrics.common import ( average_precision, compute_precision_recall, ) -from batdetect2.typing import Detection, TargetProtocol +from batdetect2.postprocess.types import Detection +from batdetect2.targets.types import TargetProtocol __all__ = [ "ClassificationMetric", diff --git a/src/batdetect2/evaluate/metrics/detection.py b/src/batdetect2/evaluate/metrics/detection.py index 2f13915..59fa0b6 100644 --- a/src/batdetect2/evaluate/metrics/detection.py +++ b/src/batdetect2/evaluate/metrics/detection.py @@ -20,7 +20,7 @@ from batdetect2.core import ( add_import_config, ) from batdetect2.evaluate.metrics.common import average_precision -from batdetect2.typing import Detection +from batdetect2.postprocess.types import Detection __all__ = [ "DetectionMetricConfig", diff --git a/src/batdetect2/evaluate/metrics/top_class.py b/src/batdetect2/evaluate/metrics/top_class.py index 5fa2605..e131e6f 100644 --- a/src/batdetect2/evaluate/metrics/top_class.py +++ b/src/batdetect2/evaluate/metrics/top_class.py @@ -20,8 +20,8 @@ from batdetect2.core import ( add_import_config, ) from batdetect2.evaluate.metrics.common import average_precision -from batdetect2.typing import Detection -from batdetect2.typing.targets import TargetProtocol +from batdetect2.postprocess.types import Detection +from batdetect2.targets.types import TargetProtocol __all__ = [ "TopClassMetricConfig", diff --git a/src/batdetect2/evaluate/plots/base.py b/src/batdetect2/evaluate/plots/base.py index beceb42..ee4028c 100644 --- a/src/batdetect2/evaluate/plots/base.py +++ b/src/batdetect2/evaluate/plots/base.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt from matplotlib.figure import Figure from batdetect2.core import BaseConfig -from batdetect2.typing import TargetProtocol +from batdetect2.targets.types import TargetProtocol class BasePlotConfig(BaseConfig): diff --git a/src/batdetect2/evaluate/plots/classification.py b/src/batdetect2/evaluate/plots/classification.py index f8727b0..400dbbc 100644 --- a/src/batdetect2/evaluate/plots/classification.py +++ b/src/batdetect2/evaluate/plots/classification.py @@ -29,7 +29,7 @@ from batdetect2.plotting.metrics import ( plot_threshold_recall_curve, plot_threshold_recall_curves, ) -from batdetect2.typing import TargetProtocol +from batdetect2.targets.types import TargetProtocol ClassificationPlotter = Callable[ [Sequence[ClipEval]], Iterable[Tuple[str, Figure]] diff --git a/src/batdetect2/evaluate/plots/clip_classification.py b/src/batdetect2/evaluate/plots/clip_classification.py index df34482..481b9a1 100644 --- a/src/batdetect2/evaluate/plots/clip_classification.py +++ b/src/batdetect2/evaluate/plots/clip_classification.py @@ -22,7 +22,7 @@ from batdetect2.plotting.metrics import ( plot_roc_curve, plot_roc_curves, ) -from batdetect2.typing import TargetProtocol +from batdetect2.targets.types import TargetProtocol __all__ = [ "ClipClassificationPlotConfig", diff --git a/src/batdetect2/evaluate/plots/clip_detection.py b/src/batdetect2/evaluate/plots/clip_detection.py index 3e44804..fc8a6b8 100644 --- a/src/batdetect2/evaluate/plots/clip_detection.py +++ b/src/batdetect2/evaluate/plots/clip_detection.py @@ -18,7 +18,7 @@ from batdetect2.evaluate.metrics.clip_detection import ClipEval from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve -from batdetect2.typing import TargetProtocol +from batdetect2.targets.types import TargetProtocol __all__ = [ "ClipDetectionPlotConfig", diff --git a/src/batdetect2/evaluate/plots/detection.py b/src/batdetect2/evaluate/plots/detection.py index cf73e5e..a99f8c4 100644 --- a/src/batdetect2/evaluate/plots/detection.py +++ b/src/batdetect2/evaluate/plots/detection.py @@ -16,6 +16,7 @@ from pydantic import Field from sklearn import metrics from batdetect2.audio import AudioConfig, build_audio_loader +from batdetect2.audio.types import AudioLoader from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.detection import ClipEval @@ -23,7 +24,8 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.plotting.detections import plot_clip_detections from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve from batdetect2.preprocess import PreprocessingConfig, build_preprocessor -from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol +from batdetect2.preprocess.types import PreprocessorProtocol +from batdetect2.targets.types import TargetProtocol DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]] diff --git a/src/batdetect2/evaluate/plots/top_class.py b/src/batdetect2/evaluate/plots/top_class.py index 003ce42..d48ce87 100644 --- a/src/batdetect2/evaluate/plots/top_class.py +++ b/src/batdetect2/evaluate/plots/top_class.py @@ -16,6 +16,7 @@ from pydantic import Field from sklearn import metrics from batdetect2.audio import AudioConfig, build_audio_loader +from batdetect2.audio.types import AudioLoader from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.top_class import ( @@ -27,7 +28,8 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.plotting.gallery import plot_match_gallery from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve from batdetect2.preprocess import PreprocessingConfig, build_preprocessor -from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol +from batdetect2.preprocess.types import PreprocessorProtocol +from batdetect2.targets.types import TargetProtocol TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]] diff --git a/src/batdetect2/evaluate/tasks/__init__.py b/src/batdetect2/evaluate/tasks/__init__.py index 625524e..11b3f01 100644 --- a/src/batdetect2/evaluate/tasks/__init__.py +++ b/src/batdetect2/evaluate/tasks/__init__.py @@ -11,12 +11,10 @@ from batdetect2.evaluate.tasks.clip_classification import ( from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig from batdetect2.evaluate.tasks.detection import DetectionTaskConfig from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig +from batdetect2.evaluate.types import EvaluatorProtocol +from batdetect2.postprocess.types import ClipDetections from batdetect2.targets import build_targets -from batdetect2.typing import ( - ClipDetections, - EvaluatorProtocol, - TargetProtocol, -) +from batdetect2.targets.types import TargetProtocol __all__ = [ "TaskConfig", diff --git a/src/batdetect2/evaluate/tasks/base.py b/src/batdetect2/evaluate/tasks/base.py index a229137..77065cd 100644 --- a/src/batdetect2/evaluate/tasks/base.py +++ b/src/batdetect2/evaluate/tasks/base.py @@ -26,13 +26,12 @@ from batdetect2.evaluate.affinity import ( TimeAffinityConfig, build_affinity_function, ) -from batdetect2.typing import ( +from batdetect2.evaluate.types import ( AffinityFunction, - ClipDetections, - Detection, EvaluatorProtocol, - TargetProtocol, ) +from batdetect2.postprocess.types import ClipDetections, Detection +from batdetect2.targets.types import TargetProtocol __all__ = [ "BaseTaskConfig", diff --git a/src/batdetect2/evaluate/tasks/classification.py b/src/batdetect2/evaluate/tasks/classification.py index 1da934d..43977bb 100644 --- a/src/batdetect2/evaluate/tasks/classification.py +++ b/src/batdetect2/evaluate/tasks/classification.py @@ -21,11 +21,8 @@ from batdetect2.evaluate.tasks.base import ( BaseSEDTaskConfig, tasks_registry, ) -from batdetect2.typing import ( - ClipDetections, - Detection, - TargetProtocol, -) +from batdetect2.postprocess.types import ClipDetections, Detection +from batdetect2.targets.types import TargetProtocol class ClassificationTaskConfig(BaseSEDTaskConfig): diff --git a/src/batdetect2/evaluate/tasks/clip_classification.py b/src/batdetect2/evaluate/tasks/clip_classification.py index 32e2383..958d279 100644 --- a/src/batdetect2/evaluate/tasks/clip_classification.py +++ b/src/batdetect2/evaluate/tasks/clip_classification.py @@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import ( BaseTaskConfig, tasks_registry, ) -from batdetect2.typing import ClipDetections, TargetProtocol +from batdetect2.postprocess.types import ClipDetections +from batdetect2.targets.types import TargetProtocol class ClipClassificationTaskConfig(BaseTaskConfig): diff --git a/src/batdetect2/evaluate/tasks/clip_detection.py b/src/batdetect2/evaluate/tasks/clip_detection.py index b19efa0..ed810d5 100644 --- a/src/batdetect2/evaluate/tasks/clip_detection.py +++ b/src/batdetect2/evaluate/tasks/clip_detection.py @@ -18,7 +18,8 @@ from batdetect2.evaluate.tasks.base import ( BaseTaskConfig, tasks_registry, ) -from batdetect2.typing import ClipDetections, TargetProtocol +from batdetect2.postprocess.types import ClipDetections +from batdetect2.targets.types import TargetProtocol class ClipDetectionTaskConfig(BaseTaskConfig): diff --git a/src/batdetect2/evaluate/tasks/detection.py b/src/batdetect2/evaluate/tasks/detection.py index 2c4e83f..49099c5 100644 --- a/src/batdetect2/evaluate/tasks/detection.py +++ b/src/batdetect2/evaluate/tasks/detection.py @@ -20,8 +20,8 @@ from batdetect2.evaluate.tasks.base import ( BaseSEDTaskConfig, tasks_registry, ) -from batdetect2.typing import TargetProtocol -from batdetect2.typing.postprocess import ClipDetections +from batdetect2.postprocess.types import ClipDetections +from batdetect2.targets.types import TargetProtocol class DetectionTaskConfig(BaseSEDTaskConfig): diff --git a/src/batdetect2/evaluate/tasks/top_class.py b/src/batdetect2/evaluate/tasks/top_class.py index 0745891..337ee6a 100644 --- a/src/batdetect2/evaluate/tasks/top_class.py +++ b/src/batdetect2/evaluate/tasks/top_class.py @@ -20,7 +20,8 @@ from batdetect2.evaluate.tasks.base import ( BaseSEDTaskConfig, tasks_registry, ) -from batdetect2.typing import ClipDetections, TargetProtocol +from batdetect2.postprocess.types import ClipDetections +from batdetect2.targets.types import TargetProtocol class TopClassDetectionTaskConfig(BaseSEDTaskConfig): diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/evaluate/types.py similarity index 86% rename from src/batdetect2/typing/evaluate.py rename to src/batdetect2/evaluate/types.py index 9698342..58f1c86 100644 --- a/src/batdetect2/typing/evaluate.py +++ b/src/batdetect2/evaluate/types.py @@ -1,45 +1,39 @@ from dataclasses import dataclass -from typing import ( - Generic, - Iterable, - Protocol, - Sequence, - TypeVar, -) +from typing import Generic, Iterable, Protocol, Sequence, TypeVar from matplotlib.figure import Figure from soundevent import data -from batdetect2.typing.postprocess import ClipDetections, Detection -from batdetect2.typing.targets import TargetProtocol +from batdetect2.postprocess.types import ClipDetections, Detection +from batdetect2.targets.types import TargetProtocol __all__ = [ + "AffinityFunction", + "ClipMatches", "EvaluatorProtocol", - "MetricsProtocol", "MatchEvaluation", + "MatcherProtocol", + "MetricsProtocol", + "PlotterProtocol", ] @dataclass class MatchEvaluation: clip: data.Clip - sound_event_annotation: data.SoundEventAnnotation | None gt_det: bool gt_class: str | None gt_geometry: data.Geometry | None - pred_score: float pred_class_scores: dict[str, float] pred_geometry: data.Geometry | None - affinity: float @property def top_class(self) -> str | None: if not self.pred_class_scores: return None - return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore @property @@ -53,10 +47,8 @@ class MatchEvaluation: @property def top_class_score(self) -> float: pred_class = self.top_class - if pred_class is None: return 0 - return self.pred_class_scores[pred_class] @@ -75,9 +67,6 @@ class MatcherProtocol(Protocol): ) -> Iterable[tuple[int | None, int | None, float]]: ... -Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True) - - class AffinityFunction(Protocol): def __call__( self, @@ -115,9 +104,11 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]): ) -> EvaluationOutput: ... def compute_metrics( - self, eval_outputs: EvaluationOutput + self, + eval_outputs: EvaluationOutput, ) -> dict[str, float]: ... def generate_plots( - self, eval_outputs: EvaluationOutput + self, + eval_outputs: EvaluationOutput, ) -> Iterable[tuple[str, Figure]]: ... diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index 23681b7..987bb74 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -4,22 +4,20 @@ from lightning import Trainer from soundevent import data from batdetect2.audio.loader import build_audio_loader +from batdetect2.audio.types import AudioLoader from batdetect2.inference.clips import get_clips_from_files from batdetect2.inference.dataset import build_inference_loader from batdetect2.inference.lightning import InferenceModule from batdetect2.models import Model from batdetect2.outputs import OutputTransformProtocol, build_output_transform +from batdetect2.postprocess.types import ClipDetections from batdetect2.preprocess.preprocessor import build_preprocessor +from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets.targets import build_targets -from batdetect2.typing.postprocess import ClipDetections +from batdetect2.targets.types import TargetProtocol if TYPE_CHECKING: from batdetect2.config import BatDetect2Config - from batdetect2.typing import ( - AudioLoader, - PreprocessorProtocol, - TargetProtocol, - ) def run_batch_inference( diff --git a/src/batdetect2/inference/dataset.py b/src/batdetect2/inference/dataset.py index 13bfd32..0b1aa50 100644 --- a/src/batdetect2/inference/dataset.py +++ b/src/batdetect2/inference/dataset.py @@ -6,10 +6,11 @@ from soundevent import data from torch.utils.data import DataLoader, Dataset from batdetect2.audio import build_audio_loader +from batdetect2.audio.types import AudioLoader from batdetect2.core import BaseConfig from batdetect2.core.arrays import adjust_width from batdetect2.preprocess import build_preprocessor -from batdetect2.typing import AudioLoader, PreprocessorProtocol +from batdetect2.preprocess.types import PreprocessorProtocol __all__ = [ "InferenceDataset", diff --git a/src/batdetect2/inference/lightning.py b/src/batdetect2/inference/lightning.py index 2db58e6..c2689d7 100644 --- a/src/batdetect2/inference/lightning.py +++ b/src/batdetect2/inference/lightning.py @@ -7,7 +7,7 @@ from batdetect2.inference.dataset import DatasetItem, InferenceDataset from batdetect2.models import Model from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.postprocess import to_raw_predictions -from batdetect2.typing.postprocess import ClipDetections +from batdetect2.postprocess.types import ClipDetections class InferenceModule(LightningModule): diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index e65ddef..824479f 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -62,16 +62,16 @@ from batdetect2.models.encoder import ( build_encoder, ) from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead +from batdetect2.models.types import DetectionModel from batdetect2.postprocess.config import PostprocessConfig -from batdetect2.preprocess.config import PreprocessingConfig -from batdetect2.targets.config import TargetConfig -from batdetect2.typing import ( +from batdetect2.postprocess.types import ( ClipDetectionsTensor, - DetectionModel, PostprocessorProtocol, - PreprocessorProtocol, - TargetProtocol, ) +from batdetect2.preprocess.config import PreprocessingConfig +from batdetect2.preprocess.types import PreprocessorProtocol +from batdetect2.targets.config import TargetConfig +from batdetect2.targets.types import TargetProtocol __all__ = [ "BBoxHead", diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index f13aa7f..c3ef77c 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -51,7 +51,7 @@ from batdetect2.models.encoder import ( EncoderConfig, build_encoder, ) -from batdetect2.typing.models import ( +from batdetect2.models.types import ( BackboneModel, BottleneckProtocol, DecoderProtocol, diff --git a/src/batdetect2/models/bottleneck.py b/src/batdetect2/models/bottleneck.py index 55fe549..9b2154a 100644 --- a/src/batdetect2/models/bottleneck.py +++ b/src/batdetect2/models/bottleneck.py @@ -31,7 +31,7 @@ from batdetect2.models.blocks import ( VerticalConv, build_layer, ) -from batdetect2.typing.models import BottleneckProtocol +from batdetect2.models.types import BottleneckProtocol __all__ = [ "BottleneckConfig", diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index 97238f6..ed49496 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -26,7 +26,7 @@ from batdetect2.models.backbones import ( build_backbone, ) from batdetect2.models.heads import BBoxHead, ClassifierHead -from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput +from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput __all__ = [ "Detector", diff --git a/src/batdetect2/models/types.py b/src/batdetect2/models/types.py new file mode 100644 index 0000000..5cf57b3 --- /dev/null +++ b/src/batdetect2/models/types.py @@ -0,0 +1,86 @@ +from abc import ABC, abstractmethod +from typing import NamedTuple, Protocol + +import torch + +__all__ = [ + "BackboneModel", + "BlockProtocol", + "BottleneckProtocol", + "DecoderProtocol", + "DetectionModel", + "EncoderDecoderModel", + "EncoderProtocol", + "ModelOutput", +] + + +class BlockProtocol(Protocol): + in_channels: int + out_channels: int + + def __call__(self, x: torch.Tensor) -> torch.Tensor: ... + + def get_output_height(self, input_height: int) -> int: ... + + +class EncoderProtocol(Protocol): + in_channels: int + out_channels: int + input_height: int + output_height: int + + def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ... + + +class BottleneckProtocol(Protocol): + in_channels: int + out_channels: int + input_height: int + + def __call__(self, x: torch.Tensor) -> torch.Tensor: ... + + +class DecoderProtocol(Protocol): + in_channels: int + out_channels: int + input_height: int + output_height: int + depth: int + + def __call__( + self, + x: torch.Tensor, + residuals: list[torch.Tensor], + ) -> torch.Tensor: ... + + +class ModelOutput(NamedTuple): + detection_probs: torch.Tensor + size_preds: torch.Tensor + class_probs: torch.Tensor + features: torch.Tensor + + +class BackboneModel(ABC, torch.nn.Module): + input_height: int + out_channels: int + + @abstractmethod + def forward(self, spec: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class EncoderDecoderModel(BackboneModel): + bottleneck_channels: int + + @abstractmethod + def encode(self, spec: torch.Tensor) -> torch.Tensor: ... + + @abstractmethod + def decode(self, encoded: torch.Tensor) -> torch.Tensor: ... + + +class DetectionModel(ABC, torch.nn.Module): + @abstractmethod + def forward(self, spec: torch.Tensor) -> ModelOutput: ... diff --git a/src/batdetect2/outputs/formats/__init__.py b/src/batdetect2/outputs/formats/__init__.py index d103a37..0244b83 100644 --- a/src/batdetect2/outputs/formats/__init__.py +++ b/src/batdetect2/outputs/formats/__init__.py @@ -11,7 +11,7 @@ from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig from batdetect2.outputs.formats.parquet import ParquetOutputConfig from batdetect2.outputs.formats.raw import RawOutputConfig from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig -from batdetect2.typing import TargetProtocol +from batdetect2.targets.types import TargetProtocol __all__ = [ "BatDetect2OutputConfig", diff --git a/src/batdetect2/outputs/formats/base.py b/src/batdetect2/outputs/formats/base.py index 5fd5f24..9b0f012 100644 --- a/src/batdetect2/outputs/formats/base.py +++ b/src/batdetect2/outputs/formats/base.py @@ -4,10 +4,8 @@ from typing import Literal from soundevent.data import PathLike from batdetect2.core import ImportConfig, Registry, add_import_config -from batdetect2.typing import ( - OutputFormatterProtocol, - TargetProtocol, -) +from batdetect2.outputs.types import OutputFormatterProtocol +from batdetect2.targets.types import TargetProtocol __all__ = [ "OutputFormatterProtocol", diff --git a/src/batdetect2/outputs/formats/batdetect2.py b/src/batdetect2/outputs/formats/batdetect2.py index 386d214..fc733b0 100644 --- a/src/batdetect2/outputs/formats/batdetect2.py +++ b/src/batdetect2/outputs/formats/batdetect2.py @@ -12,12 +12,9 @@ from batdetect2.outputs.formats.base import ( output_formatters, ) from batdetect2.targets import terms -from batdetect2.typing import ( - ClipDetections, - Detection, - OutputFormatterProtocol, - TargetProtocol, -) +from batdetect2.outputs.types import OutputFormatterProtocol +from batdetect2.postprocess.types import ClipDetections, Detection +from batdetect2.targets.types import TargetProtocol try: from typing import NotRequired # type: ignore diff --git a/src/batdetect2/outputs/formats/parquet.py b/src/batdetect2/outputs/formats/parquet.py index c6d0034..909fa9d 100644 --- a/src/batdetect2/outputs/formats/parquet.py +++ b/src/batdetect2/outputs/formats/parquet.py @@ -13,12 +13,9 @@ from batdetect2.outputs.formats.base import ( make_path_relative, output_formatters, ) -from batdetect2.typing import ( - ClipDetections, - Detection, - OutputFormatterProtocol, - TargetProtocol, -) +from batdetect2.outputs.types import OutputFormatterProtocol +from batdetect2.postprocess.types import ClipDetections, Detection +from batdetect2.targets.types import TargetProtocol class ParquetOutputConfig(BaseConfig): diff --git a/src/batdetect2/outputs/formats/raw.py b/src/batdetect2/outputs/formats/raw.py index dc58cd8..c4150df 100644 --- a/src/batdetect2/outputs/formats/raw.py +++ b/src/batdetect2/outputs/formats/raw.py @@ -14,12 +14,9 @@ from batdetect2.outputs.formats.base import ( make_path_relative, output_formatters, ) -from batdetect2.typing import ( - ClipDetections, - Detection, - OutputFormatterProtocol, - TargetProtocol, -) +from batdetect2.outputs.types import OutputFormatterProtocol +from batdetect2.postprocess.types import ClipDetections, Detection +from batdetect2.targets.types import TargetProtocol class RawOutputConfig(BaseConfig): diff --git a/src/batdetect2/outputs/formats/soundevent.py b/src/batdetect2/outputs/formats/soundevent.py index 1be9616..fd70a11 100644 --- a/src/batdetect2/outputs/formats/soundevent.py +++ b/src/batdetect2/outputs/formats/soundevent.py @@ -8,12 +8,9 @@ from batdetect2.core import BaseConfig from batdetect2.outputs.formats.base import ( output_formatters, ) -from batdetect2.typing import ( - ClipDetections, - Detection, - OutputFormatterProtocol, - TargetProtocol, -) +from batdetect2.outputs.types import OutputFormatterProtocol +from batdetect2.postprocess.types import ClipDetections, Detection +from batdetect2.targets.types import TargetProtocol class SoundEventOutputConfig(BaseConfig): diff --git a/src/batdetect2/outputs/transforms.py b/src/batdetect2/outputs/transforms.py index 89ef5dc..de7dafe 100644 --- a/src/batdetect2/outputs/transforms.py +++ b/src/batdetect2/outputs/transforms.py @@ -5,7 +5,7 @@ from typing import Protocol from soundevent.geometry import shift_geometry from batdetect2.core.configs import BaseConfig -from batdetect2.typing import ClipDetections, Detection +from batdetect2.postprocess.types import ClipDetections, Detection __all__ = [ "OutputTransform", diff --git a/src/batdetect2/typing/data.py b/src/batdetect2/outputs/types.py similarity index 55% rename from src/batdetect2/typing/data.py rename to src/batdetect2/outputs/types.py index 12f13fd..6e67fe6 100644 --- a/src/batdetect2/typing/data.py +++ b/src/batdetect2/outputs/types.py @@ -1,8 +1,9 @@ -from typing import Generic, List, Protocol, Sequence, TypeVar +from collections.abc import Sequence +from typing import Generic, Protocol, TypeVar from soundevent.data import PathLike -from batdetect2.typing.postprocess import ClipDetections +from batdetect2.postprocess.types import ClipDetections __all__ = [ "OutputFormatterProtocol", @@ -12,7 +13,7 @@ T = TypeVar("T") class OutputFormatterProtocol(Protocol, Generic[T]): - def format(self, predictions: Sequence[ClipDetections]) -> List[T]: ... + def format(self, predictions: Sequence[ClipDetections]) -> list[T]: ... def save( self, @@ -21,4 +22,4 @@ class OutputFormatterProtocol(Protocol, Generic[T]): audio_dir: PathLike | None = None, ) -> None: ... - def load(self, path: PathLike) -> List[T]: ... + def load(self, path: PathLike) -> list[T]: ... diff --git a/src/batdetect2/plotting/clip_annotations.py b/src/batdetect2/plotting/clip_annotations.py index be6f798..a866360 100644 --- a/src/batdetect2/plotting/clip_annotations.py +++ b/src/batdetect2/plotting/clip_annotations.py @@ -3,8 +3,8 @@ from soundevent import data, plot from batdetect2.plotting.clips import plot_clip from batdetect2.plotting.common import create_ax -from batdetect2.typing.preprocess import PreprocessorProtocol -from batdetect2.typing.targets import TargetProtocol +from batdetect2.preprocess.types import PreprocessorProtocol +from batdetect2.targets.types import TargetProtocol __all__ = [ "plot_clip_annotation", diff --git a/src/batdetect2/plotting/clip_predictions.py b/src/batdetect2/plotting/clip_predictions.py index a8cc198..d0b24fd 100644 --- a/src/batdetect2/plotting/clip_predictions.py +++ b/src/batdetect2/plotting/clip_predictions.py @@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag from batdetect2.plotting.clips import plot_clip -from batdetect2.typing.preprocess import PreprocessorProtocol +from batdetect2.preprocess.types import PreprocessorProtocol __all__ = [ "plot_clip_prediction", diff --git a/src/batdetect2/plotting/clips.py b/src/batdetect2/plotting/clips.py index 9b67e25..9edf99f 100644 --- a/src/batdetect2/plotting/clips.py +++ b/src/batdetect2/plotting/clips.py @@ -4,9 +4,10 @@ from matplotlib.axes import Axes from soundevent import data from batdetect2.audio import build_audio_loader +from batdetect2.audio.types import AudioLoader from batdetect2.plotting.common import plot_spectrogram from batdetect2.preprocess import build_preprocessor -from batdetect2.typing import AudioLoader, PreprocessorProtocol +from batdetect2.preprocess.types import PreprocessorProtocol __all__ = [ "plot_clip", diff --git a/src/batdetect2/plotting/gallery.py b/src/batdetect2/plotting/gallery.py index 4cc9eeb..63fb932 100644 --- a/src/batdetect2/plotting/gallery.py +++ b/src/batdetect2/plotting/gallery.py @@ -3,6 +3,7 @@ from typing import Sequence import matplotlib.pyplot as plt from matplotlib.figure import Figure +from batdetect2.audio.types import AudioLoader from batdetect2.plotting.matches import ( MatchProtocol, plot_cross_trigger_match, @@ -10,7 +11,7 @@ from batdetect2.plotting.matches import ( plot_false_positive_match, plot_true_positive_match, ) -from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol +from batdetect2.preprocess.types import PreprocessorProtocol __all__ = ["plot_match_gallery"] diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index cbd1be6..9806f5f 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ -4,12 +4,10 @@ from matplotlib.axes import Axes from soundevent import data, plot from soundevent.geometry import compute_bounds +from batdetect2.audio.types import AudioLoader from batdetect2.plotting.clips import plot_clip -from batdetect2.typing import ( - AudioLoader, - Detection, - PreprocessorProtocol, -) +from batdetect2.postprocess.types import Detection +from batdetect2.preprocess.types import PreprocessorProtocol __all__ = [ "plot_false_positive_match", diff --git a/src/batdetect2/postprocess/clips.py b/src/batdetect2/postprocess/clips.py index 5b46bbd..dcd17d5 100644 --- a/src/batdetect2/postprocess/clips.py +++ b/src/batdetect2/postprocess/clips.py @@ -1,4 +1,4 @@ -from batdetect2.typing import ClipDetections +from batdetect2.postprocess.types import ClipDetections class ClipTransform: diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index d779280..517522d 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -5,11 +5,11 @@ from typing import List import numpy as np from soundevent import data -from batdetect2.typing.postprocess import ( +from batdetect2.postprocess.types import ( ClipDetectionsArray, Detection, ) -from batdetect2.typing.targets import TargetProtocol +from batdetect2.targets.types import TargetProtocol __all__ = [ "to_raw_predictions", diff --git a/src/batdetect2/postprocess/extraction.py b/src/batdetect2/postprocess/extraction.py index 15c59be..b963e47 100644 --- a/src/batdetect2/postprocess/extraction.py +++ b/src/batdetect2/postprocess/extraction.py @@ -19,7 +19,7 @@ from typing import List import torch -from batdetect2.typing.postprocess import ClipDetectionsTensor +from batdetect2.postprocess.types import ClipDetectionsTensor __all__ = [ "extract_detection_peaks", diff --git a/src/batdetect2/postprocess/postprocessor.py b/src/batdetect2/postprocess/postprocessor.py index 104a03c..7d5d3a4 100644 --- a/src/batdetect2/postprocess/postprocessor.py +++ b/src/batdetect2/postprocess/postprocessor.py @@ -1,18 +1,18 @@ import torch from loguru import logger +from batdetect2.models.types import ModelOutput from batdetect2.postprocess.config import ( PostprocessConfig, ) from batdetect2.postprocess.extraction import extract_detection_peaks from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression from batdetect2.postprocess.remapping import map_detection_to_clip -from batdetect2.typing import ModelOutput -from batdetect2.typing.postprocess import ( +from batdetect2.postprocess.types import ( ClipDetectionsTensor, PostprocessorProtocol, ) -from batdetect2.typing.preprocess import PreprocessorProtocol +from batdetect2.preprocess.types import PreprocessorProtocol __all__ = [ "build_postprocessor", diff --git a/src/batdetect2/postprocess/remapping.py b/src/batdetect2/postprocess/remapping.py index e47aee3..d1e95c8 100644 --- a/src/batdetect2/postprocess/remapping.py +++ b/src/batdetect2/postprocess/remapping.py @@ -19,8 +19,8 @@ import torch import xarray as xr from soundevent.arrays import Dimensions +from batdetect2.postprocess.types import ClipDetectionsTensor from batdetect2.preprocess import MAX_FREQ, MIN_FREQ -from batdetect2.typing.postprocess import ClipDetectionsTensor __all__ = [ "features_to_xarray", diff --git a/src/batdetect2/postprocess/types.py b/src/batdetect2/postprocess/types.py new file mode 100644 index 0000000..61e18e3 --- /dev/null +++ b/src/batdetect2/postprocess/types.py @@ -0,0 +1,85 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, NamedTuple, Protocol + +import numpy as np +import torch +from soundevent import data + +from batdetect2.targets.types import Position, Size + +if TYPE_CHECKING: + from batdetect2.models.types import ModelOutput + +__all__ = [ + "ClipDetections", + "ClipDetectionsArray", + "ClipDetectionsTensor", + "ClipPrediction", + "Detection", + "GeometryDecoder", + "PostprocessorProtocol", +] + + +class GeometryDecoder(Protocol): + def __call__( + self, + position: Position, + size: Size, + class_name: str | None = None, + ) -> data.Geometry: ... + + +@dataclass +class Detection: + geometry: data.Geometry + detection_score: float + class_scores: np.ndarray + features: np.ndarray + + +class ClipDetectionsArray(NamedTuple): + scores: np.ndarray + sizes: np.ndarray + class_scores: np.ndarray + times: np.ndarray + frequencies: np.ndarray + features: np.ndarray + + +class ClipDetectionsTensor(NamedTuple): + scores: torch.Tensor + sizes: torch.Tensor + class_scores: torch.Tensor + times: torch.Tensor + frequencies: torch.Tensor + features: torch.Tensor + + def numpy(self) -> ClipDetectionsArray: + return ClipDetectionsArray( + scores=self.scores.detach().cpu().numpy(), + sizes=self.sizes.detach().cpu().numpy(), + class_scores=self.class_scores.detach().cpu().numpy(), + times=self.times.detach().cpu().numpy(), + frequencies=self.frequencies.detach().cpu().numpy(), + features=self.features.detach().cpu().numpy(), + ) + + +@dataclass +class ClipDetections: + clip: data.Clip + detections: list[Detection] + + +@dataclass +class ClipPrediction: + clip: data.Clip + detection_score: float + class_scores: np.ndarray + + +class PostprocessorProtocol(Protocol): + def __call__( + self, output: "ModelOutput" + ) -> list[ClipDetectionsTensor]: ... diff --git a/src/batdetect2/preprocess/preprocessor.py b/src/batdetect2/preprocess/preprocessor.py index 63781d9..e8fcc91 100644 --- a/src/batdetect2/preprocess/preprocessor.py +++ b/src/batdetect2/preprocess/preprocessor.py @@ -1,7 +1,7 @@ """Assembles the full batdetect2 preprocessing pipeline. This module defines :class:`Preprocessor`, the concrete implementation of -:class:`~batdetect2.typing.PreprocessorProtocol`, and the +:class:`~batdetect2.preprocess.types.PreprocessorProtocol`, and the :func:`build_preprocessor` factory function that constructs it from a :class:`~batdetect2.preprocess.config.PreprocessingConfig`. @@ -33,7 +33,7 @@ from batdetect2.preprocess.spectrogram import ( build_spectrogram_resizer, build_spectrogram_transform, ) -from batdetect2.typing import PreprocessorProtocol +from batdetect2.preprocess.types import PreprocessorProtocol __all__ = [ "Preprocessor", @@ -42,7 +42,7 @@ __all__ = [ class Preprocessor(torch.nn.Module, PreprocessorProtocol): - """Standard implementation of the :class:`~batdetect2.typing.PreprocessorProtocol`. + """Standard implementation of the :class:`~batdetect2.preprocess.types.PreprocessorProtocol`. Wraps all preprocessing stages as ``torch.nn.Module`` submodules so that parameters (e.g. PCEN filter coefficients) can be tracked and diff --git a/src/batdetect2/preprocess/types.py b/src/batdetect2/preprocess/types.py new file mode 100644 index 0000000..39485e9 --- /dev/null +++ b/src/batdetect2/preprocess/types.py @@ -0,0 +1,31 @@ +from typing import Protocol + +import numpy as np +import torch + +__all__ = [ + "PreprocessorProtocol", + "SpectrogramBuilder", +] + + +class SpectrogramBuilder(Protocol): + def __call__(self, wav: torch.Tensor) -> torch.Tensor: ... + + +class PreprocessorProtocol(Protocol): + max_freq: float + min_freq: float + input_samplerate: int + output_samplerate: float + + def __call__(self, wav: torch.Tensor) -> torch.Tensor: ... + + def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ... + + def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ... + + def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ... + + def process_numpy(self, wav: np.ndarray) -> np.ndarray: + return self(torch.tensor(wav)).numpy() diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index f61b848..e7b3604 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -16,7 +16,7 @@ from batdetect2.data.conditions import ( ) from batdetect2.targets.rois import ROIMapperConfig from batdetect2.targets.terms import call_type, generic_class -from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder +from batdetect2.targets.types import SoundEventDecoder, SoundEventEncoder __all__ = [ "build_sound_event_decoder", diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py index 9b87495..215b99e 100644 --- a/src/batdetect2/targets/rois.py +++ b/src/batdetect2/targets/rois.py @@ -27,17 +27,13 @@ from pydantic import Field from soundevent import data from batdetect2.audio import AudioConfig, build_audio_loader +from batdetect2.audio.types import AudioLoader from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.core.arrays import spec_to_xarray from batdetect2.core.configs import BaseConfig from batdetect2.preprocess import PreprocessingConfig, build_preprocessor -from batdetect2.typing import ( - AudioLoader, - Position, - PreprocessorProtocol, - ROITargetMapper, - Size, -) +from batdetect2.preprocess.types import PreprocessorProtocol +from batdetect2.targets.types import Position, ROITargetMapper, Size __all__ = [ "Anchor", diff --git a/src/batdetect2/targets/targets.py b/src/batdetect2/targets/targets.py index 0c83d50..f7a6dad 100644 --- a/src/batdetect2/targets/targets.py +++ b/src/batdetect2/targets/targets.py @@ -16,7 +16,7 @@ from batdetect2.targets.rois import ( AnchorBBoxMapperConfig, build_roi_mapper, ) -from batdetect2.typing.targets import Position, Size, TargetProtocol +from batdetect2.targets.types import Position, Size, TargetProtocol class Targets(TargetProtocol): diff --git a/src/batdetect2/targets/types.py b/src/batdetect2/targets/types.py new file mode 100644 index 0000000..af5ab44 --- /dev/null +++ b/src/batdetect2/targets/types.py @@ -0,0 +1,60 @@ +from collections.abc import Callable +from typing import Protocol + +import numpy as np +from soundevent import data + +__all__ = [ + "Position", + "ROITargetMapper", + "Size", + "SoundEventDecoder", + "SoundEventEncoder", + "SoundEventFilter", + "TargetProtocol", +] + +SoundEventEncoder = Callable[[data.SoundEventAnnotation], str | None] +SoundEventDecoder = Callable[[str], list[data.Tag]] +SoundEventFilter = Callable[[data.SoundEventAnnotation], bool] + +Position = tuple[float, float] +Size = np.ndarray + + +class TargetProtocol(Protocol): + class_names: list[str] + detection_class_tags: list[data.Tag] + detection_class_name: str + dimension_names: list[str] + + def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ... + + def encode_class( + self, + sound_event: data.SoundEventAnnotation, + ) -> str | None: ... + + def decode_class(self, class_label: str) -> list[data.Tag]: ... + + def encode_roi( + self, + sound_event: data.SoundEventAnnotation, + ) -> tuple[Position, Size]: ... + + def decode_roi( + self, + position: Position, + size: Size, + class_name: str | None = None, + ) -> data.Geometry: ... + + +class ROITargetMapper(Protocol): + dimension_names: list[str] + + def encode( + self, sound_event: data.SoundEvent + ) -> tuple[Position, Size]: ... + + def decode(self, position: Position, size: Size) -> data.Geometry: ... diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index 2c452cc..41e860e 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -13,6 +13,7 @@ from soundevent.geometry import scale_geometry, shift_geometry from batdetect2.audio.clips import get_subclip_annotation from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ +from batdetect2.audio.types import AudioLoader from batdetect2.core.arrays import adjust_width from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.registries import ( @@ -20,7 +21,7 @@ from batdetect2.core.registries import ( Registry, add_import_config, ) -from batdetect2.typing import AudioLoader, Augmentation +from batdetect2.train.types import Augmentation __all__ = [ "AugmentationConfig", diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 2383955..43fb000 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -5,17 +5,15 @@ from lightning.pytorch.callbacks import Callback from soundevent import data from torch.utils.data import DataLoader +from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.logging import get_image_logger +from batdetect2.models.types import ModelOutput from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.postprocess import to_raw_predictions +from batdetect2.postprocess.types import ClipDetections from batdetect2.train.dataset import ValidationDataset from batdetect2.train.lightning import TrainingModule -from batdetect2.typing import ( - ClipDetections, - EvaluatorProtocol, - ModelOutput, - TrainExample, -) +from batdetect2.train.types import TrainExample class ValidationMetrics(Callback): diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index 34898aa..3e22b8e 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -8,9 +8,11 @@ from torch.utils.data import DataLoader, Dataset from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper from batdetect2.audio.clips import PaddedClipConfig +from batdetect2.audio.types import AudioLoader, ClipperProtocol from batdetect2.core import BaseConfig from batdetect2.core.arrays import adjust_width from batdetect2.preprocess import build_preprocessor +from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.train.augmentations import ( DEFAULT_AUGMENTATION_CONFIG, AugmentationsConfig, @@ -18,14 +20,7 @@ from batdetect2.train.augmentations import ( build_augmentations, ) from batdetect2.train.labels import build_clip_labeler -from batdetect2.typing import ( - AudioLoader, - Augmentation, - ClipLabeller, - ClipperProtocol, - PreprocessorProtocol, - TrainExample, -) +from batdetect2.train.types import Augmentation, ClipLabeller, TrainExample __all__ = [ "TrainingDataset", diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index 1d83b42..96e8d44 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -15,7 +15,8 @@ from soundevent import data from batdetect2.core.configs import BaseConfig, load_config from batdetect2.preprocess import MAX_FREQ, MIN_FREQ from batdetect2.targets import build_targets, iterate_encoded_sound_events -from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol +from batdetect2.targets.types import TargetProtocol +from batdetect2.train.types import ClipLabeller, Heatmaps __all__ = [ "LabelConfig", diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 76e20c4..6e51ce1 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -2,11 +2,12 @@ import lightning as L from soundevent.data import PathLike from batdetect2.models import Model, ModelConfig, build_model +from batdetect2.models.types import ModelOutput from batdetect2.train.config import TrainingConfig from batdetect2.train.losses import build_loss from batdetect2.train.optimizers import build_optimizer from batdetect2.train.schedulers import build_scheduler -from batdetect2.typing import LossProtocol, ModelOutput, TrainExample +from batdetect2.train.types import LossProtocol, TrainExample __all__ = [ "TrainingModule", diff --git a/src/batdetect2/train/losses.py b/src/batdetect2/train/losses.py index b98d2dd..fd675a5 100644 --- a/src/batdetect2/train/losses.py +++ b/src/batdetect2/train/losses.py @@ -26,7 +26,8 @@ from pydantic import Field from torch import nn from batdetect2.core.configs import BaseConfig -from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample +from batdetect2.models.types import ModelOutput +from batdetect2.train.types import Losses, LossProtocol, TrainExample __all__ = [ "BBoxLoss", diff --git a/src/batdetect2/train/optimizers.py b/src/batdetect2/train/optimizers.py index 54543b7..c5b79cc 100644 --- a/src/batdetect2/train/optimizers.py +++ b/src/batdetect2/train/optimizers.py @@ -43,7 +43,7 @@ optimizer_registry: Registry[Optimizer, [Iterable[nn.Parameter]]] = Registry( ) -@add_import_config(optimizer_registry) +@add_import_config(optimizer_registry, arg_names=["params"]) class OptimizerImportConfig(ImportConfig): """Use any callable as an optimizer. @@ -84,4 +84,4 @@ def build_optimizer( Optimizer configuration. Defaults to ``AdamOptimizerConfig``. """ config = config or AdamOptimizerConfig() - return optimizer_registry.build(config, params=parameters) + return optimizer_registry.build(config, parameters) diff --git a/src/batdetect2/train/schedulers.py b/src/batdetect2/train/schedulers.py index ae1c742..7c69741 100644 --- a/src/batdetect2/train/schedulers.py +++ b/src/batdetect2/train/schedulers.py @@ -40,7 +40,7 @@ class CosineAnnealingSchedulerConfig(BaseConfig): scheduler_registry: Registry[LRScheduler, [Optimizer]] = Registry("scheduler") -@add_import_config(scheduler_registry) +@add_import_config(scheduler_registry, arg_names=["optimizer"]) class SchedulerImportConfig(ImportConfig): """Use any callable as a scheduler. @@ -78,4 +78,4 @@ def build_scheduler( """Build a scheduler from configuration.""" config = config or CosineAnnealingSchedulerConfig() - return scheduler_registry.build(config, optimizer=optimizer) + return scheduler_registry.build(config, optimizer) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index bec3cf1..3f61654 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -1,32 +1,28 @@ from collections.abc import Sequence from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import Optional from lightning import Trainer, seed_everything from loguru import logger from soundevent import data from batdetect2.audio import AudioConfig, build_audio_loader +from batdetect2.audio.types import AudioLoader from batdetect2.evaluate import build_evaluator +from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.logging import build_logger from batdetect2.models import ModelConfig from batdetect2.preprocess import build_preprocessor +from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets import build_targets +from batdetect2.targets.types import TargetProtocol from batdetect2.train import TrainingConfig from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.checkpoints import build_checkpoint_callback from batdetect2.train.dataset import build_train_loader, build_val_loader from batdetect2.train.labels import build_clip_labeler from batdetect2.train.lightning import build_training_module - -if TYPE_CHECKING: - from batdetect2.typing import ( - AudioLoader, - ClipLabeller, - EvaluatorProtocol, - PreprocessorProtocol, - TargetProtocol, - ) +from batdetect2.train.types import ClipLabeller __all__ = [ "build_trainer", diff --git a/src/batdetect2/train/types.py b/src/batdetect2/train/types.py new file mode 100644 index 0000000..ed347e9 --- /dev/null +++ b/src/batdetect2/train/types.py @@ -0,0 +1,70 @@ +from collections.abc import Callable +from typing import TYPE_CHECKING, NamedTuple, Protocol + +import torch +from soundevent import data + +if TYPE_CHECKING: + from batdetect2.models.types import ModelOutput + +__all__ = [ + "Augmentation", + "ClipLabeller", + "Heatmaps", + "Losses", + "LossProtocol", + "TrainExample", +] + + +class Heatmaps(NamedTuple): + detection: torch.Tensor + classes: torch.Tensor + size: torch.Tensor + + +class PreprocessedExample(NamedTuple): + audio: torch.Tensor + spectrogram: torch.Tensor + detection_heatmap: torch.Tensor + class_heatmap: torch.Tensor + size_heatmap: torch.Tensor + + def copy(self): + return PreprocessedExample( + audio=self.audio.clone(), + spectrogram=self.spectrogram.clone(), + detection_heatmap=self.detection_heatmap.clone(), + size_heatmap=self.size_heatmap.clone(), + class_heatmap=self.class_heatmap.clone(), + ) + + +ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps] + + +Augmentation = Callable[ + [torch.Tensor, data.ClipAnnotation], + tuple[torch.Tensor, data.ClipAnnotation], +] + + +class TrainExample(NamedTuple): + spec: torch.Tensor + detection_heatmap: torch.Tensor + class_heatmap: torch.Tensor + size_heatmap: torch.Tensor + idx: torch.Tensor + start_time: torch.Tensor + end_time: torch.Tensor + + +class Losses(NamedTuple): + detection: torch.Tensor + size: torch.Tensor + classification: torch.Tensor + total: torch.Tensor + + +class LossProtocol(Protocol): + def __call__(self, pred: "ModelOutput", gt: TrainExample) -> Losses: ... diff --git a/src/batdetect2/types.py b/src/batdetect2/types.py index 35cdacc..5d7455e 100644 --- a/src/batdetect2/types.py +++ b/src/batdetect2/types.py @@ -1,18 +1,14 @@ """Types used in the code base.""" -from typing import Any, NamedTuple, TypedDict +import sys +from typing import Any, NamedTuple, Protocol, TypedDict import numpy as np import torch -try: - from typing import Protocol -except ImportError: - from typing_extensions import Protocol - -try: - from typing import NotRequired # type: ignore -except ImportError: +if sys.version_info >= (3, 11): + from typing import NotRequired +else: from typing_extensions import NotRequired diff --git a/src/batdetect2/typing/__init__.py b/src/batdetect2/typing/__init__.py deleted file mode 100644 index 8571b5c..0000000 --- a/src/batdetect2/typing/__init__.py +++ /dev/null @@ -1,75 +0,0 @@ -from batdetect2.typing.data import OutputFormatterProtocol -from batdetect2.typing.evaluate import ( - AffinityFunction, - ClipMatches, - EvaluatorProtocol, - MatcherProtocol, - MatchEvaluation, - MetricsProtocol, - PlotterProtocol, -) -from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput -from batdetect2.typing.postprocess import ( - ClipDetections, - ClipDetectionsTensor, - Detection, - GeometryDecoder, - PostprocessorProtocol, -) -from batdetect2.typing.preprocess import ( - AudioLoader, - PreprocessorProtocol, -) -from batdetect2.typing.targets import ( - Position, - ROITargetMapper, - Size, - SoundEventDecoder, - SoundEventEncoder, - SoundEventFilter, - TargetProtocol, -) -from batdetect2.typing.train import ( - Augmentation, - ClipLabeller, - ClipperProtocol, - Heatmaps, - Losses, - LossProtocol, - TrainExample, -) - -__all__ = [ - "AffinityFunction", - "AudioLoader", - "Augmentation", - "BackboneModel", - "ClipDetections", - "ClipDetectionsTensor", - "ClipLabeller", - "ClipMatches", - "ClipperProtocol", - "DetectionModel", - "EvaluatorProtocol", - "GeometryDecoder", - "Heatmaps", - "LossProtocol", - "Losses", - "MatchEvaluation", - "MatcherProtocol", - "MetricsProtocol", - "ModelOutput", - "OutputFormatterProtocol", - "PlotterProtocol", - "Position", - "PostprocessorProtocol", - "PreprocessorProtocol", - "ROITargetMapper", - "Detection", - "Size", - "SoundEventDecoder", - "SoundEventEncoder", - "SoundEventFilter", - "TargetProtocol", - "TrainExample", -] diff --git a/src/batdetect2/typing/models.py b/src/batdetect2/typing/models.py deleted file mode 100644 index 2bb1881..0000000 --- a/src/batdetect2/typing/models.py +++ /dev/null @@ -1,287 +0,0 @@ -"""Defines shared interfaces (ABCs) and data structures for models. - -This module centralizes the definitions of core data structures, like the -standard model output container (`ModelOutput`), and establishes abstract base -classes (ABCs) using `abc.ABC` and `torch.nn.Module`. These define contracts -for fundamental model components, ensuring modularity and consistent -interaction within the `batdetect2.models` package. - -Key components: -- `ModelOutput`: Standard structure for outputs from detection models. -- `BackboneModel`: Generic interface for any feature extraction backbone. -- `EncoderDecoderModel`: Specialized interface for backbones with distinct - encoder-decoder stages (e.g., U-Net), providing access to intermediate - features. -- `DetectionModel`: Interface for the complete end-to-end detection model. -""" - -from abc import ABC, abstractmethod -from typing import List, NamedTuple, Protocol - -import torch - -__all__ = [ - "ModelOutput", - "BackboneModel", - "EncoderDecoderModel", - "DetectionModel", - "BlockProtocol", - "EncoderProtocol", - "BottleneckProtocol", - "DecoderProtocol", -] - - -class BlockProtocol(Protocol): - """Interface for blocks of network layers.""" - - in_channels: int - out_channels: int - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass of the block.""" - ... - - def get_output_height(self, input_height: int) -> int: - """Calculate the output height based on input height.""" - ... - - -class EncoderProtocol(Protocol): - """Interface for the downsampling path of a network.""" - - in_channels: int - out_channels: int - input_height: int - output_height: int - - def __call__(self, x: torch.Tensor) -> List[torch.Tensor]: - """Forward pass must return intermediate tensors for skip connections.""" - ... - - -class BottleneckProtocol(Protocol): - """Interface for the middle part of a U-Net-like network.""" - - in_channels: int - out_channels: int - input_height: int - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - """Processes the features from the encoder.""" - ... - - -class DecoderProtocol(Protocol): - """Interface for the upsampling reconstruction path.""" - - in_channels: int - out_channels: int - input_height: int - output_height: int - depth: int - - def __call__( - self, - x: torch.Tensor, - residuals: List[torch.Tensor], - ) -> torch.Tensor: - """Upsamples features while integrating skip connections.""" - ... - - -class ModelOutput(NamedTuple): - """Standard container for the outputs of a BatDetect2 detection model. - - This structure groups the different prediction tensors produced by the - model for a batch of input spectrograms. All tensors typically share the - same spatial dimensions (height H, width W) corresponding to the model's - output resolution, and the same batch size (N). - - Attributes - ---------- - detection_probs : torch.Tensor - Tensor containing the probability of sound event presence at each - location in the output grid. - Shape: `(N, 1, H, W)` - size_preds : torch.Tensor - Tensor containing the predicted size dimensions - (e.g., width and height) for a potential bounding box at each location. - Shape: `(N, 2, H, W)` (Channel 0 typically width, Channel 1 height) - class_probs : torch.Tensor - Tensor containing the predicted probabilities (or logits, depending on - the final activation) for each target class at each location. - The number of channels corresponds to the number of specific classes - defined in the `Targets` configuration. - Shape: `(N, num_classes, H, W)` - features : torch.Tensor - Tensor containing features extracted by the model's backbone. These - might be used for downstream tasks or analysis. The number of channels - depends on the specific model architecture. - Shape: `(N, num_features, H, W)` - """ - - detection_probs: torch.Tensor - size_preds: torch.Tensor - class_probs: torch.Tensor - features: torch.Tensor - - -class BackboneModel(ABC, torch.nn.Module): - """Abstract Base Class for generic feature extraction backbone models. - - Defines the minimal interface for a feature extractor network within a - BatDetect2 model. Its primary role is to process an input spectrogram - tensor and produce a spatially rich feature map tensor, which is then - typically consumed by separate prediction heads (for detection, - classification, size). - - This base class is agnostic to the specific internal architecture (e.g., - it could be a simple CNN, a U-Net, a Transformer, etc.). Concrete - implementations must inherit from this class and `torch.nn.Module`, - implement the `forward` method, and define the required attributes. - - Attributes - ---------- - input_height : int - Expected height (number of frequency bins) of the input spectrogram - tensor that the backbone is designed to process. - out_channels : int - Number of channels in the final feature map tensor produced by the - backbone's `forward` method. - """ - - input_height: int - """Expected input spectrogram height (frequency bins).""" - - out_channels: int - """Number of output channels in the final feature map.""" - - @abstractmethod - def forward(self, spec: torch.Tensor) -> torch.Tensor: - """Perform the forward pass to extract features from the spectrogram. - - Parameters - ---------- - spec : torch.Tensor - Input spectrogram tensor, typically with shape - `(batch_size, 1, frequency_bins, time_bins)`. - `frequency_bins` should match `self.input_height`. - - Returns - ------- - torch.Tensor - Output feature map tensor, typically with shape - `(batch_size, self.out_channels, output_height, output_width)`. - The spatial dimensions (`output_height`, `output_width`) depend - on the specific backbone architecture (e.g., they might match the - input or be downsampled). - """ - raise NotImplementedError - - -class EncoderDecoderModel(BackboneModel): - """Abstract Base Class for Encoder-Decoder style backbone models. - - This class specializes `BackboneModel` for architectures that have distinct - encoder stages (downsampling path), a bottleneck, and decoder stages - (upsampling path). - - It provides separate abstract methods for the `encode` and `decode` steps, - allowing access to the intermediate "bottleneck" features produced by the - encoder. This can be useful for tasks like transfer learning or specialized - analyses. - - Attributes - ---------- - input_height : int - (Inherited from BackboneModel) Expected input spectrogram height. - out_channels : int - (Inherited from BackboneModel) Number of output channels in the final - feature map produced by the decoder/forward pass. - bottleneck_channels : int - Number of channels in the feature map produced by the encoder at its - deepest point (the bottleneck), before the decoder starts. - """ - - bottleneck_channels: int - """Number of channels at the encoder's bottleneck.""" - - @abstractmethod - def encode(self, spec: torch.Tensor) -> torch.Tensor: - """Process the input spectrogram through the encoder part. - - Takes the input spectrogram and passes it through the downsampling path - of the network up to the bottleneck layer. - - Parameters - ---------- - spec : torch.Tensor - Input spectrogram tensor, typically with shape - `(batch_size, 1, frequency_bins, time_bins)`. - - Returns - ------- - torch.Tensor - The encoded feature map from the bottleneck layer, typically with - shape `(batch_size, self.bottleneck_channels, bottleneck_height, - bottleneck_width)`. The spatial dimensions are usually downsampled - relative to the input. - """ - ... - - @abstractmethod - def decode(self, encoded: torch.Tensor) -> torch.Tensor: - """Process the bottleneck features through the decoder part. - - Takes the encoded feature map from the bottleneck and passes it through - the upsampling path (potentially using skip connections from the - encoder) to produce the final output feature map. - - Parameters - ---------- - encoded : torch.Tensor - The bottleneck feature map tensor produced by the `encode` method. - - Returns - ------- - torch.Tensor - The final output feature map tensor, typically with shape - `(batch_size, self.out_channels, output_height, output_width)`. - This should match the output shape of the `forward` method. - """ - ... - - -class DetectionModel(ABC, torch.nn.Module): - """Abstract Base Class for complete BatDetect2 detection models. - - Defines the interface for the overall model that takes an input spectrogram - and produces all necessary outputs for detection, classification, and size - prediction, packaged within a `ModelOutput` object. - - Concrete implementations typically combine a `BackboneModel` for feature - extraction with specific prediction heads for each output type. They must - inherit from this class and `torch.nn.Module`, and implement the `forward` - method. - """ - - @abstractmethod - def forward(self, spec: torch.Tensor) -> ModelOutput: - """Perform the forward pass of the full detection model. - - Processes the input spectrogram through the backbone and prediction - heads to generate all required output tensors. - - Parameters - ---------- - spec : torch.Tensor - Input spectrogram tensor, typically with shape - `(batch_size, 1, frequency_bins, time_bins)`. - - Returns - ------- - ModelOutput - A NamedTuple containing the prediction tensors: `detection_probs`, - `size_preds`, `class_probs`, and `features`. - """ diff --git a/src/batdetect2/typing/postprocess.py b/src/batdetect2/typing/postprocess.py deleted file mode 100644 index 08e642d..0000000 --- a/src/batdetect2/typing/postprocess.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Defines shared interfaces and data structures for postprocessing. - -This module centralizes the Protocol definitions and common data structures -used throughout the `batdetect2.postprocess` module. - -The main component is the `PostprocessorProtocol`, which outlines the standard -interface for an object responsible for executing the entire postprocessing -pipeline. This pipeline transforms raw neural network outputs into interpretable -detections represented as `soundevent` objects. Using protocols ensures -modularity and consistent interaction between different parts of the BatDetect2 -system that deal with model predictions. -""" - -from dataclasses import dataclass -from typing import List, NamedTuple, Protocol - -import numpy as np -import torch -from soundevent import data - -from batdetect2.typing.models import ModelOutput -from batdetect2.typing.targets import Position, Size - -__all__ = [ - "Detection", - "PostprocessorProtocol", - "GeometryDecoder", -] - - -# TODO: update the docstring -class GeometryDecoder(Protocol): - """Type alias for a function that recovers geometry from position and size. - - This callable takes: - 1. A position tuple `(time, frequency)`. - 2. A NumPy array of size dimensions (e.g., `[width, height]`). - 3. Optionally a class name of the highest scoring class. This is to accomodate - different ways of decoding geometry that depend on the predicted class. - It should return the reconstructed `soundevent.data.Geometry` (typically a - `BoundingBox`). - """ - - def __call__( - self, position: Position, size: Size, class_name: str | None = None - ) -> data.Geometry: ... - - -@dataclass -class Detection: - geometry: data.Geometry - detection_score: float - class_scores: np.ndarray - features: np.ndarray - - -class ClipDetectionsArray(NamedTuple): - scores: np.ndarray - sizes: np.ndarray - class_scores: np.ndarray - times: np.ndarray - frequencies: np.ndarray - features: np.ndarray - - -class ClipDetectionsTensor(NamedTuple): - scores: torch.Tensor - sizes: torch.Tensor - class_scores: torch.Tensor - times: torch.Tensor - frequencies: torch.Tensor - features: torch.Tensor - - def numpy(self) -> ClipDetectionsArray: - return ClipDetectionsArray( - scores=self.scores.detach().cpu().numpy(), - sizes=self.sizes.detach().cpu().numpy(), - class_scores=self.class_scores.detach().cpu().numpy(), - times=self.times.detach().cpu().numpy(), - frequencies=self.frequencies.detach().cpu().numpy(), - features=self.features.detach().cpu().numpy(), - ) - - -@dataclass -class ClipDetections: - clip: data.Clip - detections: List[Detection] - - -@dataclass -class ClipPrediction: - clip: data.Clip - detection_score: float - class_scores: np.ndarray - - -class PostprocessorProtocol(Protocol): - """Protocol defining the interface for the full postprocessing pipeline.""" - - def __call__( - self, - output: ModelOutput, - ) -> List[ClipDetectionsTensor]: ... diff --git a/src/batdetect2/typing/preprocess.py b/src/batdetect2/typing/preprocess.py deleted file mode 100644 index 80d9069..0000000 --- a/src/batdetect2/typing/preprocess.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Defines common interfaces (Protocols) for preprocessing components. - -This module centralizes the Protocol definitions used throughout the -`batdetect2.preprocess` package. Protocols define expected methods and -signatures, allowing for flexible and interchangeable implementations of -components like audio loaders and spectrogram builders. - -Using these protocols ensures that different parts of the preprocessing -pipeline can interact consistently, regardless of the specific underlying -implementation (e.g., different libraries or custom configurations). -""" - -from typing import Protocol - -import numpy as np -import torch -from soundevent import data - -__all__ = [ - "AudioLoader", - "SpectrogramBuilder", - "PreprocessorProtocol", -] - - -class AudioLoader(Protocol): - """Defines the interface for an audio loading and processing component. - - An AudioLoader is responsible for retrieving audio data corresponding to - different soundevent objects (files, Recordings, Clips) and applying a - configured set of initial preprocessing steps. Adhering to this protocol - allows for different loading strategies or implementations. - """ - - samplerate: int - - def load_file( - self, - path: data.PathLike, - audio_dir: data.PathLike | None = None, - ) -> np.ndarray: - """Load and preprocess audio directly from a file path. - - Parameters - ---------- - path : PathLike - Path to the audio file. - audio_dir : PathLike, optional - A directory prefix to prepend to the path if `path` is relative. - - Raises - ------ - FileNotFoundError - If the audio file cannot be found. - Exception - If the audio file cannot be loaded or processed. - """ - ... - - def load_recording( - self, - recording: data.Recording, - audio_dir: data.PathLike | None = None, - ) -> np.ndarray: - """Load and preprocess the entire audio for a Recording object. - - Parameters - ---------- - recording : data.Recording - The Recording object containing metadata about the audio file. - audio_dir : PathLike, optional - A directory where the audio file associated with the recording - can be found, especially if the path in the recording is relative. - - Returns - ------- - np.ndarray - The loaded and preprocessed audio waveform as a 1-D NumPy - array. Typically loads only the first channel. - - Raises - ------ - FileNotFoundError - If the audio file associated with the recording cannot be found. - Exception - If the audio file cannot be loaded or processed. - """ - ... - - def load_clip( - self, - clip: data.Clip, - audio_dir: data.PathLike | None = None, - ) -> np.ndarray: - """Load and preprocess the audio segment defined by a Clip object. - - Parameters - ---------- - clip : data.Clip - The Clip object specifying the recording and the start/end times - of the segment to load. - audio_dir : PathLike, optional - A directory where the audio file associated with the clip's - recording can be found. - - Returns - ------- - np.ndarray - The loaded and preprocessed audio waveform for the specified - clip duration as a 1-D NumPy array. Typically loads only the - first channel. - - Raises - ------ - FileNotFoundError - If the audio file associated with the clip cannot be found. - Exception - If the audio file cannot be loaded or processed. - """ - ... - - -class SpectrogramBuilder(Protocol): - """Defines the interface for a spectrogram generation component.""" - - def __call__(self, wav: torch.Tensor) -> torch.Tensor: - """Generate a spectrogram from an audio waveform.""" - ... - - -class PreprocessorProtocol(Protocol): - """Defines a high-level interface for the complete preprocessing pipeline.""" - - max_freq: float - - min_freq: float - - input_samplerate: int - - output_samplerate: float - - def __call__(self, wav: torch.Tensor) -> torch.Tensor: ... - - def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ... - - def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ... - - def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ... - - def process_numpy(self, wav: np.ndarray) -> np.ndarray: - """Run the full preprocessing pipeline on a NumPy waveform. - - This default implementation converts the array to a - ``torch.Tensor``, calls :meth:`__call__`, and converts the - result back to a NumPy array. Concrete implementations may - override this for efficiency. - - Parameters - ---------- - wav : np.ndarray - Input waveform as a 1-D NumPy array. - - Returns - ------- - np.ndarray - Preprocessed spectrogram as a NumPy array. - """ - return self(torch.tensor(wav)).numpy() diff --git a/src/batdetect2/typing/targets.py b/src/batdetect2/typing/targets.py deleted file mode 100644 index 8dfe44f..0000000 --- a/src/batdetect2/typing/targets.py +++ /dev/null @@ -1,298 +0,0 @@ -"""Defines the core interface (Protocol) for the target definition pipeline. - -This module specifies the standard structure, attributes, and methods expected -from an object that encapsulates the complete configured logic for processing -sound event annotations within the `batdetect2.targets` system. - -The main component defined here is the `TargetProtocol`. This protocol acts as -a contract for the entire target definition process, covering semantic aspects -(filtering, tag transformation, class encoding/decoding) as well as geometric -aspects (mapping regions of interest to target positions and sizes). It ensures -that components responsible for these tasks can be interacted with consistently -throughout BatDetect2. -""" - -from collections.abc import Callable -from typing import List, Protocol - -import numpy as np -from soundevent import data - -__all__ = [ - "TargetProtocol", - "SoundEventEncoder", - "SoundEventDecoder", - "SoundEventFilter", - "Position", - "Size", -] - -SoundEventEncoder = Callable[[data.SoundEventAnnotation], str | None] -"""Type alias for a sound event class encoder function. - -An encoder function takes a sound event annotation and returns the string name -of the target class it belongs to, based on a predefined set of rules. -If the annotation does not match any defined target class according to the -rules, the function returns None. -""" - - -SoundEventDecoder = Callable[[str], List[data.Tag]] -"""Type alias for a sound event class decoder function. - -A decoder function takes a class name string (as predicted by the model or -assigned during encoding) and returns a list of `soundevent.data.Tag` objects -that represent that class according to the configuration. This is used to -translate model outputs back into meaningful annotations. -""" - -SoundEventFilter = Callable[[data.SoundEventAnnotation], bool] -"""Type alias for a filter function. - -A filter function accepts a soundevent.data.SoundEventAnnotation object -and returns True if the annotation should be kept based on the filter's -criteria, or False if it should be discarded. -""" - -Position = tuple[float, float] -"""A tuple representing (time, frequency) coordinates.""" - -Size = np.ndarray -"""A NumPy array representing the size dimensions of a target.""" - - -class TargetProtocol(Protocol): - """Protocol defining the interface for the target definition pipeline. - - This protocol outlines the standard attributes and methods for an object - that encapsulates the complete, configured process for handling sound event - annotations (both tags and geometry). It defines how to: - - Select relevant annotations. - - Encode an annotation into a specific target class name. - - Decode a class name back into representative tags. - - Extract a target reference position from an annotation's geometry (ROI). - - Calculate target size dimensions from an annotation's geometry. - - Recover an approximate geometry (ROI) from a position and size - dimensions. - - Implementations of this protocol bundle all configured logic for these - steps. - - Attributes - ---------- - class_names : List[str] - An ordered list of the unique names of the specific target classes - defined by the configuration. - generic_class_tags : List[data.Tag] - A list of `soundevent.data.Tag` objects representing the configured - generic class category (e.g., used when no specific class matches). - dimension_names : List[str] - A list containing the names of the size dimensions returned by - `get_size` and expected by `recover_roi` (e.g., ['width', 'height']). - """ - - class_names: List[str] - """Ordered list of unique names for the specific target classes.""" - - detection_class_tags: List[data.Tag] - """List of tags representing the detection category (unclassified).""" - - detection_class_name: str - - dimension_names: List[str] - """Names of the size dimensions (e.g., ['width', 'height']).""" - - def filter(self, sound_event: data.SoundEventAnnotation) -> bool: - """Apply the filter to a sound event annotation. - - Determines if the annotation should be included in further processing - and training based on the configured filtering rules. - - Parameters - ---------- - sound_event : data.SoundEventAnnotation - The annotation to filter. - - Returns - ------- - bool - True if the annotation should be kept (passes the filter), - False otherwise. Implementations should return True if no - filtering is configured. - """ - ... - - def encode_class( - self, - sound_event: data.SoundEventAnnotation, - ) -> str | None: - """Encode a sound event annotation to its target class name. - - Parameters - ---------- - sound_event : data.SoundEventAnnotation - The (potentially filtered and transformed) annotation to encode. - - Returns - ------- - str or None - The string name of the matched target class if the annotation - matches a specific class definition. Returns None if the annotation - does not match any specific class rule (indicating it may belong - to a generic category or should be handled differently downstream). - """ - ... - - def decode_class(self, class_label: str) -> List[data.Tag]: - """Decode a predicted class name back into representative tags. - - Parameters - ---------- - class_label : str - The class name string (e.g., predicted by a model) to decode. - - Returns - ------- - List[data.Tag] - The list of tags corresponding to the input class name according - to the configuration. May return an empty list or raise an error - for unmapped labels, depending on the implementation's configuration - (e.g., `raise_on_unmapped` flag during building). - - Raises - ------ - ValueError, KeyError - Implementations might raise an error if the `class_label` is not - found in the configured mapping and error raising is enabled. - """ - ... - - def encode_roi( - self, sound_event: data.SoundEventAnnotation - ) -> tuple[Position, Size]: - """Extract the target reference position from the annotation's geometry. - - Calculates the `(time, frequency)` coordinate representing the primary - location of the sound event. - - Parameters - ---------- - sound_event : data.SoundEventAnnotation - The annotation containing the geometry (ROI) to process. - - Returns - ------- - Tuple[float, float] - The calculated reference position `(time, frequency)`. - - Raises - ------ - ValueError - If the annotation lacks geometry or if the position cannot be - calculated for the geometry type or configured reference point. - """ - ... - - # TODO: Update docstrings - def decode_roi( - self, - position: Position, - size: Size, - class_name: str | None = None, - ) -> data.Geometry: - """Recover the ROI geometry from a position and dimensions. - - Performs the inverse mapping of `get_position` and `get_size`. It takes - a reference position `(time, frequency)` and an array of size - dimensions and reconstructs an approximate geometric representation. - - Parameters - ---------- - pos : Tuple[float, float] - The reference position `(time, frequency)`. - dims : np.ndarray - The NumPy array containing the dimensions (e.g., predicted - by the model), corresponding to the order in `dimension_names`. - class_name: str - class - - Returns - ------- - soundevent.data.Geometry - The reconstructed geometry. - - Raises - ------ - ValueError - If the number of provided `dims` does not match `dimension_names`, - if dimensions are invalid (e.g., negative after unscaling), or - if reconstruction fails based on the configured position type. - """ - ... - - -class ROITargetMapper(Protocol): - """Protocol defining the interface for ROI-to-target mapping. - - Specifies the `encode` and `decode` methods required for converting a - `soundevent.data.SoundEvent` into a target representation (a reference - position and a size vector) and for recovering an approximate ROI from that - representation. - - Attributes - ---------- - dimension_names : List[str] - A list containing the names of the dimensions in the `Size` array - returned by `encode` and expected by `decode`. - """ - - dimension_names: List[str] - - def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]: - """Encode a SoundEvent's geometry into a position and size. - - Parameters - ---------- - sound_event : data.SoundEvent - The input sound event, which must have a geometry attribute. - - Returns - ------- - Tuple[Position, Size] - A tuple containing: - - The reference position as (time, frequency) coordinates. - - A NumPy array with the calculated size dimensions. - - Raises - ------ - ValueError - If the sound event does not have a geometry. - """ - ... - - def decode(self, position: Position, size: Size) -> data.Geometry: - """Decode a position and size back into a geometric ROI. - - Performs the inverse mapping: takes a reference position and size - dimensions and reconstructs a geometric representation. - - Parameters - ---------- - position : Position - The reference position (time, frequency). - size : Size - NumPy array containing the size dimensions, matching the order - and meaning specified by `dimension_names`. - - Returns - ------- - soundevent.data.Geometry - The reconstructed geometry, typically a `BoundingBox`. - - Raises - ------ - ValueError - If the `size` array has an unexpected shape or if reconstruction - fails. - """ - ... diff --git a/src/batdetect2/typing/train.py b/src/batdetect2/typing/train.py deleted file mode 100644 index 076e287..0000000 --- a/src/batdetect2/typing/train.py +++ /dev/null @@ -1,108 +0,0 @@ -from typing import Callable, NamedTuple, Protocol - -import torch -from soundevent import data - -from batdetect2.typing.models import ModelOutput - -__all__ = [ - "Augmentation", - "ClipLabeller", - "ClipperProtocol", - "Heatmaps", - "LossProtocol", - "Losses", - "TrainExample", -] - - -class Heatmaps(NamedTuple): - """Structure holding the generated heatmap targets.""" - - detection: torch.Tensor - classes: torch.Tensor - size: torch.Tensor - - -class PreprocessedExample(NamedTuple): - audio: torch.Tensor - spectrogram: torch.Tensor - detection_heatmap: torch.Tensor - class_heatmap: torch.Tensor - size_heatmap: torch.Tensor - - def copy(self): - return PreprocessedExample( - audio=self.audio.clone(), - spectrogram=self.spectrogram.clone(), - detection_heatmap=self.detection_heatmap.clone(), - size_heatmap=self.size_heatmap.clone(), - class_heatmap=self.class_heatmap.clone(), - ) - - -ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps] -"""Type alias for the final clip labelling function. - -This function takes the complete annotations for a clip and the corresponding -spectrogram, applies all configured filtering, transformation, and encoding -steps, and returns the final `Heatmaps` used for model training. -""" - - -Augmentation = Callable[ - [torch.Tensor, data.ClipAnnotation], - tuple[torch.Tensor, data.ClipAnnotation], -] - - -class TrainExample(NamedTuple): - spec: torch.Tensor - detection_heatmap: torch.Tensor - class_heatmap: torch.Tensor - size_heatmap: torch.Tensor - idx: torch.Tensor - start_time: torch.Tensor - end_time: torch.Tensor - - -class Losses(NamedTuple): - """Structure to hold the computed loss values. - - Allows returning individual loss components along with the total weighted - loss for monitoring and analysis during training. - - Attributes - ---------- - detection : torch.Tensor - Scalar tensor representing the calculated detection loss component - (before weighting). - size : torch.Tensor - Scalar tensor representing the calculated size regression loss component - (before weighting). - classification : torch.Tensor - Scalar tensor representing the calculated classification loss component - (before weighting). - total : torch.Tensor - Scalar tensor representing the final combined loss, computed as the - weighted sum of the detection, size, and classification components. - This is the value typically used for backpropagation. - """ - - detection: torch.Tensor - size: torch.Tensor - classification: torch.Tensor - total: torch.Tensor - - -class LossProtocol(Protocol): - def __call__(self, pred: ModelOutput, gt: TrainExample) -> Losses: ... - - -class ClipperProtocol(Protocol): - def __call__( - self, - clip_annotation: data.ClipAnnotation, - ) -> data.ClipAnnotation: ... - - def get_subclip(self, clip: data.Clip) -> data.Clip: ... diff --git a/tests/conftest.py b/tests/conftest.py index 43922a4..367992a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,23 +11,20 @@ from soundevent import data, terms from batdetect2.audio import build_audio_loader from batdetect2.audio.clips import build_clipper +from batdetect2.audio.types import AudioLoader, ClipperProtocol from batdetect2.data import DatasetConfig, load_dataset from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations from batdetect2.preprocess import build_preprocessor +from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets import ( TargetConfig, build_targets, call_type, ) from batdetect2.targets.classes import TargetClassConfig +from batdetect2.targets.types import TargetProtocol from batdetect2.train.labels import build_clip_labeler -from batdetect2.typing import ( - ClipLabeller, - PreprocessorProtocol, - TargetProtocol, -) -from batdetect2.typing.preprocess import AudioLoader -from batdetect2.typing.train import ClipperProtocol +from batdetect2.train.types import ClipLabeller @pytest.fixture diff --git a/tests/test_data/test_predictions/test_parquet.py b/tests/test_data/test_predictions/test_parquet.py index b17e8bb..408ada9 100644 --- a/tests/test_data/test_predictions/test_parquet.py +++ b/tests/test_data/test_predictions/test_parquet.py @@ -9,11 +9,8 @@ from batdetect2.outputs.formats import ( ParquetOutputConfig, build_output_formatter, ) -from batdetect2.typing import ( - ClipDetections, - Detection, - TargetProtocol, -) +from batdetect2.postprocess.types import ClipDetections, Detection +from batdetect2.targets.types import TargetProtocol @pytest.fixture diff --git a/tests/test_data/test_predictions/test_raw.py b/tests/test_data/test_predictions/test_raw.py index c51b41f..f4fa1ab 100644 --- a/tests/test_data/test_predictions/test_raw.py +++ b/tests/test_data/test_predictions/test_raw.py @@ -5,11 +5,8 @@ import pytest from soundevent import data from batdetect2.outputs.formats import RawOutputConfig, build_output_formatter -from batdetect2.typing import ( - ClipDetections, - Detection, - TargetProtocol, -) +from batdetect2.postprocess.types import ClipDetections, Detection +from batdetect2.targets.types import TargetProtocol @pytest.fixture diff --git a/tests/test_evaluate/test_tasks/conftest.py b/tests/test_evaluate/test_tasks/conftest.py index 8999d83..25bdf31 100644 --- a/tests/test_evaluate/test_tasks/conftest.py +++ b/tests/test_evaluate/test_tasks/conftest.py @@ -4,7 +4,7 @@ import numpy as np import pytest from soundevent import data -from batdetect2.typing import Detection +from batdetect2.postprocess.types import Detection @pytest.fixture diff --git a/tests/test_evaluate/test_tasks/test_classification.py b/tests/test_evaluate/test_tasks/test_classification.py index 0647240..6123a08 100644 --- a/tests/test_evaluate/test_tasks/test_classification.py +++ b/tests/test_evaluate/test_tasks/test_classification.py @@ -4,8 +4,8 @@ from soundevent import data from batdetect2.evaluate.tasks import build_task from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig -from batdetect2.typing import ClipDetections -from batdetect2.typing.targets import TargetProtocol +from batdetect2.postprocess.types import ClipDetections +from batdetect2.targets.types import TargetProtocol def test_classification( diff --git a/tests/test_evaluate/test_tasks/test_detection.py b/tests/test_evaluate/test_tasks/test_detection.py index 5bbaa71..2d44851 100644 --- a/tests/test_evaluate/test_tasks/test_detection.py +++ b/tests/test_evaluate/test_tasks/test_detection.py @@ -4,8 +4,8 @@ from soundevent import data from batdetect2.evaluate.tasks import build_task from batdetect2.evaluate.tasks.detection import DetectionTaskConfig -from batdetect2.typing import ClipDetections -from batdetect2.typing.targets import TargetProtocol +from batdetect2.postprocess.types import ClipDetections +from batdetect2.targets.types import TargetProtocol def test_detection( diff --git a/tests/test_models/test_backbones.py b/tests/test_models/test_backbones.py index 2f9c00a..7af92cd 100644 --- a/tests/test_models/test_backbones.py +++ b/tests/test_models/test_backbones.py @@ -13,7 +13,7 @@ from batdetect2.models.backbones import ( build_backbone, load_backbone_config, ) -from batdetect2.typing.models import BackboneModel +from batdetect2.models.types import BackboneModel def test_unet_backbone_config_defaults(): diff --git a/tests/test_models/test_detectors.py b/tests/test_models/test_detectors.py index 823e07d..5cee836 100644 --- a/tests/test_models/test_detectors.py +++ b/tests/test_models/test_detectors.py @@ -7,7 +7,7 @@ from batdetect2.models.backbones import UNetBackboneConfig from batdetect2.models.detectors import Detector, build_detector from batdetect2.models.encoder import Encoder from batdetect2.models.heads import BBoxHead, ClassifierHead -from batdetect2.typing.models import ModelOutput +from batdetect2.models.types import ModelOutput @pytest.fixture diff --git a/tests/test_outputs/test_transform/test_transform.py b/tests/test_outputs/test_transform/test_transform.py index 811da12..bf98ce5 100644 --- a/tests/test_outputs/test_transform/test_transform.py +++ b/tests/test_outputs/test_transform/test_transform.py @@ -3,7 +3,7 @@ from soundevent import data from soundevent.geometry import compute_bounds from batdetect2.outputs import build_output_transform -from batdetect2.typing import ClipDetections, Detection +from batdetect2.postprocess.types import ClipDetections, Detection def test_shift_time_to_clip_start(clip: data.Clip): diff --git a/tests/test_postprocessing/test_decoding.py b/tests/test_postprocessing/test_decoding.py index ae79818..212e10b 100644 --- a/tests/test_postprocessing/test_decoding.py +++ b/tests/test_postprocessing/test_decoding.py @@ -14,7 +14,8 @@ from batdetect2.postprocess.decoding import ( get_generic_tags, get_prediction_features, ) -from batdetect2.typing import Detection, TargetProtocol +from batdetect2.postprocess.types import Detection +from batdetect2.targets.types import TargetProtocol @pytest.fixture diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index e6aabc3..2a16928 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -8,6 +8,7 @@ from torch.optim import Adam from torch.optim.lr_scheduler import CosineAnnealingLR from batdetect2.api_v2 import BatDetect2API +from batdetect2.audio.types import AudioLoader from batdetect2.config import BatDetect2Config from batdetect2.models import ModelConfig from batdetect2.train import ( @@ -19,7 +20,6 @@ from batdetect2.train import ( from batdetect2.train.optimizers import AdamOptimizerConfig from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig from batdetect2.train.train import build_training_module -from batdetect2.typing.preprocess import AudioLoader def build_default_module(config: BatDetect2Config | None = None):