mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Moving types around to each submodule
This commit is contained in:
parent
c226dc3f2b
commit
751be53edf
2
.gitignore
vendored
2
.gitignore
vendored
@ -102,7 +102,7 @@ experiments/*
|
|||||||
DvcLiveLogger/checkpoints
|
DvcLiveLogger/checkpoints
|
||||||
logs/
|
logs/
|
||||||
mlruns/
|
mlruns/
|
||||||
outputs/
|
/outputs/
|
||||||
notebooks/lightning_logs
|
notebooks/lightning_logs
|
||||||
|
|
||||||
# Jupyter notebooks
|
# Jupyter notebooks
|
||||||
|
|||||||
@ -89,5 +89,5 @@ Crucial for training, this module translates physical annotations (Regions of In
|
|||||||
## Summary
|
## Summary
|
||||||
To navigate this codebase effectively:
|
To navigate this codebase effectively:
|
||||||
1. Follow **`api_v2.py`** to see how high-level operations invoke individual components.
|
1. Follow **`api_v2.py`** to see how high-level operations invoke individual components.
|
||||||
2. Rely heavily on the typed **Protocols** located in `src/batdetect2/typing/` to understand the inputs and outputs of each subsystem without needing to read the specific implementations.
|
2. Rely heavily on the typed **Protocols** located in each subsystem's `types.py` module (for example `src/batdetect2/preprocess/types.py` and `src/batdetect2/postprocess/types.py`) to understand inputs and outputs without needing to read each implementation.
|
||||||
3. Understand that data flows structurally as `soundevent` primitives externally, and as pure `torch.Tensor` internally through the network.
|
3. Understand that data flows structurally as `soundevent` primitives externally, and as pure `torch.Tensor` internally through the network.
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from soundevent import data
|
|||||||
from soundevent.audio.files import get_audio_files
|
from soundevent.audio.files import get_audio_files
|
||||||
|
|
||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.core import merge_configs
|
from batdetect2.core import merge_configs
|
||||||
from batdetect2.data import (
|
from batdetect2.data import (
|
||||||
@ -15,6 +16,7 @@ from batdetect2.data import (
|
|||||||
)
|
)
|
||||||
from batdetect2.data.datasets import Dataset
|
from batdetect2.data.datasets import Dataset
|
||||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||||
|
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||||
from batdetect2.inference import process_file_list, run_batch_inference
|
from batdetect2.inference import process_file_list, run_batch_inference
|
||||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||||
from batdetect2.models import Model, build_model
|
from batdetect2.models import Model, build_model
|
||||||
@ -25,24 +27,22 @@ from batdetect2.outputs import (
|
|||||||
build_output_transform,
|
build_output_transform,
|
||||||
get_output_formatter,
|
get_output_formatter,
|
||||||
)
|
)
|
||||||
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
||||||
|
from batdetect2.postprocess.types import (
|
||||||
|
ClipDetections,
|
||||||
|
Detection,
|
||||||
|
PostprocessorProtocol,
|
||||||
|
)
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train import (
|
from batdetect2.train import (
|
||||||
DEFAULT_CHECKPOINT_DIR,
|
DEFAULT_CHECKPOINT_DIR,
|
||||||
load_model_from_checkpoint,
|
load_model_from_checkpoint,
|
||||||
run_train,
|
run_train,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
|
||||||
AudioLoader,
|
|
||||||
ClipDetections,
|
|
||||||
Detection,
|
|
||||||
EvaluatorProtocol,
|
|
||||||
OutputFormatterProtocol,
|
|
||||||
PostprocessorProtocol,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BatDetect2API:
|
class BatDetect2API:
|
||||||
|
|||||||
@ -6,13 +6,13 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.geometry import compute_bounds, intervals_overlap
|
from soundevent.geometry import compute_bounds, intervals_overlap
|
||||||
|
|
||||||
|
from batdetect2.audio.types import ClipperProtocol
|
||||||
from batdetect2.core import (
|
from batdetect2.core import (
|
||||||
BaseConfig,
|
BaseConfig,
|
||||||
ImportConfig,
|
ImportConfig,
|
||||||
Registry,
|
Registry,
|
||||||
add_import_config,
|
add_import_config,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import ClipperProtocol
|
|
||||||
|
|
||||||
DEFAULT_TRAIN_CLIP_DURATION = 0.256
|
DEFAULT_TRAIN_CLIP_DURATION = 0.256
|
||||||
DEFAULT_MAX_EMPTY_CLIP = 0.1
|
DEFAULT_MAX_EMPTY_CLIP = 0.1
|
||||||
|
|||||||
@ -5,8 +5,8 @@ from scipy.signal import resample, resample_poly
|
|||||||
from soundevent import audio, data
|
from soundevent import audio, data
|
||||||
from soundfile import LibsndfileError
|
from soundfile import LibsndfileError
|
||||||
|
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
from batdetect2.typing import AudioLoader
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SoundEventAudioLoader",
|
"SoundEventAudioLoader",
|
||||||
|
|||||||
40
src/batdetect2/audio/types.py
Normal file
40
src/batdetect2/audio/types.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AudioLoader",
|
||||||
|
"ClipperProtocol",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class AudioLoader(Protocol):
    """Structural interface for objects that load audio as NumPy arrays.

    Any object exposing a ``samplerate`` attribute and the three ``load_*``
    methods below satisfies this protocol; no inheritance is required.
    Each loader accepts a different description of the audio to read (a
    path, a recording, or a clip) and returns an ``np.ndarray``.
    """

    # NOTE(review): presumably the sample rate of the arrays returned by the
    # loaders (in Hz) -- confirm against the concrete implementations.
    samplerate: int

    def load_file(
        self,
        path: data.PathLike,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load the audio stored at ``path``.

        Parameters
        ----------
        path
            Location of the audio file.
        audio_dir
            Optional directory, ``None`` by default. Presumably the base
            directory that ``path`` is resolved against -- confirm against
            the concrete implementations.

        Returns
        -------
        np.ndarray
            The loaded audio data.
        """
        ...

    def load_recording(
        self,
        recording: data.Recording,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load the audio for a ``soundevent`` recording.

        Parameters
        ----------
        recording
            The recording object describing the audio to load.
        audio_dir
            Optional directory, ``None`` by default. Presumably the base
            directory used to locate the recording's file -- confirm
            against the concrete implementations.

        Returns
        -------
        np.ndarray
            The loaded audio data.
        """
        ...

    def load_clip(
        self,
        clip: data.Clip,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load the audio for a ``soundevent`` clip.

        Parameters
        ----------
        clip
            The clip object describing the audio to load.
        audio_dir
            Optional directory, ``None`` by default. Presumably the base
            directory used to locate the clip's file -- confirm against
            the concrete implementations.

        Returns
        -------
        np.ndarray
            The loaded audio data.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class ClipperProtocol(Protocol):
    """Structural interface for objects that derive clips from clips.

    An implementation is callable on a clip annotation and also exposes
    ``get_subclip`` for working on a bare clip.
    """

    def __call__(
        self,
        clip_annotation: data.ClipAnnotation,
    ) -> data.ClipAnnotation:
        """Map a clip annotation to a clip annotation.

        NOTE(review): presumably the result covers a sub-clip of the input
        together with the annotations that fall inside it -- confirm
        against the concrete clippers.
        """
        ...

    def get_subclip(self, clip: data.Clip) -> data.Clip:
        """Return a clip derived from ``clip``.

        NOTE(review): the name suggests the result lies within the input
        clip -- confirm against the concrete clippers.
        """
        ...
|
||||||
@ -4,6 +4,7 @@ from typing import (
|
|||||||
Concatenate,
|
Concatenate,
|
||||||
Generic,
|
Generic,
|
||||||
ParamSpec,
|
ParamSpec,
|
||||||
|
Sequence,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
)
|
)
|
||||||
@ -147,6 +148,7 @@ T_Import = TypeVar("T_Import", bound=ImportConfig)
|
|||||||
|
|
||||||
def add_import_config(
|
def add_import_config(
|
||||||
registry: Registry[T_Type, P_Type],
|
registry: Registry[T_Type, P_Type],
|
||||||
|
arg_names: Sequence[str] | None = None,
|
||||||
) -> Callable[[Type[T_Import]], Type[T_Import]]:
|
) -> Callable[[Type[T_Import]], Type[T_Import]]:
|
||||||
"""Decorator that registers an ImportConfig subclass as an escape hatch.
|
"""Decorator that registers an ImportConfig subclass as an escape hatch.
|
||||||
|
|
||||||
@ -181,15 +183,22 @@ def add_import_config(
|
|||||||
*args: P_Type.args,
|
*args: P_Type.args,
|
||||||
**kwargs: P_Type.kwargs,
|
**kwargs: P_Type.kwargs,
|
||||||
) -> T_Type:
|
) -> T_Type:
|
||||||
if len(args) > 0:
|
_arg_names = arg_names or []
|
||||||
|
|
||||||
|
if len(args) != len(_arg_names):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Positional arguments are not supported "
|
"Positional arguments are not supported "
|
||||||
"for import escape hatch."
|
"for import escape hatch unless you specify "
|
||||||
|
"the argument names. Use `arg_names` to specify "
|
||||||
|
"the names of the positional arguments."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
args_dict = {_arg_names[i]: args[i] for i in range(len(args))}
|
||||||
|
|
||||||
hydra_cfg = {
|
hydra_cfg = {
|
||||||
"_target_": config.target,
|
"_target_": config.target,
|
||||||
**config.arguments,
|
**config.arguments,
|
||||||
|
**args_dict,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
return instantiate(hydra_cfg)
|
return instantiate(hydra_cfg)
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from collections.abc import Generator
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.data.datasets import Dataset
|
from batdetect2.data.datasets import Dataset
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
def iterate_over_sound_events(
|
def iterate_over_sound_events(
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from batdetect2.data.summary import (
|
|||||||
extract_recordings_df,
|
extract_recordings_df,
|
||||||
extract_sound_events_df,
|
extract_sound_events_df,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
def split_dataset_by_recordings(
|
def split_dataset_by_recordings(
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import pandas as pd
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.data.datasets import Dataset
|
from batdetect2.data.datasets import Dataset
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"extract_recordings_df",
|
"extract_recordings_df",
|
||||||
|
|||||||
@ -16,7 +16,8 @@ from batdetect2.core import (
|
|||||||
Registry,
|
Registry,
|
||||||
add_import_config,
|
add_import_config,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import AffinityFunction, Detection
|
from batdetect2.evaluate.types import AffinityFunction
|
||||||
|
from batdetect2.postprocess.types import Detection
|
||||||
|
|
||||||
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
||||||
"affinity_function"
|
"affinity_function"
|
||||||
|
|||||||
@ -8,14 +8,11 @@ from torch.utils.data import DataLoader, Dataset
|
|||||||
|
|
||||||
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
|
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
|
||||||
from batdetect2.audio.clips import PaddedClipConfig
|
from batdetect2.audio.clips import PaddedClipConfig
|
||||||
|
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
from batdetect2.core.arrays import adjust_width
|
from batdetect2.core.arrays import adjust_width
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.typing import (
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
AudioLoader,
|
|
||||||
ClipperProtocol,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TestDataset",
|
"TestDataset",
|
||||||
|
|||||||
@ -5,22 +5,20 @@ from lightning import Trainer
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.evaluate.dataset import build_test_loader
|
from batdetect2.evaluate.dataset import build_test_loader
|
||||||
from batdetect2.evaluate.evaluator import build_evaluator
|
from batdetect2.evaluate.evaluator import build_evaluator
|
||||||
from batdetect2.evaluate.lightning import EvaluationModule
|
from batdetect2.evaluate.lightning import EvaluationModule
|
||||||
from batdetect2.logging import build_logger
|
from batdetect2.logging import build_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.outputs import build_output_transform
|
from batdetect2.outputs import build_output_transform
|
||||||
from batdetect2.typing import Detection
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
|
from batdetect2.postprocess.types import Detection
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.typing import (
|
|
||||||
AudioLoader,
|
|
||||||
OutputFormatterProtocol,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||||
|
|
||||||
|
|||||||
@ -5,9 +5,10 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.evaluate.config import EvaluationConfig
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
from batdetect2.evaluate.tasks import build_task
|
from batdetect2.evaluate.tasks import build_task
|
||||||
|
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||||
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.typing.postprocess import ClipDetections
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Evaluator",
|
"Evaluator",
|
||||||
|
|||||||
@ -5,12 +5,12 @@ from soundevent import data
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||||
|
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||||
from batdetect2.logging import get_image_logger
|
from batdetect2.logging import get_image_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
from batdetect2.typing import EvaluatorProtocol
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.typing.postprocess import ClipDetections
|
|
||||||
|
|
||||||
|
|
||||||
class EvaluationModule(LightningModule):
|
class EvaluationModule(LightningModule):
|
||||||
|
|||||||
@ -26,7 +26,8 @@ from batdetect2.evaluate.metrics.common import (
|
|||||||
average_precision,
|
average_precision,
|
||||||
compute_precision_recall,
|
compute_precision_recall,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import Detection, TargetProtocol
|
from batdetect2.postprocess.types import Detection
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClassificationMetric",
|
"ClassificationMetric",
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from batdetect2.core import (
|
|||||||
add_import_config,
|
add_import_config,
|
||||||
)
|
)
|
||||||
from batdetect2.evaluate.metrics.common import average_precision
|
from batdetect2.evaluate.metrics.common import average_precision
|
||||||
from batdetect2.typing import Detection
|
from batdetect2.postprocess.types import Detection
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DetectionMetricConfig",
|
"DetectionMetricConfig",
|
||||||
|
|||||||
@ -20,8 +20,8 @@ from batdetect2.core import (
|
|||||||
add_import_config,
|
add_import_config,
|
||||||
)
|
)
|
||||||
from batdetect2.evaluate.metrics.common import average_precision
|
from batdetect2.evaluate.metrics.common import average_precision
|
||||||
from batdetect2.typing import Detection
|
from batdetect2.postprocess.types import Detection
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TopClassMetricConfig",
|
"TopClassMetricConfig",
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import matplotlib.pyplot as plt
|
|||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
class BasePlotConfig(BaseConfig):
|
class BasePlotConfig(BaseConfig):
|
||||||
|
|||||||
@ -29,7 +29,7 @@ from batdetect2.plotting.metrics import (
|
|||||||
plot_threshold_recall_curve,
|
plot_threshold_recall_curve,
|
||||||
plot_threshold_recall_curves,
|
plot_threshold_recall_curves,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
ClassificationPlotter = Callable[
|
ClassificationPlotter = Callable[
|
||||||
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
|
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from batdetect2.plotting.metrics import (
|
|||||||
plot_roc_curve,
|
plot_roc_curve,
|
||||||
plot_roc_curves,
|
plot_roc_curves,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClipClassificationPlotConfig",
|
"ClipClassificationPlotConfig",
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from batdetect2.evaluate.metrics.clip_detection import ClipEval
|
|||||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClipDetectionPlotConfig",
|
"ClipDetectionPlotConfig",
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from pydantic import Field
|
|||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
from batdetect2.evaluate.metrics.detection import ClipEval
|
from batdetect2.evaluate.metrics.detection import ClipEval
|
||||||
@ -23,7 +24,8 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
|||||||
from batdetect2.plotting.detections import plot_clip_detections
|
from batdetect2.plotting.detections import plot_clip_detections
|
||||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from pydantic import Field
|
|||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
from batdetect2.evaluate.metrics.top_class import (
|
from batdetect2.evaluate.metrics.top_class import (
|
||||||
@ -27,7 +28,8 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
|||||||
from batdetect2.plotting.gallery import plot_match_gallery
|
from batdetect2.plotting.gallery import plot_match_gallery
|
||||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]
|
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]
|
||||||
|
|
||||||
|
|||||||
@ -11,12 +11,10 @@ from batdetect2.evaluate.tasks.clip_classification import (
|
|||||||
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
|
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
|
||||||
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||||
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
||||||
|
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||||
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing import (
|
from batdetect2.targets.types import TargetProtocol
|
||||||
ClipDetections,
|
|
||||||
EvaluatorProtocol,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TaskConfig",
|
"TaskConfig",
|
||||||
|
|||||||
@ -26,13 +26,12 @@ from batdetect2.evaluate.affinity import (
|
|||||||
TimeAffinityConfig,
|
TimeAffinityConfig,
|
||||||
build_affinity_function,
|
build_affinity_function,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.evaluate.types import (
|
||||||
AffinityFunction,
|
AffinityFunction,
|
||||||
ClipDetections,
|
|
||||||
Detection,
|
|
||||||
EvaluatorProtocol,
|
EvaluatorProtocol,
|
||||||
TargetProtocol,
|
|
||||||
)
|
)
|
||||||
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseTaskConfig",
|
"BaseTaskConfig",
|
||||||
|
|||||||
@ -21,11 +21,8 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseSEDTaskConfig,
|
BaseSEDTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
ClipDetections,
|
from batdetect2.targets.types import TargetProtocol
|
||||||
Detection,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
||||||
|
|||||||
@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseTaskConfig,
|
BaseTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import ClipDetections, TargetProtocol
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||||
|
|||||||
@ -18,7 +18,8 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseTaskConfig,
|
BaseTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import ClipDetections, TargetProtocol
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||||
|
|||||||
@ -20,8 +20,8 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseSEDTaskConfig,
|
BaseSEDTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.typing.postprocess import ClipDetections
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
class DetectionTaskConfig(BaseSEDTaskConfig):
|
class DetectionTaskConfig(BaseSEDTaskConfig):
|
||||||
|
|||||||
@ -20,7 +20,8 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseSEDTaskConfig,
|
BaseSEDTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import ClipDetections, TargetProtocol
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
||||||
|
|||||||
@ -1,45 +1,39 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import Generic, Iterable, Protocol, Sequence, TypeVar
|
||||||
Generic,
|
|
||||||
Iterable,
|
|
||||||
Protocol,
|
|
||||||
Sequence,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
|
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.typing.postprocess import ClipDetections, Detection
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AffinityFunction",
|
||||||
|
"ClipMatches",
|
||||||
"EvaluatorProtocol",
|
"EvaluatorProtocol",
|
||||||
"MetricsProtocol",
|
|
||||||
"MatchEvaluation",
|
"MatchEvaluation",
|
||||||
|
"MatcherProtocol",
|
||||||
|
"MetricsProtocol",
|
||||||
|
"PlotterProtocol",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MatchEvaluation:
|
class MatchEvaluation:
|
||||||
clip: data.Clip
|
clip: data.Clip
|
||||||
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation | None
|
sound_event_annotation: data.SoundEventAnnotation | None
|
||||||
gt_det: bool
|
gt_det: bool
|
||||||
gt_class: str | None
|
gt_class: str | None
|
||||||
gt_geometry: data.Geometry | None
|
gt_geometry: data.Geometry | None
|
||||||
|
|
||||||
pred_score: float
|
pred_score: float
|
||||||
pred_class_scores: dict[str, float]
|
pred_class_scores: dict[str, float]
|
||||||
pred_geometry: data.Geometry | None
|
pred_geometry: data.Geometry | None
|
||||||
|
|
||||||
affinity: float
|
affinity: float
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def top_class(self) -> str | None:
|
def top_class(self) -> str | None:
|
||||||
if not self.pred_class_scores:
|
if not self.pred_class_scores:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
|
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -53,10 +47,8 @@ class MatchEvaluation:
|
|||||||
@property
|
@property
|
||||||
def top_class_score(self) -> float:
|
def top_class_score(self) -> float:
|
||||||
pred_class = self.top_class
|
pred_class = self.top_class
|
||||||
|
|
||||||
if pred_class is None:
|
if pred_class is None:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
return self.pred_class_scores[pred_class]
|
return self.pred_class_scores[pred_class]
|
||||||
|
|
||||||
|
|
||||||
@ -75,9 +67,6 @@ class MatcherProtocol(Protocol):
|
|||||||
) -> Iterable[tuple[int | None, int | None, float]]: ...
|
) -> Iterable[tuple[int | None, int | None, float]]: ...
|
||||||
|
|
||||||
|
|
||||||
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
|
||||||
|
|
||||||
|
|
||||||
class AffinityFunction(Protocol):
|
class AffinityFunction(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -115,9 +104,11 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
|||||||
) -> EvaluationOutput: ...
|
) -> EvaluationOutput: ...
|
||||||
|
|
||||||
def compute_metrics(
|
def compute_metrics(
|
||||||
self, eval_outputs: EvaluationOutput
|
self,
|
||||||
|
eval_outputs: EvaluationOutput,
|
||||||
) -> dict[str, float]: ...
|
) -> dict[str, float]: ...
|
||||||
|
|
||||||
def generate_plots(
|
def generate_plots(
|
||||||
self, eval_outputs: EvaluationOutput
|
self,
|
||||||
|
eval_outputs: EvaluationOutput,
|
||||||
) -> Iterable[tuple[str, Figure]]: ...
|
) -> Iterable[tuple[str, Figure]]: ...
|
||||||
@ -4,22 +4,20 @@ from lightning import Trainer
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio.loader import build_audio_loader
|
from batdetect2.audio.loader import build_audio_loader
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.inference.clips import get_clips_from_files
|
from batdetect2.inference.clips import get_clips_from_files
|
||||||
from batdetect2.inference.dataset import build_inference_loader
|
from batdetect2.inference.dataset import build_inference_loader
|
||||||
from batdetect2.inference.lightning import InferenceModule
|
from batdetect2.inference.lightning import InferenceModule
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.preprocess.preprocessor import build_preprocessor
|
from batdetect2.preprocess.preprocessor import build_preprocessor
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets.targets import build_targets
|
from batdetect2.targets.targets import build_targets
|
||||||
from batdetect2.typing.postprocess import ClipDetections
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.typing import (
|
|
||||||
AudioLoader,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def run_batch_inference(
|
def run_batch_inference(
|
||||||
|
|||||||
@ -6,10 +6,11 @@ from soundevent import data
|
|||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
from batdetect2.core.arrays import adjust_width
|
from batdetect2.core.arrays import adjust_width
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"InferenceDataset",
|
"InferenceDataset",
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
|||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
from batdetect2.typing.postprocess import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
class InferenceModule(LightningModule):
|
class InferenceModule(LightningModule):
|
||||||
|
|||||||
@ -62,16 +62,16 @@ from batdetect2.models.encoder import (
|
|||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||||
|
from batdetect2.models.types import DetectionModel
|
||||||
from batdetect2.postprocess.config import PostprocessConfig
|
from batdetect2.postprocess.config import PostprocessConfig
|
||||||
from batdetect2.preprocess.config import PreprocessingConfig
|
from batdetect2.postprocess.types import (
|
||||||
from batdetect2.targets.config import TargetConfig
|
|
||||||
from batdetect2.typing import (
|
|
||||||
ClipDetectionsTensor,
|
ClipDetectionsTensor,
|
||||||
DetectionModel,
|
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
PreprocessorProtocol,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
)
|
||||||
|
from batdetect2.preprocess.config import PreprocessingConfig
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
from batdetect2.targets.config import TargetConfig
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BBoxHead",
|
"BBoxHead",
|
||||||
|
|||||||
@ -51,7 +51,7 @@ from batdetect2.models.encoder import (
|
|||||||
EncoderConfig,
|
EncoderConfig,
|
||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.models import (
|
from batdetect2.models.types import (
|
||||||
BackboneModel,
|
BackboneModel,
|
||||||
BottleneckProtocol,
|
BottleneckProtocol,
|
||||||
DecoderProtocol,
|
DecoderProtocol,
|
||||||
|
|||||||
@ -31,7 +31,7 @@ from batdetect2.models.blocks import (
|
|||||||
VerticalConv,
|
VerticalConv,
|
||||||
build_layer,
|
build_layer,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.models import BottleneckProtocol
|
from batdetect2.models.types import BottleneckProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BottleneckConfig",
|
"BottleneckConfig",
|
||||||
|
|||||||
@ -26,7 +26,7 @@ from batdetect2.models.backbones import (
|
|||||||
build_backbone,
|
build_backbone,
|
||||||
)
|
)
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Detector",
|
"Detector",
|
||||||
|
|||||||
86
src/batdetect2/models/types.py
Normal file
86
src/batdetect2/models/types.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import NamedTuple, Protocol
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BackboneModel",
|
||||||
|
"BlockProtocol",
|
||||||
|
"BottleneckProtocol",
|
||||||
|
"DecoderProtocol",
|
||||||
|
"DetectionModel",
|
||||||
|
"EncoderDecoderModel",
|
||||||
|
"EncoderProtocol",
|
||||||
|
"ModelOutput",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BlockProtocol(Protocol):
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
|
||||||
|
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
def get_output_height(self, input_height: int) -> int: ...
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderProtocol(Protocol):
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
input_height: int
|
||||||
|
output_height: int
|
||||||
|
|
||||||
|
def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class BottleneckProtocol(Protocol):
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
input_height: int
|
||||||
|
|
||||||
|
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderProtocol(Protocol):
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
input_height: int
|
||||||
|
output_height: int
|
||||||
|
depth: int
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residuals: list[torch.Tensor],
|
||||||
|
) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOutput(NamedTuple):
|
||||||
|
detection_probs: torch.Tensor
|
||||||
|
size_preds: torch.Tensor
|
||||||
|
class_probs: torch.Tensor
|
||||||
|
features: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class BackboneModel(ABC, torch.nn.Module):
|
||||||
|
input_height: int
|
||||||
|
out_channels: int
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderDecoderModel(BackboneModel):
|
||||||
|
bottleneck_channels: int
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def encode(self, spec: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionModel(ABC, torch.nn.Module):
|
||||||
|
@abstractmethod
|
||||||
|
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
|
||||||
@ -11,7 +11,7 @@ from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig
|
|||||||
from batdetect2.outputs.formats.parquet import ParquetOutputConfig
|
from batdetect2.outputs.formats.parquet import ParquetOutputConfig
|
||||||
from batdetect2.outputs.formats.raw import RawOutputConfig
|
from batdetect2.outputs.formats.raw import RawOutputConfig
|
||||||
from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
|
from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BatDetect2OutputConfig",
|
"BatDetect2OutputConfig",
|
||||||
|
|||||||
@ -4,10 +4,8 @@ from typing import Literal
|
|||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||||
from batdetect2.typing import (
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
OutputFormatterProtocol,
|
from batdetect2.targets.types import TargetProtocol
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"OutputFormatterProtocol",
|
"OutputFormatterProtocol",
|
||||||
|
|||||||
@ -12,12 +12,9 @@ from batdetect2.outputs.formats.base import (
|
|||||||
output_formatters,
|
output_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.targets import terms
|
from batdetect2.targets import terms
|
||||||
from batdetect2.typing import (
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
ClipDetections,
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
Detection,
|
from batdetect2.targets.types import TargetProtocol
|
||||||
OutputFormatterProtocol,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import NotRequired # type: ignore
|
from typing import NotRequired # type: ignore
|
||||||
|
|||||||
@ -13,12 +13,9 @@ from batdetect2.outputs.formats.base import (
|
|||||||
make_path_relative,
|
make_path_relative,
|
||||||
output_formatters,
|
output_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
ClipDetections,
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
Detection,
|
from batdetect2.targets.types import TargetProtocol
|
||||||
OutputFormatterProtocol,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ParquetOutputConfig(BaseConfig):
|
class ParquetOutputConfig(BaseConfig):
|
||||||
|
|||||||
@ -14,12 +14,9 @@ from batdetect2.outputs.formats.base import (
|
|||||||
make_path_relative,
|
make_path_relative,
|
||||||
output_formatters,
|
output_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
ClipDetections,
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
Detection,
|
from batdetect2.targets.types import TargetProtocol
|
||||||
OutputFormatterProtocol,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RawOutputConfig(BaseConfig):
|
class RawOutputConfig(BaseConfig):
|
||||||
|
|||||||
@ -8,12 +8,9 @@ from batdetect2.core import BaseConfig
|
|||||||
from batdetect2.outputs.formats.base import (
|
from batdetect2.outputs.formats.base import (
|
||||||
output_formatters,
|
output_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
ClipDetections,
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
Detection,
|
from batdetect2.targets.types import TargetProtocol
|
||||||
OutputFormatterProtocol,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SoundEventOutputConfig(BaseConfig):
|
class SoundEventOutputConfig(BaseConfig):
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from typing import Protocol
|
|||||||
from soundevent.geometry import shift_geometry
|
from soundevent.geometry import shift_geometry
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.typing import ClipDetections, Detection
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"OutputTransform",
|
"OutputTransform",
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
from typing import Generic, List, Protocol, Sequence, TypeVar
|
from collections.abc import Sequence
|
||||||
|
from typing import Generic, Protocol, TypeVar
|
||||||
|
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.typing.postprocess import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"OutputFormatterProtocol",
|
"OutputFormatterProtocol",
|
||||||
@ -12,7 +13,7 @@ T = TypeVar("T")
|
|||||||
|
|
||||||
|
|
||||||
class OutputFormatterProtocol(Protocol, Generic[T]):
|
class OutputFormatterProtocol(Protocol, Generic[T]):
|
||||||
def format(self, predictions: Sequence[ClipDetections]) -> List[T]: ...
|
def format(self, predictions: Sequence[ClipDetections]) -> list[T]: ...
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
@ -21,4 +22,4 @@ class OutputFormatterProtocol(Protocol, Generic[T]):
|
|||||||
audio_dir: PathLike | None = None,
|
audio_dir: PathLike | None = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
def load(self, path: PathLike) -> List[T]: ...
|
def load(self, path: PathLike) -> list[T]: ...
|
||||||
@ -3,8 +3,8 @@ from soundevent import data, plot
|
|||||||
|
|
||||||
from batdetect2.plotting.clips import plot_clip
|
from batdetect2.plotting.clips import plot_clip
|
||||||
from batdetect2.plotting.common import create_ax
|
from batdetect2.plotting.common import create_ax
|
||||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_clip_annotation",
|
"plot_clip_annotation",
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry
|
|||||||
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
|
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
|
||||||
|
|
||||||
from batdetect2.plotting.clips import plot_clip
|
from batdetect2.plotting.clips import plot_clip
|
||||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_clip_prediction",
|
"plot_clip_prediction",
|
||||||
|
|||||||
@ -4,9 +4,10 @@ from matplotlib.axes import Axes
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.plotting.common import plot_spectrogram
|
from batdetect2.plotting.common import plot_spectrogram
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_clip",
|
"plot_clip",
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from typing import Sequence
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.plotting.matches import (
|
from batdetect2.plotting.matches import (
|
||||||
MatchProtocol,
|
MatchProtocol,
|
||||||
plot_cross_trigger_match,
|
plot_cross_trigger_match,
|
||||||
@ -10,7 +11,7 @@ from batdetect2.plotting.matches import (
|
|||||||
plot_false_positive_match,
|
plot_false_positive_match,
|
||||||
plot_true_positive_match,
|
plot_true_positive_match,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = ["plot_match_gallery"]
|
__all__ = ["plot_match_gallery"]
|
||||||
|
|
||||||
|
|||||||
@ -4,12 +4,10 @@ from matplotlib.axes import Axes
|
|||||||
from soundevent import data, plot
|
from soundevent import data, plot
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.plotting.clips import plot_clip
|
from batdetect2.plotting.clips import plot_clip
|
||||||
from batdetect2.typing import (
|
from batdetect2.postprocess.types import Detection
|
||||||
AudioLoader,
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
Detection,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_false_positive_match",
|
"plot_false_positive_match",
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from batdetect2.typing import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
class ClipTransform:
|
class ClipTransform:
|
||||||
|
|||||||
@ -5,11 +5,11 @@ from typing import List
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.typing.postprocess import (
|
from batdetect2.postprocess.types import (
|
||||||
ClipDetectionsArray,
|
ClipDetectionsArray,
|
||||||
Detection,
|
Detection,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"to_raw_predictions",
|
"to_raw_predictions",
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from typing import List
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from batdetect2.typing.postprocess import ClipDetectionsTensor
|
from batdetect2.postprocess.types import ClipDetectionsTensor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"extract_detection_peaks",
|
"extract_detection_peaks",
|
||||||
|
|||||||
@ -1,18 +1,18 @@
|
|||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.postprocess.config import (
|
from batdetect2.postprocess.config import (
|
||||||
PostprocessConfig,
|
PostprocessConfig,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess.extraction import extract_detection_peaks
|
from batdetect2.postprocess.extraction import extract_detection_peaks
|
||||||
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
|
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
|
||||||
from batdetect2.postprocess.remapping import map_detection_to_clip
|
from batdetect2.postprocess.remapping import map_detection_to_clip
|
||||||
from batdetect2.typing import ModelOutput
|
from batdetect2.postprocess.types import (
|
||||||
from batdetect2.typing.postprocess import (
|
|
||||||
ClipDetectionsTensor,
|
ClipDetectionsTensor,
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"build_postprocessor",
|
"build_postprocessor",
|
||||||
|
|||||||
@ -19,8 +19,8 @@ import torch
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent.arrays import Dimensions
|
from soundevent.arrays import Dimensions
|
||||||
|
|
||||||
|
from batdetect2.postprocess.types import ClipDetectionsTensor
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
from batdetect2.typing.postprocess import ClipDetectionsTensor
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"features_to_xarray",
|
"features_to_xarray",
|
||||||
|
|||||||
85
src/batdetect2/postprocess/types.py
Normal file
85
src/batdetect2/postprocess/types.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, NamedTuple, Protocol
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets.types import Position, Size
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from batdetect2.models.types import ModelOutput
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ClipDetections",
|
||||||
|
"ClipDetectionsArray",
|
||||||
|
"ClipDetectionsTensor",
|
||||||
|
"ClipPrediction",
|
||||||
|
"Detection",
|
||||||
|
"GeometryDecoder",
|
||||||
|
"PostprocessorProtocol",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class GeometryDecoder(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
position: Position,
|
||||||
|
size: Size,
|
||||||
|
class_name: str | None = None,
|
||||||
|
) -> data.Geometry: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Detection:
|
||||||
|
geometry: data.Geometry
|
||||||
|
detection_score: float
|
||||||
|
class_scores: np.ndarray
|
||||||
|
features: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionsArray(NamedTuple):
|
||||||
|
scores: np.ndarray
|
||||||
|
sizes: np.ndarray
|
||||||
|
class_scores: np.ndarray
|
||||||
|
times: np.ndarray
|
||||||
|
frequencies: np.ndarray
|
||||||
|
features: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionsTensor(NamedTuple):
|
||||||
|
scores: torch.Tensor
|
||||||
|
sizes: torch.Tensor
|
||||||
|
class_scores: torch.Tensor
|
||||||
|
times: torch.Tensor
|
||||||
|
frequencies: torch.Tensor
|
||||||
|
features: torch.Tensor
|
||||||
|
|
||||||
|
def numpy(self) -> ClipDetectionsArray:
|
||||||
|
return ClipDetectionsArray(
|
||||||
|
scores=self.scores.detach().cpu().numpy(),
|
||||||
|
sizes=self.sizes.detach().cpu().numpy(),
|
||||||
|
class_scores=self.class_scores.detach().cpu().numpy(),
|
||||||
|
times=self.times.detach().cpu().numpy(),
|
||||||
|
frequencies=self.frequencies.detach().cpu().numpy(),
|
||||||
|
features=self.features.detach().cpu().numpy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClipDetections:
|
||||||
|
clip: data.Clip
|
||||||
|
detections: list[Detection]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClipPrediction:
|
||||||
|
clip: data.Clip
|
||||||
|
detection_score: float
|
||||||
|
class_scores: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
|
class PostprocessorProtocol(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self, output: "ModelOutput"
|
||||||
|
) -> list[ClipDetectionsTensor]: ...
|
||||||
@ -1,7 +1,7 @@
|
|||||||
"""Assembles the full batdetect2 preprocessing pipeline.
|
"""Assembles the full batdetect2 preprocessing pipeline.
|
||||||
|
|
||||||
This module defines :class:`Preprocessor`, the concrete implementation of
|
This module defines :class:`Preprocessor`, the concrete implementation of
|
||||||
:class:`~batdetect2.typing.PreprocessorProtocol`, and the
|
:class:`~batdetect2.preprocess.types.PreprocessorProtocol`, and the
|
||||||
:func:`build_preprocessor` factory function that constructs it from a
|
:func:`build_preprocessor` factory function that constructs it from a
|
||||||
:class:`~batdetect2.preprocess.config.PreprocessingConfig`.
|
:class:`~batdetect2.preprocess.config.PreprocessingConfig`.
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ from batdetect2.preprocess.spectrogram import (
|
|||||||
build_spectrogram_resizer,
|
build_spectrogram_resizer,
|
||||||
build_spectrogram_transform,
|
build_spectrogram_transform,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Preprocessor",
|
"Preprocessor",
|
||||||
@ -42,7 +42,7 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class Preprocessor(torch.nn.Module, PreprocessorProtocol):
|
class Preprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||||
"""Standard implementation of the :class:`~batdetect2.typing.PreprocessorProtocol`.
|
"""Standard implementation of the :class:`~batdetect2.preprocess.types.PreprocessorProtocol`.
|
||||||
|
|
||||||
Wraps all preprocessing stages as ``torch.nn.Module`` submodules so
|
Wraps all preprocessing stages as ``torch.nn.Module`` submodules so
|
||||||
that parameters (e.g. PCEN filter coefficients) can be tracked and
|
that parameters (e.g. PCEN filter coefficients) can be tracked and
|
||||||
|
|||||||
31
src/batdetect2/preprocess/types.py
Normal file
31
src/batdetect2/preprocess/types.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PreprocessorProtocol",
|
||||||
|
"SpectrogramBuilder",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SpectrogramBuilder(Protocol):
|
||||||
|
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessorProtocol(Protocol):
|
||||||
|
max_freq: float
|
||||||
|
min_freq: float
|
||||||
|
input_samplerate: int
|
||||||
|
output_samplerate: float
|
||||||
|
|
||||||
|
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
|
||||||
|
return self(torch.tensor(wav)).numpy()
|
||||||
@ -16,7 +16,7 @@ from batdetect2.data.conditions import (
|
|||||||
)
|
)
|
||||||
from batdetect2.targets.rois import ROIMapperConfig
|
from batdetect2.targets.rois import ROIMapperConfig
|
||||||
from batdetect2.targets.terms import call_type, generic_class
|
from batdetect2.targets.terms import call_type, generic_class
|
||||||
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
from batdetect2.targets.types import SoundEventDecoder, SoundEventEncoder
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"build_sound_event_decoder",
|
"build_sound_event_decoder",
|
||||||
|
|||||||
@ -27,17 +27,13 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||||
from batdetect2.core.arrays import spec_to_xarray
|
from batdetect2.core.arrays import spec_to_xarray
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.typing import (
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
AudioLoader,
|
from batdetect2.targets.types import Position, ROITargetMapper, Size
|
||||||
Position,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
ROITargetMapper,
|
|
||||||
Size,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Anchor",
|
"Anchor",
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from batdetect2.targets.rois import (
|
|||||||
AnchorBBoxMapperConfig,
|
AnchorBBoxMapperConfig,
|
||||||
build_roi_mapper,
|
build_roi_mapper,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.targets import Position, Size, TargetProtocol
|
from batdetect2.targets.types import Position, Size, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
class Targets(TargetProtocol):
|
class Targets(TargetProtocol):
|
||||||
|
|||||||
60
src/batdetect2/targets/types.py
Normal file
60
src/batdetect2/targets/types.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Position",
|
||||||
|
"ROITargetMapper",
|
||||||
|
"Size",
|
||||||
|
"SoundEventDecoder",
|
||||||
|
"SoundEventEncoder",
|
||||||
|
"SoundEventFilter",
|
||||||
|
"TargetProtocol",
|
||||||
|
]
|
||||||
|
|
||||||
|
SoundEventEncoder = Callable[[data.SoundEventAnnotation], str | None]
|
||||||
|
SoundEventDecoder = Callable[[str], list[data.Tag]]
|
||||||
|
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
|
||||||
|
|
||||||
|
Position = tuple[float, float]
|
||||||
|
Size = np.ndarray
|
||||||
|
|
||||||
|
|
||||||
|
class TargetProtocol(Protocol):
|
||||||
|
class_names: list[str]
|
||||||
|
detection_class_tags: list[data.Tag]
|
||||||
|
detection_class_name: str
|
||||||
|
dimension_names: list[str]
|
||||||
|
|
||||||
|
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
|
||||||
|
|
||||||
|
def encode_class(
|
||||||
|
self,
|
||||||
|
sound_event: data.SoundEventAnnotation,
|
||||||
|
) -> str | None: ...
|
||||||
|
|
||||||
|
def decode_class(self, class_label: str) -> list[data.Tag]: ...
|
||||||
|
|
||||||
|
def encode_roi(
|
||||||
|
self,
|
||||||
|
sound_event: data.SoundEventAnnotation,
|
||||||
|
) -> tuple[Position, Size]: ...
|
||||||
|
|
||||||
|
def decode_roi(
|
||||||
|
self,
|
||||||
|
position: Position,
|
||||||
|
size: Size,
|
||||||
|
class_name: str | None = None,
|
||||||
|
) -> data.Geometry: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ROITargetMapper(Protocol):
|
||||||
|
dimension_names: list[str]
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self, sound_event: data.SoundEvent
|
||||||
|
) -> tuple[Position, Size]: ...
|
||||||
|
|
||||||
|
def decode(self, position: Position, size: Size) -> data.Geometry: ...
|
||||||
@ -13,6 +13,7 @@ from soundevent.geometry import scale_geometry, shift_geometry
|
|||||||
|
|
||||||
from batdetect2.audio.clips import get_subclip_annotation
|
from batdetect2.audio.clips import get_subclip_annotation
|
||||||
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.core.arrays import adjust_width
|
from batdetect2.core.arrays import adjust_width
|
||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.core.registries import (
|
from batdetect2.core.registries import (
|
||||||
@ -20,7 +21,7 @@ from batdetect2.core.registries import (
|
|||||||
Registry,
|
Registry,
|
||||||
add_import_config,
|
add_import_config,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import AudioLoader, Augmentation
|
from batdetect2.train.types import Augmentation
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AugmentationConfig",
|
"AugmentationConfig",
|
||||||
|
|||||||
@ -5,17 +5,15 @@ from lightning.pytorch.callbacks import Callback
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||||
from batdetect2.logging import get_image_logger
|
from batdetect2.logging import get_image_logger
|
||||||
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.train.dataset import ValidationDataset
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.typing import (
|
from batdetect2.train.types import TrainExample
|
||||||
ClipDetections,
|
|
||||||
EvaluatorProtocol,
|
|
||||||
ModelOutput,
|
|
||||||
TrainExample,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationMetrics(Callback):
|
class ValidationMetrics(Callback):
|
||||||
|
|||||||
@ -8,9 +8,11 @@ from torch.utils.data import DataLoader, Dataset
|
|||||||
|
|
||||||
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
|
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
|
||||||
from batdetect2.audio.clips import PaddedClipConfig
|
from batdetect2.audio.clips import PaddedClipConfig
|
||||||
|
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
from batdetect2.core.arrays import adjust_width
|
from batdetect2.core.arrays import adjust_width
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
DEFAULT_AUGMENTATION_CONFIG,
|
DEFAULT_AUGMENTATION_CONFIG,
|
||||||
AugmentationsConfig,
|
AugmentationsConfig,
|
||||||
@ -18,14 +20,7 @@ from batdetect2.train.augmentations import (
|
|||||||
build_augmentations,
|
build_augmentations,
|
||||||
)
|
)
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.typing import (
|
from batdetect2.train.types import Augmentation, ClipLabeller, TrainExample
|
||||||
AudioLoader,
|
|
||||||
Augmentation,
|
|
||||||
ClipLabeller,
|
|
||||||
ClipperProtocol,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
TrainExample,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingDataset",
|
"TrainingDataset",
|
||||||
|
|||||||
@ -15,7 +15,8 @@ from soundevent import data
|
|||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
||||||
from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
from batdetect2.train.types import ClipLabeller, Heatmaps
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LabelConfig",
|
"LabelConfig",
|
||||||
|
|||||||
@ -2,11 +2,12 @@ import lightning as L
|
|||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.models import Model, ModelConfig, build_model
|
from batdetect2.models import Model, ModelConfig, build_model
|
||||||
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.losses import build_loss
|
||||||
from batdetect2.train.optimizers import build_optimizer
|
from batdetect2.train.optimizers import build_optimizer
|
||||||
from batdetect2.train.schedulers import build_scheduler
|
from batdetect2.train.schedulers import build_scheduler
|
||||||
from batdetect2.typing import LossProtocol, ModelOutput, TrainExample
|
from batdetect2.train.types import LossProtocol, TrainExample
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
|
|||||||
@ -26,7 +26,8 @@ from pydantic import Field
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
|
from batdetect2.models.types import ModelOutput
|
||||||
|
from batdetect2.train.types import Losses, LossProtocol, TrainExample
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BBoxLoss",
|
"BBoxLoss",
|
||||||
|
|||||||
@ -43,7 +43,7 @@ optimizer_registry: Registry[Optimizer, [Iterable[nn.Parameter]]] = Registry(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_import_config(optimizer_registry)
|
@add_import_config(optimizer_registry, arg_names=["params"])
|
||||||
class OptimizerImportConfig(ImportConfig):
|
class OptimizerImportConfig(ImportConfig):
|
||||||
"""Use any callable as an optimizer.
|
"""Use any callable as an optimizer.
|
||||||
|
|
||||||
@ -84,4 +84,4 @@ def build_optimizer(
|
|||||||
Optimizer configuration. Defaults to ``AdamOptimizerConfig``.
|
Optimizer configuration. Defaults to ``AdamOptimizerConfig``.
|
||||||
"""
|
"""
|
||||||
config = config or AdamOptimizerConfig()
|
config = config or AdamOptimizerConfig()
|
||||||
return optimizer_registry.build(config, params=parameters)
|
return optimizer_registry.build(config, parameters)
|
||||||
|
|||||||
@ -40,7 +40,7 @@ class CosineAnnealingSchedulerConfig(BaseConfig):
|
|||||||
scheduler_registry: Registry[LRScheduler, [Optimizer]] = Registry("scheduler")
|
scheduler_registry: Registry[LRScheduler, [Optimizer]] = Registry("scheduler")
|
||||||
|
|
||||||
|
|
||||||
@add_import_config(scheduler_registry)
|
@add_import_config(scheduler_registry, arg_names=["optimizer"])
|
||||||
class SchedulerImportConfig(ImportConfig):
|
class SchedulerImportConfig(ImportConfig):
|
||||||
"""Use any callable as a scheduler.
|
"""Use any callable as a scheduler.
|
||||||
|
|
||||||
@ -78,4 +78,4 @@ def build_scheduler(
|
|||||||
"""Build a scheduler from configuration."""
|
"""Build a scheduler from configuration."""
|
||||||
config = config or CosineAnnealingSchedulerConfig()
|
config = config or CosineAnnealingSchedulerConfig()
|
||||||
|
|
||||||
return scheduler_registry.build(config, optimizer=optimizer)
|
return scheduler_registry.build(config, optimizer)
|
||||||
|
|||||||
@ -1,32 +1,28 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from lightning import Trainer, seed_everything
|
from lightning import Trainer, seed_everything
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.evaluate import build_evaluator
|
from batdetect2.evaluate import build_evaluator
|
||||||
|
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||||
from batdetect2.logging import build_logger
|
from batdetect2.logging import build_logger
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.models import ModelConfig
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train import TrainingConfig
|
from batdetect2.train import TrainingConfig
|
||||||
from batdetect2.train.callbacks import ValidationMetrics
|
from batdetect2.train.callbacks import ValidationMetrics
|
||||||
from batdetect2.train.checkpoints import build_checkpoint_callback
|
from batdetect2.train.checkpoints import build_checkpoint_callback
|
||||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.lightning import build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
|
from batdetect2.train.types import ClipLabeller
|
||||||
if TYPE_CHECKING:
|
|
||||||
from batdetect2.typing import (
|
|
||||||
AudioLoader,
|
|
||||||
ClipLabeller,
|
|
||||||
EvaluatorProtocol,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"build_trainer",
|
"build_trainer",
|
||||||
|
|||||||
70
src/batdetect2/train/types.py
Normal file
70
src/batdetect2/train/types.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING, NamedTuple, Protocol
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from batdetect2.models.types import ModelOutput
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Augmentation",
|
||||||
|
"ClipLabeller",
|
||||||
|
"Heatmaps",
|
||||||
|
"Losses",
|
||||||
|
"LossProtocol",
|
||||||
|
"TrainExample",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Heatmaps(NamedTuple):
|
||||||
|
detection: torch.Tensor
|
||||||
|
classes: torch.Tensor
|
||||||
|
size: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessedExample(NamedTuple):
|
||||||
|
audio: torch.Tensor
|
||||||
|
spectrogram: torch.Tensor
|
||||||
|
detection_heatmap: torch.Tensor
|
||||||
|
class_heatmap: torch.Tensor
|
||||||
|
size_heatmap: torch.Tensor
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return PreprocessedExample(
|
||||||
|
audio=self.audio.clone(),
|
||||||
|
spectrogram=self.spectrogram.clone(),
|
||||||
|
detection_heatmap=self.detection_heatmap.clone(),
|
||||||
|
size_heatmap=self.size_heatmap.clone(),
|
||||||
|
class_heatmap=self.class_heatmap.clone(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
|
||||||
|
|
||||||
|
|
||||||
|
Augmentation = Callable[
|
||||||
|
[torch.Tensor, data.ClipAnnotation],
|
||||||
|
tuple[torch.Tensor, data.ClipAnnotation],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TrainExample(NamedTuple):
|
||||||
|
spec: torch.Tensor
|
||||||
|
detection_heatmap: torch.Tensor
|
||||||
|
class_heatmap: torch.Tensor
|
||||||
|
size_heatmap: torch.Tensor
|
||||||
|
idx: torch.Tensor
|
||||||
|
start_time: torch.Tensor
|
||||||
|
end_time: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Losses(NamedTuple):
|
||||||
|
detection: torch.Tensor
|
||||||
|
size: torch.Tensor
|
||||||
|
classification: torch.Tensor
|
||||||
|
total: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class LossProtocol(Protocol):
|
||||||
|
def __call__(self, pred: "ModelOutput", gt: TrainExample) -> Losses: ...
|
||||||
@ -1,18 +1,14 @@
|
|||||||
"""Types used in the code base."""
|
"""Types used in the code base."""
|
||||||
|
|
||||||
from typing import Any, NamedTuple, TypedDict
|
import sys
|
||||||
|
from typing import Any, NamedTuple, Protocol, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
if sys.version_info >= (3, 11):
|
||||||
from typing import Protocol
|
from typing import NotRequired
|
||||||
except ImportError:
|
else:
|
||||||
from typing_extensions import Protocol
|
|
||||||
|
|
||||||
try:
|
|
||||||
from typing import NotRequired # type: ignore
|
|
||||||
except ImportError:
|
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,75 +0,0 @@
|
|||||||
from batdetect2.typing.data import OutputFormatterProtocol
|
|
||||||
from batdetect2.typing.evaluate import (
|
|
||||||
AffinityFunction,
|
|
||||||
ClipMatches,
|
|
||||||
EvaluatorProtocol,
|
|
||||||
MatcherProtocol,
|
|
||||||
MatchEvaluation,
|
|
||||||
MetricsProtocol,
|
|
||||||
PlotterProtocol,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
|
||||||
from batdetect2.typing.postprocess import (
|
|
||||||
ClipDetections,
|
|
||||||
ClipDetectionsTensor,
|
|
||||||
Detection,
|
|
||||||
GeometryDecoder,
|
|
||||||
PostprocessorProtocol,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.preprocess import (
|
|
||||||
AudioLoader,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.targets import (
|
|
||||||
Position,
|
|
||||||
ROITargetMapper,
|
|
||||||
Size,
|
|
||||||
SoundEventDecoder,
|
|
||||||
SoundEventEncoder,
|
|
||||||
SoundEventFilter,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.train import (
|
|
||||||
Augmentation,
|
|
||||||
ClipLabeller,
|
|
||||||
ClipperProtocol,
|
|
||||||
Heatmaps,
|
|
||||||
Losses,
|
|
||||||
LossProtocol,
|
|
||||||
TrainExample,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AffinityFunction",
|
|
||||||
"AudioLoader",
|
|
||||||
"Augmentation",
|
|
||||||
"BackboneModel",
|
|
||||||
"ClipDetections",
|
|
||||||
"ClipDetectionsTensor",
|
|
||||||
"ClipLabeller",
|
|
||||||
"ClipMatches",
|
|
||||||
"ClipperProtocol",
|
|
||||||
"DetectionModel",
|
|
||||||
"EvaluatorProtocol",
|
|
||||||
"GeometryDecoder",
|
|
||||||
"Heatmaps",
|
|
||||||
"LossProtocol",
|
|
||||||
"Losses",
|
|
||||||
"MatchEvaluation",
|
|
||||||
"MatcherProtocol",
|
|
||||||
"MetricsProtocol",
|
|
||||||
"ModelOutput",
|
|
||||||
"OutputFormatterProtocol",
|
|
||||||
"PlotterProtocol",
|
|
||||||
"Position",
|
|
||||||
"PostprocessorProtocol",
|
|
||||||
"PreprocessorProtocol",
|
|
||||||
"ROITargetMapper",
|
|
||||||
"Detection",
|
|
||||||
"Size",
|
|
||||||
"SoundEventDecoder",
|
|
||||||
"SoundEventEncoder",
|
|
||||||
"SoundEventFilter",
|
|
||||||
"TargetProtocol",
|
|
||||||
"TrainExample",
|
|
||||||
]
|
|
||||||
@ -1,287 +0,0 @@
|
|||||||
"""Defines shared interfaces (ABCs) and data structures for models.
|
|
||||||
|
|
||||||
This module centralizes the definitions of core data structures, like the
|
|
||||||
standard model output container (`ModelOutput`), and establishes abstract base
|
|
||||||
classes (ABCs) using `abc.ABC` and `torch.nn.Module`. These define contracts
|
|
||||||
for fundamental model components, ensuring modularity and consistent
|
|
||||||
interaction within the `batdetect2.models` package.
|
|
||||||
|
|
||||||
Key components:
|
|
||||||
- `ModelOutput`: Standard structure for outputs from detection models.
|
|
||||||
- `BackboneModel`: Generic interface for any feature extraction backbone.
|
|
||||||
- `EncoderDecoderModel`: Specialized interface for backbones with distinct
|
|
||||||
encoder-decoder stages (e.g., U-Net), providing access to intermediate
|
|
||||||
features.
|
|
||||||
- `DetectionModel`: Interface for the complete end-to-end detection model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List, NamedTuple, Protocol
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ModelOutput",
|
|
||||||
"BackboneModel",
|
|
||||||
"EncoderDecoderModel",
|
|
||||||
"DetectionModel",
|
|
||||||
"BlockProtocol",
|
|
||||||
"EncoderProtocol",
|
|
||||||
"BottleneckProtocol",
|
|
||||||
"DecoderProtocol",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class BlockProtocol(Protocol):
|
|
||||||
"""Interface for blocks of network layers."""
|
|
||||||
|
|
||||||
in_channels: int
|
|
||||||
out_channels: int
|
|
||||||
|
|
||||||
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Forward pass of the block."""
|
|
||||||
...
|
|
||||||
|
|
||||||
def get_output_height(self, input_height: int) -> int:
|
|
||||||
"""Calculate the output height based on input height."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderProtocol(Protocol):
|
|
||||||
"""Interface for the downsampling path of a network."""
|
|
||||||
|
|
||||||
in_channels: int
|
|
||||||
out_channels: int
|
|
||||||
input_height: int
|
|
||||||
output_height: int
|
|
||||||
|
|
||||||
def __call__(self, x: torch.Tensor) -> List[torch.Tensor]:
|
|
||||||
"""Forward pass must return intermediate tensors for skip connections."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class BottleneckProtocol(Protocol):
|
|
||||||
"""Interface for the middle part of a U-Net-like network."""
|
|
||||||
|
|
||||||
in_channels: int
|
|
||||||
out_channels: int
|
|
||||||
input_height: int
|
|
||||||
|
|
||||||
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Processes the features from the encoder."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class DecoderProtocol(Protocol):
|
|
||||||
"""Interface for the upsampling reconstruction path."""
|
|
||||||
|
|
||||||
in_channels: int
|
|
||||||
out_channels: int
|
|
||||||
input_height: int
|
|
||||||
output_height: int
|
|
||||||
depth: int
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
residuals: List[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Upsamples features while integrating skip connections."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOutput(NamedTuple):
|
|
||||||
"""Standard container for the outputs of a BatDetect2 detection model.
|
|
||||||
|
|
||||||
This structure groups the different prediction tensors produced by the
|
|
||||||
model for a batch of input spectrograms. All tensors typically share the
|
|
||||||
same spatial dimensions (height H, width W) corresponding to the model's
|
|
||||||
output resolution, and the same batch size (N).
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
detection_probs : torch.Tensor
|
|
||||||
Tensor containing the probability of sound event presence at each
|
|
||||||
location in the output grid.
|
|
||||||
Shape: `(N, 1, H, W)`
|
|
||||||
size_preds : torch.Tensor
|
|
||||||
Tensor containing the predicted size dimensions
|
|
||||||
(e.g., width and height) for a potential bounding box at each location.
|
|
||||||
Shape: `(N, 2, H, W)` (Channel 0 typically width, Channel 1 height)
|
|
||||||
class_probs : torch.Tensor
|
|
||||||
Tensor containing the predicted probabilities (or logits, depending on
|
|
||||||
the final activation) for each target class at each location.
|
|
||||||
The number of channels corresponds to the number of specific classes
|
|
||||||
defined in the `Targets` configuration.
|
|
||||||
Shape: `(N, num_classes, H, W)`
|
|
||||||
features : torch.Tensor
|
|
||||||
Tensor containing features extracted by the model's backbone. These
|
|
||||||
might be used for downstream tasks or analysis. The number of channels
|
|
||||||
depends on the specific model architecture.
|
|
||||||
Shape: `(N, num_features, H, W)`
|
|
||||||
"""
|
|
||||||
|
|
||||||
detection_probs: torch.Tensor
|
|
||||||
size_preds: torch.Tensor
|
|
||||||
class_probs: torch.Tensor
|
|
||||||
features: torch.Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class BackboneModel(ABC, torch.nn.Module):
|
|
||||||
"""Abstract Base Class for generic feature extraction backbone models.
|
|
||||||
|
|
||||||
Defines the minimal interface for a feature extractor network within a
|
|
||||||
BatDetect2 model. Its primary role is to process an input spectrogram
|
|
||||||
tensor and produce a spatially rich feature map tensor, which is then
|
|
||||||
typically consumed by separate prediction heads (for detection,
|
|
||||||
classification, size).
|
|
||||||
|
|
||||||
This base class is agnostic to the specific internal architecture (e.g.,
|
|
||||||
it could be a simple CNN, a U-Net, a Transformer, etc.). Concrete
|
|
||||||
implementations must inherit from this class and `torch.nn.Module`,
|
|
||||||
implement the `forward` method, and define the required attributes.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
input_height : int
|
|
||||||
Expected height (number of frequency bins) of the input spectrogram
|
|
||||||
tensor that the backbone is designed to process.
|
|
||||||
out_channels : int
|
|
||||||
Number of channels in the final feature map tensor produced by the
|
|
||||||
backbone's `forward` method.
|
|
||||||
"""
|
|
||||||
|
|
||||||
input_height: int
|
|
||||||
"""Expected input spectrogram height (frequency bins)."""
|
|
||||||
|
|
||||||
out_channels: int
|
|
||||||
"""Number of output channels in the final feature map."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Perform the forward pass to extract features from the spectrogram.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
spec : torch.Tensor
|
|
||||||
Input spectrogram tensor, typically with shape
|
|
||||||
`(batch_size, 1, frequency_bins, time_bins)`.
|
|
||||||
`frequency_bins` should match `self.input_height`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
Output feature map tensor, typically with shape
|
|
||||||
`(batch_size, self.out_channels, output_height, output_width)`.
|
|
||||||
The spatial dimensions (`output_height`, `output_width`) depend
|
|
||||||
on the specific backbone architecture (e.g., they might match the
|
|
||||||
input or be downsampled).
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderDecoderModel(BackboneModel):
|
|
||||||
"""Abstract Base Class for Encoder-Decoder style backbone models.
|
|
||||||
|
|
||||||
This class specializes `BackboneModel` for architectures that have distinct
|
|
||||||
encoder stages (downsampling path), a bottleneck, and decoder stages
|
|
||||||
(upsampling path).
|
|
||||||
|
|
||||||
It provides separate abstract methods for the `encode` and `decode` steps,
|
|
||||||
allowing access to the intermediate "bottleneck" features produced by the
|
|
||||||
encoder. This can be useful for tasks like transfer learning or specialized
|
|
||||||
analyses.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
input_height : int
|
|
||||||
(Inherited from BackboneModel) Expected input spectrogram height.
|
|
||||||
out_channels : int
|
|
||||||
(Inherited from BackboneModel) Number of output channels in the final
|
|
||||||
feature map produced by the decoder/forward pass.
|
|
||||||
bottleneck_channels : int
|
|
||||||
Number of channels in the feature map produced by the encoder at its
|
|
||||||
deepest point (the bottleneck), before the decoder starts.
|
|
||||||
"""
|
|
||||||
|
|
||||||
bottleneck_channels: int
|
|
||||||
"""Number of channels at the encoder's bottleneck."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def encode(self, spec: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Process the input spectrogram through the encoder part.
|
|
||||||
|
|
||||||
Takes the input spectrogram and passes it through the downsampling path
|
|
||||||
of the network up to the bottleneck layer.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
spec : torch.Tensor
|
|
||||||
Input spectrogram tensor, typically with shape
|
|
||||||
`(batch_size, 1, frequency_bins, time_bins)`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
The encoded feature map from the bottleneck layer, typically with
|
|
||||||
shape `(batch_size, self.bottleneck_channels, bottleneck_height,
|
|
||||||
bottleneck_width)`. The spatial dimensions are usually downsampled
|
|
||||||
relative to the input.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def decode(self, encoded: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Process the bottleneck features through the decoder part.
|
|
||||||
|
|
||||||
Takes the encoded feature map from the bottleneck and passes it through
|
|
||||||
the upsampling path (potentially using skip connections from the
|
|
||||||
encoder) to produce the final output feature map.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
encoded : torch.Tensor
|
|
||||||
The bottleneck feature map tensor produced by the `encode` method.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
The final output feature map tensor, typically with shape
|
|
||||||
`(batch_size, self.out_channels, output_height, output_width)`.
|
|
||||||
This should match the output shape of the `forward` method.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionModel(ABC, torch.nn.Module):
|
|
||||||
"""Abstract Base Class for complete BatDetect2 detection models.
|
|
||||||
|
|
||||||
Defines the interface for the overall model that takes an input spectrogram
|
|
||||||
and produces all necessary outputs for detection, classification, and size
|
|
||||||
prediction, packaged within a `ModelOutput` object.
|
|
||||||
|
|
||||||
Concrete implementations typically combine a `BackboneModel` for feature
|
|
||||||
extraction with specific prediction heads for each output type. They must
|
|
||||||
inherit from this class and `torch.nn.Module`, and implement the `forward`
|
|
||||||
method.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
|
||||||
"""Perform the forward pass of the full detection model.
|
|
||||||
|
|
||||||
Processes the input spectrogram through the backbone and prediction
|
|
||||||
heads to generate all required output tensors.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
spec : torch.Tensor
|
|
||||||
Input spectrogram tensor, typically with shape
|
|
||||||
`(batch_size, 1, frequency_bins, time_bins)`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
ModelOutput
|
|
||||||
A NamedTuple containing the prediction tensors: `detection_probs`,
|
|
||||||
`size_preds`, `class_probs`, and `features`.
|
|
||||||
"""
|
|
||||||
@ -1,104 +0,0 @@
|
|||||||
"""Defines shared interfaces and data structures for postprocessing.
|
|
||||||
|
|
||||||
This module centralizes the Protocol definitions and common data structures
|
|
||||||
used throughout the `batdetect2.postprocess` module.
|
|
||||||
|
|
||||||
The main component is the `PostprocessorProtocol`, which outlines the standard
|
|
||||||
interface for an object responsible for executing the entire postprocessing
|
|
||||||
pipeline. This pipeline transforms raw neural network outputs into interpretable
|
|
||||||
detections represented as `soundevent` objects. Using protocols ensures
|
|
||||||
modularity and consistent interaction between different parts of the BatDetect2
|
|
||||||
system that deal with model predictions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, NamedTuple, Protocol
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.typing.models import ModelOutput
|
|
||||||
from batdetect2.typing.targets import Position, Size
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Detection",
|
|
||||||
"PostprocessorProtocol",
|
|
||||||
"GeometryDecoder",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: update the docstring
|
|
||||||
class GeometryDecoder(Protocol):
|
|
||||||
"""Type alias for a function that recovers geometry from position and size.
|
|
||||||
|
|
||||||
This callable takes:
|
|
||||||
1. A position tuple `(time, frequency)`.
|
|
||||||
2. A NumPy array of size dimensions (e.g., `[width, height]`).
|
|
||||||
3. Optionally a class name of the highest scoring class. This is to accommodate
|
|
||||||
different ways of decoding geometry that depend on the predicted class.
|
|
||||||
It should return the reconstructed `soundevent.data.Geometry` (typically a
|
|
||||||
`BoundingBox`).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, position: Position, size: Size, class_name: str | None = None
|
|
||||||
) -> data.Geometry: ...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Detection:
|
|
||||||
geometry: data.Geometry
|
|
||||||
detection_score: float
|
|
||||||
class_scores: np.ndarray
|
|
||||||
features: np.ndarray
|
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionsArray(NamedTuple):
|
|
||||||
scores: np.ndarray
|
|
||||||
sizes: np.ndarray
|
|
||||||
class_scores: np.ndarray
|
|
||||||
times: np.ndarray
|
|
||||||
frequencies: np.ndarray
|
|
||||||
features: np.ndarray
|
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionsTensor(NamedTuple):
|
|
||||||
scores: torch.Tensor
|
|
||||||
sizes: torch.Tensor
|
|
||||||
class_scores: torch.Tensor
|
|
||||||
times: torch.Tensor
|
|
||||||
frequencies: torch.Tensor
|
|
||||||
features: torch.Tensor
|
|
||||||
|
|
||||||
def numpy(self) -> ClipDetectionsArray:
|
|
||||||
return ClipDetectionsArray(
|
|
||||||
scores=self.scores.detach().cpu().numpy(),
|
|
||||||
sizes=self.sizes.detach().cpu().numpy(),
|
|
||||||
class_scores=self.class_scores.detach().cpu().numpy(),
|
|
||||||
times=self.times.detach().cpu().numpy(),
|
|
||||||
frequencies=self.frequencies.detach().cpu().numpy(),
|
|
||||||
features=self.features.detach().cpu().numpy(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ClipDetections:
|
|
||||||
clip: data.Clip
|
|
||||||
detections: List[Detection]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ClipPrediction:
|
|
||||||
clip: data.Clip
|
|
||||||
detection_score: float
|
|
||||||
class_scores: np.ndarray
|
|
||||||
|
|
||||||
|
|
||||||
class PostprocessorProtocol(Protocol):
|
|
||||||
"""Protocol defining the interface for the full postprocessing pipeline."""
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
output: ModelOutput,
|
|
||||||
) -> List[ClipDetectionsTensor]: ...
|
|
||||||
@ -1,168 +0,0 @@
|
|||||||
"""Defines common interfaces (Protocols) for preprocessing components.
|
|
||||||
|
|
||||||
This module centralizes the Protocol definitions used throughout the
|
|
||||||
`batdetect2.preprocess` package. Protocols define expected methods and
|
|
||||||
signatures, allowing for flexible and interchangeable implementations of
|
|
||||||
components like audio loaders and spectrogram builders.
|
|
||||||
|
|
||||||
Using these protocols ensures that different parts of the preprocessing
|
|
||||||
pipeline can interact consistently, regardless of the specific underlying
|
|
||||||
implementation (e.g., different libraries or custom configurations).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Protocol
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AudioLoader",
|
|
||||||
"SpectrogramBuilder",
|
|
||||||
"PreprocessorProtocol",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class AudioLoader(Protocol):
    """Interface for components that load and preprocess audio.

    An implementation retrieves the audio that corresponds to a soundevent
    object (a file path, a ``Recording`` or a ``Clip``) and applies its
    configured initial preprocessing steps. Because this is a protocol, any
    loading strategy exposing these members can be used interchangeably.

    Attributes
    ----------
    samplerate : int
        Sample rate exposed by the loader.
        NOTE(review): presumably the rate of the returned waveforms;
        confirm against the implementations.
    """

    samplerate: int

    def load_file(
        self,
        path: data.PathLike,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load and preprocess audio straight from a file path.

        Parameters
        ----------
        path : PathLike
            Location of the audio file.
        audio_dir : PathLike, optional
            Directory prefix prepended to ``path`` when ``path`` is
            relative.

        Returns
        -------
        np.ndarray
            The loaded and preprocessed audio waveform.

        Raises
        ------
        FileNotFoundError
            When the audio file does not exist.
        Exception
            When the audio file cannot be loaded or processed.
        """
        ...

    def load_recording(
        self,
        recording: data.Recording,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load and preprocess the whole audio of a ``Recording``.

        Parameters
        ----------
        recording : data.Recording
            Recording object carrying the metadata of the audio file.
        audio_dir : PathLike, optional
            Directory in which the recording's audio file can be found,
            useful when the path stored in the recording is relative.

        Returns
        -------
        np.ndarray
            The loaded and preprocessed waveform as a 1-D NumPy array.
            Typically only the first channel is loaded.

        Raises
        ------
        FileNotFoundError
            When the audio file of the recording does not exist.
        Exception
            When the audio file cannot be loaded or processed.
        """
        ...

    def load_clip(
        self,
        clip: data.Clip,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load and preprocess the audio segment described by a ``Clip``.

        Parameters
        ----------
        clip : data.Clip
            Clip object naming the recording and the start/end times of
            the segment to load.
        audio_dir : PathLike, optional
            Directory in which the audio file of the clip's recording can
            be found.

        Returns
        -------
        np.ndarray
            The loaded and preprocessed waveform covering the clip
            duration, as a 1-D NumPy array. Typically only the first
            channel is loaded.

        Raises
        ------
        FileNotFoundError
            When the audio file of the clip does not exist.
        Exception
            When the audio file cannot be loaded or processed.
        """
        ...
|
|
||||||
|
|
||||||
|
|
||||||
class SpectrogramBuilder(Protocol):
    """Interface for a component that turns a waveform into a spectrogram."""

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        """Build the spectrogram tensor for the waveform tensor ``wav``."""
        ...
|
|
||||||
|
|
||||||
|
|
||||||
class PreprocessorProtocol(Protocol):
    """High-level interface for the complete preprocessing pipeline.

    Attributes
    ----------
    max_freq : float
        Upper frequency bound exposed by the preprocessor.
    min_freq : float
        Lower frequency bound exposed by the preprocessor.
    input_samplerate : int
        Sample rate on the input side of the pipeline.
    output_samplerate : float
        Sample rate on the output side of the pipeline.

    NOTE(review): units and exact meaning of these attributes are not
    visible from this module; confirm against the implementations.
    """

    max_freq: float

    min_freq: float

    input_samplerate: int

    output_samplerate: float

    # Full pipeline entry point; this is what `process_numpy` delegates to.
    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...

    def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...

    def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ...

    def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...

    def process_numpy(self, wav: np.ndarray) -> np.ndarray:
        """Run the full preprocessing pipeline on a NumPy waveform.

        This default implementation converts the array to a
        ``torch.Tensor``, calls :meth:`__call__`, and converts the
        result back to a NumPy array. Concrete implementations may
        override this for efficiency.

        Parameters
        ----------
        wav : np.ndarray
            Input waveform as a 1-D NumPy array.

        Returns
        -------
        np.ndarray
            Preprocessed spectrogram as a NumPy array.
        """
        spec = self(torch.tensor(wav))
        # Tensor.numpy() raises for tensors that require grad or live on a
        # non-CPU device, so detach and move to the host first. For plain
        # CPU tensors both calls are no-ops and the result is unchanged.
        return spec.detach().cpu().numpy()
|
|
||||||
@ -1,298 +0,0 @@
|
|||||||
"""Defines the core interface (Protocol) for the target definition pipeline.
|
|
||||||
|
|
||||||
This module specifies the standard structure, attributes, and methods expected
|
|
||||||
from an object that encapsulates the complete configured logic for processing
|
|
||||||
sound event annotations within the `batdetect2.targets` system.
|
|
||||||
|
|
||||||
The main component defined here is the `TargetProtocol`. This protocol acts as
|
|
||||||
a contract for the entire target definition process, covering semantic aspects
|
|
||||||
(filtering, tag transformation, class encoding/decoding) as well as geometric
|
|
||||||
aspects (mapping regions of interest to target positions and sizes). It ensures
|
|
||||||
that components responsible for these tasks can be interacted with consistently
|
|
||||||
throughout BatDetect2.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import List, Protocol
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"TargetProtocol",
|
|
||||||
"SoundEventEncoder",
|
|
||||||
"SoundEventDecoder",
|
|
||||||
"SoundEventFilter",
|
|
||||||
"Position",
|
|
||||||
"Size",
|
|
||||||
]
|
|
||||||
|
|
||||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], str | None]
|
|
||||||
"""Type alias for a sound event class encoder function.
|
|
||||||
|
|
||||||
An encoder function takes a sound event annotation and returns the string name
|
|
||||||
of the target class it belongs to, based on a predefined set of rules.
|
|
||||||
If the annotation does not match any defined target class according to the
|
|
||||||
rules, the function returns None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
SoundEventDecoder = Callable[[str], List[data.Tag]]
|
|
||||||
"""Type alias for a sound event class decoder function.
|
|
||||||
|
|
||||||
A decoder function takes a class name string (as predicted by the model or
|
|
||||||
assigned during encoding) and returns a list of `soundevent.data.Tag` objects
|
|
||||||
that represent that class according to the configuration. This is used to
|
|
||||||
translate model outputs back into meaningful annotations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
|
|
||||||
"""Type alias for a filter function.
|
|
||||||
|
|
||||||
A filter function accepts a soundevent.data.SoundEventAnnotation object
|
|
||||||
and returns True if the annotation should be kept based on the filter's
|
|
||||||
criteria, or False if it should be discarded.
|
|
||||||
"""
|
|
||||||
|
|
||||||
Position = tuple[float, float]
|
|
||||||
"""A tuple representing (time, frequency) coordinates."""
|
|
||||||
|
|
||||||
Size = np.ndarray
|
|
||||||
"""A NumPy array representing the size dimensions of a target."""
|
|
||||||
|
|
||||||
|
|
||||||
class TargetProtocol(Protocol):
    """Interface of the complete target definition pipeline.

    An object satisfying this protocol bundles every configured step for
    handling sound event annotations, covering both their tags and their
    geometry. It specifies how to:

    - decide whether an annotation is kept (`filter`);
    - map an annotation to a target class name (`encode_class`);
    - map a class name back to representative tags (`decode_class`);
    - turn an annotation's geometry (ROI) into a reference position and
      size dimensions (`encode_roi`);
    - rebuild an approximate geometry from a position and size dimensions
      (`decode_roi`).

    Attributes
    ----------
    class_names : List[str]
        Ordered list of the unique names of the specific target classes
        defined by the configuration.
    detection_class_tags : List[data.Tag]
        Tags representing the detection category (unclassified).
    detection_class_name : str
        Name of the detection category.
    dimension_names : List[str]
        Names of the size dimensions produced by `encode_roi` and expected
        by `decode_roi` (e.g., ['width', 'height']).
    """

    class_names: List[str]
    """Ordered list of unique names for the specific target classes."""

    detection_class_tags: List[data.Tag]
    """List of tags representing the detection category (unclassified)."""

    detection_class_name: str

    dimension_names: List[str]
    """Names of the size dimensions (e.g., ['width', 'height'])."""

    def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
        """Decide whether a sound event annotation should be kept.

        The configured filtering rules determine if the annotation takes
        part in further processing and training.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation to check.

        Returns
        -------
        bool
            True when the annotation passes the filter and should be kept,
            False otherwise. Implementations should return True when no
            filtering is configured.
        """
        ...

    def encode_class(
        self,
        sound_event: data.SoundEventAnnotation,
    ) -> str | None:
        """Map a sound event annotation to its target class name.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The (potentially filtered and transformed) annotation to
            encode.

        Returns
        -------
        str or None
            Name of the matched target class when the annotation matches a
            specific class definition. None when no specific class rule
            matches, meaning the annotation may belong to a generic
            category or must be handled differently downstream.
        """
        ...

    def decode_class(self, class_label: str) -> List[data.Tag]:
        """Translate a class name back into its representative tags.

        Parameters
        ----------
        class_label : str
            The class name (e.g., predicted by a model) to decode.

        Returns
        -------
        List[data.Tag]
            Tags that correspond to ``class_label`` according to the
            configuration. For unmapped labels an implementation may
            return an empty list or raise, depending on how it was
            configured (e.g., a `raise_on_unmapped` flag during building).

        Raises
        ------
        ValueError, KeyError
            Implementations might raise when ``class_label`` is missing
            from the configured mapping and error raising is enabled.
        """
        ...

    def encode_roi(
        self, sound_event: data.SoundEventAnnotation
    ) -> tuple[Position, Size]:
        """Encode the annotation's geometry as a position and a size.

        Computes the `(time, frequency)` coordinate representing the
        primary location of the sound event, together with its size
        dimensions.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation containing the geometry (ROI) to process.

        Returns
        -------
        tuple[Position, Size]
            A pair made of:

            - the reference position as `(time, frequency)`;
            - a NumPy array with the size dimensions, ordered as in
              `dimension_names`.

        Raises
        ------
        ValueError
            If the annotation lacks geometry or if the position cannot be
            calculated for the geometry type or configured reference
            point.
        """
        ...

    def decode_roi(
        self,
        position: Position,
        size: Size,
        class_name: str | None = None,
    ) -> data.Geometry:
        """Rebuild the ROI geometry from a position and size dimensions.

        This is the inverse mapping of `encode_roi`: given a reference
        position `(time, frequency)` and an array of size dimensions it
        reconstructs an approximate geometric representation.

        Parameters
        ----------
        position : Position
            The reference position `(time, frequency)`.
        size : Size
            NumPy array with the dimensions (e.g., predicted by the
            model), in the order given by `dimension_names`.
        class_name : str, optional
            Name of a target class. Defaults to None.
            NOTE(review): presumably lets an implementation apply a
            class-specific mapping; confirm against the implementations.

        Returns
        -------
        soundevent.data.Geometry
            The reconstructed geometry.

        Raises
        ------
        ValueError
            If the number of values in ``size`` does not match
            `dimension_names`, if dimensions are invalid (e.g., negative
            after unscaling), or if reconstruction fails based on the
            configured position type.
        """
        ...
|
|
||||||
|
|
||||||
|
|
||||||
class ROITargetMapper(Protocol):
    """Interface for mapping between ROIs and target representations.

    Implementations provide `encode`, which converts a
    `soundevent.data.SoundEvent` into a target representation (a reference
    position plus a size vector), and `decode`, which recovers an
    approximate ROI from that representation.

    Attributes
    ----------
    dimension_names : List[str]
        Names of the dimensions in the `Size` array returned by `encode`
        and expected by `decode`.
    """

    dimension_names: List[str]

    def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
        """Turn a sound event's geometry into a position and a size.

        Parameters
        ----------
        sound_event : data.SoundEvent
            The input sound event; it must have a geometry attribute.

        Returns
        -------
        Tuple[Position, Size]
            A pair made of:

            - the reference position as (time, frequency) coordinates;
            - a NumPy array with the calculated size dimensions.

        Raises
        ------
        ValueError
            If the sound event does not have a geometry.
        """
        ...

    def decode(self, position: Position, size: Size) -> data.Geometry:
        """Turn a position and a size back into a geometric ROI.

        This is the inverse of `encode`: from a reference position and the
        size dimensions it reconstructs a geometric representation.

        Parameters
        ----------
        position : Position
            The reference position (time, frequency).
        size : Size
            NumPy array with the size dimensions, in the order and with
            the meaning given by `dimension_names`.

        Returns
        -------
        soundevent.data.Geometry
            The reconstructed geometry, typically a `BoundingBox`.

        Raises
        ------
        ValueError
            If the `size` array has an unexpected shape or if
            reconstruction fails.
        """
        ...
|
|
||||||
@ -1,108 +0,0 @@
|
|||||||
from typing import Callable, NamedTuple, Protocol
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.typing.models import ModelOutput
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Augmentation",
|
|
||||||
"ClipLabeller",
|
|
||||||
"ClipperProtocol",
|
|
||||||
"Heatmaps",
|
|
||||||
"LossProtocol",
|
|
||||||
"Losses",
|
|
||||||
"TrainExample",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class Heatmaps(NamedTuple):
    """Bundle of the heatmap targets that were generated.

    Attributes
    ----------
    detection : torch.Tensor
        Detection heatmap.
    classes : torch.Tensor
        Class heatmap.
    size : torch.Tensor
        Size heatmap.
    """

    detection: torch.Tensor
    classes: torch.Tensor
    size: torch.Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class PreprocessedExample(NamedTuple):
    """Audio, spectrogram and target heatmaps of one preprocessed example.

    Attributes
    ----------
    audio : torch.Tensor
        Audio tensor.
    spectrogram : torch.Tensor
        Spectrogram tensor.
    detection_heatmap : torch.Tensor
        Detection heatmap tensor.
    class_heatmap : torch.Tensor
        Class heatmap tensor.
    size_heatmap : torch.Tensor
        Size heatmap tensor.
    """

    audio: torch.Tensor
    spectrogram: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor

    def copy(self):
        """Return a new example whose tensors are clones of this one's."""
        # Iterating the tuple yields the fields in declaration order, so
        # every clone ends up in the slot of the tensor it was cloned from.
        return PreprocessedExample._make(field.clone() for field in self)
|
|
||||||
|
|
||||||
|
|
||||||
ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
"""Type alias for the final clip labelling function.

Given the full annotations of a clip and the matching spectrogram, the
function runs every configured filtering, transformation, and encoding
step and returns the final `Heatmaps` used for model training.
"""


Augmentation = Callable[
    [torch.Tensor, data.ClipAnnotation],
    tuple[torch.Tensor, data.ClipAnnotation],
]
"""Type alias for an augmentation function.

An augmentation receives a tensor together with the clip annotation it
belongs to and returns a tensor and a clip annotation of the same types.
NOTE(review): whether the tensor is a waveform or a spectrogram is not
visible from this alias; confirm against the augmentation implementations.
"""
|
|
||||||
|
|
||||||
|
|
||||||
class TrainExample(NamedTuple):
    """Tensors that make up a single training example.

    Attributes
    ----------
    spec : torch.Tensor
        Spectrogram tensor.
    detection_heatmap : torch.Tensor
        Detection heatmap target.
    class_heatmap : torch.Tensor
        Class heatmap target.
    size_heatmap : torch.Tensor
        Size heatmap target.
    idx : torch.Tensor
        Index of the example, stored as a tensor.
    start_time : torch.Tensor
        Start time of the example, stored as a tensor.
    end_time : torch.Tensor
        End time of the example, stored as a tensor.
    """

    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor
    idx: torch.Tensor
    start_time: torch.Tensor
    end_time: torch.Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class Losses(NamedTuple):
    """Container for the computed loss values.

    Keeping the individual components next to the total weighted loss
    makes it possible to monitor and analyse each of them during training.

    Attributes
    ----------
    detection : torch.Tensor
        Scalar tensor with the detection loss component (before
        weighting).
    size : torch.Tensor
        Scalar tensor with the size regression loss component (before
        weighting).
    classification : torch.Tensor
        Scalar tensor with the classification loss component (before
        weighting).
    total : torch.Tensor
        Scalar tensor with the final combined loss, i.e. the weighted sum
        of the detection, size, and classification components. This is
        the value typically used for backpropagation.
    """

    detection: torch.Tensor
    size: torch.Tensor
    classification: torch.Tensor
    total: torch.Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class LossProtocol(Protocol):
    """Interface of a callable that computes the training losses.

    The callable receives the model prediction (`ModelOutput`) and the
    ground-truth `TrainExample`, and returns a `Losses` tuple.
    """

    def __call__(self, pred: ModelOutput, gt: TrainExample) -> Losses:
        ...
|
|
||||||
|
|
||||||
|
|
||||||
class ClipperProtocol(Protocol):
    """Interface of a component that derives sub-clips.

    Calling the object maps a clip annotation to another clip annotation,
    while `get_subclip` maps a clip to another clip.
    NOTE(review): how the sub-clip is chosen (e.g. randomly or with a
    fixed duration) is implementation-specific and not visible here.
    """

    def __call__(
        self,
        clip_annotation: data.ClipAnnotation,
    ) -> data.ClipAnnotation:
        ...

    def get_subclip(self, clip: data.Clip) -> data.Clip:
        ...
|
|
||||||
@ -11,23 +11,20 @@ from soundevent import data, terms
|
|||||||
|
|
||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
from batdetect2.audio.clips import build_clipper
|
from batdetect2.audio.clips import build_clipper
|
||||||
|
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||||
from batdetect2.data import DatasetConfig, load_dataset
|
from batdetect2.data import DatasetConfig, load_dataset
|
||||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
build_targets,
|
build_targets,
|
||||||
call_type,
|
call_type,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.classes import TargetClassConfig
|
from batdetect2.targets.classes import TargetClassConfig
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.typing import (
|
from batdetect2.train.types import ClipLabeller
|
||||||
ClipLabeller,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
|
||||||
from batdetect2.typing.train import ClipperProtocol
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -9,11 +9,8 @@ from batdetect2.outputs.formats import (
|
|||||||
ParquetOutputConfig,
|
ParquetOutputConfig,
|
||||||
build_output_formatter,
|
build_output_formatter,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
ClipDetections,
|
from batdetect2.targets.types import TargetProtocol
|
||||||
Detection,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -5,11 +5,8 @@ import pytest
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.outputs.formats import RawOutputConfig, build_output_formatter
|
from batdetect2.outputs.formats import RawOutputConfig, build_output_formatter
|
||||||
from batdetect2.typing import (
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
ClipDetections,
|
from batdetect2.targets.types import TargetProtocol
|
||||||
Detection,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.typing import Detection
|
from batdetect2.postprocess.types import Detection
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -4,8 +4,8 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.evaluate.tasks import build_task
|
from batdetect2.evaluate.tasks import build_task
|
||||||
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
||||||
from batdetect2.typing import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
def test_classification(
|
def test_classification(
|
||||||
|
|||||||
@ -4,8 +4,8 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.evaluate.tasks import build_task
|
from batdetect2.evaluate.tasks import build_task
|
||||||
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||||
from batdetect2.typing import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
def test_detection(
|
def test_detection(
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from batdetect2.models.backbones import (
|
|||||||
build_backbone,
|
build_backbone,
|
||||||
load_backbone_config,
|
load_backbone_config,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.models import BackboneModel
|
from batdetect2.models.types import BackboneModel
|
||||||
|
|
||||||
|
|
||||||
def test_unet_backbone_config_defaults():
|
def test_unet_backbone_config_defaults():
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from batdetect2.models.backbones import UNetBackboneConfig
|
|||||||
from batdetect2.models.detectors import Detector, build_detector
|
from batdetect2.models.detectors import Detector, build_detector
|
||||||
from batdetect2.models.encoder import Encoder
|
from batdetect2.models.encoder import Encoder
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||||
from batdetect2.typing.models import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from soundevent import data
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.outputs import build_output_transform
|
from batdetect2.outputs import build_output_transform
|
||||||
from batdetect2.typing import ClipDetections, Detection
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
|
|
||||||
|
|
||||||
def test_shift_time_to_clip_start(clip: data.Clip):
|
def test_shift_time_to_clip_start(clip: data.Clip):
|
||||||
|
|||||||
@ -14,7 +14,8 @@ from batdetect2.postprocess.decoding import (
|
|||||||
get_generic_tags,
|
get_generic_tags,
|
||||||
get_prediction_features,
|
get_prediction_features,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import Detection, TargetProtocol
|
from batdetect2.postprocess.types import Detection
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from torch.optim import Adam
|
|||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.models import ModelConfig
|
||||||
from batdetect2.train import (
|
from batdetect2.train import (
|
||||||
@ -19,7 +20,6 @@ from batdetect2.train import (
|
|||||||
from batdetect2.train.optimizers import AdamOptimizerConfig
|
from batdetect2.train.optimizers import AdamOptimizerConfig
|
||||||
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
|
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
|
||||||
from batdetect2.train.train import build_training_module
|
from batdetect2.train.train import build_training_module
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
|
||||||
|
|
||||||
|
|
||||||
def build_default_module(config: BatDetect2Config | None = None):
|
def build_default_module(config: BatDetect2Config | None = None):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user