mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Moving types around to each submodule
This commit is contained in:
parent
c226dc3f2b
commit
751be53edf
2
.gitignore
vendored
2
.gitignore
vendored
@ -102,7 +102,7 @@ experiments/*
|
||||
DvcLiveLogger/checkpoints
|
||||
logs/
|
||||
mlruns/
|
||||
outputs/
|
||||
/outputs/
|
||||
notebooks/lightning_logs
|
||||
|
||||
# Jupiter notebooks
|
||||
|
||||
@ -89,5 +89,5 @@ Crucial for training, this module translates physical annotations (Regions of In
|
||||
## Summary
|
||||
To navigate this codebase effectively:
|
||||
1. Follow **`api_v2.py`** to see how high-level operations invoke individual components.
|
||||
2. Rely heavily on the typed **Protocols** located in `src/batdetect2/typing/` to understand the inputs and outputs of each subsystem without needing to read the specific implementations.
|
||||
2. Rely heavily on the typed **Protocols** located in each subsystem's `types.py` module (for example `src/batdetect2/preprocess/types.py` and `src/batdetect2/postprocess/types.py`) to understand inputs and outputs without needing to read each implementation.
|
||||
3. Understand that data flows structurally as `soundevent` primitives externally, and as pure `torch.Tensor` internally through the network.
|
||||
@ -8,6 +8,7 @@ from soundevent import data
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import merge_configs
|
||||
from batdetect2.data import (
|
||||
@ -15,6 +16,7 @@ from batdetect2.data import (
|
||||
)
|
||||
from batdetect2.data.datasets import Dataset
|
||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.inference import process_file_list, run_batch_inference
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
from batdetect2.models import Model, build_model
|
||||
@ -25,24 +27,22 @@ from batdetect2.outputs import (
|
||||
build_output_transform,
|
||||
get_output_formatter,
|
||||
)
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
PostprocessorProtocol,
|
||||
)
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
load_model_from_checkpoint,
|
||||
run_train,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
ClipDetections,
|
||||
Detection,
|
||||
EvaluatorProtocol,
|
||||
OutputFormatterProtocol,
|
||||
PostprocessorProtocol,
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
|
||||
class BatDetect2API:
|
||||
|
||||
@ -6,13 +6,13 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds, intervals_overlap
|
||||
|
||||
from batdetect2.audio.types import ClipperProtocol
|
||||
from batdetect2.core import (
|
||||
BaseConfig,
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.typing import ClipperProtocol
|
||||
|
||||
DEFAULT_TRAIN_CLIP_DURATION = 0.256
|
||||
DEFAULT_MAX_EMPTY_CLIP = 0.1
|
||||
|
||||
@ -5,8 +5,8 @@ from scipy.signal import resample, resample_poly
|
||||
from soundevent import audio, data
|
||||
from soundfile import LibsndfileError
|
||||
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.typing import AudioLoader
|
||||
|
||||
__all__ = [
|
||||
"SoundEventAudioLoader",
|
||||
|
||||
40
src/batdetect2/audio/types.py
Normal file
40
src/batdetect2/audio/types.py
Normal file
@ -0,0 +1,40 @@
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"AudioLoader",
|
||||
"ClipperProtocol",
|
||||
]
|
||||
|
||||
|
||||
class AudioLoader(Protocol):
|
||||
samplerate: int
|
||||
|
||||
def load_file(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray: ...
|
||||
|
||||
def load_recording(
|
||||
self,
|
||||
recording: data.Recording,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray: ...
|
||||
|
||||
def load_clip(
|
||||
self,
|
||||
clip: data.Clip,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray: ...
|
||||
|
||||
|
||||
class ClipperProtocol(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
) -> data.ClipAnnotation: ...
|
||||
|
||||
def get_subclip(self, clip: data.Clip) -> data.Clip: ...
|
||||
@ -4,6 +4,7 @@ from typing import (
|
||||
Concatenate,
|
||||
Generic,
|
||||
ParamSpec,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
@ -147,6 +148,7 @@ T_Import = TypeVar("T_Import", bound=ImportConfig)
|
||||
|
||||
def add_import_config(
|
||||
registry: Registry[T_Type, P_Type],
|
||||
arg_names: Sequence[str] | None = None,
|
||||
) -> Callable[[Type[T_Import]], Type[T_Import]]:
|
||||
"""Decorator that registers an ImportConfig subclass as an escape hatch.
|
||||
|
||||
@ -181,15 +183,22 @@ def add_import_config(
|
||||
*args: P_Type.args,
|
||||
**kwargs: P_Type.kwargs,
|
||||
) -> T_Type:
|
||||
if len(args) > 0:
|
||||
_arg_names = arg_names or []
|
||||
|
||||
if len(args) != len(_arg_names):
|
||||
raise ValueError(
|
||||
"Positional arguments are not supported "
|
||||
"for import escape hatch."
|
||||
"for import escape hatch unless you specify "
|
||||
"the argument names. Use `arg_names` to specify "
|
||||
"the names of the positional arguments."
|
||||
)
|
||||
|
||||
args_dict = {_arg_names[i]: args[i] for i in range(len(args))}
|
||||
|
||||
hydra_cfg = {
|
||||
"_target_": config.target,
|
||||
**config.arguments,
|
||||
**args_dict,
|
||||
**kwargs,
|
||||
}
|
||||
return instantiate(hydra_cfg)
|
||||
|
||||
@ -3,7 +3,7 @@ from collections.abc import Generator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.datasets import Dataset
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def iterate_over_sound_events(
|
||||
|
||||
@ -5,7 +5,7 @@ from batdetect2.data.summary import (
|
||||
extract_recordings_df,
|
||||
extract_sound_events_df,
|
||||
)
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def split_dataset_by_recordings(
|
||||
|
||||
@ -2,7 +2,7 @@ import pandas as pd
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.data.datasets import Dataset
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"extract_recordings_df",
|
||||
|
||||
@ -16,7 +16,8 @@ from batdetect2.core import (
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.typing import AffinityFunction, Detection
|
||||
from batdetect2.evaluate.types import AffinityFunction
|
||||
from batdetect2.postprocess.types import Detection
|
||||
|
||||
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
||||
"affinity_function"
|
||||
|
||||
@ -8,14 +8,11 @@ from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
|
||||
from batdetect2.audio.clips import PaddedClipConfig
|
||||
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.core.arrays import adjust_width
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
ClipperProtocol,
|
||||
PreprocessorProtocol,
|
||||
)
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"TestDataset",
|
||||
|
||||
@ -5,22 +5,20 @@ from lightning import Trainer
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.evaluate.dataset import build_test_loader
|
||||
from batdetect2.evaluate.evaluator import build_evaluator
|
||||
from batdetect2.evaluate.lightning import EvaluationModule
|
||||
from batdetect2.logging import build_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import build_output_transform
|
||||
from batdetect2.typing import Detection
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess.types import Detection
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
OutputFormatterProtocol,
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||
|
||||
|
||||
@ -5,9 +5,10 @@ from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.evaluate.tasks import build_task
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"Evaluator",
|
||||
|
||||
@ -5,12 +5,12 @@ from soundevent import data
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.typing import EvaluatorProtocol
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
|
||||
|
||||
class EvaluationModule(LightningModule):
|
||||
|
||||
@ -26,7 +26,8 @@ from batdetect2.evaluate.metrics.common import (
|
||||
average_precision,
|
||||
compute_precision_recall,
|
||||
)
|
||||
from batdetect2.typing import Detection, TargetProtocol
|
||||
from batdetect2.postprocess.types import Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClassificationMetric",
|
||||
|
||||
@ -20,7 +20,7 @@ from batdetect2.core import (
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import Detection
|
||||
from batdetect2.postprocess.types import Detection
|
||||
|
||||
__all__ = [
|
||||
"DetectionMetricConfig",
|
||||
|
||||
@ -20,8 +20,8 @@ from batdetect2.core import (
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import Detection
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
from batdetect2.postprocess.types import Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"TopClassMetricConfig",
|
||||
|
||||
@ -2,7 +2,7 @@ import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class BasePlotConfig(BaseConfig):
|
||||
|
||||
@ -29,7 +29,7 @@ from batdetect2.plotting.metrics import (
|
||||
plot_threshold_recall_curve,
|
||||
plot_threshold_recall_curves,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
ClassificationPlotter = Callable[
|
||||
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
|
||||
|
||||
@ -22,7 +22,7 @@ from batdetect2.plotting.metrics import (
|
||||
plot_roc_curve,
|
||||
plot_roc_curves,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClipClassificationPlotConfig",
|
||||
|
||||
@ -18,7 +18,7 @@ from batdetect2.evaluate.metrics.clip_detection import ClipEval
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClipDetectionPlotConfig",
|
||||
|
||||
@ -16,6 +16,7 @@ from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.metrics.detection import ClipEval
|
||||
@ -23,7 +24,8 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
from batdetect2.plotting.detections import plot_clip_detections
|
||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.metrics.top_class import (
|
||||
@ -27,7 +28,8 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
from batdetect2.plotting.gallery import plot_match_gallery
|
||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]
|
||||
|
||||
|
||||
@ -11,12 +11,10 @@ from batdetect2.evaluate.tasks.clip_classification import (
|
||||
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
|
||||
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
EvaluatorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"TaskConfig",
|
||||
|
||||
@ -26,13 +26,12 @@ from batdetect2.evaluate.affinity import (
|
||||
TimeAffinityConfig,
|
||||
build_affinity_function,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
from batdetect2.evaluate.types import (
|
||||
AffinityFunction,
|
||||
ClipDetections,
|
||||
Detection,
|
||||
EvaluatorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"BaseTaskConfig",
|
||||
|
||||
@ -21,11 +21,8 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseSEDTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
||||
|
||||
@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import ClipDetections, TargetProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||
|
||||
@ -18,7 +18,8 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import ClipDetections, TargetProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||
|
||||
@ -20,8 +20,8 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseSEDTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class DetectionTaskConfig(BaseSEDTaskConfig):
|
||||
|
||||
@ -20,7 +20,8 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseSEDTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import ClipDetections, TargetProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
||||
|
||||
@ -1,45 +1,39 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Generic,
|
||||
Iterable,
|
||||
Protocol,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import Generic, Iterable, Protocol, Sequence, TypeVar
|
||||
|
||||
from matplotlib.figure import Figure
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.typing.postprocess import ClipDetections, Detection
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"AffinityFunction",
|
||||
"ClipMatches",
|
||||
"EvaluatorProtocol",
|
||||
"MetricsProtocol",
|
||||
"MatchEvaluation",
|
||||
"MatcherProtocol",
|
||||
"MetricsProtocol",
|
||||
"PlotterProtocol",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchEvaluation:
|
||||
clip: data.Clip
|
||||
|
||||
sound_event_annotation: data.SoundEventAnnotation | None
|
||||
gt_det: bool
|
||||
gt_class: str | None
|
||||
gt_geometry: data.Geometry | None
|
||||
|
||||
pred_score: float
|
||||
pred_class_scores: dict[str, float]
|
||||
pred_geometry: data.Geometry | None
|
||||
|
||||
affinity: float
|
||||
|
||||
@property
|
||||
def top_class(self) -> str | None:
|
||||
if not self.pred_class_scores:
|
||||
return None
|
||||
|
||||
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
|
||||
|
||||
@property
|
||||
@ -53,10 +47,8 @@ class MatchEvaluation:
|
||||
@property
|
||||
def top_class_score(self) -> float:
|
||||
pred_class = self.top_class
|
||||
|
||||
if pred_class is None:
|
||||
return 0
|
||||
|
||||
return self.pred_class_scores[pred_class]
|
||||
|
||||
|
||||
@ -75,9 +67,6 @@ class MatcherProtocol(Protocol):
|
||||
) -> Iterable[tuple[int | None, int | None, float]]: ...
|
||||
|
||||
|
||||
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
||||
|
||||
|
||||
class AffinityFunction(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
@ -115,9 +104,11 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
||||
) -> EvaluationOutput: ...
|
||||
|
||||
def compute_metrics(
|
||||
self, eval_outputs: EvaluationOutput
|
||||
self,
|
||||
eval_outputs: EvaluationOutput,
|
||||
) -> dict[str, float]: ...
|
||||
|
||||
def generate_plots(
|
||||
self, eval_outputs: EvaluationOutput
|
||||
self,
|
||||
eval_outputs: EvaluationOutput,
|
||||
) -> Iterable[tuple[str, Figure]]: ...
|
||||
@ -4,22 +4,20 @@ from lightning import Trainer
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio.loader import build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.inference.clips import get_clips_from_files
|
||||
from batdetect2.inference.dataset import build_inference_loader
|
||||
from batdetect2.inference.lightning import InferenceModule
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.preprocess.preprocessor import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.targets import build_targets
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
|
||||
def run_batch_inference(
|
||||
|
||||
@ -6,10 +6,11 @@ from soundevent import data
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.core.arrays import adjust_width
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"InferenceDataset",
|
||||
|
||||
@ -7,7 +7,7 @@ from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
|
||||
|
||||
class InferenceModule(LightningModule):
|
||||
|
||||
@ -62,16 +62,16 @@ from batdetect2.models.encoder import (
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||
from batdetect2.models.types import DetectionModel
|
||||
from batdetect2.postprocess.config import PostprocessConfig
|
||||
from batdetect2.preprocess.config import PreprocessingConfig
|
||||
from batdetect2.targets.config import TargetConfig
|
||||
from batdetect2.typing import (
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetectionsTensor,
|
||||
DetectionModel,
|
||||
PostprocessorProtocol,
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.preprocess.config import PreprocessingConfig
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.config import TargetConfig
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"BBoxHead",
|
||||
|
||||
@ -51,7 +51,7 @@ from batdetect2.models.encoder import (
|
||||
EncoderConfig,
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.typing.models import (
|
||||
from batdetect2.models.types import (
|
||||
BackboneModel,
|
||||
BottleneckProtocol,
|
||||
DecoderProtocol,
|
||||
|
||||
@ -31,7 +31,7 @@ from batdetect2.models.blocks import (
|
||||
VerticalConv,
|
||||
build_layer,
|
||||
)
|
||||
from batdetect2.typing.models import BottleneckProtocol
|
||||
from batdetect2.models.types import BottleneckProtocol
|
||||
|
||||
__all__ = [
|
||||
"BottleneckConfig",
|
||||
|
||||
@ -26,7 +26,7 @@ from batdetect2.models.backbones import (
|
||||
build_backbone,
|
||||
)
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"Detector",
|
||||
|
||||
86
src/batdetect2/models/types.py
Normal file
86
src/batdetect2/models/types.py
Normal file
@ -0,0 +1,86 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import NamedTuple, Protocol
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"BackboneModel",
|
||||
"BlockProtocol",
|
||||
"BottleneckProtocol",
|
||||
"DecoderProtocol",
|
||||
"DetectionModel",
|
||||
"EncoderDecoderModel",
|
||||
"EncoderProtocol",
|
||||
"ModelOutput",
|
||||
]
|
||||
|
||||
|
||||
class BlockProtocol(Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
|
||||
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
def get_output_height(self, input_height: int) -> int: ...
|
||||
|
||||
|
||||
class EncoderProtocol(Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
output_height: int
|
||||
|
||||
def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
|
||||
|
||||
|
||||
class BottleneckProtocol(Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
|
||||
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class DecoderProtocol(Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
output_height: int
|
||||
depth: int
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residuals: list[torch.Tensor],
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
|
||||
detection_probs: torch.Tensor
|
||||
size_preds: torch.Tensor
|
||||
class_probs: torch.Tensor
|
||||
features: torch.Tensor
|
||||
|
||||
|
||||
class BackboneModel(ABC, torch.nn.Module):
|
||||
input_height: int
|
||||
out_channels: int
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EncoderDecoderModel(BackboneModel):
|
||||
bottleneck_channels: int
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, spec: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class DetectionModel(ABC, torch.nn.Module):
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
|
||||
@ -11,7 +11,7 @@ from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig
|
||||
from batdetect2.outputs.formats.parquet import ParquetOutputConfig
|
||||
from batdetect2.outputs.formats.raw import RawOutputConfig
|
||||
from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"BatDetect2OutputConfig",
|
||||
|
||||
@ -4,10 +4,8 @@ from typing import Literal
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.typing import (
|
||||
OutputFormatterProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"OutputFormatterProtocol",
|
||||
|
||||
@ -12,12 +12,9 @@ from batdetect2.outputs.formats.base import (
|
||||
output_formatters,
|
||||
)
|
||||
from batdetect2.targets import terms
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
OutputFormatterProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
try:
|
||||
from typing import NotRequired # type: ignore
|
||||
|
||||
@ -13,12 +13,9 @@ from batdetect2.outputs.formats.base import (
|
||||
make_path_relative,
|
||||
output_formatters,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
OutputFormatterProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class ParquetOutputConfig(BaseConfig):
|
||||
|
||||
@ -14,12 +14,9 @@ from batdetect2.outputs.formats.base import (
|
||||
make_path_relative,
|
||||
output_formatters,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
OutputFormatterProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class RawOutputConfig(BaseConfig):
|
||||
|
||||
@ -8,12 +8,9 @@ from batdetect2.core import BaseConfig
|
||||
from batdetect2.outputs.formats.base import (
|
||||
output_formatters,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
OutputFormatterProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class SoundEventOutputConfig(BaseConfig):
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Protocol
|
||||
from soundevent.geometry import shift_geometry
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.typing import ClipDetections, Detection
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
|
||||
__all__ = [
|
||||
"OutputTransform",
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Generic, List, Protocol, Sequence, TypeVar
|
||||
from collections.abc import Sequence
|
||||
from typing import Generic, Protocol, TypeVar
|
||||
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
|
||||
__all__ = [
|
||||
"OutputFormatterProtocol",
|
||||
@ -12,7 +13,7 @@ T = TypeVar("T")
|
||||
|
||||
|
||||
class OutputFormatterProtocol(Protocol, Generic[T]):
|
||||
def format(self, predictions: Sequence[ClipDetections]) -> List[T]: ...
|
||||
def format(self, predictions: Sequence[ClipDetections]) -> list[T]: ...
|
||||
|
||||
def save(
|
||||
self,
|
||||
@ -21,4 +22,4 @@ class OutputFormatterProtocol(Protocol, Generic[T]):
|
||||
audio_dir: PathLike | None = None,
|
||||
) -> None: ...
|
||||
|
||||
def load(self, path: PathLike) -> List[T]: ...
|
||||
def load(self, path: PathLike) -> list[T]: ...
|
||||
@ -3,8 +3,8 @@ from soundevent import data, plot
|
||||
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.plotting.common import create_ax
|
||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"plot_clip_annotation",
|
||||
|
||||
@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry
|
||||
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
|
||||
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"plot_clip_prediction",
|
||||
|
||||
@ -4,9 +4,10 @@ from matplotlib.axes import Axes
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.plotting.common import plot_spectrogram
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"plot_clip",
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Sequence
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.plotting.matches import (
|
||||
MatchProtocol,
|
||||
plot_cross_trigger_match,
|
||||
@ -10,7 +11,7 @@ from batdetect2.plotting.matches import (
|
||||
plot_false_positive_match,
|
||||
plot_true_positive_match,
|
||||
)
|
||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
__all__ = ["plot_match_gallery"]
|
||||
|
||||
|
||||
@ -4,12 +4,10 @@ from matplotlib.axes import Axes
|
||||
from soundevent import data, plot
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
Detection,
|
||||
PreprocessorProtocol,
|
||||
)
|
||||
from batdetect2.postprocess.types import Detection
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"plot_false_positive_match",
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from batdetect2.typing import ClipDetections
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
|
||||
|
||||
class ClipTransform:
|
||||
|
||||
@ -5,11 +5,11 @@ from typing import List
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.typing.postprocess import (
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetectionsArray,
|
||||
Detection,
|
||||
)
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"to_raw_predictions",
|
||||
|
||||
@ -19,7 +19,7 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.typing.postprocess import ClipDetectionsTensor
|
||||
from batdetect2.postprocess.types import ClipDetectionsTensor
|
||||
|
||||
__all__ = [
|
||||
"extract_detection_peaks",
|
||||
|
||||
@ -1,18 +1,18 @@
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.postprocess.config import (
|
||||
PostprocessConfig,
|
||||
)
|
||||
from batdetect2.postprocess.extraction import extract_detection_peaks
|
||||
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
|
||||
from batdetect2.postprocess.remapping import map_detection_to_clip
|
||||
from batdetect2.typing import ModelOutput
|
||||
from batdetect2.typing.postprocess import (
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetectionsTensor,
|
||||
PostprocessorProtocol,
|
||||
)
|
||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"build_postprocessor",
|
||||
|
||||
@ -19,8 +19,8 @@ import torch
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
from batdetect2.postprocess.types import ClipDetectionsTensor
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.typing.postprocess import ClipDetectionsTensor
|
||||
|
||||
__all__ = [
|
||||
"features_to_xarray",
|
||||
|
||||
85
src/batdetect2/postprocess/types.py
Normal file
85
src/batdetect2/postprocess/types.py
Normal file
@ -0,0 +1,85 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, NamedTuple, Protocol
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets.types import Position, Size
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.models.types import ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"ClipDetections",
|
||||
"ClipDetectionsArray",
|
||||
"ClipDetectionsTensor",
|
||||
"ClipPrediction",
|
||||
"Detection",
|
||||
"GeometryDecoder",
|
||||
"PostprocessorProtocol",
|
||||
]
|
||||
|
||||
|
||||
class GeometryDecoder(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
position: Position,
|
||||
size: Size,
|
||||
class_name: str | None = None,
|
||||
) -> data.Geometry: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class Detection:
|
||||
geometry: data.Geometry
|
||||
detection_score: float
|
||||
class_scores: np.ndarray
|
||||
features: np.ndarray
|
||||
|
||||
|
||||
class ClipDetectionsArray(NamedTuple):
|
||||
scores: np.ndarray
|
||||
sizes: np.ndarray
|
||||
class_scores: np.ndarray
|
||||
times: np.ndarray
|
||||
frequencies: np.ndarray
|
||||
features: np.ndarray
|
||||
|
||||
|
||||
class ClipDetectionsTensor(NamedTuple):
|
||||
scores: torch.Tensor
|
||||
sizes: torch.Tensor
|
||||
class_scores: torch.Tensor
|
||||
times: torch.Tensor
|
||||
frequencies: torch.Tensor
|
||||
features: torch.Tensor
|
||||
|
||||
def numpy(self) -> ClipDetectionsArray:
|
||||
return ClipDetectionsArray(
|
||||
scores=self.scores.detach().cpu().numpy(),
|
||||
sizes=self.sizes.detach().cpu().numpy(),
|
||||
class_scores=self.class_scores.detach().cpu().numpy(),
|
||||
times=self.times.detach().cpu().numpy(),
|
||||
frequencies=self.frequencies.detach().cpu().numpy(),
|
||||
features=self.features.detach().cpu().numpy(),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipDetections:
|
||||
clip: data.Clip
|
||||
detections: list[Detection]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipPrediction:
|
||||
clip: data.Clip
|
||||
detection_score: float
|
||||
class_scores: np.ndarray
|
||||
|
||||
|
||||
class PostprocessorProtocol(Protocol):
|
||||
def __call__(
|
||||
self, output: "ModelOutput"
|
||||
) -> list[ClipDetectionsTensor]: ...
|
||||
@ -1,7 +1,7 @@
|
||||
"""Assembles the full batdetect2 preprocessing pipeline.
|
||||
|
||||
This module defines :class:`Preprocessor`, the concrete implementation of
|
||||
:class:`~batdetect2.typing.PreprocessorProtocol`, and the
|
||||
:class:`~batdetect2.preprocess.types.PreprocessorProtocol`, and the
|
||||
:func:`build_preprocessor` factory function that constructs it from a
|
||||
:class:`~batdetect2.preprocess.config.PreprocessingConfig`.
|
||||
|
||||
@ -33,7 +33,7 @@ from batdetect2.preprocess.spectrogram import (
|
||||
build_spectrogram_resizer,
|
||||
build_spectrogram_transform,
|
||||
)
|
||||
from batdetect2.typing import PreprocessorProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"Preprocessor",
|
||||
@ -42,7 +42,7 @@ __all__ = [
|
||||
|
||||
|
||||
class Preprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||
"""Standard implementation of the :class:`~batdetect2.typing.PreprocessorProtocol`.
|
||||
"""Standard implementation of the :class:`~batdetect2.preprocess.types.PreprocessorProtocol`.
|
||||
|
||||
Wraps all preprocessing stages as ``torch.nn.Module`` submodules so
|
||||
that parameters (e.g. PCEN filter coefficients) can be tracked and
|
||||
|
||||
31
src/batdetect2/preprocess/types.py
Normal file
31
src/batdetect2/preprocess/types.py
Normal file
@ -0,0 +1,31 @@
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"PreprocessorProtocol",
|
||||
"SpectrogramBuilder",
|
||||
]
|
||||
|
||||
|
||||
class SpectrogramBuilder(Protocol):
|
||||
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class PreprocessorProtocol(Protocol):
|
||||
max_freq: float
|
||||
min_freq: float
|
||||
input_samplerate: int
|
||||
output_samplerate: float
|
||||
|
||||
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
|
||||
return self(torch.tensor(wav)).numpy()
|
||||
@ -16,7 +16,7 @@ from batdetect2.data.conditions import (
|
||||
)
|
||||
from batdetect2.targets.rois import ROIMapperConfig
|
||||
from batdetect2.targets.terms import call_type, generic_class
|
||||
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
||||
from batdetect2.targets.types import SoundEventDecoder, SoundEventEncoder
|
||||
|
||||
__all__ = [
|
||||
"build_sound_event_decoder",
|
||||
|
||||
@ -27,17 +27,13 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.core.arrays import spec_to_xarray
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
Position,
|
||||
PreprocessorProtocol,
|
||||
ROITargetMapper,
|
||||
Size,
|
||||
)
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import Position, ROITargetMapper, Size
|
||||
|
||||
__all__ = [
|
||||
"Anchor",
|
||||
|
||||
@ -16,7 +16,7 @@ from batdetect2.targets.rois import (
|
||||
AnchorBBoxMapperConfig,
|
||||
build_roi_mapper,
|
||||
)
|
||||
from batdetect2.typing.targets import Position, Size, TargetProtocol
|
||||
from batdetect2.targets.types import Position, Size, TargetProtocol
|
||||
|
||||
|
||||
class Targets(TargetProtocol):
|
||||
|
||||
60
src/batdetect2/targets/types.py
Normal file
60
src/batdetect2/targets/types.py
Normal file
@ -0,0 +1,60 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"Position",
|
||||
"ROITargetMapper",
|
||||
"Size",
|
||||
"SoundEventDecoder",
|
||||
"SoundEventEncoder",
|
||||
"SoundEventFilter",
|
||||
"TargetProtocol",
|
||||
]
|
||||
|
||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], str | None]
|
||||
SoundEventDecoder = Callable[[str], list[data.Tag]]
|
||||
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
|
||||
|
||||
Position = tuple[float, float]
|
||||
Size = np.ndarray
|
||||
|
||||
|
||||
class TargetProtocol(Protocol):
|
||||
class_names: list[str]
|
||||
detection_class_tags: list[data.Tag]
|
||||
detection_class_name: str
|
||||
dimension_names: list[str]
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
|
||||
|
||||
def encode_class(
|
||||
self,
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
) -> str | None: ...
|
||||
|
||||
def decode_class(self, class_label: str) -> list[data.Tag]: ...
|
||||
|
||||
def encode_roi(
|
||||
self,
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
) -> tuple[Position, Size]: ...
|
||||
|
||||
def decode_roi(
|
||||
self,
|
||||
position: Position,
|
||||
size: Size,
|
||||
class_name: str | None = None,
|
||||
) -> data.Geometry: ...
|
||||
|
||||
|
||||
class ROITargetMapper(Protocol):
|
||||
dimension_names: list[str]
|
||||
|
||||
def encode(
|
||||
self, sound_event: data.SoundEvent
|
||||
) -> tuple[Position, Size]: ...
|
||||
|
||||
def decode(self, position: Position, size: Size) -> data.Geometry: ...
|
||||
@ -13,6 +13,7 @@ from soundevent.geometry import scale_geometry, shift_geometry
|
||||
|
||||
from batdetect2.audio.clips import get_subclip_annotation
|
||||
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.core.arrays import adjust_width
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.core.registries import (
|
||||
@ -20,7 +21,7 @@ from batdetect2.core.registries import (
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.typing import AudioLoader, Augmentation
|
||||
from batdetect2.train.types import Augmentation
|
||||
|
||||
__all__ = [
|
||||
"AugmentationConfig",
|
||||
|
||||
@ -5,17 +5,15 @@ from lightning.pytorch.callbacks import Callback
|
||||
from soundevent import data
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.train.dataset import ValidationDataset
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
EvaluatorProtocol,
|
||||
ModelOutput,
|
||||
TrainExample,
|
||||
)
|
||||
from batdetect2.train.types import TrainExample
|
||||
|
||||
|
||||
class ValidationMetrics(Callback):
|
||||
|
||||
@ -8,9 +8,11 @@ from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
|
||||
from batdetect2.audio.clips import PaddedClipConfig
|
||||
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.core.arrays import adjust_width
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.train.augmentations import (
|
||||
DEFAULT_AUGMENTATION_CONFIG,
|
||||
AugmentationsConfig,
|
||||
@ -18,14 +20,7 @@ from batdetect2.train.augmentations import (
|
||||
build_augmentations,
|
||||
)
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
Augmentation,
|
||||
ClipLabeller,
|
||||
ClipperProtocol,
|
||||
PreprocessorProtocol,
|
||||
TrainExample,
|
||||
)
|
||||
from batdetect2.train.types import Augmentation, ClipLabeller, TrainExample
|
||||
|
||||
__all__ = [
|
||||
"TrainingDataset",
|
||||
|
||||
@ -15,7 +15,8 @@ from soundevent import data
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
||||
from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.types import ClipLabeller, Heatmaps
|
||||
|
||||
__all__ = [
|
||||
"LabelConfig",
|
||||
|
||||
@ -2,11 +2,12 @@ import lightning as L
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.losses import build_loss
|
||||
from batdetect2.train.optimizers import build_optimizer
|
||||
from batdetect2.train.schedulers import build_scheduler
|
||||
from batdetect2.typing import LossProtocol, ModelOutput, TrainExample
|
||||
from batdetect2.train.types import LossProtocol, TrainExample
|
||||
|
||||
__all__ = [
|
||||
"TrainingModule",
|
||||
|
||||
@ -26,7 +26,8 @@ from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.train.types import Losses, LossProtocol, TrainExample
|
||||
|
||||
__all__ = [
|
||||
"BBoxLoss",
|
||||
|
||||
@ -43,7 +43,7 @@ optimizer_registry: Registry[Optimizer, [Iterable[nn.Parameter]]] = Registry(
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(optimizer_registry)
|
||||
@add_import_config(optimizer_registry, arg_names=["params"])
|
||||
class OptimizerImportConfig(ImportConfig):
|
||||
"""Use any callable as an optimizer.
|
||||
|
||||
@ -84,4 +84,4 @@ def build_optimizer(
|
||||
Optimizer configuration. Defaults to ``AdamOptimizerConfig``.
|
||||
"""
|
||||
config = config or AdamOptimizerConfig()
|
||||
return optimizer_registry.build(config, params=parameters)
|
||||
return optimizer_registry.build(config, parameters)
|
||||
|
||||
@ -40,7 +40,7 @@ class CosineAnnealingSchedulerConfig(BaseConfig):
|
||||
scheduler_registry: Registry[LRScheduler, [Optimizer]] = Registry("scheduler")
|
||||
|
||||
|
||||
@add_import_config(scheduler_registry)
|
||||
@add_import_config(scheduler_registry, arg_names=["optimizer"])
|
||||
class SchedulerImportConfig(ImportConfig):
|
||||
"""Use any callable as a scheduler.
|
||||
|
||||
@ -78,4 +78,4 @@ def build_scheduler(
|
||||
"""Build a scheduler from configuration."""
|
||||
config = config or CosineAnnealingSchedulerConfig()
|
||||
|
||||
return scheduler_registry.build(config, optimizer=optimizer)
|
||||
return scheduler_registry.build(config, optimizer)
|
||||
|
||||
@ -1,32 +1,28 @@
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import Optional
|
||||
|
||||
from lightning import Trainer, seed_everything
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.evaluate import build_evaluator
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.logging import build_logger
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train import TrainingConfig
|
||||
from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.checkpoints import build_checkpoint_callback
|
||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
ClipLabeller,
|
||||
EvaluatorProtocol,
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.train.types import ClipLabeller
|
||||
|
||||
__all__ = [
|
||||
"build_trainer",
|
||||
|
||||
70
src/batdetect2/train/types.py
Normal file
70
src/batdetect2/train/types.py
Normal file
@ -0,0 +1,70 @@
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, NamedTuple, Protocol
|
||||
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.models.types import ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"Augmentation",
|
||||
"ClipLabeller",
|
||||
"Heatmaps",
|
||||
"Losses",
|
||||
"LossProtocol",
|
||||
"TrainExample",
|
||||
]
|
||||
|
||||
|
||||
class Heatmaps(NamedTuple):
|
||||
detection: torch.Tensor
|
||||
classes: torch.Tensor
|
||||
size: torch.Tensor
|
||||
|
||||
|
||||
class PreprocessedExample(NamedTuple):
|
||||
audio: torch.Tensor
|
||||
spectrogram: torch.Tensor
|
||||
detection_heatmap: torch.Tensor
|
||||
class_heatmap: torch.Tensor
|
||||
size_heatmap: torch.Tensor
|
||||
|
||||
def copy(self):
|
||||
return PreprocessedExample(
|
||||
audio=self.audio.clone(),
|
||||
spectrogram=self.spectrogram.clone(),
|
||||
detection_heatmap=self.detection_heatmap.clone(),
|
||||
size_heatmap=self.size_heatmap.clone(),
|
||||
class_heatmap=self.class_heatmap.clone(),
|
||||
)
|
||||
|
||||
|
||||
ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
|
||||
|
||||
|
||||
Augmentation = Callable[
|
||||
[torch.Tensor, data.ClipAnnotation],
|
||||
tuple[torch.Tensor, data.ClipAnnotation],
|
||||
]
|
||||
|
||||
|
||||
class TrainExample(NamedTuple):
|
||||
spec: torch.Tensor
|
||||
detection_heatmap: torch.Tensor
|
||||
class_heatmap: torch.Tensor
|
||||
size_heatmap: torch.Tensor
|
||||
idx: torch.Tensor
|
||||
start_time: torch.Tensor
|
||||
end_time: torch.Tensor
|
||||
|
||||
|
||||
class Losses(NamedTuple):
|
||||
detection: torch.Tensor
|
||||
size: torch.Tensor
|
||||
classification: torch.Tensor
|
||||
total: torch.Tensor
|
||||
|
||||
|
||||
class LossProtocol(Protocol):
|
||||
def __call__(self, pred: "ModelOutput", gt: TrainExample) -> Losses: ...
|
||||
@ -1,18 +1,14 @@
|
||||
"""Types used in the code base."""
|
||||
|
||||
from typing import Any, NamedTuple, TypedDict
|
||||
import sys
|
||||
from typing import Any, NamedTuple, Protocol, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
from typing import Protocol
|
||||
except ImportError:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
try:
|
||||
from typing import NotRequired # type: ignore
|
||||
except ImportError:
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import NotRequired
|
||||
else:
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
|
||||
|
||||
@ -1,75 +0,0 @@
|
||||
from batdetect2.typing.data import OutputFormatterProtocol
|
||||
from batdetect2.typing.evaluate import (
|
||||
AffinityFunction,
|
||||
ClipMatches,
|
||||
EvaluatorProtocol,
|
||||
MatcherProtocol,
|
||||
MatchEvaluation,
|
||||
MetricsProtocol,
|
||||
PlotterProtocol,
|
||||
)
|
||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||
from batdetect2.typing.postprocess import (
|
||||
ClipDetections,
|
||||
ClipDetectionsTensor,
|
||||
Detection,
|
||||
GeometryDecoder,
|
||||
PostprocessorProtocol,
|
||||
)
|
||||
from batdetect2.typing.preprocess import (
|
||||
AudioLoader,
|
||||
PreprocessorProtocol,
|
||||
)
|
||||
from batdetect2.typing.targets import (
|
||||
Position,
|
||||
ROITargetMapper,
|
||||
Size,
|
||||
SoundEventDecoder,
|
||||
SoundEventEncoder,
|
||||
SoundEventFilter,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.typing.train import (
|
||||
Augmentation,
|
||||
ClipLabeller,
|
||||
ClipperProtocol,
|
||||
Heatmaps,
|
||||
Losses,
|
||||
LossProtocol,
|
||||
TrainExample,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AffinityFunction",
|
||||
"AudioLoader",
|
||||
"Augmentation",
|
||||
"BackboneModel",
|
||||
"ClipDetections",
|
||||
"ClipDetectionsTensor",
|
||||
"ClipLabeller",
|
||||
"ClipMatches",
|
||||
"ClipperProtocol",
|
||||
"DetectionModel",
|
||||
"EvaluatorProtocol",
|
||||
"GeometryDecoder",
|
||||
"Heatmaps",
|
||||
"LossProtocol",
|
||||
"Losses",
|
||||
"MatchEvaluation",
|
||||
"MatcherProtocol",
|
||||
"MetricsProtocol",
|
||||
"ModelOutput",
|
||||
"OutputFormatterProtocol",
|
||||
"PlotterProtocol",
|
||||
"Position",
|
||||
"PostprocessorProtocol",
|
||||
"PreprocessorProtocol",
|
||||
"ROITargetMapper",
|
||||
"Detection",
|
||||
"Size",
|
||||
"SoundEventDecoder",
|
||||
"SoundEventEncoder",
|
||||
"SoundEventFilter",
|
||||
"TargetProtocol",
|
||||
"TrainExample",
|
||||
]
|
||||
@ -1,287 +0,0 @@
|
||||
"""Defines shared interfaces (ABCs) and data structures for models.
|
||||
|
||||
This module centralizes the definitions of core data structures, like the
|
||||
standard model output container (`ModelOutput`), and establishes abstract base
|
||||
classes (ABCs) using `abc.ABC` and `torch.nn.Module`. These define contracts
|
||||
for fundamental model components, ensuring modularity and consistent
|
||||
interaction within the `batdetect2.models` package.
|
||||
|
||||
Key components:
|
||||
- `ModelOutput`: Standard structure for outputs from detection models.
|
||||
- `BackboneModel`: Generic interface for any feature extraction backbone.
|
||||
- `EncoderDecoderModel`: Specialized interface for backbones with distinct
|
||||
encoder-decoder stages (e.g., U-Net), providing access to intermediate
|
||||
features.
|
||||
- `DetectionModel`: Interface for the complete end-to-end detection model.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, NamedTuple, Protocol
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"ModelOutput",
|
||||
"BackboneModel",
|
||||
"EncoderDecoderModel",
|
||||
"DetectionModel",
|
||||
"BlockProtocol",
|
||||
"EncoderProtocol",
|
||||
"BottleneckProtocol",
|
||||
"DecoderProtocol",
|
||||
]
|
||||
|
||||
|
||||
class BlockProtocol(Protocol):
|
||||
"""Interface for blocks of network layers."""
|
||||
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
|
||||
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass of the block."""
|
||||
...
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
"""Calculate the output height based on input height."""
|
||||
...
|
||||
|
||||
|
||||
class EncoderProtocol(Protocol):
|
||||
"""Interface for the downsampling path of a network."""
|
||||
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
output_height: int
|
||||
|
||||
def __call__(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""Forward pass must return intermediate tensors for skip connections."""
|
||||
...
|
||||
|
||||
|
||||
class BottleneckProtocol(Protocol):
|
||||
"""Interface for the middle part of a U-Net-like network."""
|
||||
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
|
||||
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Processes the features from the encoder."""
|
||||
...
|
||||
|
||||
|
||||
class DecoderProtocol(Protocol):
|
||||
"""Interface for the upsampling reconstruction path."""
|
||||
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
output_height: int
|
||||
depth: int
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residuals: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Upsamples features while integrating skip connections."""
|
||||
...
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
|
||||
"""Standard container for the outputs of a BatDetect2 detection model.
|
||||
|
||||
This structure groups the different prediction tensors produced by the
|
||||
model for a batch of input spectrograms. All tensors typically share the
|
||||
same spatial dimensions (height H, width W) corresponding to the model's
|
||||
output resolution, and the same batch size (N).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
detection_probs : torch.Tensor
|
||||
Tensor containing the probability of sound event presence at each
|
||||
location in the output grid.
|
||||
Shape: `(N, 1, H, W)`
|
||||
size_preds : torch.Tensor
|
||||
Tensor containing the predicted size dimensions
|
||||
(e.g., width and height) for a potential bounding box at each location.
|
||||
Shape: `(N, 2, H, W)` (Channel 0 typically width, Channel 1 height)
|
||||
class_probs : torch.Tensor
|
||||
Tensor containing the predicted probabilities (or logits, depending on
|
||||
the final activation) for each target class at each location.
|
||||
The number of channels corresponds to the number of specific classes
|
||||
defined in the `Targets` configuration.
|
||||
Shape: `(N, num_classes, H, W)`
|
||||
features : torch.Tensor
|
||||
Tensor containing features extracted by the model's backbone. These
|
||||
might be used for downstream tasks or analysis. The number of channels
|
||||
depends on the specific model architecture.
|
||||
Shape: `(N, num_features, H, W)`
|
||||
"""
|
||||
|
||||
detection_probs: torch.Tensor
|
||||
size_preds: torch.Tensor
|
||||
class_probs: torch.Tensor
|
||||
features: torch.Tensor
|
||||
|
||||
|
||||
class BackboneModel(ABC, torch.nn.Module):
|
||||
"""Abstract Base Class for generic feature extraction backbone models.
|
||||
|
||||
Defines the minimal interface for a feature extractor network within a
|
||||
BatDetect2 model. Its primary role is to process an input spectrogram
|
||||
tensor and produce a spatially rich feature map tensor, which is then
|
||||
typically consumed by separate prediction heads (for detection,
|
||||
classification, size).
|
||||
|
||||
This base class is agnostic to the specific internal architecture (e.g.,
|
||||
it could be a simple CNN, a U-Net, a Transformer, etc.). Concrete
|
||||
implementations must inherit from this class and `torch.nn.Module`,
|
||||
implement the `forward` method, and define the required attributes.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
input_height : int
|
||||
Expected height (number of frequency bins) of the input spectrogram
|
||||
tensor that the backbone is designed to process.
|
||||
out_channels : int
|
||||
Number of channels in the final feature map tensor produced by the
|
||||
backbone's `forward` method.
|
||||
"""
|
||||
|
||||
input_height: int
|
||||
"""Expected input spectrogram height (frequency bins)."""
|
||||
|
||||
out_channels: int
|
||||
"""Number of output channels in the final feature map."""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Perform the forward pass to extract features from the spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram tensor, typically with shape
|
||||
`(batch_size, 1, frequency_bins, time_bins)`.
|
||||
`frequency_bins` should match `self.input_height`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output feature map tensor, typically with shape
|
||||
`(batch_size, self.out_channels, output_height, output_width)`.
|
||||
The spatial dimensions (`output_height`, `output_width`) depend
|
||||
on the specific backbone architecture (e.g., they might match the
|
||||
input or be downsampled).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EncoderDecoderModel(BackboneModel):
    """Abstract Base Class for Encoder-Decoder style backbone models.

    This class specializes `BackboneModel` for architectures that have distinct
    encoder stages (downsampling path), a bottleneck, and decoder stages
    (upsampling path).

    It provides separate abstract methods for the `encode` and `decode` steps,
    allowing access to the intermediate "bottleneck" features produced by the
    encoder. This can be useful for tasks like transfer learning or specialized
    analyses.

    Attributes
    ----------
    input_height : int
        (Inherited from BackboneModel) Expected input spectrogram height.
    out_channels : int
        (Inherited from BackboneModel) Number of output channels in the final
        feature map produced by the decoder/forward pass.
    bottleneck_channels : int
        Number of channels in the feature map produced by the encoder at its
        deepest point (the bottleneck), before the decoder starts.
    """

    bottleneck_channels: int
    """Number of channels at the encoder's bottleneck."""

    @abstractmethod
    def encode(self, spec: torch.Tensor) -> torch.Tensor:
        """Process the input spectrogram through the encoder part.

        Takes the input spectrogram and passes it through the downsampling path
        of the network up to the bottleneck layer.

        Parameters
        ----------
        spec : torch.Tensor
            Input spectrogram tensor, typically with shape
            `(batch_size, 1, frequency_bins, time_bins)`.

        Returns
        -------
        torch.Tensor
            The encoded feature map from the bottleneck layer, typically with
            shape `(batch_size, self.bottleneck_channels, bottleneck_height,
            bottleneck_width)`. The spatial dimensions are usually downsampled
            relative to the input.
        """
        ...

    @abstractmethod
    def decode(self, encoded: torch.Tensor) -> torch.Tensor:
        """Process the bottleneck features through the decoder part.

        Takes the encoded feature map from the bottleneck and passes it through
        the upsampling path (potentially using skip connections from the
        encoder) to produce the final output feature map.

        Parameters
        ----------
        encoded : torch.Tensor
            The bottleneck feature map tensor produced by the `encode` method.

        Returns
        -------
        torch.Tensor
            The final output feature map tensor, typically with shape
            `(batch_size, self.out_channels, output_height, output_width)`.
            This should match the output shape of the `forward` method.
        """
        ...
|
||||
|
||||
|
||||
class DetectionModel(ABC, torch.nn.Module):
    """Abstract Base Class for complete BatDetect2 detection models.

    Defines the interface for the overall model that takes an input spectrogram
    and produces all necessary outputs for detection, classification, and size
    prediction, packaged within a `ModelOutput` object.

    Concrete implementations typically combine a `BackboneModel` for feature
    extraction with specific prediction heads for each output type. They must
    inherit from this class and `torch.nn.Module`, and implement the `forward`
    method.
    """

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> ModelOutput:
        """Perform the forward pass of the full detection model.

        Processes the input spectrogram through the backbone and prediction
        heads to generate all required output tensors.

        Parameters
        ----------
        spec : torch.Tensor
            Input spectrogram tensor, typically with shape
            `(batch_size, 1, frequency_bins, time_bins)`.

        Returns
        -------
        ModelOutput
            A NamedTuple containing the prediction tensors: `detection_probs`,
            `size_preds`, `class_probs`, and `features`.
        """
        # Abstract: the docstring is the entire body; subclasses must override.
|
||||
@ -1,104 +0,0 @@
|
||||
"""Defines shared interfaces and data structures for postprocessing.
|
||||
|
||||
This module centralizes the Protocol definitions and common data structures
|
||||
used throughout the `batdetect2.postprocess` module.
|
||||
|
||||
The main component is the `PostprocessorProtocol`, which outlines the standard
|
||||
interface for an object responsible for executing the entire postprocessing
|
||||
pipeline. This pipeline transforms raw neural network outputs into interpretable
|
||||
detections represented as `soundevent` objects. Using protocols ensures
|
||||
modularity and consistent interaction between different parts of the BatDetect2
|
||||
system that deal with model predictions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Protocol
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
from batdetect2.typing.targets import Position, Size
|
||||
|
||||
__all__ = [
|
||||
"Detection",
|
||||
"PostprocessorProtocol",
|
||||
"GeometryDecoder",
|
||||
]
|
||||
|
||||
|
||||
class GeometryDecoder(Protocol):
    """Protocol for a callable that recovers geometry from position and size.

    Implementations take:

    1. A position tuple `(time, frequency)`.
    2. A NumPy array of size dimensions (e.g., `[width, height]`).
    3. Optionally the class name of the highest scoring class. This is to
       accommodate different ways of decoding geometry that depend on the
       predicted class.

    They return the reconstructed `soundevent.data.Geometry` (typically a
    `BoundingBox`).
    """

    def __call__(
        self, position: Position, size: Size, class_name: str | None = None
    ) -> data.Geometry: ...
|
||||
|
||||
|
||||
@dataclass
class Detection:
    """A single decoded detection extracted from the model output."""

    geometry: data.Geometry  # reconstructed ROI (typically a bounding box)
    detection_score: float  # confidence that a sound event is present
    class_scores: np.ndarray  # per-class scores — presumably ordered to match the configured class names; verify
    features: np.ndarray  # feature vector associated with this detection
|
||||
|
||||
|
||||
class ClipDetectionsArray(NamedTuple):
    """NumPy counterpart of `ClipDetectionsTensor`, one entry per detection."""

    scores: np.ndarray
    sizes: np.ndarray
    class_scores: np.ndarray
    times: np.ndarray
    frequencies: np.ndarray
    features: np.ndarray


class ClipDetectionsTensor(NamedTuple):
    """Batched detection attributes for a clip, stored as torch tensors."""

    scores: torch.Tensor
    sizes: torch.Tensor
    class_scores: torch.Tensor
    times: torch.Tensor
    frequencies: torch.Tensor
    features: torch.Tensor

    def numpy(self) -> ClipDetectionsArray:
        """Detach every field, move it to the CPU, and return NumPy arrays."""
        # Both NamedTuples declare their fields in the same order, so the
        # converted tensors can be forwarded positionally.
        converted = (field.detach().cpu().numpy() for field in self)
        return ClipDetectionsArray(*converted)
|
||||
|
||||
|
||||
@dataclass
class ClipDetections:
    """All decoded detections belonging to a single audio clip."""

    clip: data.Clip  # the clip the detections were extracted from
    detections: List[Detection]  # decoded detections within the clip
|
||||
|
||||
|
||||
@dataclass
class ClipPrediction:
    """Clip-level prediction scores (no per-sound-event geometry)."""

    clip: data.Clip  # the clip being scored
    detection_score: float  # overall detection confidence for the clip
    class_scores: np.ndarray  # clip-level class scores
|
||||
|
||||
|
||||
class PostprocessorProtocol(Protocol):
    """Protocol defining the interface for the full postprocessing pipeline."""

    def __call__(
        self,
        output: ModelOutput,
    ) -> List[ClipDetectionsTensor]:
        """Extract detections from a batched model output.

        Returns a list of `ClipDetectionsTensor` — presumably one entry per
        clip in the batch; verify against concrete implementations.
        """
        ...
|
||||
@ -1,168 +0,0 @@
|
||||
"""Defines common interfaces (Protocols) for preprocessing components.
|
||||
|
||||
This module centralizes the Protocol definitions used throughout the
|
||||
`batdetect2.preprocess` package. Protocols define expected methods and
|
||||
signatures, allowing for flexible and interchangeable implementations of
|
||||
components like audio loaders and spectrogram builders.
|
||||
|
||||
Using these protocols ensures that different parts of the preprocessing
|
||||
pipeline can interact consistently, regardless of the specific underlying
|
||||
implementation (e.g., different libraries or custom configurations).
|
||||
"""
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"AudioLoader",
|
||||
"SpectrogramBuilder",
|
||||
"PreprocessorProtocol",
|
||||
]
|
||||
|
||||
|
||||
class AudioLoader(Protocol):
    """Defines the interface for an audio loading and processing component.

    An AudioLoader is responsible for retrieving audio data corresponding to
    different soundevent objects (files, Recordings, Clips) and applying a
    configured set of initial preprocessing steps. Adhering to this protocol
    allows for different loading strategies or implementations.
    """

    samplerate: int
    # Sample rate (Hz) associated with the loader's output.
    # NOTE(review): presumably audio is resampled to this rate on load —
    # confirm against the concrete implementation.

    def load_file(
        self,
        path: data.PathLike,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load and preprocess audio directly from a file path.

        Parameters
        ----------
        path : PathLike
            Path to the audio file.
        audio_dir : PathLike, optional
            A directory prefix to prepend to the path if `path` is relative.

        Returns
        -------
        np.ndarray
            The loaded and preprocessed audio waveform.

        Raises
        ------
        FileNotFoundError
            If the audio file cannot be found.
        Exception
            If the audio file cannot be loaded or processed.
        """
        ...

    def load_recording(
        self,
        recording: data.Recording,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load and preprocess the entire audio for a Recording object.

        Parameters
        ----------
        recording : data.Recording
            The Recording object containing metadata about the audio file.
        audio_dir : PathLike, optional
            A directory where the audio file associated with the recording
            can be found, especially if the path in the recording is relative.

        Returns
        -------
        np.ndarray
            The loaded and preprocessed audio waveform as a 1-D NumPy
            array. Typically loads only the first channel.

        Raises
        ------
        FileNotFoundError
            If the audio file associated with the recording cannot be found.
        Exception
            If the audio file cannot be loaded or processed.
        """
        ...

    def load_clip(
        self,
        clip: data.Clip,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load and preprocess the audio segment defined by a Clip object.

        Parameters
        ----------
        clip : data.Clip
            The Clip object specifying the recording and the start/end times
            of the segment to load.
        audio_dir : PathLike, optional
            A directory where the audio file associated with the clip's
            recording can be found.

        Returns
        -------
        np.ndarray
            The loaded and preprocessed audio waveform for the specified
            clip duration as a 1-D NumPy array. Typically loads only the
            first channel.

        Raises
        ------
        FileNotFoundError
            If the audio file associated with the clip cannot be found.
        Exception
            If the audio file cannot be loaded or processed.
        """
        ...
|
||||
|
||||
|
||||
class SpectrogramBuilder(Protocol):
    """Defines the interface for a spectrogram generation component."""

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        """Generate a spectrogram from an audio waveform.

        Takes a waveform tensor and returns a spectrogram tensor; the exact
        shapes and parameterisation (window, hop, scaling) are determined by
        the concrete implementation.
        """
        ...
|
||||
|
||||
|
||||
class PreprocessorProtocol(Protocol):
    """Defines a high-level interface for the complete preprocessing pipeline."""

    max_freq: float
    # Upper frequency bound of the output spectrogram — assumed Hz; confirm.

    min_freq: float
    # Lower frequency bound of the output spectrogram — assumed Hz; confirm.

    input_samplerate: int
    # Sample rate expected of input audio.

    output_samplerate: float
    # Effective sample rate after audio processing.

    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...

    def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...

    def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ...

    def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...

    def process_numpy(self, wav: np.ndarray) -> np.ndarray:
        """Run the full preprocessing pipeline on a NumPy waveform.

        This default implementation converts the array to a
        ``torch.Tensor``, calls :meth:`__call__`, and converts the
        result back to a NumPy array. Concrete implementations may
        override this for efficiency.

        Parameters
        ----------
        wav : np.ndarray
            Input waveform as a 1-D NumPy array.

        Returns
        -------
        np.ndarray
            Preprocessed spectrogram as a NumPy array.
        """
        # Detach and move to CPU before converting: Tensor.numpy() raises if
        # the pipeline returned a tensor that requires grad or lives on a
        # non-CPU device (e.g. CUDA), which the original bare `.numpy()`
        # call did not handle.
        return self(torch.tensor(wav)).detach().cpu().numpy()
|
||||
@ -1,298 +0,0 @@
|
||||
"""Defines the core interface (Protocol) for the target definition pipeline.
|
||||
|
||||
This module specifies the standard structure, attributes, and methods expected
|
||||
from an object that encapsulates the complete configured logic for processing
|
||||
sound event annotations within the `batdetect2.targets` system.
|
||||
|
||||
The main component defined here is the `TargetProtocol`. This protocol acts as
|
||||
a contract for the entire target definition process, covering semantic aspects
|
||||
(filtering, tag transformation, class encoding/decoding) as well as geometric
|
||||
aspects (mapping regions of interest to target positions and sizes). It ensures
|
||||
that components responsible for these tasks can be interacted with consistently
|
||||
throughout BatDetect2.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import List, Protocol
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"TargetProtocol",
|
||||
"SoundEventEncoder",
|
||||
"SoundEventDecoder",
|
||||
"SoundEventFilter",
|
||||
"Position",
|
||||
"Size",
|
||||
]
|
||||
|
||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], str | None]
|
||||
"""Type alias for a sound event class encoder function.
|
||||
|
||||
An encoder function takes a sound event annotation and returns the string name
|
||||
of the target class it belongs to, based on a predefined set of rules.
|
||||
If the annotation does not match any defined target class according to the
|
||||
rules, the function returns None.
|
||||
"""
|
||||
|
||||
|
||||
SoundEventDecoder = Callable[[str], List[data.Tag]]
|
||||
"""Type alias for a sound event class decoder function.
|
||||
|
||||
A decoder function takes a class name string (as predicted by the model or
|
||||
assigned during encoding) and returns a list of `soundevent.data.Tag` objects
|
||||
that represent that class according to the configuration. This is used to
|
||||
translate model outputs back into meaningful annotations.
|
||||
"""
|
||||
|
||||
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
|
||||
"""Type alias for a filter function.
|
||||
|
||||
A filter function accepts a soundevent.data.SoundEventAnnotation object
|
||||
and returns True if the annotation should be kept based on the filter's
|
||||
criteria, or False if it should be discarded.
|
||||
"""
|
||||
|
||||
Position = tuple[float, float]
|
||||
"""A tuple representing (time, frequency) coordinates."""
|
||||
|
||||
Size = np.ndarray
|
||||
"""A NumPy array representing the size dimensions of a target."""
|
||||
|
||||
|
||||
class TargetProtocol(Protocol):
    """Protocol defining the interface for the target definition pipeline.

    This protocol outlines the standard attributes and methods for an object
    that encapsulates the complete, configured process for handling sound event
    annotations (both tags and geometry). It defines how to:
    - Select relevant annotations.
    - Encode an annotation into a specific target class name.
    - Decode a class name back into representative tags.
    - Extract a target reference position from an annotation's geometry (ROI).
    - Calculate target size dimensions from an annotation's geometry.
    - Recover an approximate geometry (ROI) from a position and size
      dimensions.

    Implementations of this protocol bundle all configured logic for these
    steps.

    Attributes
    ----------
    class_names : List[str]
        An ordered list of the unique names of the specific target classes
        defined by the configuration.
    detection_class_tags : List[data.Tag]
        A list of `soundevent.data.Tag` objects representing the configured
        detection (unclassified) category, used when no specific class
        matches.
    detection_class_name : str
        Name of the generic detection class.
    dimension_names : List[str]
        A list containing the names of the size dimensions returned by
        `encode_roi` and expected by `decode_roi` (e.g.,
        ['width', 'height']).
    """

    class_names: List[str]
    """Ordered list of unique names for the specific target classes."""

    detection_class_tags: List[data.Tag]
    """List of tags representing the detection category (unclassified)."""

    detection_class_name: str
    # Name used for the generic detection (unclassified) class.

    dimension_names: List[str]
    """Names of the size dimensions (e.g., ['width', 'height'])."""

    def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
        """Apply the filter to a sound event annotation.

        Determines if the annotation should be included in further processing
        and training based on the configured filtering rules.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation to filter.

        Returns
        -------
        bool
            True if the annotation should be kept (passes the filter),
            False otherwise. Implementations should return True if no
            filtering is configured.
        """
        ...

    def encode_class(
        self,
        sound_event: data.SoundEventAnnotation,
    ) -> str | None:
        """Encode a sound event annotation to its target class name.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The (potentially filtered and transformed) annotation to encode.

        Returns
        -------
        str or None
            The string name of the matched target class if the annotation
            matches a specific class definition. Returns None if the annotation
            does not match any specific class rule (indicating it may belong
            to a generic category or should be handled differently downstream).
        """
        ...

    def decode_class(self, class_label: str) -> List[data.Tag]:
        """Decode a predicted class name back into representative tags.

        Parameters
        ----------
        class_label : str
            The class name string (e.g., predicted by a model) to decode.

        Returns
        -------
        List[data.Tag]
            The list of tags corresponding to the input class name according
            to the configuration. May return an empty list or raise an error
            for unmapped labels, depending on the implementation's configuration
            (e.g., `raise_on_unmapped` flag during building).

        Raises
        ------
        ValueError, KeyError
            Implementations might raise an error if the `class_label` is not
            found in the configured mapping and error raising is enabled.
        """
        ...

    def encode_roi(
        self, sound_event: data.SoundEventAnnotation
    ) -> tuple[Position, Size]:
        """Extract the target position and size from the annotation's geometry.

        Calculates the `(time, frequency)` coordinate representing the primary
        location of the sound event, together with its size dimensions.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation containing the geometry (ROI) to process.

        Returns
        -------
        tuple[Position, Size]
            The calculated reference position `(time, frequency)` and a
            NumPy array of size dimensions ordered as `dimension_names`.

        Raises
        ------
        ValueError
            If the annotation lacks geometry or if the position cannot be
            calculated for the geometry type or configured reference point.
        """
        ...

    def decode_roi(
        self,
        position: Position,
        size: Size,
        class_name: str | None = None,
    ) -> data.Geometry:
        """Recover the ROI geometry from a position and dimensions.

        Performs the inverse mapping of `encode_roi`. It takes a reference
        position `(time, frequency)` and an array of size dimensions and
        reconstructs an approximate geometric representation.

        Parameters
        ----------
        position : Position
            The reference position `(time, frequency)`.
        size : Size
            The NumPy array containing the dimensions (e.g., predicted
            by the model), corresponding to the order in `dimension_names`.
        class_name : str, optional
            Name of the predicted class, for decoding strategies that depend
            on the class.

        Returns
        -------
        soundevent.data.Geometry
            The reconstructed geometry.

        Raises
        ------
        ValueError
            If the number of provided dimensions does not match
            `dimension_names`, if dimensions are invalid (e.g., negative
            after unscaling), or if reconstruction fails based on the
            configured position type.
        """
        ...
|
||||
|
||||
|
||||
class ROITargetMapper(Protocol):
    """Protocol defining the interface for ROI-to-target mapping.

    Specifies the `encode` and `decode` methods required for converting a
    `soundevent.data.SoundEvent` into a target representation (a reference
    position and a size vector) and for recovering an approximate ROI from that
    representation.

    Attributes
    ----------
    dimension_names : List[str]
        A list containing the names of the dimensions in the `Size` array
        returned by `encode` and expected by `decode`.
    """

    dimension_names: List[str]
    # Ordered names of the entries in the Size array.

    def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
        """Encode a SoundEvent's geometry into a position and size.

        Parameters
        ----------
        sound_event : data.SoundEvent
            The input sound event, which must have a geometry attribute.

        Returns
        -------
        Tuple[Position, Size]
            A tuple containing:
            - The reference position as (time, frequency) coordinates.
            - A NumPy array with the calculated size dimensions.

        Raises
        ------
        ValueError
            If the sound event does not have a geometry.
        """
        ...

    def decode(self, position: Position, size: Size) -> data.Geometry:
        """Decode a position and size back into a geometric ROI.

        Performs the inverse mapping: takes a reference position and size
        dimensions and reconstructs a geometric representation.

        Parameters
        ----------
        position : Position
            The reference position (time, frequency).
        size : Size
            NumPy array containing the size dimensions, matching the order
            and meaning specified by `dimension_names`.

        Returns
        -------
        soundevent.data.Geometry
            The reconstructed geometry, typically a `BoundingBox`.

        Raises
        ------
        ValueError
            If the `size` array has an unexpected shape or if reconstruction
            fails.
        """
        ...
|
||||
@ -1,108 +0,0 @@
|
||||
from typing import Callable, NamedTuple, Protocol
|
||||
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"Augmentation",
|
||||
"ClipLabeller",
|
||||
"ClipperProtocol",
|
||||
"Heatmaps",
|
||||
"LossProtocol",
|
||||
"Losses",
|
||||
"TrainExample",
|
||||
]
|
||||
|
||||
|
||||
class Heatmaps(NamedTuple):
    """Structure holding the generated heatmap targets."""

    detection: torch.Tensor  # detection target heatmap
    classes: torch.Tensor  # per-class target heatmap
    size: torch.Tensor  # size regression target map
|
||||
|
||||
|
||||
class PreprocessedExample(NamedTuple):
    """A fully preprocessed training example (audio, spectrogram, targets)."""

    audio: torch.Tensor
    spectrogram: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor

    def copy(self):
        """Return a deep copy of this example with every tensor cloned."""
        # Cloning each field in declaration order is equivalent to the
        # keyword-by-keyword construction, but works uniformly over fields.
        return PreprocessedExample(*(tensor.clone() for tensor in self))
|
||||
|
||||
|
||||
ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
|
||||
"""Type alias for the final clip labelling function.
|
||||
|
||||
This function takes the complete annotations for a clip and the corresponding
|
||||
spectrogram, applies all configured filtering, transformation, and encoding
|
||||
steps, and returns the final `Heatmaps` used for model training.
|
||||
"""
|
||||
|
||||
|
||||
Augmentation = Callable[
|
||||
[torch.Tensor, data.ClipAnnotation],
|
||||
tuple[torch.Tensor, data.ClipAnnotation],
|
||||
]
|
||||
|
||||
|
||||
class TrainExample(NamedTuple):
    """Single training example as consumed by the training loop."""

    spec: torch.Tensor  # input spectrogram
    detection_heatmap: torch.Tensor  # target detection heatmap
    class_heatmap: torch.Tensor  # target per-class heatmap
    size_heatmap: torch.Tensor  # target size regression map
    idx: torch.Tensor  # index of the source example — presumably into the dataset; verify
    start_time: torch.Tensor  # start of the training sub-clip — units assumed seconds; confirm
    end_time: torch.Tensor  # end of the training sub-clip — units assumed seconds; confirm
|
||||
|
||||
|
||||
class Losses(NamedTuple):
    """Structure to hold the computed loss values.

    Allows returning individual loss components along with the total weighted
    loss for monitoring and analysis during training.

    Attributes
    ----------
    detection : torch.Tensor
        Scalar tensor representing the calculated detection loss component
        (before weighting).
    size : torch.Tensor
        Scalar tensor representing the calculated size regression loss component
        (before weighting).
    classification : torch.Tensor
        Scalar tensor representing the calculated classification loss component
        (before weighting).
    total : torch.Tensor
        Scalar tensor representing the final combined loss, computed as the
        weighted sum of the detection, size, and classification components.
        This is the value typically used for backpropagation.
    """

    detection: torch.Tensor  # unweighted detection loss
    size: torch.Tensor  # unweighted size regression loss
    classification: torch.Tensor  # unweighted classification loss
    total: torch.Tensor  # weighted sum; backpropagate on this
|
||||
|
||||
|
||||
class LossProtocol(Protocol):
    """Callable computing training losses from predictions and targets."""

    def __call__(self, pred: ModelOutput, gt: TrainExample) -> Losses: ...
|
||||
|
||||
|
||||
class ClipperProtocol(Protocol):
    """Extracts a training sub-clip and crops annotations to match."""

    def __call__(
        self,
        clip_annotation: data.ClipAnnotation,
    ) -> data.ClipAnnotation:
        """Return the annotation restricted to the extracted sub-clip."""
        ...

    def get_subclip(self, clip: data.Clip) -> data.Clip:
        """Return the sub-clip that would be extracted from `clip`."""
        ...
|
||||
@ -11,23 +11,20 @@ from soundevent import data, terms
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.audio.clips import build_clipper
|
||||
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||
from batdetect2.data import DatasetConfig, load_dataset
|
||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import (
|
||||
TargetConfig,
|
||||
build_targets,
|
||||
call_type,
|
||||
)
|
||||
from batdetect2.targets.classes import TargetClassConfig
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.typing import (
|
||||
ClipLabeller,
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.typing.preprocess import AudioLoader
|
||||
from batdetect2.typing.train import ClipperProtocol
|
||||
from batdetect2.train.types import ClipLabeller
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -9,11 +9,8 @@ from batdetect2.outputs.formats import (
|
||||
ParquetOutputConfig,
|
||||
build_output_formatter,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -5,11 +5,8 @@ import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.outputs.formats import RawOutputConfig, build_output_formatter
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -4,7 +4,7 @@ import numpy as np
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.typing import Detection
|
||||
from batdetect2.postprocess.types import Detection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -4,8 +4,8 @@ from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.tasks import build_task
|
||||
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
||||
from batdetect2.typing import ClipDetections
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def test_classification(
|
||||
|
||||
@ -4,8 +4,8 @@ from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.tasks import build_task
|
||||
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||
from batdetect2.typing import ClipDetections
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def test_detection(
|
||||
|
||||
@ -13,7 +13,7 @@ from batdetect2.models.backbones import (
|
||||
build_backbone,
|
||||
load_backbone_config,
|
||||
)
|
||||
from batdetect2.typing.models import BackboneModel
|
||||
from batdetect2.models.types import BackboneModel
|
||||
|
||||
|
||||
def test_unet_backbone_config_defaults():
|
||||
|
||||
@ -7,7 +7,7 @@ from batdetect2.models.backbones import UNetBackboneConfig
|
||||
from batdetect2.models.detectors import Detector, build_detector
|
||||
from batdetect2.models.encoder import Encoder
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
from batdetect2.models.types import ModelOutput
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -3,7 +3,7 @@ from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.outputs import build_output_transform
|
||||
from batdetect2.typing import ClipDetections, Detection
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
|
||||
|
||||
def test_shift_time_to_clip_start(clip: data.Clip):
|
||||
|
||||
@ -14,7 +14,8 @@ from batdetect2.postprocess.decoding import (
|
||||
get_generic_tags,
|
||||
get_prediction_features,
|
||||
)
|
||||
from batdetect2.typing import Detection, TargetProtocol
|
||||
from batdetect2.postprocess.types import Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -8,6 +8,7 @@ from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.train import (
|
||||
@ -19,7 +20,6 @@ from batdetect2.train import (
|
||||
from batdetect2.train.optimizers import AdamOptimizerConfig
|
||||
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
|
||||
from batdetect2.train.train import build_training_module
|
||||
from batdetect2.typing.preprocess import AudioLoader
|
||||
|
||||
|
||||
def build_default_module(config: BatDetect2Config | None = None):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user