Moved types to dedicated module

mbsantiago 2025-08-24 10:55:48 +01:00
parent 02adc19070
commit 61115d562c
47 changed files with 280 additions and 238 deletions
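In short, the protocol and type-alias definitions that previously lived in per-package `*.types` modules (batdetect2.targets.types, batdetect2.models.types, batdetect2.preprocess.types, batdetect2.postprocess.types, batdetect2.evaluate.types, batdetect2.train.types) now live under a single batdetect2.typing package. A minimal sketch of the resulting import change, using only names that appear in the diffs below:

# Before this commit: each subpackage exposed its own types module
from batdetect2.targets.types import TargetProtocol
from batdetect2.models.types import ModelOutput

# After this commit: the same names come from the dedicated typing package
from batdetect2.typing.targets import TargetProtocol
from batdetect2.typing.models import ModelOutput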

View File

@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 from soundevent import data
 from batdetect2.data.datasets import Dataset
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing.targets import TargetProtocol
 def iterate_over_sound_events(

View File

@@ -7,7 +7,7 @@ from batdetect2.data.summary import (
     extract_recordings_df,
     extract_sound_events_df,
 )
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing.targets import TargetProtocol
 def split_dataset_by_recordings(

View File

@@ -2,7 +2,7 @@ import pandas as pd
 from soundevent.geometry import compute_bounds
 from batdetect2.data.datasets import Dataset
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing.targets import TargetProtocol
 __all__ = [
     "extract_recordings_df",

View File

@@ -4,15 +4,15 @@ from typing import List, Literal, Optional, Tuple
 import numpy as np
 from soundevent import data
 from soundevent.evaluation import compute_affinity
-from soundevent.evaluation import (
-    match_geometries as optimal_match,
-)
+from soundevent.evaluation import match_geometries as optimal_match
 from soundevent.geometry import compute_bounds
 from batdetect2.configs import BaseConfig
-from batdetect2.evaluate.types import MatchEvaluation
-from batdetect2.postprocess.types import BatDetect2Prediction
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import (
+    BatDetect2Prediction,
+    MatchEvaluation,
+    TargetProtocol,
+)
 MatchingStrategy = Literal["greedy", "optimal"]
 """The type of matching algorithm to use: 'greedy' or 'optimal'."""

View File

@@ -4,7 +4,7 @@ import pandas as pd
 from sklearn import metrics
 from sklearn.preprocessing import label_binarize
-from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
+from batdetect2.typing import MatchEvaluation, MetricsProtocol
 __all__ = ["DetectionAveragePrecision"]

View File

@@ -28,8 +28,11 @@ provided here.
 from typing import Optional
-from loguru import logger
+import torch
+from lightning import LightningModule
+from pydantic import Field
+from batdetect2.configs import BaseConfig
 from batdetect2.models.backbones import (
     Backbone,
     BackboneConfig,
@@ -53,24 +56,25 @@ from batdetect2.models.decoder import (
     DecoderConfig,
     build_decoder,
 )
-from batdetect2.models.detectors import (
-    Detector,
-    build_detector,
-)
+from batdetect2.models.detectors import Detector, build_detector
 from batdetect2.models.encoder import (
     DEFAULT_ENCODER_CONFIG,
     EncoderConfig,
     build_encoder,
 )
 from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
-from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
+from batdetect2.postprocess import PostprocessConfig, build_postprocessor
+from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
+from batdetect2.targets import TargetConfig, build_targets
+from batdetect2.typing.models import DetectionModel, ModelOutput
+from batdetect2.typing.postprocess import PostprocessorProtocol
+from batdetect2.typing.preprocess import PreprocessorProtocol
+from batdetect2.typing.targets import TargetProtocol
 
 __all__ = [
     "BBoxHead",
     "Backbone",
     "BackboneConfig",
-    "BackboneModel",
-    "BackboneModel",
     "Bottleneck",
     "BottleneckConfig",
     "ClassifierHead",
@@ -78,65 +82,75 @@ __all__ = [
     "DEFAULT_DECODER_CONFIG",
     "DEFAULT_ENCODER_CONFIG",
     "DecoderConfig",
-    "DetectionModel",
     "Detector",
     "DetectorHead",
     "EncoderConfig",
     "FreqCoordConvDownConfig",
     "FreqCoordConvUpConfig",
-    "ModelOutput",
     "StandardConvDownConfig",
     "StandardConvUpConfig",
     "build_backbone",
     "build_bottleneck",
     "build_decoder",
-    "build_detector",
     "build_encoder",
-    "build_model",
+    "build_detector",
     "load_backbone_config",
+    "Model",
+    "ModelConfig",
+    "build_model",
 ]
-def build_model(
-    num_classes: int,
-    config: Optional[BackboneConfig] = None,
-) -> DetectionModel:
-    """Build the complete BatDetect2 detection model.
-
-    This high-level factory function constructs the standard BatDetect2 model
-    architecture. It first builds the feature extraction backbone (typically an
-    encoder-bottleneck-decoder structure) based on the provided
-    `BackboneConfig` (or defaults if None), and then attaches the standard
-    prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the
-    `build_detector` function.
-
-    Parameters
-    ----------
-    num_classes : int
-        The number of specific target classes the model should predict
-        (required for the `ClassifierHead`). Must be positive.
-    config : BackboneConfig, optional
-        Configuration object specifying the architecture of the backbone
-        (encoder, bottleneck, decoder). If None, default configurations defined
-        within the respective builder functions (`build_encoder`, etc.) will be
-        used to construct a default backbone architecture.
-
-    Returns
-    -------
-    DetectionModel
-        An initialized `Detector` model instance.
-
-    Raises
-    ------
-    ValueError
-        If `num_classes` is not positive, or if errors occur during the
-        construction of the backbone or detector components (e.g., incompatible
-        configurations, invalid parameters).
-    """
-    config = config or BackboneConfig()
-    logger.opt(lazy=True).debug(
-        "Building model with config: \n{}",
-        lambda: config.to_yaml_string(),
-    )
-    backbone = build_backbone(config)
-    return build_detector(num_classes, backbone)
+class Model(LightningModule):
+    detector: DetectionModel
+    preprocessor: PreprocessorProtocol
+    postprocessor: PostprocessorProtocol
+    targets: TargetProtocol
+
+    def __init__(
+        self,
+        detector: DetectionModel,
+        preprocessor: PreprocessorProtocol,
+        postprocessor: PostprocessorProtocol,
+        targets: TargetProtocol,
+    ):
+        super().__init__()
+        self.detector = detector
+        self.preprocessor = preprocessor
+        self.postprocessor = postprocessor
+        self.targets = targets
+
+    def forward(self, spec: torch.Tensor) -> ModelOutput:
+        return self.detector(spec)
+
+
+class ModelConfig(BaseConfig):
+    model: BackboneConfig = Field(default_factory=BackboneConfig)
+    preprocess: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
+    targets: TargetConfig = Field(default_factory=TargetConfig)
+
+
+def build_model(config: Optional[ModelConfig] = None):
+    config = config or ModelConfig()
+    targets = build_targets(config=config.targets)
+    preprocessor = build_preprocessor(config=config.preprocess)
+    postprocessor = build_postprocessor(
+        targets=targets,
+        config=config.postprocess,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
+    )
+    detector = build_detector(
+        num_classes=len(targets.class_names),
+        config=config.model,
+    )
+    return Model(
+        detector=detector,
+        postprocessor=postprocessor,
+        preprocessor=preprocessor,
+        targets=targets,
+    )
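
For orientation: the hunk above replaces the old `build_model(num_classes, config)` factory with a `Model` LightningModule that bundles detector, preprocessor, postprocessor, and targets, plus a `ModelConfig` and a config-driven `build_model`. A rough usage sketch based only on the new code shown here (the spectrogram shape is an assumption for illustration):

import torch
from batdetect2.models import ModelConfig, build_model

# Build every component from one configuration object (defaults throughout)
model = build_model(ModelConfig())  # equivalent to build_model()

# The preprocessor, postprocessor, and targets travel with the module
assert model.preprocessor is not None

# Calling the module runs the detector on a spectrogram tensor
spec = torch.rand(1, 1, 128, 512)  # assumed batch/channel/freq/time layout
output = model(spec)  # returns a ModelOutput named tuple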

View File

@@ -43,7 +43,7 @@ from batdetect2.models.encoder import (
     EncoderConfig,
     build_encoder,
 )
-from batdetect2.models.types import BackboneModel
+from batdetect2.typing.models import BackboneModel
 __all__ = [
     "Backbone",

View File

@@ -8,18 +8,25 @@ classifying them.
 The primary components are:
 - `Detector`: The `torch.nn.Module` subclass representing the complete model.
-- `build_detector`: A factory function to conveniently construct a standard
-  `Detector` instance given a backbone and the number of target classes.
 This module focuses purely on the neural network architecture definition. The
 logic for preprocessing inputs and postprocessing/decoding outputs resides in
 the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
 """
-import torch
+from typing import Optional
+
+import torch
+from loguru import logger
+
+from batdetect2.models.backbones import BackboneConfig, build_backbone
 from batdetect2.models.heads import BBoxHead, ClassifierHead
-from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
+from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
+
+__all__ = [
+    "Detector",
+    "build_detector",
+]
 class Detector(DetectionModel):
@@ -119,36 +126,41 @@ class Detector(DetectionModel):
     )
-def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
-    """Factory function to build a standard Detector model instance.
-
-    Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`,
-    `BBoxHead`) configured appropriately based on the output channels of the
-    provided `backbone` and the specified `num_classes`. It then assembles
-    these components into a `Detector` model.
+def build_detector(
+    num_classes: int, config: Optional[BackboneConfig] = None
+) -> DetectionModel:
+    """Build the complete BatDetect2 detection model.
 
     Parameters
     ----------
     num_classes : int
-        The number of specific target classes for the classification head
-        (excluding any implicit background class). Must be positive.
-    backbone : BackboneModel
-        An initialized feature extraction backbone module instance. The number
-        of output channels from this backbone (`backbone.out_channels`) is used
-        to configure the input channels for the prediction heads.
+        The number of specific target classes the model should predict
+        (required for the `ClassifierHead`). Must be positive.
+    config : BackboneConfig, optional
+        Configuration object specifying the architecture of the backbone
+        (encoder, bottleneck, decoder). If None, default configurations defined
+        within the respective builder functions (`build_encoder`, etc.) will be
+        used to construct a default backbone architecture.
 
     Returns
     -------
-    Detector
+    DetectionModel
         An initialized `Detector` model instance.
 
     Raises
     ------
     ValueError
-        If `num_classes` is not positive.
-    AttributeError
-        If `backbone` does not have the required `out_channels` attribute.
+        If `num_classes` is not positive, or if errors occur during the
+        construction of the backbone or detector components (e.g., incompatible
+        configurations, invalid parameters).
     """
+    config = config or BackboneConfig()
+    logger.opt(lazy=True).debug(
+        "Building model with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
+    backbone = build_backbone(config=config)
     classifier_head = ClassifierHead(
         num_classes=num_classes,
         in_channels=backbone.out_channels,

View File

@@ -4,7 +4,7 @@ from matplotlib.axes import Axes
 from soundevent import data, plot
 from batdetect2.plotting.clips import plot_clip
-from batdetect2.preprocess import PreprocessorProtocol
+from batdetect2.typing.preprocess import PreprocessorProtocol
 __all__ = [
     "plot_clip_annotation",

View File

@@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry
 from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
 from batdetect2.plotting.clips import plot_clip
-from batdetect2.preprocess import PreprocessorProtocol
+from batdetect2.typing.preprocess import PreprocessorProtocol
 __all__ = [
     "plot_clip_prediction",

View File

@@ -7,8 +7,8 @@ import matplotlib.pyplot as plt
 import pandas as pd
 from batdetect2 import plotting
-from batdetect2.evaluate.types import MatchEvaluation
-from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.typing.evaluate import MatchEvaluation
+from batdetect2.typing.preprocess import PreprocessorProtocol
 @dataclass

View File

@@ -6,13 +6,13 @@ from soundevent import data, plot
 from soundevent.geometry import compute_bounds
 from soundevent.plot.tags import TagColorMapper
-from batdetect2.evaluate.types import MatchEvaluation
 from batdetect2.plotting.clip_predictions import plot_prediction
 from batdetect2.plotting.clips import plot_clip
 from batdetect2.preprocess import (
     PreprocessorProtocol,
     get_default_preprocessor,
 )
+from batdetect2.typing.evaluate import MatchEvaluation
 __all__ = [
     "plot_matches",

View File

@@ -36,7 +36,6 @@ from pydantic import Field
 from soundevent import data
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.models.types import ModelOutput
 from batdetect2.postprocess.decoding import (
     DEFAULT_CLASSIFICATION_THRESHOLD,
     convert_raw_prediction_to_sound_event_prediction,
@@ -62,13 +61,14 @@ from batdetect2.postprocess.remapping import (
     features_to_xarray,
     sizes_to_xarray,
 )
-from batdetect2.postprocess.types import (
+from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
+from batdetect2.typing.models import ModelOutput
+from batdetect2.typing.postprocess import (
     BatDetect2Prediction,
     PostprocessorProtocol,
     RawPrediction,
 )
-from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing.targets import TargetProtocol
 __all__ = [
     "DEFAULT_CLASSIFICATION_THRESHOLD",
@@ -79,8 +79,6 @@ __all__ = [
     "NMS_KERNEL_SIZE",
     "PostprocessConfig",
     "Postprocessor",
-    "PostprocessorProtocol",
-    "RawPrediction",
     "TOP_K_PER_SEC",
     "build_postprocessor",
     "classification_to_xarray",

View File

@@ -32,8 +32,8 @@ import numpy as np
 import xarray as xr
 from soundevent import data
-from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing.postprocess import GeometryDecoder, RawPrediction
+from batdetect2.typing.targets import TargetProtocol
 __all__ = [
     "convert_xr_dataset_to_raw_prediction",

View File

@@ -57,7 +57,7 @@ from batdetect2.preprocess.spectrogram import (
     build_spectrogram_builder,
     get_spectrogram_resolution,
 )
-from batdetect2.preprocess.types import (
+from batdetect2.typing.preprocess import (
     AudioLoader,
     PreprocessorProtocol,
     SpectrogramBuilder,
@@ -65,7+65,6 @@ from batdetect2.preprocess.types import (
 __all__ = [
     "AudioConfig",
-    "AudioLoader",
     "ConfigurableSpectrogramBuilder",
     "DEFAULT_DURATION",
     "FrequencyConfig",
@@ -77,7+76,6 @@ __all__ = [
     "SCALE_RAW_AUDIO",
     "STFTConfig",
     "SpecSizeConfig",
-    "SpectrogramBuilder",
     "SpectrogramConfig",
     "StandardPreprocessor",
     "TARGET_SAMPLERATE_HZ",

View File

@@ -32,7 +32,7 @@ from soundevent.arrays import operations as ops
 from soundfile import LibsndfileError
 from batdetect2.configs import BaseConfig
-from batdetect2.preprocess.types import AudioLoader
+from batdetect2.typing.preprocess import AudioLoader
 __all__ = [
     "ResampleConfig",

View File

@@ -31,7 +31,7 @@ from soundevent.arrays import operations as ops
 from batdetect2.configs import BaseConfig
 from batdetect2.preprocess.audio import convert_to_xr
-from batdetect2.preprocess.types import SpectrogramBuilder
+from batdetect2.typing.preprocess import SpectrogramBuilder
 __all__ = [
     "STFTConfig",
@@ -540,7 +540,6 @@ def apply_pcen(
     xr.DataArray
         PCEN-scaled spectrogram.
     """
-    spec = spec * (2**31)
     samplerate = 1 / spec.time.attrs["step"]
     hop_size = spec.attrs["hop_size"]
@@ -559,22 +558,24 @@ def apply_pcen(
         [1, smoothing_constant - 1],
     )[:]
+    spec_data = spec.data * (2**31)
     # Smooth the input array along the given axis
     smoothed, _ = signal.lfilter(
         [smoothing_constant],
         [1, smoothing_constant - 1],
-        spec.data,
+        spec_data,
         zi=zi,
         axis=axis,  # type: ignore
     )
     smooth = np.exp(-gain * (np.log(eps) + np.log1p(smoothed / eps)))
     data = (bias**power) * np.expm1(
-        power * np.log1p(spec.data * smooth / bias)
+        power * np.log1p(spec_data * smooth / bias)
     )
     return xr.DataArray(
-        data,
+        data.astype(spec.dtype),
         dims=spec.dims,
         coords=spec.coords,
         attrs=spec.attrs,

View File

@@ -80,7 +80,7 @@ from batdetect2.targets.transform import (
     load_transformation_from_config,
     register_derivation,
 )
-from batdetect2.targets.types import Position, Size, TargetProtocol
+from batdetect2.typing.targets import Position, Size, TargetProtocol
 __all__ = [
     "ClassesConfig",
@@ -99,7 +99,6 @@ __all__ = [
     "TagInfo",
     "TargetClass",
     "TargetConfig",
-    "TargetProtocol",
     "Targets",
     "TermInfo",
     "TransformConfig",

View File

@@ -14,6 +14,7 @@ from batdetect2.targets.terms import (
     default_term_registry,
     get_tag_from_info,
 )
+from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
 __all__ = [
     "DEFAULT_SPECIES_LIST",
@@ -27,25 +28,6 @@ __all__ = [
 ]
-
-SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
-"""Type alias for a sound event class encoder function.
-
-An encoder function takes a sound event annotation and returns the string name
-of the target class it belongs to, based on a predefined set of rules.
-If the annotation does not match any defined target class according to the
-rules, the function returns None.
-"""
-
-SoundEventDecoder = Callable[[str], List[data.Tag]]
-"""Type alias for a sound event class decoder function.
-
-A decoder function takes a class name string (as predicted by the model or
-assigned during encoding) and returns a list of `soundevent.data.Tag` objects
-that represent that class according to the configuration. This is used to
-translate model outputs back into meaningful annotations.
-"""
 DEFAULT_SPECIES_LIST = [
     "Barbastella barbastellus",
     "Eptesicus serotinus",

View File

@@ -1,6 +1,6 @@
 import logging
 from functools import partial
-from typing import Callable, List, Literal, Optional, Set
+from typing import List, Literal, Optional, Set
 from pydantic import Field
 from soundevent import data
@@ -12,11 +12,11 @@ from batdetect2.targets.terms import (
     default_term_registry,
     get_tag_from_info,
 )
+from batdetect2.typing.targets import SoundEventFilter
 __all__ = [
     "FilterConfig",
     "FilterRule",
-    "SoundEventFilter",
     "build_sound_event_filter",
     "build_filter_from_rule",
     "load_filter_config",
@@ -24,14 +24,6 @@ __all__ = [
 ]
-
-SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
-"""Type alias for a filter function.
-
-A filter function accepts a soundevent.data.SoundEventAnnotation object
-and returns True if the annotation should be kept based on the filter's
-criteria, or False if it should be discarded.
-"""
 logger = logging.getLogger(__name__)

View File

@@ -28,8 +28,8 @@ from soundevent import data
 from batdetect2.configs import BaseConfig
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
-from batdetect2.preprocess.types import PreprocessorProtocol
-from batdetect2.targets.types import Position, Size
+from batdetect2.typing.preprocess import PreprocessorProtocol
+from batdetect2.typing.targets import Position, Size
 __all__ = [
     "Anchor",

View File

@@ -24,7 +24,6 @@ from batdetect2.train.config import (
 from batdetect2.train.dataset import (
     LabeledDataset,
     RandomExampleSource,
-    TrainExample,
     list_preprocessed_files,
 )
 from batdetect2.train.labels import build_clip_labeler, load_label_config
@@ -64,7 +63,6 @@ __all__ = [
     "RandomExampleSource",
     "SizeLossConfig",
     "TimeMaskAugmentationConfig",
-    "TrainExample",
     "TrainingConfig",
     "TrainingModule",
     "VolumeAugmentationConfig",

View File

@@ -33,8 +33,7 @@ from pydantic import Field
 from soundevent import arrays, data
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.preprocess import PreprocessorProtocol
-from batdetect2.train.types import Augmentation
+from batdetect2.typing import Augmentation, PreprocessorProtocol
 from batdetect2.utils.arrays import adjust_width
 __all__ = [

View File

@@ -14,16 +14,18 @@ from batdetect2.evaluate.match import (
     MatchConfig,
     match_sound_events_and_raw_predictions,
 )
-from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
 from batdetect2.plotting.evaluation import plot_example_gallery
-from batdetect2.postprocess.types import (
-    BatDetect2Prediction,
-    PostprocessorProtocol,
-)
-from batdetect2.targets.types import TargetProtocol
-from batdetect2.train.dataset import LabeledDataset, TrainExample
+from batdetect2.train.dataset import LabeledDataset
 from batdetect2.train.lightning import TrainingModule
-from batdetect2.train.types import ModelOutput
+from batdetect2.typing import (
+    BatDetect2Prediction,
+    MatchEvaluation,
+    MetricsProtocol,
+    ModelOutput,
+    PostprocessorProtocol,
+    TargetProtocol,
+    TrainExample,
+)
 class ValidationMetrics(Callback):

View File

@@ -6,7 +6,7 @@ from loguru import logger
 from soundevent import arrays
 from batdetect2.configs import BaseConfig
-from batdetect2.train.types import ClipperProtocol
+from batdetect2.typing import ClipperProtocol
 DEFAULT_TRAIN_CLIP_DURATION = 0.513
 DEFAULT_MAX_EMPTY_CLIP = 0.1

View File

@@ -4,11 +4,8 @@ from pydantic import Field
 from soundevent import data
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.models import BackboneConfig
-from batdetect2.postprocess import PostprocessConfig
-from batdetect2.preprocess import PreprocessingConfig
-from batdetect2.targets import TargetConfig
+from batdetect2.evaluate import EvaluationConfig
+from batdetect2.models import ModelConfig
 from batdetect2.train.augmentations import (
     DEFAULT_AUGMENTATION_CONFIG,
     AugmentationsConfig,
@@ -85,16 +82,10 @@ def load_train_config(
     return load_config(path, schema=TrainingConfig, field=field)
-class FullTrainingConfig(BaseConfig):
+class FullTrainingConfig(ModelConfig):
     """Full training configuration."""
     train: TrainingConfig = Field(default_factory=TrainingConfig)
-    targets: TargetConfig = Field(default_factory=TargetConfig)
-    model: BackboneConfig = Field(default_factory=BackboneConfig)
-    preprocess: PreprocessingConfig = Field(
-        default_factory=PreprocessingConfig
-    )
-    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
     evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)

View File

@@ -8,7 +8,7 @@ from soundevent import data
 from torch.utils.data import Dataset
 from batdetect2.train.augmentations import Augmentation
-from batdetect2.train.types import ClipperProtocol, TrainExample
+from batdetect2.typing import ClipperProtocol, TrainExample
 from batdetect2.utils.tensors import adjust_width
 __all__ = [

View File

@@ -34,10 +34,10 @@ from scipy.ndimage import gaussian_filter
 from soundevent import arrays, data
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.targets.types import TargetProtocol
-from batdetect2.train.types import (
+from batdetect2.typing import (
     ClipLabeller,
     Heatmaps,
+    TargetProtocol,
 )
 __all__ = [

View File

@@ -3,12 +3,8 @@ import torch
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
-from batdetect2.models import ModelOutput
-from batdetect2.models.types import DetectionModel
-from batdetect2.postprocess.types import PostprocessorProtocol
-from batdetect2.preprocess.types import PreprocessorProtocol
-from batdetect2.targets.types import TargetProtocol
-from batdetect2.train import TrainExample
+from batdetect2.models import Model
+from batdetect2.typing import ModelOutput, TrainExample
 __all__ = [
     "TrainingModule",
@@ -18,11 +14,8 @@ __all__ = [
 class TrainingModule(L.LightningModule):
     def __init__(
         self,
-        detector: DetectionModel,
+        model: Model,
         loss: torch.nn.Module,
-        targets: TargetProtocol,
-        preprocessor: PreprocessorProtocol,
-        postprocessor: PostprocessorProtocol,
         learning_rate: float = 0.001,
         t_max: int = 100,
     ):
@@ -32,18 +25,14 @@ class TrainingModule(L.LightningModule):
         self.t_max = t_max
         self.loss = loss
-        self.targets = targets
-        self.detector = detector
-        self.preprocessor = preprocessor
-        self.postprocessor = postprocessor
+        self.model = model
         self.save_hyperparameters(logger=False)
     def forward(self, spec: torch.Tensor) -> ModelOutput:
-        return self.detector(spec)
+        return self.model(spec)
     def training_step(self, batch: TrainExample):
-        outputs = self.forward(batch.spec)
+        outputs = self.model(batch.spec)
         losses = self.loss(outputs, batch)
         self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
         self.log("detection_loss/train", losses.total, logger=True)
@@ -56,7 +45,7 @@ class TrainingModule(L.LightningModule):
         batch: TrainExample,
         batch_idx: int,
     ) -> ModelOutput:
-        outputs = self.forward(batch.spec)
+        outputs = self.model(batch.spec)
         losses = self.loss(outputs, batch)
         self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
         self.log("detection_loss/val", losses.total, logger=True)

View File

@@ -28,9 +28,7 @@ from pydantic import Field
 from torch import nn
 from batdetect2.configs import BaseConfig
-from batdetect2.models.types import ModelOutput
-from batdetect2.train.dataset import TrainExample
-from batdetect2.train.types import Losses, LossProtocol
+from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
 __all__ = [
     "BBoxLoss",

View File

@@ -34,10 +34,9 @@ from tqdm.auto import tqdm
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.data.datasets import Dataset
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
-from batdetect2.preprocess.types import PreprocessorProtocol
 from batdetect2.targets import TargetConfig, build_targets
 from batdetect2.train.labels import LabelConfig, build_clip_labeler
-from batdetect2.train.types import ClipLabeller
+from batdetect2.typing import ClipLabeller, PreprocessorProtocol
 __all__ = [
     "preprocess_annotations",

View File

@@ -14,12 +14,6 @@ from batdetect2.evaluate.metrics import (
     DetectionAveragePrecision,
 )
 from batdetect2.models import build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.preprocess import (
-    PreprocessorProtocol,
-    build_preprocessor,
-)
-from batdetect2.targets import TargetProtocol, build_targets
 from batdetect2.train.augmentations import build_augmentations
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
@@ -32,6 +26,7 @@ from batdetect2.train.dataset import (
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
 from batdetect2.train.losses import build_loss
+from batdetect2.typing import PreprocessorProtocol, TargetProtocol
 __all__ = [
     "build_train_dataset",
@@ -88,25 +83,11 @@ def train(
 def build_training_module(config: FullTrainingConfig) -> TrainingModule:
-    targets = build_targets(config=config.targets)
+    model = build_model(config=config)
     loss = build_loss(config=config.train.loss)
-    preprocessor = build_preprocessor(config.preprocess)
-    postprocessor = build_postprocessor(
-        targets,
-        config=config.postprocess,
-        max_freq=preprocessor.max_freq,
-        min_freq=preprocessor.min_freq,
-    )
-    model = build_model(
-        num_classes=len(targets.class_names),
-        config=config.model,
-    )
     return TrainingModule(
-        detector=model,
+        model=model,
         loss=loss,
-        preprocessor=preprocessor,
-        postprocessor=postprocessor,
-        targets=targets,
         learning_rate=config.train.learning_rate,
         t_max=config.train.t_max,
     )

View File

@@ -0,0 +1,58 @@
from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
GeometryDecoder,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.typing.preprocess import (
AudioLoader,
PreprocessorProtocol,
SpectrogramBuilder,
)
from batdetect2.typing.targets import (
Position,
Size,
SoundEventDecoder,
SoundEventEncoder,
SoundEventFilter,
TargetProtocol,
)
from batdetect2.typing.train import (
Augmentation,
ClipLabeller,
ClipperProtocol,
Heatmaps,
Losses,
LossProtocol,
TrainExample,
)
__all__ = [
"AudioLoader",
"Augmentation",
"BackboneModel",
"BatDetect2Prediction",
"ClipLabeller",
"ClipperProtocol",
"DetectionModel",
"GeometryDecoder",
"Heatmaps",
"LossProtocol",
"Losses",
"MatchEvaluation",
"MetricsProtocol",
"ModelOutput",
"Position",
"PostprocessorProtocol",
"PreprocessorProtocol",
"RawPrediction",
"Size",
"SoundEventDecoder",
"SoundEventEncoder",
"SoundEventFilter",
"SpectrogramBuilder",
"TargetProtocol",
"TrainExample",
]
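
The new batdetect2/typing/__init__.py above re-exports every protocol and alias, so downstream code can pull them from a single place. A small illustrative sketch (the names come straight from the __all__ above; the helper function is hypothetical):

from batdetect2.typing import ModelOutput, TargetProtocol, TrainExample

def describe_targets(targets: TargetProtocol) -> list:
    # Illustrative only: the protocols are meant to be used as annotations,
    # and class_names is the attribute the diff itself relies on elsewhere.
    return list(targets.class_names)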

View File

@@ -19,7 +19,6 @@ from abc import ABC, abstractmethod
 from typing import NamedTuple
 import torch
-import torch.nn as nn
 __all__ = [
     "ModelOutput",
@@ -65,7 +64,7 @@ class ModelOutput(NamedTuple):
     features: torch.Tensor
-class BackboneModel(ABC, nn.Module):
+class BackboneModel(ABC, torch.nn.Module):
     """Abstract Base Class for generic feature extraction backbone models.
     Defines the minimal interface for a feature extractor network within a
@@ -191,7 +190,7 @@ class EncoderDecoderModel(BackboneModel):
         ...
-class DetectionModel(ABC, nn.Module):
+class DetectionModel(ABC, torch.nn.Module):
     """Abstract Base Class for complete BatDetect2 detection models.
     Defines the interface for the overall model that takes an input spectrogram

View File

@@ -18,8 +18,8 @@ import numpy as np
 import xarray as xr
 from soundevent import data
-from batdetect2.models.types import ModelOutput
-from batdetect2.targets.types import Position, Size
+from batdetect2.typing.models import ModelOutput
+from batdetect2.typing.targets import Position, Size
 __all__ = [
     "RawPrediction",

View File

@@ -16,6 +16,12 @@ import numpy as np
 import xarray as xr
 from soundevent import data
+
+__all__ = [
+    "AudioLoader",
+    "SpectrogramBuilder",
+    "PreprocessorProtocol",
+]
 class AudioLoader(Protocol):
     """Defines the interface for an audio loading and processing component.

View File

@@ -12,6 +12,7 @@ that components responsible for these tasks can be interacted with consistently
 throughout BatDetect2.
 """
+from collections.abc import Callable
 from typing import List, Optional, Protocol
 import numpy as np
@@ -19,10 +20,40 @@ from soundevent import data
 __all__ = [
     "TargetProtocol",
+    "SoundEventEncoder",
+    "SoundEventDecoder",
+    "SoundEventFilter",
     "Position",
     "Size",
 ]
+
+SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
+"""Type alias for a sound event class encoder function.
+
+An encoder function takes a sound event annotation and returns the string name
+of the target class it belongs to, based on a predefined set of rules.
+If the annotation does not match any defined target class according to the
+rules, the function returns None.
+"""
+
+SoundEventDecoder = Callable[[str], List[data.Tag]]
+"""Type alias for a sound event class decoder function.
+
+A decoder function takes a class name string (as predicted by the model or
+assigned during encoding) and returns a list of `soundevent.data.Tag` objects
+that represent that class according to the configuration. This is used to
+translate model outputs back into meaningful annotations.
+"""
+
+SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
+"""Type alias for a filter function.
+
+A filter function accepts a soundevent.data.SoundEventAnnotation object
+and returns True if the annotation should be kept based on the filter's
+criteria, or False if it should be discarded.
+"""
 Position = tuple[float, float]
 """A tuple representing (time, frequency) coordinates."""

View File

@@ -4,13 +4,15 @@ import torch
 import xarray as xr
 from soundevent import data
-from batdetect2.models import ModelOutput
+from batdetect2.typing.models import ModelOutput
 __all__ = [
-    "Heatmaps",
-    "ClipLabeller",
     "Augmentation",
+    "ClipLabeller",
+    "ClipperProtocol",
+    "Heatmaps",
     "LossProtocol",
+    "Losses",
     "TrainExample",
 ]

View File

@@ -12,7 +12,6 @@ from soundevent import data, terms
 from batdetect2.data import DatasetConfig, load_dataset
 from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
 from batdetect2.preprocess import build_preprocessor
-from batdetect2.preprocess.types import PreprocessorProtocol
 from batdetect2.targets import (
     TargetConfig,
     TermRegistry,
@@ -22,9 +21,12 @@ from batdetect2.targets import (
 from batdetect2.targets.classes import ClassesConfig, TargetClass
 from batdetect2.targets.filtering import FilterConfig, FilterRule
 from batdetect2.targets.terms import TagInfo
-from batdetect2.targets.types import TargetProtocol
 from batdetect2.train.labels import build_clip_labeler
-from batdetect2.train.types import ClipLabeller
+from batdetect2.typing import (
+    ClipLabeller,
+    PreprocessorProtocol,
+    TargetProtocol,
+)
 @pytest.fixture

View File

@@ -15,8 +15,7 @@ from batdetect2.postprocess.decoding import (
     get_generic_tags,
     get_prediction_features,
 )
-from batdetect2.postprocess.types import RawPrediction
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import RawPrediction, TargetProtocol
 @pytest.fixture

View File

@@ -142,7 +142,7 @@ def test_audio_config_defaults():
     assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
     assert config.resample.method == "poly"
     assert config.scale == audio.SCALE_RAW_AUDIO
-    assert config.center is True
+    assert config.center is False
     assert config.duration == audio.DEFAULT_DURATION

View File

@@ -108,7 +108,7 @@ def test_spec_size_config_defaults():
 def test_pcen_config_defaults():
     config = PcenConfig()
-    assert config.time_constant == 0.4
+    assert config.time_constant == 0.01
     assert config.gain == 0.98
     assert config.bias == 2
     assert config.power == 0.5
@@ -202,13 +202,6 @@ def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
 def test_apply_pcen(sample_spec: xr.DataArray):
-    if "original_samplerate" not in sample_spec.attrs:
-        sample_spec.attrs["original_samplerate"] = SAMPLERATE
-    if "nfft" not in sample_spec.attrs:
-        sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
-    if "noverlap" not in sample_spec.attrs:
-        sample_spec.attrs["noverlap"] = int(0.5 * sample_spec.attrs["nfft"])
-
     pcen_config = PcenConfig()
     pcen_spec = apply_pcen(
         sample_spec,

View File

@@ -5,14 +5,13 @@ import pytest
 import xarray as xr
 from soundevent import arrays, data
-from batdetect2.preprocess.types import PreprocessorProtocol
 from batdetect2.train.augmentations import (
     add_echo,
     mix_examples,
 )
 from batdetect2.train.clips import select_subclip
 from batdetect2.train.preprocess import generate_train_example
-from batdetect2.train.types import ClipLabeller
+from batdetect2.typing import ClipLabeller, PreprocessorProtocol
 def test_mix_examples(

View File

@@ -6,7 +6,7 @@ from soundevent import data
 from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
 from batdetect2.targets.rois import AnchorBBoxMapperConfig
-from batdetect2.targets.terms import TagInfo, TermRegistry
+from batdetect2.targets.terms import TagInfo
 from batdetect2.train.labels import generate_heatmaps
 recording = data.Recording(

View File

@@ -28,8 +28,8 @@ def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
     recovered = TrainingModule.load_from_checkpoint(path)
-    spec1 = module.preprocessor.preprocess_clip(clip)
-    spec2 = recovered.preprocessor.preprocess_clip(clip)
+    spec1 = module.model.preprocessor.preprocess_clip(clip)
+    spec2 = recovered.model.preprocessor.preprocess_clip(clip)
     xr.testing.assert_equal(spec1, spec2)

View File

@@ -4,12 +4,12 @@ import xarray as xr
 from soundevent import data
 from soundevent.terms import get_term
-from batdetect2.models.types import ModelOutput
 from batdetect2.postprocess import build_postprocessor, load_postprocess_config
 from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
 from batdetect2.targets import build_targets, load_target_config
 from batdetect2.train.labels import build_clip_labeler, load_label_config
 from batdetect2.train.preprocess import generate_train_example
+from batdetect2.typing import ModelOutput
 @pytest.fixture