Moving types around to each submodule

This commit is contained in:
mbsantiago 2026-03-18 00:01:35 +00:00
parent c226dc3f2b
commit 751be53edf
93 changed files with 570 additions and 1276 deletions

2
.gitignore vendored
View File

@ -102,7 +102,7 @@ experiments/*
DvcLiveLogger/checkpoints
logs/
mlruns/
outputs/
/outputs/
notebooks/lightning_logs
# Jupyter notebooks

View File

@ -89,5 +89,5 @@ Crucial for training, this module translates physical annotations (Regions of In
## Summary
To navigate this codebase effectively:
1. Follow **`api_v2.py`** to see how high-level operations invoke individual components.
2. Rely heavily on the typed **Protocols** located in `src/batdetect2/typing/` to understand the inputs and outputs of each subsystem without needing to read the specific implementations.
2. Rely heavily on the typed **Protocols** located in each subsystem's `types.py` module (for example `src/batdetect2/preprocess/types.py` and `src/batdetect2/postprocess/types.py`) to understand inputs and outputs without needing to read each implementation.
3. Understand that data flows structurally as `soundevent` primitives externally, and as pure `torch.Tensor` internally through the network.

View File

@ -8,6 +8,7 @@ from soundevent import data
from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.data import (
@ -15,6 +16,7 @@ from batdetect2.data import (
)
from batdetect2.data.datasets import Dataset
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model
@ -25,24 +27,22 @@ from batdetect2.outputs import (
build_output_transform,
get_output_formatter,
)
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
from batdetect2.postprocess.types import (
ClipDetections,
Detection,
PostprocessorProtocol,
)
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol
from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR,
load_model_from_checkpoint,
run_train,
)
from batdetect2.typing import (
AudioLoader,
ClipDetections,
Detection,
EvaluatorProtocol,
OutputFormatterProtocol,
PostprocessorProtocol,
PreprocessorProtocol,
TargetProtocol,
)
class BatDetect2API:

View File

@ -6,13 +6,13 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap
from batdetect2.audio.types import ClipperProtocol
from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.typing import ClipperProtocol
DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1

View File

@ -5,8 +5,8 @@ from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError
from batdetect2.audio.types import AudioLoader
from batdetect2.core import BaseConfig
from batdetect2.typing import AudioLoader
__all__ = [
"SoundEventAudioLoader",

View File

@ -0,0 +1,40 @@
from typing import Protocol
import numpy as np
from soundevent import data
__all__ = [
"AudioLoader",
"ClipperProtocol",
]
class AudioLoader(Protocol):
    """Structural interface for objects that load audio as numpy arrays.

    Implementations resolve a file path, a ``Recording``, or a ``Clip``
    to a waveform array. All methods accept an optional ``audio_dir``
    against which relative paths are resolved.
    """

    # Samplerate of the loaded audio.
    samplerate: int

    def load_file(
        self,
        path: data.PathLike,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load the audio file at ``path`` as a waveform array."""
        ...

    def load_recording(
        self,
        recording: data.Recording,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load the full audio of a ``soundevent`` recording."""
        ...

    def load_clip(
        self,
        clip: data.Clip,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load only the audio covered by ``clip``."""
        ...
class ClipperProtocol(Protocol):
    """Structural interface for extracting sub-clips from annotated clips."""

    def __call__(
        self,
        clip_annotation: data.ClipAnnotation,
    ) -> data.ClipAnnotation:
        """Return ``clip_annotation`` cropped to the clipper's window."""
        ...

    def get_subclip(self, clip: data.Clip) -> data.Clip:
        """Return the sub-clip of ``clip`` that this clipper would extract."""
        ...

View File

@ -4,6 +4,7 @@ from typing import (
Concatenate,
Generic,
ParamSpec,
Sequence,
Type,
TypeVar,
)
@ -147,6 +148,7 @@ T_Import = TypeVar("T_Import", bound=ImportConfig)
def add_import_config(
registry: Registry[T_Type, P_Type],
arg_names: Sequence[str] | None = None,
) -> Callable[[Type[T_Import]], Type[T_Import]]:
"""Decorator that registers an ImportConfig subclass as an escape hatch.
@ -181,15 +183,22 @@ def add_import_config(
*args: P_Type.args,
**kwargs: P_Type.kwargs,
) -> T_Type:
if len(args) > 0:
_arg_names = arg_names or []
if len(args) != len(_arg_names):
raise ValueError(
"Positional arguments are not supported "
"for import escape hatch."
"for import escape hatch unless you specify "
"the argument names. Use `arg_names` to specify "
"the names of the positional arguments."
)
args_dict = {_arg_names[i]: args[i] for i in range(len(args))}
hydra_cfg = {
"_target_": config.target,
**config.arguments,
**args_dict,
**kwargs,
}
return instantiate(hydra_cfg)

View File

@ -3,7 +3,7 @@ from collections.abc import Generator
from soundevent import data
from batdetect2.data.datasets import Dataset
from batdetect2.typing.targets import TargetProtocol
from batdetect2.targets.types import TargetProtocol
def iterate_over_sound_events(

View File

@ -5,7 +5,7 @@ from batdetect2.data.summary import (
extract_recordings_df,
extract_sound_events_df,
)
from batdetect2.typing.targets import TargetProtocol
from batdetect2.targets.types import TargetProtocol
def split_dataset_by_recordings(

View File

@ -2,7 +2,7 @@ import pandas as pd
from soundevent.geometry import compute_bounds
from batdetect2.data.datasets import Dataset
from batdetect2.typing.targets import TargetProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"extract_recordings_df",

View File

@ -16,7 +16,8 @@ from batdetect2.core import (
Registry,
add_import_config,
)
from batdetect2.typing import AffinityFunction, Detection
from batdetect2.evaluate.types import AffinityFunction
from batdetect2.postprocess.types import Detection
affinity_functions: Registry[AffinityFunction, []] = Registry(
"affinity_function"

View File

@ -8,14 +8,11 @@ from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.audio.types import AudioLoader, ClipperProtocol
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import (
AudioLoader,
ClipperProtocol,
PreprocessorProtocol,
)
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"TestDataset",

View File

@ -5,22 +5,20 @@ from lightning import Trainer
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.outputs import build_output_transform
from batdetect2.typing import Detection
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import Detection
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
OutputFormatterProtocol,
PreprocessorProtocol,
TargetProtocol,
)
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"

View File

@ -5,9 +5,10 @@ from soundevent import data
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
from batdetect2.typing.postprocess import ClipDetections
from batdetect2.targets.types import TargetProtocol
__all__ = [
"Evaluator",

View File

@ -5,12 +5,12 @@ from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import ClipDetections
from batdetect2.postprocess.types import ClipDetections
class EvaluationModule(LightningModule):

View File

@ -26,7 +26,8 @@ from batdetect2.evaluate.metrics.common import (
average_precision,
compute_precision_recall,
)
from batdetect2.typing import Detection, TargetProtocol
from batdetect2.postprocess.types import Detection
from batdetect2.targets.types import TargetProtocol
__all__ = [
"ClassificationMetric",

View File

@ -20,7 +20,7 @@ from batdetect2.core import (
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import Detection
from batdetect2.postprocess.types import Detection
__all__ = [
"DetectionMetricConfig",

View File

@ -20,8 +20,8 @@ from batdetect2.core import (
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import Detection
from batdetect2.typing.targets import TargetProtocol
from batdetect2.postprocess.types import Detection
from batdetect2.targets.types import TargetProtocol
__all__ = [
"TopClassMetricConfig",

View File

@ -2,7 +2,7 @@ import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol
class BasePlotConfig(BaseConfig):

View File

@ -29,7 +29,7 @@ from batdetect2.plotting.metrics import (
plot_threshold_recall_curve,
plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol
ClassificationPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]

View File

@ -22,7 +22,7 @@ from batdetect2.plotting.metrics import (
plot_roc_curve,
plot_roc_curves,
)
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"ClipClassificationPlotConfig",

View File

@ -18,7 +18,7 @@ from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"ClipDetectionPlotConfig",

View File

@ -16,6 +16,7 @@ from pydantic import Field
from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
@ -23,7 +24,8 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.detections import plot_clip_detections
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]

View File

@ -16,6 +16,7 @@ from pydantic import Field
from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import (
@ -27,7 +28,8 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]

View File

@ -11,12 +11,10 @@ from batdetect2.evaluate.tasks.clip_classification import (
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets import build_targets
from batdetect2.typing import (
ClipDetections,
EvaluatorProtocol,
TargetProtocol,
)
from batdetect2.targets.types import TargetProtocol
__all__ = [
"TaskConfig",

View File

@ -26,13 +26,12 @@ from batdetect2.evaluate.affinity import (
TimeAffinityConfig,
build_affinity_function,
)
from batdetect2.typing import (
from batdetect2.evaluate.types import (
AffinityFunction,
ClipDetections,
Detection,
EvaluatorProtocol,
TargetProtocol,
)
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
__all__ = [
"BaseTaskConfig",

View File

@ -21,11 +21,8 @@ from batdetect2.evaluate.tasks.base import (
BaseSEDTaskConfig,
tasks_registry,
)
from batdetect2.typing import (
ClipDetections,
Detection,
TargetProtocol,
)
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
class ClassificationTaskConfig(BaseSEDTaskConfig):

View File

@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import ClipDetections, TargetProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
class ClipClassificationTaskConfig(BaseTaskConfig):

View File

@ -18,7 +18,8 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import ClipDetections, TargetProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
class ClipDetectionTaskConfig(BaseTaskConfig):

View File

@ -20,8 +20,8 @@ from batdetect2.evaluate.tasks.base import (
BaseSEDTaskConfig,
tasks_registry,
)
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import ClipDetections
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
class DetectionTaskConfig(BaseSEDTaskConfig):

View File

@ -20,7 +20,8 @@ from batdetect2.evaluate.tasks.base import (
BaseSEDTaskConfig,
tasks_registry,
)
from batdetect2.typing import ClipDetections, TargetProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):

View File

@ -1,45 +1,39 @@
from dataclasses import dataclass
from typing import (
Generic,
Iterable,
Protocol,
Sequence,
TypeVar,
)
from typing import Generic, Iterable, Protocol, Sequence, TypeVar
from matplotlib.figure import Figure
from soundevent import data
from batdetect2.typing.postprocess import ClipDetections, Detection
from batdetect2.typing.targets import TargetProtocol
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
__all__ = [
"AffinityFunction",
"ClipMatches",
"EvaluatorProtocol",
"MetricsProtocol",
"MatchEvaluation",
"MatcherProtocol",
"MetricsProtocol",
"PlotterProtocol",
]
@dataclass
class MatchEvaluation:
clip: data.Clip
sound_event_annotation: data.SoundEventAnnotation | None
gt_det: bool
gt_class: str | None
gt_geometry: data.Geometry | None
pred_score: float
pred_class_scores: dict[str, float]
pred_geometry: data.Geometry | None
affinity: float
@property
def top_class(self) -> str | None:
if not self.pred_class_scores:
return None
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
@property
@ -53,10 +47,8 @@ class MatchEvaluation:
@property
def top_class_score(self) -> float:
pred_class = self.top_class
if pred_class is None:
return 0
return self.pred_class_scores[pred_class]
@ -75,9 +67,6 @@ class MatcherProtocol(Protocol):
) -> Iterable[tuple[int | None, int | None, float]]: ...
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
class AffinityFunction(Protocol):
def __call__(
self,
@ -115,9 +104,11 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
) -> EvaluationOutput: ...
def compute_metrics(
self, eval_outputs: EvaluationOutput
self,
eval_outputs: EvaluationOutput,
) -> dict[str, float]: ...
def generate_plots(
self, eval_outputs: EvaluationOutput
self,
eval_outputs: EvaluationOutput,
) -> Iterable[tuple[str, Figure]]: ...

View File

@ -4,22 +4,20 @@ from lightning import Trainer
from soundevent import data
from batdetect2.audio.loader import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.targets import build_targets
from batdetect2.typing.postprocess import ClipDetections
from batdetect2.targets.types import TargetProtocol
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
PreprocessorProtocol,
TargetProtocol,
)
def run_batch_inference(

View File

@ -6,10 +6,11 @@ from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"InferenceDataset",

View File

@ -7,7 +7,7 @@ from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing.postprocess import ClipDetections
from batdetect2.postprocess.types import ClipDetections
class InferenceModule(LightningModule):

View File

@ -62,16 +62,16 @@ from batdetect2.models.encoder import (
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import DetectionModel
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.typing import (
from batdetect2.postprocess.types import (
ClipDetectionsTensor,
DetectionModel,
PostprocessorProtocol,
PreprocessorProtocol,
TargetProtocol,
)
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.config import TargetConfig
from batdetect2.targets.types import TargetProtocol
__all__ = [
"BBoxHead",

View File

@ -51,7 +51,7 @@ from batdetect2.models.encoder import (
EncoderConfig,
build_encoder,
)
from batdetect2.typing.models import (
from batdetect2.models.types import (
BackboneModel,
BottleneckProtocol,
DecoderProtocol,

View File

@ -31,7 +31,7 @@ from batdetect2.models.blocks import (
VerticalConv,
build_layer,
)
from batdetect2.typing.models import BottleneckProtocol
from batdetect2.models.types import BottleneckProtocol
__all__ = [
"BottleneckConfig",

View File

@ -26,7 +26,7 @@ from batdetect2.models.backbones import (
build_backbone,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
__all__ = [
"Detector",

View File

@ -0,0 +1,86 @@
from abc import ABC, abstractmethod
from typing import NamedTuple, Protocol
import torch
__all__ = [
"BackboneModel",
"BlockProtocol",
"BottleneckProtocol",
"DecoderProtocol",
"DetectionModel",
"EncoderDecoderModel",
"EncoderProtocol",
"ModelOutput",
]
class BlockProtocol(Protocol):
    """Structural interface for a single network block (tensor -> tensor)."""

    # Number of channels in the block's input tensor.
    in_channels: int
    # Number of channels in the block's output tensor.
    out_channels: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...

    def get_output_height(self, input_height: int) -> int:
        """Return the output height the block produces for ``input_height``."""
        ...
class EncoderProtocol(Protocol):
    """Structural interface for an encoder stage.

    Returns a list of intermediate tensors (used as residual/skip
    connections by a matching :class:`DecoderProtocol`).
    """

    in_channels: int
    out_channels: int
    # Height of the expected input tensor.
    input_height: int
    # Height of the final output tensor.
    output_height: int

    def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
class BottleneckProtocol(Protocol):
    """Structural interface for the bottleneck between encoder and decoder."""

    in_channels: int
    out_channels: int
    # Height of the expected input tensor.
    input_height: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
class DecoderProtocol(Protocol):
    """Structural interface for a decoder stage.

    Consumes the bottleneck output plus the encoder's intermediate
    tensors (``residuals``) and produces a single output tensor.
    """

    in_channels: int
    out_channels: int
    input_height: int
    output_height: int
    # Number of decoding stages; presumably matches the encoder depth —
    # confirm against the concrete implementations.
    depth: int

    def __call__(
        self,
        x: torch.Tensor,
        residuals: list[torch.Tensor],
    ) -> torch.Tensor: ...
class ModelOutput(NamedTuple):
    """Raw tensors produced by a :class:`DetectionModel` forward pass."""

    # Detection probability map.
    detection_probs: torch.Tensor
    # Predicted size values for each detection location.
    size_preds: torch.Tensor
    # Per-class probability map.
    class_probs: torch.Tensor
    # Intermediate feature tensor.
    features: torch.Tensor
class BackboneModel(ABC, torch.nn.Module):
    """Abstract base for backbone networks mapping a spectrogram to features."""

    # Height of the spectrogram the backbone expects.
    input_height: int
    # Number of channels of the produced feature tensor.
    out_channels: int

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Compute the feature tensor for spectrogram ``spec``."""
        raise NotImplementedError
class EncoderDecoderModel(BackboneModel):
    """Backbone split into explicit encode and decode stages."""

    # Number of channels at the bottleneck between encode and decode.
    bottleneck_channels: int

    @abstractmethod
    def encode(self, spec: torch.Tensor) -> torch.Tensor:
        """Encode spectrogram ``spec`` down to the bottleneck representation."""
        ...

    @abstractmethod
    def decode(self, encoded: torch.Tensor) -> torch.Tensor:
        """Decode a bottleneck representation back into the feature tensor."""
        ...
class DetectionModel(ABC, torch.nn.Module):
    """Abstract base for full detection networks.

    Maps a spectrogram tensor to a :class:`ModelOutput` bundle.
    """

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> ModelOutput: ...

View File

@ -11,7 +11,7 @@ from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig
from batdetect2.outputs.formats.parquet import ParquetOutputConfig
from batdetect2.outputs.formats.raw import RawOutputConfig
from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"BatDetect2OutputConfig",

View File

@ -4,10 +4,8 @@ from typing import Literal
from soundevent.data import PathLike
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.typing import (
OutputFormatterProtocol,
TargetProtocol,
)
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"OutputFormatterProtocol",

View File

@ -12,12 +12,9 @@ from batdetect2.outputs.formats.base import (
output_formatters,
)
from batdetect2.targets import terms
from batdetect2.typing import (
ClipDetections,
Detection,
OutputFormatterProtocol,
TargetProtocol,
)
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
try:
from typing import NotRequired # type: ignore

View File

@ -13,12 +13,9 @@ from batdetect2.outputs.formats.base import (
make_path_relative,
output_formatters,
)
from batdetect2.typing import (
ClipDetections,
Detection,
OutputFormatterProtocol,
TargetProtocol,
)
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
class ParquetOutputConfig(BaseConfig):

View File

@ -14,12 +14,9 @@ from batdetect2.outputs.formats.base import (
make_path_relative,
output_formatters,
)
from batdetect2.typing import (
ClipDetections,
Detection,
OutputFormatterProtocol,
TargetProtocol,
)
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
class RawOutputConfig(BaseConfig):

View File

@ -8,12 +8,9 @@ from batdetect2.core import BaseConfig
from batdetect2.outputs.formats.base import (
output_formatters,
)
from batdetect2.typing import (
ClipDetections,
Detection,
OutputFormatterProtocol,
TargetProtocol,
)
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
class SoundEventOutputConfig(BaseConfig):

View File

@ -5,7 +5,7 @@ from typing import Protocol
from soundevent.geometry import shift_geometry
from batdetect2.core.configs import BaseConfig
from batdetect2.typing import ClipDetections, Detection
from batdetect2.postprocess.types import ClipDetections, Detection
__all__ = [
"OutputTransform",

View File

@ -1,8 +1,9 @@
from typing import Generic, List, Protocol, Sequence, TypeVar
from collections.abc import Sequence
from typing import Generic, Protocol, TypeVar
from soundevent.data import PathLike
from batdetect2.typing.postprocess import ClipDetections
from batdetect2.postprocess.types import ClipDetections
__all__ = [
"OutputFormatterProtocol",
@ -12,7 +13,7 @@ T = TypeVar("T")
class OutputFormatterProtocol(Protocol, Generic[T]):
def format(self, predictions: Sequence[ClipDetections]) -> List[T]: ...
def format(self, predictions: Sequence[ClipDetections]) -> list[T]: ...
def save(
self,
@ -21,4 +22,4 @@ class OutputFormatterProtocol(Protocol, Generic[T]):
audio_dir: PathLike | None = None,
) -> None: ...
def load(self, path: PathLike) -> List[T]: ...
def load(self, path: PathLike) -> list[T]: ...

View File

@ -3,8 +3,8 @@ from soundevent import data, plot
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.common import create_ax
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"plot_clip_annotation",

View File

@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
from batdetect2.plotting.clips import plot_clip
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"plot_clip_prediction",

View File

@ -4,9 +4,10 @@ from matplotlib.axes import Axes
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.plotting.common import plot_spectrogram
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"plot_clip",

View File

@ -3,6 +3,7 @@ from typing import Sequence
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from batdetect2.audio.types import AudioLoader
from batdetect2.plotting.matches import (
MatchProtocol,
plot_cross_trigger_match,
@ -10,7 +11,7 @@ from batdetect2.plotting.matches import (
plot_false_positive_match,
plot_true_positive_match,
)
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = ["plot_match_gallery"]

View File

@ -4,12 +4,10 @@ from matplotlib.axes import Axes
from soundevent import data, plot
from soundevent.geometry import compute_bounds
from batdetect2.audio.types import AudioLoader
from batdetect2.plotting.clips import plot_clip
from batdetect2.typing import (
AudioLoader,
Detection,
PreprocessorProtocol,
)
from batdetect2.postprocess.types import Detection
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"plot_false_positive_match",

View File

@ -1,4 +1,4 @@
from batdetect2.typing import ClipDetections
from batdetect2.postprocess.types import ClipDetections
class ClipTransform:

View File

@ -5,11 +5,11 @@ from typing import List
import numpy as np
from soundevent import data
from batdetect2.typing.postprocess import (
from batdetect2.postprocess.types import (
ClipDetectionsArray,
Detection,
)
from batdetect2.typing.targets import TargetProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"to_raw_predictions",

View File

@ -19,7 +19,7 @@ from typing import List
import torch
from batdetect2.typing.postprocess import ClipDetectionsTensor
from batdetect2.postprocess.types import ClipDetectionsTensor
__all__ = [
"extract_detection_peaks",

View File

@ -1,18 +1,18 @@
import torch
from loguru import logger
from batdetect2.models.types import ModelOutput
from batdetect2.postprocess.config import (
PostprocessConfig,
)
from batdetect2.postprocess.extraction import extract_detection_peaks
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
from batdetect2.postprocess.types import (
ClipDetectionsTensor,
PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"build_postprocessor",

View File

@ -19,8 +19,8 @@ import torch
import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.postprocess.types import ClipDetectionsTensor
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing.postprocess import ClipDetectionsTensor
__all__ = [
"features_to_xarray",

View File

@ -0,0 +1,85 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Protocol
import numpy as np
import torch
from soundevent import data
from batdetect2.targets.types import Position, Size
if TYPE_CHECKING:
from batdetect2.models.types import ModelOutput
__all__ = [
"ClipDetections",
"ClipDetectionsArray",
"ClipDetectionsTensor",
"ClipPrediction",
"Detection",
"GeometryDecoder",
"PostprocessorProtocol",
]
class GeometryDecoder(Protocol):
    """Callable converting a predicted position and size into a geometry.

    ``class_name`` optionally selects class-specific decoding behavior.
    """

    def __call__(
        self,
        position: Position,
        size: Size,
        class_name: str | None = None,
    ) -> data.Geometry: ...
@dataclass
class Detection:
    """A single detected sound event with its scores and features."""

    # Geometry of the detection within the clip.
    geometry: data.Geometry
    # Overall detection confidence score.
    detection_score: float
    # Array of per-class scores for this detection.
    class_scores: np.ndarray
    # Associated feature array for this detection.
    features: np.ndarray
class ClipDetectionsArray(NamedTuple):
    """Detection outputs for a clip as numpy arrays.

    The numpy counterpart of :class:`ClipDetectionsTensor`; fields
    correspond one-to-one and in the same order.
    """

    # Detection scores.
    scores: np.ndarray
    # Predicted sizes.
    sizes: np.ndarray
    # Per-class scores.
    class_scores: np.ndarray
    # Detection times.
    times: np.ndarray
    # Detection frequencies.
    frequencies: np.ndarray
    # Feature arrays.
    features: np.ndarray
class ClipDetectionsTensor(NamedTuple):
    """Detection outputs for a clip as torch tensors.

    Mirrors :class:`ClipDetectionsArray` field-for-field and provides
    :meth:`numpy` to convert into that form.
    """

    scores: torch.Tensor
    sizes: torch.Tensor
    class_scores: torch.Tensor
    times: torch.Tensor
    frequencies: torch.Tensor
    features: torch.Tensor

    def numpy(self) -> ClipDetectionsArray:
        """Return all fields as detached CPU numpy arrays.

        Field order is identical in both named tuples, so positional
        expansion maps each tensor to its matching array field.
        """
        converted = (tensor.detach().cpu().numpy() for tensor in self)
        return ClipDetectionsArray(*converted)
@dataclass
class ClipDetections:
    """All detections extracted from a single clip."""

    # The clip the detections belong to.
    clip: data.Clip
    # Individual detections found within the clip.
    detections: list[Detection]
@dataclass
class ClipPrediction:
    """Clip-level prediction scores, without per-event geometry."""

    # The clip the scores refer to.
    clip: data.Clip
    # Overall detection score for the clip.
    detection_score: float
    # Per-class scores for the clip.
    class_scores: np.ndarray
class PostprocessorProtocol(Protocol):
    """Protocol defining the interface for the full postprocessing pipeline.

    Implementations turn a raw model output batch into a list of
    per-clip detection tensors.
    """

    def __call__(
        self, output: "ModelOutput"
    ) -> list[ClipDetectionsTensor]: ...

View File

@ -1,7 +1,7 @@
"""Assembles the full batdetect2 preprocessing pipeline.
This module defines :class:`Preprocessor`, the concrete implementation of
:class:`~batdetect2.typing.PreprocessorProtocol`, and the
:class:`~batdetect2.preprocess.types.PreprocessorProtocol`, and the
:func:`build_preprocessor` factory function that constructs it from a
:class:`~batdetect2.preprocess.config.PreprocessingConfig`.
@ -33,7 +33,7 @@ from batdetect2.preprocess.spectrogram import (
build_spectrogram_resizer,
build_spectrogram_transform,
)
from batdetect2.typing import PreprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"Preprocessor",
@ -42,7 +42,7 @@ __all__ = [
class Preprocessor(torch.nn.Module, PreprocessorProtocol):
"""Standard implementation of the :class:`~batdetect2.typing.PreprocessorProtocol`.
"""Standard implementation of the :class:`~batdetect2.preprocess.types.PreprocessorProtocol`.
Wraps all preprocessing stages as ``torch.nn.Module`` submodules so
that parameters (e.g. PCEN filter coefficients) can be tracked and

View File

@ -0,0 +1,31 @@
from typing import Protocol
import numpy as np
import torch
__all__ = [
"PreprocessorProtocol",
"SpectrogramBuilder",
]
class SpectrogramBuilder(Protocol):
    """Defines the interface for a spectrogram generation component."""

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        """Generate a spectrogram from an audio waveform."""
        ...
class PreprocessorProtocol(Protocol):
    """Defines a high-level interface for the complete preprocessing pipeline."""

    # Upper frequency bound of the output spectrogram — presumably Hz;
    # confirm against the preprocess configuration.
    max_freq: float
    # Lower frequency bound of the output spectrogram.
    min_freq: float
    # Sample rate expected of input audio.
    input_samplerate: int
    # Effective sample rate of the processed output.
    output_samplerate: float

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        """Run the full pipeline: waveform tensor in, spectrogram tensor out."""
        ...

    def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
        """Compute a spectrogram from a waveform tensor."""
        ...

    def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
        """Apply the audio (waveform-domain) preprocessing steps."""
        ...

    def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        """Apply the spectrogram-domain preprocessing steps."""
        ...

    def process_numpy(self, wav: np.ndarray) -> np.ndarray:
        """Run the full preprocessing pipeline on a NumPy waveform.

        This default implementation converts the array to a
        ``torch.Tensor``, calls :meth:`__call__`, and converts the
        result back to a NumPy array. Concrete implementations may
        override this for efficiency.

        Parameters
        ----------
        wav : np.ndarray
            Input waveform as a 1-D NumPy array.

        Returns
        -------
        np.ndarray
            Preprocessed spectrogram as a NumPy array.
        """
        return self(torch.tensor(wav)).numpy()

View File

@ -16,7 +16,7 @@ from batdetect2.data.conditions import (
)
from batdetect2.targets.rois import ROIMapperConfig
from batdetect2.targets.terms import call_type, generic_class
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
from batdetect2.targets.types import SoundEventDecoder, SoundEventEncoder
__all__ = [
"build_sound_event_decoder",

View File

@ -27,17 +27,13 @@ from pydantic import Field
from soundevent import data
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.core.arrays import spec_to_xarray
from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import (
AudioLoader,
Position,
PreprocessorProtocol,
ROITargetMapper,
Size,
)
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import Position, ROITargetMapper, Size
__all__ = [
"Anchor",

View File

@ -16,7 +16,7 @@ from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
build_roi_mapper,
)
from batdetect2.typing.targets import Position, Size, TargetProtocol
from batdetect2.targets.types import Position, Size, TargetProtocol
class Targets(TargetProtocol):

View File

@ -0,0 +1,60 @@
from collections.abc import Callable
from typing import Protocol
import numpy as np
from soundevent import data
__all__ = [
"Position",
"ROITargetMapper",
"Size",
"SoundEventDecoder",
"SoundEventEncoder",
"SoundEventFilter",
"TargetProtocol",
]
# Maps a sound event annotation to the name of its target class, or None
# when the annotation matches no configured target class.
SoundEventEncoder = Callable[[data.SoundEventAnnotation], str | None]

# Maps a class name (e.g. a model prediction) back to the list of tags
# that represent that class.
SoundEventDecoder = Callable[[str], list[data.Tag]]

# Returns True if the annotation should be kept, False to discard it.
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]

# A (time, frequency) coordinate pair.
Position = tuple[float, float]

# A NumPy array holding the size dimensions of a target.
Size = np.ndarray
class TargetProtocol(Protocol):
    """Protocol defining the interface for the target definition pipeline.

    Encapsulates the configured logic for handling sound event
    annotations: selecting relevant annotations, encoding them into
    target class names, decoding class names back into tags, and mapping
    between regions of interest and (position, size) targets.
    """

    # Ordered list of unique names for the specific target classes.
    class_names: list[str]
    # Tags representing the detection category (unclassified).
    detection_class_tags: list[data.Tag]
    # Name of the generic detection class.
    detection_class_name: str
    # Names of the size dimensions (e.g., ['width', 'height']).
    dimension_names: list[str]

    def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
        """Return True if the annotation should be kept, False otherwise."""
        ...

    def encode_class(
        self,
        sound_event: data.SoundEventAnnotation,
    ) -> str | None:
        """Return the target class name for the annotation, or None."""
        ...

    def decode_class(self, class_label: str) -> list[data.Tag]:
        """Return the tags that represent the given class name."""
        ...

    def encode_roi(
        self,
        sound_event: data.SoundEventAnnotation,
    ) -> tuple[Position, Size]:
        """Map the annotation's geometry to a (position, size) target."""
        ...

    def decode_roi(
        self,
        position: Position,
        size: Size,
        class_name: str | None = None,
    ) -> data.Geometry:
        """Recover an approximate geometry from a position and size."""
        ...
class ROITargetMapper(Protocol):
    """Maps a sound event's geometry to a (position, size) target and back."""

    # Names of the size dimensions produced by `encode`.
    dimension_names: list[str]

    def encode(
        self, sound_event: data.SoundEvent
    ) -> tuple[Position, Size]:
        """Map the sound event's geometry to a (position, size) pair."""
        ...

    def decode(self, position: Position, size: Size) -> data.Geometry:
        """Recover a geometry from a (position, size) pair."""
        ...

View File

@ -13,6 +13,7 @@ from soundevent.geometry import scale_geometry, shift_geometry
from batdetect2.audio.clips import get_subclip_annotation
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.audio.types import AudioLoader
from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import (
@ -20,7 +21,7 @@ from batdetect2.core.registries import (
Registry,
add_import_config,
)
from batdetect2.typing import AudioLoader, Augmentation
from batdetect2.train.types import Augmentation
__all__ = [
"AugmentationConfig",

View File

@ -5,17 +5,15 @@ from lightning.pytorch.callbacks import Callback
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger
from batdetect2.models.types import ModelOutput
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.postprocess.types import ClipDetections
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import (
ClipDetections,
EvaluatorProtocol,
ModelOutput,
TrainExample,
)
from batdetect2.train.types import TrainExample
class ValidationMetrics(Callback):

View File

@ -8,9 +8,11 @@ from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.audio.types import AudioLoader, ClipperProtocol
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
@ -18,14 +20,7 @@ from batdetect2.train.augmentations import (
build_augmentations,
)
from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import (
AudioLoader,
Augmentation,
ClipLabeller,
ClipperProtocol,
PreprocessorProtocol,
TrainExample,
)
from batdetect2.train.types import Augmentation, ClipLabeller, TrainExample
__all__ = [
"TrainingDataset",

View File

@ -15,7 +15,8 @@ from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets import build_targets, iterate_encoded_sound_events
from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.types import ClipLabeller, Heatmaps
__all__ = [
"LabelConfig",

View File

@ -2,11 +2,12 @@ import lightning as L
from soundevent.data import PathLike
from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.models.types import ModelOutput
from batdetect2.train.config import TrainingConfig
from batdetect2.train.losses import build_loss
from batdetect2.train.optimizers import build_optimizer
from batdetect2.train.schedulers import build_scheduler
from batdetect2.typing import LossProtocol, ModelOutput, TrainExample
from batdetect2.train.types import LossProtocol, TrainExample
__all__ = [
"TrainingModule",

View File

@ -26,7 +26,8 @@ from pydantic import Field
from torch import nn
from batdetect2.core.configs import BaseConfig
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
from batdetect2.models.types import ModelOutput
from batdetect2.train.types import Losses, LossProtocol, TrainExample
__all__ = [
"BBoxLoss",

View File

@ -43,7 +43,7 @@ optimizer_registry: Registry[Optimizer, [Iterable[nn.Parameter]]] = Registry(
)
@add_import_config(optimizer_registry)
@add_import_config(optimizer_registry, arg_names=["params"])
class OptimizerImportConfig(ImportConfig):
"""Use any callable as an optimizer.
@ -84,4 +84,4 @@ def build_optimizer(
Optimizer configuration. Defaults to ``AdamOptimizerConfig``.
"""
config = config or AdamOptimizerConfig()
return optimizer_registry.build(config, params=parameters)
return optimizer_registry.build(config, parameters)

View File

@ -40,7 +40,7 @@ class CosineAnnealingSchedulerConfig(BaseConfig):
scheduler_registry: Registry[LRScheduler, [Optimizer]] = Registry("scheduler")
@add_import_config(scheduler_registry)
@add_import_config(scheduler_registry, arg_names=["optimizer"])
class SchedulerImportConfig(ImportConfig):
"""Use any callable as a scheduler.
@ -78,4 +78,4 @@ def build_scheduler(
"""Build a scheduler from configuration."""
config = config or CosineAnnealingSchedulerConfig()
return scheduler_registry.build(config, optimizer=optimizer)
return scheduler_registry.build(config, optimizer)

View File

@ -1,32 +1,28 @@
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import Optional
from lightning import Trainer, seed_everything
from loguru import logger
from soundevent import data
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.evaluate import build_evaluator
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import build_logger
from batdetect2.models import ModelConfig
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol
from batdetect2.train import TrainingConfig
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.checkpoints import build_checkpoint_callback
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import build_training_module
if TYPE_CHECKING:
from batdetect2.typing import (
AudioLoader,
ClipLabeller,
EvaluatorProtocol,
PreprocessorProtocol,
TargetProtocol,
)
from batdetect2.train.types import ClipLabeller
__all__ = [
"build_trainer",

View File

@ -0,0 +1,70 @@
from collections.abc import Callable
from typing import TYPE_CHECKING, NamedTuple, Protocol
import torch
from soundevent import data
if TYPE_CHECKING:
from batdetect2.models.types import ModelOutput
__all__ = [
"Augmentation",
"ClipLabeller",
"Heatmaps",
"Losses",
"LossProtocol",
"TrainExample",
]
class Heatmaps(NamedTuple):
    """Ground-truth heatmaps produced by a clip labeller."""

    # Detection target heatmap.
    detection: torch.Tensor
    # Per-class target heatmap.
    classes: torch.Tensor
    # Size regression target heatmap.
    size: torch.Tensor
class PreprocessedExample(NamedTuple):
    """A preprocessed training example with its target heatmaps."""

    audio: torch.Tensor
    spectrogram: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor

    def copy(self):
        """Return a new example whose tensors are clones of this one's."""
        # Field order of the NamedTuple is preserved by iteration, so the
        # clones land in the matching positional slots.
        cloned = [tensor.clone() for tensor in self]
        return PreprocessedExample(*cloned)
# Maps a clip annotation plus a tensor (presumably the clip's
# spectrogram — confirm against the labeller implementation) to the
# training Heatmaps.
ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]

# Transforms a (tensor, clip annotation) pair and returns the augmented
# pair; NOTE(review): whether the tensor is audio or spectrogram depends
# on the augmentation — verify at the call site.
Augmentation = Callable[
    [torch.Tensor, data.ClipAnnotation],
    tuple[torch.Tensor, data.ClipAnnotation],
]
class TrainExample(NamedTuple):
    """A single training example as consumed by the loss (see LossProtocol)."""

    # Input spectrogram.
    spec: torch.Tensor
    # Detection target heatmap.
    detection_heatmap: torch.Tensor
    # Per-class target heatmap.
    class_heatmap: torch.Tensor
    # Size regression target heatmap.
    size_heatmap: torch.Tensor
    # Index of the source example — presumably into the dataset; confirm.
    idx: torch.Tensor
    # Start time of the clip the example was taken from.
    start_time: torch.Tensor
    # End time of the clip the example was taken from.
    end_time: torch.Tensor
class Losses(NamedTuple):
    """Individual loss components plus their combined total."""

    # Detection loss term.
    detection: torch.Tensor
    # Size regression loss term.
    size: torch.Tensor
    # Classification loss term.
    classification: torch.Tensor
    # Combined loss — presumably the value used for backpropagation;
    # confirm in the loss implementation.
    total: torch.Tensor
class LossProtocol(Protocol):
    """Callable computing all loss components from predictions and targets."""

    def __call__(self, pred: "ModelOutput", gt: TrainExample) -> Losses: ...

View File

@ -1,18 +1,14 @@
"""Types used in the code base."""
from typing import Any, NamedTuple, TypedDict
import sys
from typing import Any, NamedTuple, Protocol, TypedDict
import numpy as np
import torch
try:
from typing import Protocol
except ImportError:
from typing_extensions import Protocol
try:
from typing import NotRequired # type: ignore
except ImportError:
if sys.version_info >= (3, 11):
from typing import NotRequired
else:
from typing_extensions import NotRequired

View File

@ -1,75 +0,0 @@
from batdetect2.typing.data import OutputFormatterProtocol
from batdetect2.typing.evaluate import (
AffinityFunction,
ClipMatches,
EvaluatorProtocol,
MatcherProtocol,
MatchEvaluation,
MetricsProtocol,
PlotterProtocol,
)
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import (
ClipDetections,
ClipDetectionsTensor,
Detection,
GeometryDecoder,
PostprocessorProtocol,
)
from batdetect2.typing.preprocess import (
AudioLoader,
PreprocessorProtocol,
)
from batdetect2.typing.targets import (
Position,
ROITargetMapper,
Size,
SoundEventDecoder,
SoundEventEncoder,
SoundEventFilter,
TargetProtocol,
)
from batdetect2.typing.train import (
Augmentation,
ClipLabeller,
ClipperProtocol,
Heatmaps,
Losses,
LossProtocol,
TrainExample,
)
__all__ = [
"AffinityFunction",
"AudioLoader",
"Augmentation",
"BackboneModel",
"ClipDetections",
"ClipDetectionsTensor",
"ClipLabeller",
"ClipMatches",
"ClipperProtocol",
"DetectionModel",
"EvaluatorProtocol",
"GeometryDecoder",
"Heatmaps",
"LossProtocol",
"Losses",
"MatchEvaluation",
"MatcherProtocol",
"MetricsProtocol",
"ModelOutput",
"OutputFormatterProtocol",
"PlotterProtocol",
"Position",
"PostprocessorProtocol",
"PreprocessorProtocol",
"ROITargetMapper",
"Detection",
"Size",
"SoundEventDecoder",
"SoundEventEncoder",
"SoundEventFilter",
"TargetProtocol",
"TrainExample",
]

View File

@ -1,287 +0,0 @@
"""Defines shared interfaces (ABCs) and data structures for models.
This module centralizes the definitions of core data structures, like the
standard model output container (`ModelOutput`), and establishes abstract base
classes (ABCs) using `abc.ABC` and `torch.nn.Module`. These define contracts
for fundamental model components, ensuring modularity and consistent
interaction within the `batdetect2.models` package.
Key components:
- `ModelOutput`: Standard structure for outputs from detection models.
- `BackboneModel`: Generic interface for any feature extraction backbone.
- `EncoderDecoderModel`: Specialized interface for backbones with distinct
encoder-decoder stages (e.g., U-Net), providing access to intermediate
features.
- `DetectionModel`: Interface for the complete end-to-end detection model.
"""
from abc import ABC, abstractmethod
from typing import List, NamedTuple, Protocol
import torch
__all__ = [
"ModelOutput",
"BackboneModel",
"EncoderDecoderModel",
"DetectionModel",
"BlockProtocol",
"EncoderProtocol",
"BottleneckProtocol",
"DecoderProtocol",
]
class BlockProtocol(Protocol):
"""Interface for blocks of network layers."""
in_channels: int
out_channels: int
def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the block."""
...
def get_output_height(self, input_height: int) -> int:
"""Calculate the output height based on input height."""
...
class EncoderProtocol(Protocol):
"""Interface for the downsampling path of a network."""
in_channels: int
out_channels: int
input_height: int
output_height: int
def __call__(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Forward pass must return intermediate tensors for skip connections."""
...
class BottleneckProtocol(Protocol):
"""Interface for the middle part of a U-Net-like network."""
in_channels: int
out_channels: int
input_height: int
def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""Processes the features from the encoder."""
...
class DecoderProtocol(Protocol):
"""Interface for the upsampling reconstruction path."""
in_channels: int
out_channels: int
input_height: int
output_height: int
depth: int
def __call__(
self,
x: torch.Tensor,
residuals: List[torch.Tensor],
) -> torch.Tensor:
"""Upsamples features while integrating skip connections."""
...
class ModelOutput(NamedTuple):
"""Standard container for the outputs of a BatDetect2 detection model.
This structure groups the different prediction tensors produced by the
model for a batch of input spectrograms. All tensors typically share the
same spatial dimensions (height H, width W) corresponding to the model's
output resolution, and the same batch size (N).
Attributes
----------
detection_probs : torch.Tensor
Tensor containing the probability of sound event presence at each
location in the output grid.
Shape: `(N, 1, H, W)`
size_preds : torch.Tensor
Tensor containing the predicted size dimensions
(e.g., width and height) for a potential bounding box at each location.
Shape: `(N, 2, H, W)` (Channel 0 typically width, Channel 1 height)
class_probs : torch.Tensor
Tensor containing the predicted probabilities (or logits, depending on
the final activation) for each target class at each location.
The number of channels corresponds to the number of specific classes
defined in the `Targets` configuration.
Shape: `(N, num_classes, H, W)`
features : torch.Tensor
Tensor containing features extracted by the model's backbone. These
might be used for downstream tasks or analysis. The number of channels
depends on the specific model architecture.
Shape: `(N, num_features, H, W)`
"""
detection_probs: torch.Tensor
size_preds: torch.Tensor
class_probs: torch.Tensor
features: torch.Tensor
class BackboneModel(ABC, torch.nn.Module):
"""Abstract Base Class for generic feature extraction backbone models.
Defines the minimal interface for a feature extractor network within a
BatDetect2 model. Its primary role is to process an input spectrogram
tensor and produce a spatially rich feature map tensor, which is then
typically consumed by separate prediction heads (for detection,
classification, size).
This base class is agnostic to the specific internal architecture (e.g.,
it could be a simple CNN, a U-Net, a Transformer, etc.). Concrete
implementations must inherit from this class and `torch.nn.Module`,
implement the `forward` method, and define the required attributes.
Attributes
----------
input_height : int
Expected height (number of frequency bins) of the input spectrogram
tensor that the backbone is designed to process.
out_channels : int
Number of channels in the final feature map tensor produced by the
backbone's `forward` method.
"""
input_height: int
"""Expected input spectrogram height (frequency bins)."""
out_channels: int
"""Number of output channels in the final feature map."""
@abstractmethod
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Perform the forward pass to extract features from the spectrogram.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, typically with shape
`(batch_size, 1, frequency_bins, time_bins)`.
`frequency_bins` should match `self.input_height`.
Returns
-------
torch.Tensor
Output feature map tensor, typically with shape
`(batch_size, self.out_channels, output_height, output_width)`.
The spatial dimensions (`output_height`, `output_width`) depend
on the specific backbone architecture (e.g., they might match the
input or be downsampled).
"""
raise NotImplementedError
class EncoderDecoderModel(BackboneModel):
"""Abstract Base Class for Encoder-Decoder style backbone models.
This class specializes `BackboneModel` for architectures that have distinct
encoder stages (downsampling path), a bottleneck, and decoder stages
(upsampling path).
It provides separate abstract methods for the `encode` and `decode` steps,
allowing access to the intermediate "bottleneck" features produced by the
encoder. This can be useful for tasks like transfer learning or specialized
analyses.
Attributes
----------
input_height : int
(Inherited from BackboneModel) Expected input spectrogram height.
out_channels : int
(Inherited from BackboneModel) Number of output channels in the final
feature map produced by the decoder/forward pass.
bottleneck_channels : int
Number of channels in the feature map produced by the encoder at its
deepest point (the bottleneck), before the decoder starts.
"""
bottleneck_channels: int
"""Number of channels at the encoder's bottleneck."""
@abstractmethod
def encode(self, spec: torch.Tensor) -> torch.Tensor:
"""Process the input spectrogram through the encoder part.
Takes the input spectrogram and passes it through the downsampling path
of the network up to the bottleneck layer.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, typically with shape
`(batch_size, 1, frequency_bins, time_bins)`.
Returns
-------
torch.Tensor
The encoded feature map from the bottleneck layer, typically with
shape `(batch_size, self.bottleneck_channels, bottleneck_height,
bottleneck_width)`. The spatial dimensions are usually downsampled
relative to the input.
"""
...
@abstractmethod
def decode(self, encoded: torch.Tensor) -> torch.Tensor:
"""Process the bottleneck features through the decoder part.
Takes the encoded feature map from the bottleneck and passes it through
the upsampling path (potentially using skip connections from the
encoder) to produce the final output feature map.
Parameters
----------
encoded : torch.Tensor
The bottleneck feature map tensor produced by the `encode` method.
Returns
-------
torch.Tensor
The final output feature map tensor, typically with shape
`(batch_size, self.out_channels, output_height, output_width)`.
This should match the output shape of the `forward` method.
"""
...
class DetectionModel(ABC, torch.nn.Module):
"""Abstract Base Class for complete BatDetect2 detection models.
Defines the interface for the overall model that takes an input spectrogram
and produces all necessary outputs for detection, classification, and size
prediction, packaged within a `ModelOutput` object.
Concrete implementations typically combine a `BackboneModel` for feature
extraction with specific prediction heads for each output type. They must
inherit from this class and `torch.nn.Module`, and implement the `forward`
method.
"""
@abstractmethod
def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Perform the forward pass of the full detection model.
Processes the input spectrogram through the backbone and prediction
heads to generate all required output tensors.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, typically with shape
`(batch_size, 1, frequency_bins, time_bins)`.
Returns
-------
ModelOutput
A NamedTuple containing the prediction tensors: `detection_probs`,
`size_preds`, `class_probs`, and `features`.
"""

View File

@ -1,104 +0,0 @@
"""Defines shared interfaces and data structures for postprocessing.
This module centralizes the Protocol definitions and common data structures
used throughout the `batdetect2.postprocess` module.
The main component is the `PostprocessorProtocol`, which outlines the standard
interface for an object responsible for executing the entire postprocessing
pipeline. This pipeline transforms raw neural network outputs into interpretable
detections represented as `soundevent` objects. Using protocols ensures
modularity and consistent interaction between different parts of the BatDetect2
system that deal with model predictions.
"""
from dataclasses import dataclass
from typing import List, NamedTuple, Protocol
import numpy as np
import torch
from soundevent import data
from batdetect2.typing.models import ModelOutput
from batdetect2.typing.targets import Position, Size
__all__ = [
"Detection",
"PostprocessorProtocol",
"GeometryDecoder",
]
# TODO: update the docstring
class GeometryDecoder(Protocol):
"""Type alias for a function that recovers geometry from position and size.
This callable takes:
1. A position tuple `(time, frequency)`.
2. A NumPy array of size dimensions (e.g., `[width, height]`).
3. Optionally a class name of the highest scoring class. This is to accomodate
different ways of decoding geometry that depend on the predicted class.
It should return the reconstructed `soundevent.data.Geometry` (typically a
`BoundingBox`).
"""
def __call__(
self, position: Position, size: Size, class_name: str | None = None
) -> data.Geometry: ...
@dataclass
class Detection:
geometry: data.Geometry
detection_score: float
class_scores: np.ndarray
features: np.ndarray
class ClipDetectionsArray(NamedTuple):
scores: np.ndarray
sizes: np.ndarray
class_scores: np.ndarray
times: np.ndarray
frequencies: np.ndarray
features: np.ndarray
class ClipDetectionsTensor(NamedTuple):
scores: torch.Tensor
sizes: torch.Tensor
class_scores: torch.Tensor
times: torch.Tensor
frequencies: torch.Tensor
features: torch.Tensor
def numpy(self) -> ClipDetectionsArray:
return ClipDetectionsArray(
scores=self.scores.detach().cpu().numpy(),
sizes=self.sizes.detach().cpu().numpy(),
class_scores=self.class_scores.detach().cpu().numpy(),
times=self.times.detach().cpu().numpy(),
frequencies=self.frequencies.detach().cpu().numpy(),
features=self.features.detach().cpu().numpy(),
)
@dataclass
class ClipDetections:
clip: data.Clip
detections: List[Detection]
@dataclass
class ClipPrediction:
clip: data.Clip
detection_score: float
class_scores: np.ndarray
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline."""
def __call__(
self,
output: ModelOutput,
) -> List[ClipDetectionsTensor]: ...

View File

@ -1,168 +0,0 @@
"""Defines common interfaces (Protocols) for preprocessing components.
This module centralizes the Protocol definitions used throughout the
`batdetect2.preprocess` package. Protocols define expected methods and
signatures, allowing for flexible and interchangeable implementations of
components like audio loaders and spectrogram builders.
Using these protocols ensures that different parts of the preprocessing
pipeline can interact consistently, regardless of the specific underlying
implementation (e.g., different libraries or custom configurations).
"""
from typing import Protocol
import numpy as np
import torch
from soundevent import data
__all__ = [
"AudioLoader",
"SpectrogramBuilder",
"PreprocessorProtocol",
]
class AudioLoader(Protocol):
"""Defines the interface for an audio loading and processing component.
An AudioLoader is responsible for retrieving audio data corresponding to
different soundevent objects (files, Recordings, Clips) and applying a
configured set of initial preprocessing steps. Adhering to this protocol
allows for different loading strategies or implementations.
"""
samplerate: int
def load_file(
self,
path: data.PathLike,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess audio directly from a file path.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix to prepend to the path if `path` is relative.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
def load_recording(
self,
recording: data.Recording,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object.
Parameters
----------
recording : data.Recording
The Recording object containing metadata about the audio file.
audio_dir : PathLike, optional
A directory where the audio file associated with the recording
can be found, especially if the path in the recording is relative.
Returns
-------
np.ndarray
The loaded and preprocessed audio waveform as a 1-D NumPy
array. Typically loads only the first channel.
Raises
------
FileNotFoundError
If the audio file associated with the recording cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
def load_clip(
self,
clip: data.Clip,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object.
Parameters
----------
clip : data.Clip
The Clip object specifying the recording and the start/end times
of the segment to load.
audio_dir : PathLike, optional
A directory where the audio file associated with the clip's
recording can be found.
Returns
-------
np.ndarray
The loaded and preprocessed audio waveform for the specified
clip duration as a 1-D NumPy array. Typically loads only the
first channel.
Raises
------
FileNotFoundError
If the audio file associated with the clip cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
class SpectrogramBuilder(Protocol):
"""Defines the interface for a spectrogram generation component."""
def __call__(self, wav: torch.Tensor) -> torch.Tensor:
"""Generate a spectrogram from an audio waveform."""
...
class PreprocessorProtocol(Protocol):
"""Defines a high-level interface for the complete preprocessing pipeline."""
max_freq: float
min_freq: float
input_samplerate: int
output_samplerate: float
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...
def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ...
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
"""Run the full preprocessing pipeline on a NumPy waveform.
This default implementation converts the array to a
``torch.Tensor``, calls :meth:`__call__`, and converts the
result back to a NumPy array. Concrete implementations may
override this for efficiency.
Parameters
----------
wav : np.ndarray
Input waveform as a 1-D NumPy array.
Returns
-------
np.ndarray
Preprocessed spectrogram as a NumPy array.
"""
return self(torch.tensor(wav)).numpy()

View File

@ -1,298 +0,0 @@
"""Defines the core interface (Protocol) for the target definition pipeline.
This module specifies the standard structure, attributes, and methods expected
from an object that encapsulates the complete configured logic for processing
sound event annotations within the `batdetect2.targets` system.
The main component defined here is the `TargetProtocol`. This protocol acts as
a contract for the entire target definition process, covering semantic aspects
(filtering, tag transformation, class encoding/decoding) as well as geometric
aspects (mapping regions of interest to target positions and sizes). It ensures
that components responsible for these tasks can be interacted with consistently
throughout BatDetect2.
"""
from collections.abc import Callable
from typing import List, Protocol
import numpy as np
from soundevent import data
__all__ = [
"TargetProtocol",
"SoundEventEncoder",
"SoundEventDecoder",
"SoundEventFilter",
"Position",
"Size",
]
SoundEventEncoder = Callable[[data.SoundEventAnnotation], str | None]
"""Type alias for a sound event class encoder function.
An encoder function takes a sound event annotation and returns the string name
of the target class it belongs to, based on a predefined set of rules.
If the annotation does not match any defined target class according to the
rules, the function returns None.
"""
SoundEventDecoder = Callable[[str], List[data.Tag]]
"""Type alias for a sound event class decoder function.
A decoder function takes a class name string (as predicted by the model or
assigned during encoding) and returns a list of `soundevent.data.Tag` objects
that represent that class according to the configuration. This is used to
translate model outputs back into meaningful annotations.
"""
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
"""Type alias for a filter function.
A filter function accepts a soundevent.data.SoundEventAnnotation object
and returns True if the annotation should be kept based on the filter's
criteria, or False if it should be discarded.
"""
Position = tuple[float, float]
"""A tuple representing (time, frequency) coordinates."""
Size = np.ndarray
"""A NumPy array representing the size dimensions of a target."""
class TargetProtocol(Protocol):
    """Protocol defining the interface for the target definition pipeline.

    This protocol outlines the standard attributes and methods for an object
    that encapsulates the complete, configured process for handling sound event
    annotations (both tags and geometry). It defines how to:

    - Select relevant annotations (`filter`).
    - Encode an annotation into a specific target class name (`encode_class`).
    - Decode a class name back into representative tags (`decode_class`).
    - Encode an annotation's geometry (ROI) into a target reference position
      and size dimensions (`encode_roi`).
    - Recover an approximate geometry (ROI) from a position and size
      dimensions (`decode_roi`).

    Implementations of this protocol bundle all configured logic for these
    steps.

    Attributes
    ----------
    class_names : List[str]
        An ordered list of the unique names of the specific target classes
        defined by the configuration.
    detection_class_tags : List[data.Tag]
        A list of `soundevent.data.Tag` objects representing the generic
        detection category, used for sounds that are detected but not
        assigned to a specific class.
    detection_class_name : str
        The name used for the generic detection (unclassified) category.
    dimension_names : List[str]
        A list containing the names of the size dimensions returned by
        `encode_roi` and expected by `decode_roi` (e.g., ['width', 'height']).
    """

    class_names: List[str]
    """Ordered list of unique names for the specific target classes."""

    detection_class_tags: List[data.Tag]
    """List of tags representing the detection category (unclassified)."""

    # Name of the generic detection (unclassified) category.
    detection_class_name: str

    dimension_names: List[str]
    """Names of the size dimensions (e.g., ['width', 'height'])."""

    def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
        """Apply the filter to a sound event annotation.

        Determines if the annotation should be included in further processing
        and training based on the configured filtering rules.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation to filter.

        Returns
        -------
        bool
            True if the annotation should be kept (passes the filter),
            False otherwise. Implementations should return True if no
            filtering is configured.
        """
        ...

    def encode_class(
        self,
        sound_event: data.SoundEventAnnotation,
    ) -> str | None:
        """Encode a sound event annotation to its target class name.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The (potentially filtered and transformed) annotation to encode.

        Returns
        -------
        str or None
            The string name of the matched target class if the annotation
            matches a specific class definition. Returns None if the
            annotation does not match any specific class rule (indicating it
            may belong to a generic category or should be handled differently
            downstream).
        """
        ...

    def decode_class(self, class_label: str) -> List[data.Tag]:
        """Decode a predicted class name back into representative tags.

        Parameters
        ----------
        class_label : str
            The class name string (e.g., predicted by a model) to decode.

        Returns
        -------
        List[data.Tag]
            The list of tags corresponding to the input class name according
            to the configuration. May return an empty list or raise an error
            for unmapped labels, depending on the implementation's
            configuration (e.g., `raise_on_unmapped` flag during building).

        Raises
        ------
        ValueError, KeyError
            Implementations might raise an error if the `class_label` is not
            found in the configured mapping and error raising is enabled.
        """
        ...

    def encode_roi(
        self, sound_event: data.SoundEventAnnotation
    ) -> tuple[Position, Size]:
        """Encode the annotation's geometry into a position and size.

        Calculates the `(time, frequency)` coordinate representing the
        primary location of the sound event, together with the array of size
        dimensions describing its extent.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation containing the geometry (ROI) to process.

        Returns
        -------
        tuple[Position, Size]
            The reference position `(time, frequency)` and a NumPy array of
            size dimensions, ordered as in `dimension_names`.

        Raises
        ------
        ValueError
            If the annotation lacks geometry or if the position or size
            cannot be calculated for the geometry type or configured
            reference point.
        """
        ...

    def decode_roi(
        self,
        position: Position,
        size: Size,
        class_name: str | None = None,
    ) -> data.Geometry:
        """Recover the ROI geometry from a position and size dimensions.

        Performs the inverse of `encode_roi`: takes a reference position
        `(time, frequency)` and an array of size dimensions and reconstructs
        an approximate geometric representation.

        Parameters
        ----------
        position : Position
            The reference position `(time, frequency)`.
        size : Size
            The NumPy array containing the size dimensions (e.g., predicted
            by the model), ordered as in `dimension_names`.
        class_name : str, optional
            Name of the predicted class, which implementations may use to
            refine the reconstruction. Defaults to None.

        Returns
        -------
        soundevent.data.Geometry
            The reconstructed geometry.

        Raises
        ------
        ValueError
            If the number of provided size dimensions does not match
            `dimension_names`, if dimensions are invalid (e.g., negative
            after unscaling), or if reconstruction fails based on the
            configured position type.
        """
        ...
class ROITargetMapper(Protocol):
    """Protocol defining the interface for ROI-to-target mapping.

    Specifies the `encode` and `decode` methods required for converting a
    `soundevent.data.SoundEvent` into a target representation (a reference
    position and a size vector) and for recovering an approximate ROI from
    that representation.

    Attributes
    ----------
    dimension_names : List[str]
        A list containing the names of the dimensions in the `Size` array
        returned by `encode` and expected by `decode`.
    """

    # Names of the entries of the `Size` array, in order.
    dimension_names: List[str]

    def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
        """Encode a SoundEvent's geometry into a position and size.

        Parameters
        ----------
        sound_event : data.SoundEvent
            The input sound event, which must have a geometry attribute.

        Returns
        -------
        tuple[Position, Size]
            A tuple containing:

            - The reference position as `(time, frequency)` coordinates.
            - A NumPy array with the calculated size dimensions.

        Raises
        ------
        ValueError
            If the sound event does not have a geometry.
        """
        ...

    def decode(self, position: Position, size: Size) -> data.Geometry:
        """Decode a position and size back into a geometric ROI.

        Performs the inverse mapping: takes a reference position and size
        dimensions and reconstructs a geometric representation.

        Parameters
        ----------
        position : Position
            The reference position `(time, frequency)`.
        size : Size
            NumPy array containing the size dimensions, matching the order
            and meaning specified by `dimension_names`.

        Returns
        -------
        soundevent.data.Geometry
            The reconstructed geometry, typically a `BoundingBox`.

        Raises
        ------
        ValueError
            If the `size` array has an unexpected shape or if reconstruction
            fails.
        """
        ...

View File

@ -1,108 +0,0 @@
from typing import Callable, NamedTuple, Protocol
import torch
from soundevent import data
from batdetect2.typing.models import ModelOutput
# Public interface of this module. `PreprocessedExample` is defined below but
# was previously missing from this list; it is added here, keeping the list
# alphabetically sorted.
__all__ = [
    "Augmentation",
    "ClipLabeller",
    "ClipperProtocol",
    "Heatmaps",
    "LossProtocol",
    "Losses",
    "PreprocessedExample",
    "TrainExample",
]
class Heatmaps(NamedTuple):
    """Structure holding the generated heatmap targets.

    Instances are produced by the clip labelling step (see `ClipLabeller`)
    and serve as the ground-truth targets during model training.
    """

    # Target heatmap for sound event detection.
    detection: torch.Tensor
    # Target heatmap for per-class presence.
    classes: torch.Tensor
    # Target heatmap encoding the size dimensions of each event.
    size: torch.Tensor
class PreprocessedExample(NamedTuple):
    """Container pairing a clip's audio and spectrogram with its targets.

    Bundles the input tensors together with the generated heatmap targets so
    a fully preprocessed example can be passed around (or duplicated) as a
    single unit.
    """

    # Input waveform tensor.
    audio: torch.Tensor
    # Spectrogram computed from the audio.
    spectrogram: torch.Tensor
    # Ground-truth detection heatmap.
    detection_heatmap: torch.Tensor
    # Ground-truth per-class heatmap.
    class_heatmap: torch.Tensor
    # Ground-truth size regression heatmap.
    size_heatmap: torch.Tensor

    def copy(self):
        """Return a new example with every tensor cloned.

        Cloning copies the underlying data, so in-place modifications of the
        returned tensors do not affect the originals.
        """
        # NamedTuple iteration yields the fields in declaration order, so
        # `_make` rebuilds an identical instance from the cloned tensors.
        return PreprocessedExample._make(field.clone() for field in self)
ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
"""Type alias for the final clip labelling function.
This function takes the complete annotations for a clip and the corresponding
spectrogram, applies all configured filtering, transformation, and encoding
steps, and returns the final `Heatmaps` used for model training.
"""
Augmentation = Callable[
[torch.Tensor, data.ClipAnnotation],
tuple[torch.Tensor, data.ClipAnnotation],
]
class TrainExample(NamedTuple):
    """A single training example: input spectrogram plus target heatmaps."""

    # Input spectrogram fed to the model.
    spec: torch.Tensor
    # Ground-truth detection heatmap.
    detection_heatmap: torch.Tensor
    # Ground-truth per-class heatmap.
    class_heatmap: torch.Tensor
    # Ground-truth size regression heatmap.
    size_heatmap: torch.Tensor
    # Index of the example — presumably within its dataset; confirm against
    # the dataset code.
    idx: torch.Tensor
    # NOTE(review): start/end presumably delimit the clip within the source
    # recording (seconds) — verify against the clipper/dataset code.
    start_time: torch.Tensor
    end_time: torch.Tensor
class Losses(NamedTuple):
    """Structure to hold the computed loss values.

    Allows returning individual loss components along with the total weighted
    loss for monitoring and analysis during training.

    Attributes
    ----------
    detection : torch.Tensor
        Scalar tensor representing the calculated detection loss component
        (before weighting).
    size : torch.Tensor
        Scalar tensor representing the calculated size regression loss
        component (before weighting).
    classification : torch.Tensor
        Scalar tensor representing the calculated classification loss
        component (before weighting).
    total : torch.Tensor
        Scalar tensor representing the final combined loss, computed as the
        weighted sum of the detection, size, and classification components.
        This is the value typically used for backpropagation.
    """

    detection: torch.Tensor
    size: torch.Tensor
    classification: torch.Tensor
    total: torch.Tensor
class LossProtocol(Protocol):
    """Interface for training loss functions.

    A loss callable compares the model output against the ground truth of a
    training example and returns the individual loss components together with
    the combined total (see `Losses`).
    """

    def __call__(self, pred: ModelOutput, gt: TrainExample) -> Losses:
        """Compute the losses for a prediction/ground-truth pair."""
        ...
class ClipperProtocol(Protocol):
    """Interface for objects that extract training clips from annotations.

    NOTE(review): the exact selection strategy (random vs. deterministic
    cropping) is implementation-defined and not visible here.
    """

    def __call__(
        self,
        clip_annotation: data.ClipAnnotation,
    ) -> data.ClipAnnotation:
        """Return a clip annotation adjusted to the selected sub-clip."""
        ...

    def get_subclip(self, clip: data.Clip) -> data.Clip:
        """Return the sub-clip this clipper would select from `clip`."""
        ...

View File

@ -11,23 +11,20 @@ from soundevent import data, terms
from batdetect2.audio import build_audio_loader
from batdetect2.audio.clips import build_clipper
from batdetect2.audio.types import AudioLoader, ClipperProtocol
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import (
TargetConfig,
build_targets,
call_type,
)
from batdetect2.targets.classes import TargetClassConfig
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import (
ClipLabeller,
PreprocessorProtocol,
TargetProtocol,
)
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.typing.train import ClipperProtocol
from batdetect2.train.types import ClipLabeller
@pytest.fixture

View File

@ -9,11 +9,8 @@ from batdetect2.outputs.formats import (
ParquetOutputConfig,
build_output_formatter,
)
from batdetect2.typing import (
ClipDetections,
Detection,
TargetProtocol,
)
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
@pytest.fixture

View File

@ -5,11 +5,8 @@ import pytest
from soundevent import data
from batdetect2.outputs.formats import RawOutputConfig, build_output_formatter
from batdetect2.typing import (
ClipDetections,
Detection,
TargetProtocol,
)
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
@pytest.fixture

View File

@ -4,7 +4,7 @@ import numpy as np
import pytest
from soundevent import data
from batdetect2.typing import Detection
from batdetect2.postprocess.types import Detection
@pytest.fixture

View File

@ -4,8 +4,8 @@ from soundevent import data
from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.typing import ClipDetections
from batdetect2.typing.targets import TargetProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
def test_classification(

View File

@ -4,8 +4,8 @@ from soundevent import data
from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.typing import ClipDetections
from batdetect2.typing.targets import TargetProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
def test_detection(

View File

@ -13,7 +13,7 @@ from batdetect2.models.backbones import (
build_backbone,
load_backbone_config,
)
from batdetect2.typing.models import BackboneModel
from batdetect2.models.types import BackboneModel
def test_unet_backbone_config_defaults():

View File

@ -7,7 +7,7 @@ from batdetect2.models.backbones import UNetBackboneConfig
from batdetect2.models.detectors import Detector, build_detector
from batdetect2.models.encoder import Encoder
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.typing.models import ModelOutput
from batdetect2.models.types import ModelOutput
@pytest.fixture

View File

@ -3,7 +3,7 @@ from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.outputs import build_output_transform
from batdetect2.typing import ClipDetections, Detection
from batdetect2.postprocess.types import ClipDetections, Detection
def test_shift_time_to_clip_start(clip: data.Clip):

View File

@ -14,7 +14,8 @@ from batdetect2.postprocess.decoding import (
get_generic_tags,
get_prediction_features,
)
from batdetect2.typing import Detection, TargetProtocol
from batdetect2.postprocess.types import Detection
from batdetect2.targets.types import TargetProtocol
@pytest.fixture

View File

@ -8,6 +8,7 @@ from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio.types import AudioLoader
from batdetect2.config import BatDetect2Config
from batdetect2.models import ModelConfig
from batdetect2.train import (
@ -19,7 +20,6 @@ from batdetect2.train import (
from batdetect2.train.optimizers import AdamOptimizerConfig
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
from batdetect2.train.train import build_training_module
from batdetect2.typing.preprocess import AudioLoader
def build_default_module(config: BatDetect2Config | None = None):