Moved types to dedicated module

This commit is contained in:
mbsantiago 2025-08-24 10:55:48 +01:00
parent 02adc19070
commit 61115d562c
47 changed files with 280 additions and 238 deletions

View File

@ -4,7 +4,7 @@ from typing import Optional, Tuple
from soundevent import data
from batdetect2.data.datasets import Dataset
from batdetect2.targets.types import TargetProtocol
from batdetect2.typing.targets import TargetProtocol
def iterate_over_sound_events(

View File

@ -7,7 +7,7 @@ from batdetect2.data.summary import (
extract_recordings_df,
extract_sound_events_df,
)
from batdetect2.targets.types import TargetProtocol
from batdetect2.typing.targets import TargetProtocol
def split_dataset_by_recordings(

View File

@ -2,7 +2,7 @@ import pandas as pd
from soundevent.geometry import compute_bounds
from batdetect2.data.datasets import Dataset
from batdetect2.targets.types import TargetProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"extract_recordings_df",

View File

@ -4,15 +4,15 @@ from typing import List, Literal, Optional, Tuple
import numpy as np
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.evaluation import (
match_geometries as optimal_match,
)
from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.typing import (
BatDetect2Prediction,
MatchEvaluation,
TargetProtocol,
)
MatchingStrategy = Literal["greedy", "optimal"]
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""

View File

@ -4,7 +4,7 @@ import pandas as pd
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
from batdetect2.typing import MatchEvaluation, MetricsProtocol
__all__ = ["DetectionAveragePrecision"]

View File

@ -28,8 +28,11 @@ provided here.
from typing import Optional
from loguru import logger
import torch
from lightning import LightningModule
from pydantic import Field
from batdetect2.configs import BaseConfig
from batdetect2.models.backbones import (
Backbone,
BackboneConfig,
@ -53,24 +56,25 @@ from batdetect2.models.decoder import (
DecoderConfig,
build_decoder,
)
from batdetect2.models.detectors import (
Detector,
build_detector,
)
from batdetect2.models.detectors import Detector, build_detector
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.typing.models import DetectionModel, ModelOutput
from batdetect2.typing.postprocess import PostprocessorProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"BBoxHead",
"Backbone",
"BackboneConfig",
"BackboneModel",
"BackboneModel",
"Bottleneck",
"BottleneckConfig",
"ClassifierHead",
@ -78,65 +82,75 @@ __all__ = [
"DEFAULT_DECODER_CONFIG",
"DEFAULT_ENCODER_CONFIG",
"DecoderConfig",
"DetectionModel",
"Detector",
"DetectorHead",
"EncoderConfig",
"FreqCoordConvDownConfig",
"FreqCoordConvUpConfig",
"ModelOutput",
"StandardConvDownConfig",
"StandardConvUpConfig",
"build_backbone",
"build_bottleneck",
"build_decoder",
"build_detector",
"build_encoder",
"build_model",
"build_detector",
"load_backbone_config",
"Model",
"ModelConfig",
"build_model",
]
def build_model(
num_classes: int,
config: Optional[BackboneConfig] = None,
) -> DetectionModel:
"""Build the complete BatDetect2 detection model.
class Model(LightningModule):
detector: DetectionModel
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
targets: TargetProtocol
This high-level factory function constructs the standard BatDetect2 model
architecture. It first builds the feature extraction backbone (typically an
encoder-bottleneck-decoder structure) based on the provided
`BackboneConfig` (or defaults if None), and then attaches the standard
prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the
`build_detector` function.
def __init__(
self,
detector: DetectionModel,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
):
super().__init__()
self.detector = detector
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.targets = targets
Parameters
----------
num_classes : int
The number of specific target classes the model should predict
(required for the `ClassifierHead`). Must be positive.
config : BackboneConfig, optional
Configuration object specifying the architecture of the backbone
(encoder, bottleneck, decoder). If None, default configurations defined
within the respective builder functions (`build_encoder`, etc.) will be
used to construct a default backbone architecture.
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.detector(spec)
Returns
-------
DetectionModel
An initialized `Detector` model instance.
Raises
------
ValueError
If `num_classes` is not positive, or if errors occur during the
construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters).
"""
config = config or BackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(),
class ModelConfig(BaseConfig):
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
def build_model(config: Optional[ModelConfig] = None):
config = config or ModelConfig()
targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess)
postprocessor = build_postprocessor(
targets=targets,
config=config.postprocess,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
detector = build_detector(
num_classes=len(targets.class_names),
config=config.model,
)
return Model(
detector=detector,
postprocessor=postprocessor,
preprocessor=preprocessor,
targets=targets,
)
backbone = build_backbone(config)
return build_detector(num_classes, backbone)

View File

@ -43,7 +43,7 @@ from batdetect2.models.encoder import (
EncoderConfig,
build_encoder,
)
from batdetect2.models.types import BackboneModel
from batdetect2.typing.models import BackboneModel
__all__ = [
"Backbone",

View File

@ -8,18 +8,25 @@ classifying them.
The primary components are:
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
- `build_detector`: A factory function to conveniently construct a standard
`Detector` instance given a backbone and the number of target classes.
This module focuses purely on the neural network architecture definition. The
logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
"""
import torch
from typing import Optional
import torch
from loguru import logger
from batdetect2.models.backbones import BackboneConfig, build_backbone
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
__all__ = [
"Detector",
"build_detector",
]
class Detector(DetectionModel):
@ -119,36 +126,41 @@ class Detector(DetectionModel):
)
def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
"""Factory function to build a standard Detector model instance.
Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`,
`BBoxHead`) configured appropriately based on the output channels of the
provided `backbone` and the specified `num_classes`. It then assembles
these components into a `Detector` model.
def build_detector(
num_classes: int, config: Optional[BackboneConfig] = None
) -> DetectionModel:
"""Build the complete BatDetect2 detection model.
Parameters
----------
num_classes : int
The number of specific target classes for the classification head
(excluding any implicit background class). Must be positive.
backbone : BackboneModel
An initialized feature extraction backbone module instance. The number
of output channels from this backbone (`backbone.out_channels`) is used
to configure the input channels for the prediction heads.
The number of specific target classes the model should predict
(required for the `ClassifierHead`). Must be positive.
config : BackboneConfig, optional
Configuration object specifying the architecture of the backbone
(encoder, bottleneck, decoder). If None, default configurations defined
within the respective builder functions (`build_encoder`, etc.) will be
used to construct a default backbone architecture.
Returns
-------
Detector
DetectionModel
An initialized `Detector` model instance.
Raises
------
ValueError
If `num_classes` is not positive.
AttributeError
If `backbone` does not have the required `out_channels` attribute.
If `num_classes` is not positive, or if errors occur during the
construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters).
"""
config = config or BackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(),
)
backbone = build_backbone(config=config)
classifier_head = ClassifierHead(
num_classes=num_classes,
in_channels=backbone.out_channels,

View File

@ -4,7 +4,7 @@ from matplotlib.axes import Axes
from soundevent import data, plot
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol
__all__ = [
"plot_clip_annotation",

View File

@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol
__all__ = [
"plot_clip_prediction",

View File

@ -7,8 +7,8 @@ import matplotlib.pyplot as plt
import pandas as pd
from batdetect2 import plotting
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing.preprocess import PreprocessorProtocol
@dataclass

View File

@ -6,13 +6,13 @@ from soundevent import data, plot
from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.plotting.clip_predictions import plot_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import (
PreprocessorProtocol,
get_default_preprocessor,
)
from batdetect2.typing.evaluate import MatchEvaluation
__all__ = [
"plot_matches",

View File

@ -36,7 +36,6 @@ from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.types import ModelOutput
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
@ -62,13 +61,14 @@ from batdetect2.postprocess.remapping import (
features_to_xarray,
sizes_to_xarray,
)
from batdetect2.postprocess.types import (
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing.models import ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets.types import TargetProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"DEFAULT_CLASSIFICATION_THRESHOLD",
@ -79,8 +79,6 @@ __all__ = [
"NMS_KERNEL_SIZE",
"PostprocessConfig",
"Postprocessor",
"PostprocessorProtocol",
"RawPrediction",
"TOP_K_PER_SEC",
"build_postprocessor",
"classification_to_xarray",

View File

@ -32,8 +32,8 @@ import numpy as np
import xarray as xr
from soundevent import data
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.typing.postprocess import GeometryDecoder, RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"convert_xr_dataset_to_raw_prediction",

View File

@ -57,7 +57,7 @@ from batdetect2.preprocess.spectrogram import (
build_spectrogram_builder,
get_spectrogram_resolution,
)
from batdetect2.preprocess.types import (
from batdetect2.typing.preprocess import (
AudioLoader,
PreprocessorProtocol,
SpectrogramBuilder,
@ -65,7 +65,6 @@ from batdetect2.preprocess.types import (
__all__ = [
"AudioConfig",
"AudioLoader",
"ConfigurableSpectrogramBuilder",
"DEFAULT_DURATION",
"FrequencyConfig",
@ -77,7 +76,6 @@ __all__ = [
"SCALE_RAW_AUDIO",
"STFTConfig",
"SpecSizeConfig",
"SpectrogramBuilder",
"SpectrogramConfig",
"StandardPreprocessor",
"TARGET_SAMPLERATE_HZ",

View File

@ -32,7 +32,7 @@ from soundevent.arrays import operations as ops
from soundfile import LibsndfileError
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.types import AudioLoader
from batdetect2.typing.preprocess import AudioLoader
__all__ = [
"ResampleConfig",

View File

@ -31,7 +31,7 @@ from soundevent.arrays import operations as ops
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.audio import convert_to_xr
from batdetect2.preprocess.types import SpectrogramBuilder
from batdetect2.typing.preprocess import SpectrogramBuilder
__all__ = [
"STFTConfig",
@ -540,7 +540,6 @@ def apply_pcen(
xr.DataArray
PCEN-scaled spectrogram.
"""
spec = spec * (2**31)
samplerate = 1 / spec.time.attrs["step"]
hop_size = spec.attrs["hop_size"]
@ -559,22 +558,24 @@ def apply_pcen(
[1, smoothing_constant - 1],
)[:]
spec_data = spec.data * (2**31)
# Smooth the input array along the given axis
smoothed, _ = signal.lfilter(
[smoothing_constant],
[1, smoothing_constant - 1],
spec.data,
spec_data,
zi=zi,
axis=axis, # type: ignore
)
smooth = np.exp(-gain * (np.log(eps) + np.log1p(smoothed / eps)))
data = (bias**power) * np.expm1(
power * np.log1p(spec.data * smooth / bias)
power * np.log1p(spec_data * smooth / bias)
)
return xr.DataArray(
data,
data.astype(spec.dtype),
dims=spec.dims,
coords=spec.coords,
attrs=spec.attrs,

View File

@ -80,7 +80,7 @@ from batdetect2.targets.transform import (
load_transformation_from_config,
register_derivation,
)
from batdetect2.targets.types import Position, Size, TargetProtocol
from batdetect2.typing.targets import Position, Size, TargetProtocol
__all__ = [
"ClassesConfig",
@ -99,7 +99,6 @@ __all__ = [
"TagInfo",
"TargetClass",
"TargetConfig",
"TargetProtocol",
"Targets",
"TermInfo",
"TransformConfig",

View File

@ -14,6 +14,7 @@ from batdetect2.targets.terms import (
default_term_registry,
get_tag_from_info,
)
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
__all__ = [
"DEFAULT_SPECIES_LIST",
@ -27,25 +28,6 @@ __all__ = [
]
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Type alias for a sound event class encoder function.
An encoder function takes a sound event annotation and returns the string name
of the target class it belongs to, based on a predefined set of rules.
If the annotation does not match any defined target class according to the
rules, the function returns None.
"""
SoundEventDecoder = Callable[[str], List[data.Tag]]
"""Type alias for a sound event class decoder function.
A decoder function takes a class name string (as predicted by the model or
assigned during encoding) and returns a list of `soundevent.data.Tag` objects
that represent that class according to the configuration. This is used to
translate model outputs back into meaningful annotations.
"""
DEFAULT_SPECIES_LIST = [
"Barbastella barbastellus",
"Eptesicus serotinus",

View File

@ -1,6 +1,6 @@
import logging
from functools import partial
from typing import Callable, List, Literal, Optional, Set
from typing import List, Literal, Optional, Set
from pydantic import Field
from soundevent import data
@ -12,11 +12,11 @@ from batdetect2.targets.terms import (
default_term_registry,
get_tag_from_info,
)
from batdetect2.typing.targets import SoundEventFilter
__all__ = [
"FilterConfig",
"FilterRule",
"SoundEventFilter",
"build_sound_event_filter",
"build_filter_from_rule",
"load_filter_config",
@ -24,14 +24,6 @@ __all__ = [
]
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
"""Type alias for a filter function.
A filter function accepts a soundevent.data.SoundEventAnnotation object
and returns True if the annotation should be kept based on the filter's
criteria, or False if it should be discarded.
"""
logger = logging.getLogger(__name__)

View File

@ -28,8 +28,8 @@ from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import Position, Size
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import Position, Size
__all__ = [
"Anchor",

View File

@ -24,7 +24,6 @@ from batdetect2.train.config import (
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
TrainExample,
list_preprocessed_files,
)
from batdetect2.train.labels import build_clip_labeler, load_label_config
@ -64,7 +63,6 @@ __all__ = [
"RandomExampleSource",
"SizeLossConfig",
"TimeMaskAugmentationConfig",
"TrainExample",
"TrainingConfig",
"TrainingModule",
"VolumeAugmentationConfig",

View File

@ -33,8 +33,7 @@ from pydantic import Field
from soundevent import arrays, data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.train.types import Augmentation
from batdetect2.typing import Augmentation, PreprocessorProtocol
from batdetect2.utils.arrays import adjust_width
__all__ = [

View File

@ -14,16 +14,18 @@ from batdetect2.evaluate.match import (
MatchConfig,
match_sound_events_and_raw_predictions,
)
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess.types import (
BatDetect2Prediction,
PostprocessorProtocol,
)
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.dataset import LabeledDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.types import ModelOutput
from batdetect2.typing import (
BatDetect2Prediction,
MatchEvaluation,
MetricsProtocol,
ModelOutput,
PostprocessorProtocol,
TargetProtocol,
TrainExample,
)
class ValidationMetrics(Callback):

View File

@ -6,7 +6,7 @@ from loguru import logger
from soundevent import arrays
from batdetect2.configs import BaseConfig
from batdetect2.train.types import ClipperProtocol
from batdetect2.typing import ClipperProtocol
DEFAULT_TRAIN_CLIP_DURATION = 0.513
DEFAULT_MAX_EMPTY_CLIP = 0.1

View File

@ -4,11 +4,8 @@ from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.models import BackboneConfig
from batdetect2.postprocess import PostprocessConfig
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.targets import TargetConfig
from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
@ -85,16 +82,10 @@ def load_train_config(
return load_config(path, schema=TrainingConfig, field=field)
class FullTrainingConfig(BaseConfig):
class FullTrainingConfig(ModelConfig):
"""Full training configuration."""
train: TrainingConfig = Field(default_factory=TrainingConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)

View File

@ -8,7 +8,7 @@ from soundevent import data
from torch.utils.data import Dataset
from batdetect2.train.augmentations import Augmentation
from batdetect2.train.types import ClipperProtocol, TrainExample
from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.utils.tensors import adjust_width
__all__ = [

View File

@ -34,10 +34,10 @@ from scipy.ndimage import gaussian_filter
from soundevent import arrays, data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.types import (
from batdetect2.typing import (
ClipLabeller,
Heatmaps,
TargetProtocol,
)
__all__ = [

View File

@ -3,12 +3,8 @@ import torch
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.models import ModelOutput
from batdetect2.models.types import DetectionModel
from batdetect2.postprocess.types import PostprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
from batdetect2.train import TrainExample
from batdetect2.models import Model
from batdetect2.typing import ModelOutput, TrainExample
__all__ = [
"TrainingModule",
@ -18,11 +14,8 @@ __all__ = [
class TrainingModule(L.LightningModule):
def __init__(
self,
detector: DetectionModel,
model: Model,
loss: torch.nn.Module,
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
learning_rate: float = 0.001,
t_max: int = 100,
):
@ -32,18 +25,14 @@ class TrainingModule(L.LightningModule):
self.t_max = t_max
self.loss = loss
self.targets = targets
self.detector = detector
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.model = model
self.save_hyperparameters(logger=False)
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.detector(spec)
return self.model(spec)
def training_step(self, batch: TrainExample):
outputs = self.forward(batch.spec)
outputs = self.model(batch.spec)
losses = self.loss(outputs, batch)
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/train", losses.total, logger=True)
@ -56,7 +45,7 @@ class TrainingModule(L.LightningModule):
batch: TrainExample,
batch_idx: int,
) -> ModelOutput:
outputs = self.forward(batch.spec)
outputs = self.model(batch.spec)
losses = self.loss(outputs, batch)
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/val", losses.total, logger=True)

View File

@ -28,9 +28,7 @@ from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.types import ModelOutput
from batdetect2.train.dataset import TrainExample
from batdetect2.train.types import Losses, LossProtocol
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
__all__ = [
"BBoxLoss",

View File

@ -34,10 +34,9 @@ from tqdm.auto import tqdm
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.datasets import Dataset
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.train.labels import LabelConfig, build_clip_labeler
from batdetect2.train.types import ClipLabeller
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
__all__ = [
"preprocess_annotations",

View File

@ -14,12 +14,6 @@ from batdetect2.evaluate.metrics import (
DetectionAveragePrecision,
)
from batdetect2.models import build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import (
PreprocessorProtocol,
build_preprocessor,
)
from batdetect2.targets import TargetProtocol, build_targets
from batdetect2.train.augmentations import build_augmentations
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper
@ -32,6 +26,7 @@ from batdetect2.train.dataset import (
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger
from batdetect2.train.losses import build_loss
from batdetect2.typing import PreprocessorProtocol, TargetProtocol
__all__ = [
"build_train_dataset",
@ -88,25 +83,11 @@ def train(
def build_training_module(config: FullTrainingConfig) -> TrainingModule:
targets = build_targets(config=config.targets)
model = build_model(config=config)
loss = build_loss(config=config.train.loss)
preprocessor = build_preprocessor(config.preprocess)
postprocessor = build_postprocessor(
targets,
config=config.postprocess,
max_freq=preprocessor.max_freq,
min_freq=preprocessor.min_freq,
)
model = build_model(
num_classes=len(targets.class_names),
config=config.model,
)
return TrainingModule(
detector=model,
model=model,
loss=loss,
preprocessor=preprocessor,
postprocessor=postprocessor,
targets=targets,
learning_rate=config.train.learning_rate,
t_max=config.train.t_max,
)

View File

@ -0,0 +1,58 @@
from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
GeometryDecoder,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.typing.preprocess import (
AudioLoader,
PreprocessorProtocol,
SpectrogramBuilder,
)
from batdetect2.typing.targets import (
Position,
Size,
SoundEventDecoder,
SoundEventEncoder,
SoundEventFilter,
TargetProtocol,
)
from batdetect2.typing.train import (
Augmentation,
ClipLabeller,
ClipperProtocol,
Heatmaps,
Losses,
LossProtocol,
TrainExample,
)
__all__ = [
"AudioLoader",
"Augmentation",
"BackboneModel",
"BatDetect2Prediction",
"ClipLabeller",
"ClipperProtocol",
"DetectionModel",
"GeometryDecoder",
"Heatmaps",
"LossProtocol",
"Losses",
"MatchEvaluation",
"MetricsProtocol",
"ModelOutput",
"Position",
"PostprocessorProtocol",
"PreprocessorProtocol",
"RawPrediction",
"Size",
"SoundEventDecoder",
"SoundEventEncoder",
"SoundEventFilter",
"SpectrogramBuilder",
"TargetProtocol",
"TrainExample",
]

View File

@ -19,7 +19,6 @@ from abc import ABC, abstractmethod
from typing import NamedTuple
import torch
import torch.nn as nn
__all__ = [
"ModelOutput",
@ -65,7 +64,7 @@ class ModelOutput(NamedTuple):
features: torch.Tensor
class BackboneModel(ABC, nn.Module):
class BackboneModel(ABC, torch.nn.Module):
"""Abstract Base Class for generic feature extraction backbone models.
Defines the minimal interface for a feature extractor network within a
@ -191,7 +190,7 @@ class EncoderDecoderModel(BackboneModel):
...
class DetectionModel(ABC, nn.Module):
class DetectionModel(ABC, torch.nn.Module):
"""Abstract Base Class for complete BatDetect2 detection models.
Defines the interface for the overall model that takes an input spectrogram

View File

@ -18,8 +18,8 @@ import numpy as np
import xarray as xr
from soundevent import data
from batdetect2.models.types import ModelOutput
from batdetect2.targets.types import Position, Size
from batdetect2.typing.models import ModelOutput
from batdetect2.typing.targets import Position, Size
__all__ = [
"RawPrediction",

View File

@ -16,6 +16,12 @@ import numpy as np
import xarray as xr
from soundevent import data
__all__ = [
"AudioLoader",
"SpectrogramBuilder",
"PreprocessorProtocol",
]
class AudioLoader(Protocol):
"""Defines the interface for an audio loading and processing component.

View File

@ -12,6 +12,7 @@ that components responsible for these tasks can be interacted with consistently
throughout BatDetect2.
"""
from collections.abc import Callable
from typing import List, Optional, Protocol
import numpy as np
@ -19,10 +20,40 @@ from soundevent import data
__all__ = [
"TargetProtocol",
"SoundEventEncoder",
"SoundEventDecoder",
"SoundEventFilter",
"Position",
"Size",
]
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Type alias for a sound event class encoder function.
An encoder function takes a sound event annotation and returns the string name
of the target class it belongs to, based on a predefined set of rules.
If the annotation does not match any defined target class according to the
rules, the function returns None.
"""
SoundEventDecoder = Callable[[str], List[data.Tag]]
"""Type alias for a sound event class decoder function.
A decoder function takes a class name string (as predicted by the model or
assigned during encoding) and returns a list of `soundevent.data.Tag` objects
that represent that class according to the configuration. This is used to
translate model outputs back into meaningful annotations.
"""
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
"""Type alias for a filter function.
A filter function accepts a soundevent.data.SoundEventAnnotation object
and returns True if the annotation should be kept based on the filter's
criteria, or False if it should be discarded.
"""
Position = tuple[float, float]
"""A tuple representing (time, frequency) coordinates."""

View File

@ -4,13 +4,15 @@ import torch
import xarray as xr
from soundevent import data
from batdetect2.models import ModelOutput
from batdetect2.typing.models import ModelOutput
__all__ = [
"Heatmaps",
"ClipLabeller",
"Augmentation",
"ClipLabeller",
"ClipperProtocol",
"Heatmaps",
"LossProtocol",
"Losses",
"TrainExample",
]

View File

@ -12,7 +12,6 @@ from soundevent import data, terms
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import (
TargetConfig,
TermRegistry,
@ -22,9 +21,12 @@ from batdetect2.targets import (
from batdetect2.targets.classes import ClassesConfig, TargetClass
from batdetect2.targets.filtering import FilterConfig, FilterRule
from batdetect2.targets.terms import TagInfo
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.types import ClipLabeller
from batdetect2.typing import (
ClipLabeller,
PreprocessorProtocol,
TargetProtocol,
)
@pytest.fixture

View File

@ -15,8 +15,7 @@ from batdetect2.postprocess.decoding import (
get_generic_tags,
get_prediction_features,
)
from batdetect2.postprocess.types import RawPrediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.typing import RawPrediction, TargetProtocol
@pytest.fixture

View File

@ -142,7 +142,7 @@ def test_audio_config_defaults():
assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
assert config.resample.method == "poly"
assert config.scale == audio.SCALE_RAW_AUDIO
assert config.center is True
assert config.center is False
assert config.duration == audio.DEFAULT_DURATION

View File

@ -108,7 +108,7 @@ def test_spec_size_config_defaults():
def test_pcen_config_defaults():
config = PcenConfig()
assert config.time_constant == 0.4
assert config.time_constant == 0.01
assert config.gain == 0.98
assert config.bias == 2
assert config.power == 0.5
@ -202,13 +202,6 @@ def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
def test_apply_pcen(sample_spec: xr.DataArray):
if "original_samplerate" not in sample_spec.attrs:
sample_spec.attrs["original_samplerate"] = SAMPLERATE
if "nfft" not in sample_spec.attrs:
sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
if "noverlap" not in sample_spec.attrs:
sample_spec.attrs["noverlap"] = int(0.5 * sample_spec.attrs["nfft"])
pcen_config = PcenConfig()
pcen_spec = apply_pcen(
sample_spec,

View File

@ -5,14 +5,13 @@ import pytest
import xarray as xr
from soundevent import arrays, data
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.train.augmentations import (
add_echo,
mix_examples,
)
from batdetect2.train.clips import select_subclip
from batdetect2.train.preprocess import generate_train_example
from batdetect2.train.types import ClipLabeller
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
def test_mix_examples(

View File

@ -6,7 +6,7 @@ from soundevent import data
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
from batdetect2.targets.rois import AnchorBBoxMapperConfig
from batdetect2.targets.terms import TagInfo, TermRegistry
from batdetect2.targets.terms import TagInfo
from batdetect2.train.labels import generate_heatmaps
recording = data.Recording(

View File

@ -28,8 +28,8 @@ def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
recovered = TrainingModule.load_from_checkpoint(path)
spec1 = module.preprocessor.preprocess_clip(clip)
spec2 = recovered.preprocessor.preprocess_clip(clip)
spec1 = module.model.preprocessor.preprocess_clip(clip)
spec2 = recovered.model.preprocessor.preprocess_clip(clip)
xr.testing.assert_equal(spec1, spec2)

View File

@ -4,12 +4,12 @@ import xarray as xr
from soundevent import data
from soundevent.terms import get_term
from batdetect2.models.types import ModelOutput
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.preprocess import generate_train_example
from batdetect2.typing import ModelOutput
@pytest.fixture