diff --git a/src/batdetect2/data/iterators.py b/src/batdetect2/data/iterators.py index 289f7ce..08c4411 100644 --- a/src/batdetect2/data/iterators.py +++ b/src/batdetect2/data/iterators.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple from soundevent import data from batdetect2.data.datasets import Dataset -from batdetect2.targets.types import TargetProtocol +from batdetect2.typing.targets import TargetProtocol def iterate_over_sound_events( diff --git a/src/batdetect2/data/split.py b/src/batdetect2/data/split.py index aaa1c5e..adf426d 100644 --- a/src/batdetect2/data/split.py +++ b/src/batdetect2/data/split.py @@ -7,7 +7,7 @@ from batdetect2.data.summary import ( extract_recordings_df, extract_sound_events_df, ) -from batdetect2.targets.types import TargetProtocol +from batdetect2.typing.targets import TargetProtocol def split_dataset_by_recordings( diff --git a/src/batdetect2/data/summary.py b/src/batdetect2/data/summary.py index 713520b..f7828b5 100644 --- a/src/batdetect2/data/summary.py +++ b/src/batdetect2/data/summary.py @@ -2,7 +2,7 @@ import pandas as pd from soundevent.geometry import compute_bounds from batdetect2.data.datasets import Dataset -from batdetect2.targets.types import TargetProtocol +from batdetect2.typing.targets import TargetProtocol __all__ = [ "extract_recordings_df", diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index 6c80cf8..b911e67 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -4,15 +4,15 @@ from typing import List, Literal, Optional, Tuple import numpy as np from soundevent import data from soundevent.evaluation import compute_affinity -from soundevent.evaluation import ( - match_geometries as optimal_match, -) +from soundevent.evaluation import match_geometries as optimal_match from soundevent.geometry import compute_bounds from batdetect2.configs import BaseConfig -from batdetect2.evaluate.types import MatchEvaluation -from batdetect2.postprocess.types import BatDetect2Prediction -from batdetect2.targets.types import TargetProtocol +from batdetect2.typing import ( + BatDetect2Prediction, + MatchEvaluation, + TargetProtocol, +) MatchingStrategy = Literal["greedy", "optimal"] """The type of matching algorithm to use: 'greedy' or 'optimal'.""" diff --git a/src/batdetect2/evaluate/metrics.py b/src/batdetect2/evaluate/metrics.py index 7b1c933..c5df1d0 100644 --- a/src/batdetect2/evaluate/metrics.py +++ b/src/batdetect2/evaluate/metrics.py @@ -4,7 +4,7 @@ import pandas as pd from sklearn import metrics from sklearn.preprocessing import label_binarize -from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol +from batdetect2.typing import MatchEvaluation, MetricsProtocol __all__ = ["DetectionAveragePrecision"] diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 7d6fb9f..471fd0b 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -28,8 +28,11 @@ provided here. from typing import Optional -from loguru import logger +import torch +from lightning import LightningModule +from pydantic import Field +from batdetect2.configs import BaseConfig from batdetect2.models.backbones import ( Backbone, BackboneConfig, @@ -53,24 +56,25 @@ from batdetect2.models.decoder import ( DecoderConfig, build_decoder, ) -from batdetect2.models.detectors import ( - Detector, - build_detector, -) +from batdetect2.models.detectors import Detector, build_detector from batdetect2.models.encoder import ( DEFAULT_ENCODER_CONFIG, EncoderConfig, build_encoder, ) from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead -from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput +from batdetect2.postprocess import PostprocessConfig, build_postprocessor +from batdetect2.preprocess import PreprocessingConfig, build_preprocessor +from batdetect2.targets import TargetConfig, build_targets +from batdetect2.typing.models import DetectionModel, ModelOutput +from batdetect2.typing.postprocess import PostprocessorProtocol +from batdetect2.typing.preprocess import PreprocessorProtocol +from batdetect2.typing.targets import TargetProtocol __all__ = [ "BBoxHead", "Backbone", "BackboneConfig", - "BackboneModel", - "BackboneModel", "Bottleneck", "BottleneckConfig", "ClassifierHead", @@ -78,65 +82,75 @@ __all__ = [ "DEFAULT_DECODER_CONFIG", "DEFAULT_ENCODER_CONFIG", "DecoderConfig", - "DetectionModel", "Detector", "DetectorHead", "EncoderConfig", "FreqCoordConvDownConfig", "FreqCoordConvUpConfig", - "ModelOutput", "StandardConvDownConfig", "StandardConvUpConfig", "build_backbone", "build_bottleneck", "build_decoder", - "build_detector", "build_encoder", - "build_model", + "build_detector", "load_backbone_config", + "Model", + "ModelConfig", + "build_model", ] -def build_model( - num_classes: int, - config: Optional[BackboneConfig] = None, -) -> DetectionModel: - """Build the complete BatDetect2 detection model. +class Model(LightningModule): + detector: DetectionModel + preprocessor: PreprocessorProtocol + postprocessor: PostprocessorProtocol + targets: TargetProtocol - This high-level factory function constructs the standard BatDetect2 model - architecture. It first builds the feature extraction backbone (typically an - encoder-bottleneck-decoder structure) based on the provided - `BackboneConfig` (or defaults if None), and then attaches the standard - prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the - `build_detector` function. + def __init__( + self, + detector: DetectionModel, + preprocessor: PreprocessorProtocol, + postprocessor: PostprocessorProtocol, + targets: TargetProtocol, + ): + super().__init__() + self.detector = detector + self.preprocessor = preprocessor + self.postprocessor = postprocessor + self.targets = targets - Parameters - ---------- - num_classes : int - The number of specific target classes the model should predict - (required for the `ClassifierHead`). Must be positive. - config : BackboneConfig, optional - Configuration object specifying the architecture of the backbone - (encoder, bottleneck, decoder). If None, default configurations defined - within the respective builder functions (`build_encoder`, etc.) will be - used to construct a default backbone architecture. + def forward(self, spec: torch.Tensor) -> ModelOutput: + return self.detector(spec) - Returns - ------- - DetectionModel - An initialized `Detector` model instance. - Raises - ------ - ValueError - If `num_classes` is not positive, or if errors occur during the - construction of the backbone or detector components (e.g., incompatible - configurations, invalid parameters). - """ - config = config or BackboneConfig() - logger.opt(lazy=True).debug( - "Building model with config: \n{}", - lambda: config.to_yaml_string(), +class ModelConfig(BaseConfig): + model: BackboneConfig = Field(default_factory=BackboneConfig) + preprocess: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) + targets: TargetConfig = Field(default_factory=TargetConfig) + + +def build_model(config: Optional[ModelConfig] = None): + config = config or ModelConfig() + + targets = build_targets(config=config.targets) + preprocessor = build_preprocessor(config=config.preprocess) + postprocessor = build_postprocessor( + targets=targets, + config=config.postprocess, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, + ) + detector = build_detector( + num_classes=len(targets.class_names), + config=config.model, + ) + return Model( + detector=detector, + postprocessor=postprocessor, + preprocessor=preprocessor, + targets=targets, ) - backbone = build_backbone(config) - return build_detector(num_classes, backbone) diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index be55932..7fc377a 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -43,7 +43,7 @@ from batdetect2.models.encoder import ( EncoderConfig, build_encoder, ) -from batdetect2.models.types import BackboneModel +from batdetect2.typing.models import BackboneModel __all__ = [ "Backbone", diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index 1b94f0a..518ae94 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -8,18 +8,25 @@ classifying them. The primary components are: - `Detector`: The `torch.nn.Module` subclass representing the complete model. -- `build_detector`: A factory function to conveniently construct a standard - `Detector` instance given a backbone and the number of target classes. This module focuses purely on the neural network architecture definition. The logic for preprocessing inputs and postprocessing/decoding outputs resides in the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively. """ -import torch +from typing import Optional +import torch +from loguru import logger + +from batdetect2.models.backbones import BackboneConfig, build_backbone from batdetect2.models.heads import BBoxHead, ClassifierHead -from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput +from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput + +__all__ = [ + "Detector", + "build_detector", +] class Detector(DetectionModel): @@ -119,36 +126,41 @@ class Detector(DetectionModel): ) -def build_detector(num_classes: int, backbone: BackboneModel) -> Detector: - """Factory function to build a standard Detector model instance. - - Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`, - `BBoxHead`) configured appropriately based on the output channels of the - provided `backbone` and the specified `num_classes`. It then assembles - these components into a `Detector` model. +def build_detector( + num_classes: int, config: Optional[BackboneConfig] = None +) -> DetectionModel: + """Build the complete BatDetect2 detection model. Parameters ---------- num_classes : int - The number of specific target classes for the classification head - (excluding any implicit background class). Must be positive. - backbone : BackboneModel - An initialized feature extraction backbone module instance. The number - of output channels from this backbone (`backbone.out_channels`) is used - to configure the input channels for the prediction heads. + The number of specific target classes the model should predict + (required for the `ClassifierHead`). Must be positive. + config : BackboneConfig, optional + Configuration object specifying the architecture of the backbone + (encoder, bottleneck, decoder). If None, default configurations defined + within the respective builder functions (`build_encoder`, etc.) will be + used to construct a default backbone architecture. Returns ------- - Detector + DetectionModel An initialized `Detector` model instance. Raises ------ ValueError - If `num_classes` is not positive. - AttributeError - If `backbone` does not have the required `out_channels` attribute. + If `num_classes` is not positive, or if errors occur during the + construction of the backbone or detector components (e.g., incompatible + configurations, invalid parameters). """ + config = config or BackboneConfig() + + logger.opt(lazy=True).debug( + "Building model with config: \n{}", + lambda: config.to_yaml_string(), + ) + backbone = build_backbone(config=config) classifier_head = ClassifierHead( num_classes=num_classes, in_channels=backbone.out_channels, diff --git a/src/batdetect2/plotting/clip_annotations.py b/src/batdetect2/plotting/clip_annotations.py index 40016e5..276c5f4 100644 --- a/src/batdetect2/plotting/clip_annotations.py +++ b/src/batdetect2/plotting/clip_annotations.py @@ -4,7 +4,7 @@ from matplotlib.axes import Axes from soundevent import data, plot from batdetect2.plotting.clips import plot_clip -from batdetect2.preprocess import PreprocessorProtocol +from batdetect2.typing.preprocess import PreprocessorProtocol __all__ = [ "plot_clip_annotation", diff --git a/src/batdetect2/plotting/clip_predictions.py b/src/batdetect2/plotting/clip_predictions.py index b741e61..994f984 100644 --- a/src/batdetect2/plotting/clip_predictions.py +++ b/src/batdetect2/plotting/clip_predictions.py @@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag from batdetect2.plotting.clips import plot_clip -from batdetect2.preprocess import PreprocessorProtocol +from batdetect2.typing.preprocess import PreprocessorProtocol __all__ = [ "plot_clip_prediction", diff --git a/src/batdetect2/plotting/evaluation.py b/src/batdetect2/plotting/evaluation.py index 41b92fa..5abdcd2 100644 --- a/src/batdetect2/plotting/evaluation.py +++ b/src/batdetect2/plotting/evaluation.py @@ -7,8 +7,8 @@ import matplotlib.pyplot as plt import pandas as pd from batdetect2 import plotting -from batdetect2.evaluate.types import MatchEvaluation -from batdetect2.preprocess.types import PreprocessorProtocol +from batdetect2.typing.evaluate import MatchEvaluation +from batdetect2.typing.preprocess import PreprocessorProtocol @dataclass diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index d462954..7d5b19c 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ -6,13 +6,13 @@ from soundevent import data, plot from soundevent.geometry import compute_bounds from soundevent.plot.tags import TagColorMapper -from batdetect2.evaluate.types import MatchEvaluation from batdetect2.plotting.clip_predictions import plot_prediction from batdetect2.plotting.clips import plot_clip from batdetect2.preprocess import ( PreprocessorProtocol, get_default_preprocessor, ) +from batdetect2.typing.evaluate import MatchEvaluation __all__ = [ "plot_matches", diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index 8dcc53b..e724e2b 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -36,7 +36,6 @@ from pydantic import Field from soundevent import data from batdetect2.configs import BaseConfig, load_config -from batdetect2.models.types import ModelOutput from batdetect2.postprocess.decoding import ( DEFAULT_CLASSIFICATION_THRESHOLD, convert_raw_prediction_to_sound_event_prediction, @@ -62,13 +61,14 @@ from batdetect2.postprocess.remapping import ( features_to_xarray, sizes_to_xarray, ) -from batdetect2.postprocess.types import ( +from batdetect2.preprocess import MAX_FREQ, MIN_FREQ +from batdetect2.typing.models import ModelOutput +from batdetect2.typing.postprocess import ( BatDetect2Prediction, PostprocessorProtocol, RawPrediction, ) -from batdetect2.preprocess import MAX_FREQ, MIN_FREQ -from batdetect2.targets.types import TargetProtocol +from batdetect2.typing.targets import TargetProtocol __all__ = [ "DEFAULT_CLASSIFICATION_THRESHOLD", @@ -79,8 +79,6 @@ __all__ = [ "NMS_KERNEL_SIZE", "PostprocessConfig", "Postprocessor", - "PostprocessorProtocol", - "RawPrediction", "TOP_K_PER_SEC", "build_postprocessor", "classification_to_xarray", diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index 599f34c..d105080 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -32,8 +32,8 @@ import numpy as np import xarray as xr from soundevent import data -from batdetect2.postprocess.types import GeometryDecoder, RawPrediction -from batdetect2.targets.types import TargetProtocol +from batdetect2.typing.postprocess import GeometryDecoder, RawPrediction +from batdetect2.typing.targets import TargetProtocol __all__ = [ "convert_xr_dataset_to_raw_prediction", diff --git a/src/batdetect2/preprocess/__init__.py b/src/batdetect2/preprocess/__init__.py index 69ed341..37f7ad0 100644 --- a/src/batdetect2/preprocess/__init__.py +++ b/src/batdetect2/preprocess/__init__.py @@ -57,7 +57,7 @@ from batdetect2.preprocess.spectrogram import ( build_spectrogram_builder, get_spectrogram_resolution, ) -from batdetect2.preprocess.types import ( +from batdetect2.typing.preprocess import ( AudioLoader, PreprocessorProtocol, SpectrogramBuilder, @@ -65,7 +65,6 @@ from batdetect2.preprocess.types import ( __all__ = [ "AudioConfig", - "AudioLoader", "ConfigurableSpectrogramBuilder", "DEFAULT_DURATION", "FrequencyConfig", @@ -77,7 +76,6 @@ __all__ = [ "SCALE_RAW_AUDIO", "STFTConfig", "SpecSizeConfig", - "SpectrogramBuilder", "SpectrogramConfig", "StandardPreprocessor", "TARGET_SAMPLERATE_HZ", diff --git a/src/batdetect2/preprocess/audio.py b/src/batdetect2/preprocess/audio.py index 060c941..a67c065 100644 --- a/src/batdetect2/preprocess/audio.py +++ b/src/batdetect2/preprocess/audio.py @@ -32,7 +32,7 @@ from soundevent.arrays import operations as ops from soundfile import LibsndfileError from batdetect2.configs import BaseConfig -from batdetect2.preprocess.types import AudioLoader +from batdetect2.typing.preprocess import AudioLoader __all__ = [ "ResampleConfig", diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index 8a799fa..77a311d 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -31,7 +31,7 @@ from soundevent.arrays import operations as ops from batdetect2.configs import BaseConfig from batdetect2.preprocess.audio import convert_to_xr -from batdetect2.preprocess.types import SpectrogramBuilder +from batdetect2.typing.preprocess import SpectrogramBuilder __all__ = [ "STFTConfig", @@ -540,7 +540,6 @@ def apply_pcen( xr.DataArray PCEN-scaled spectrogram. """ - spec = spec * (2**31) samplerate = 1 / spec.time.attrs["step"] hop_size = spec.attrs["hop_size"] @@ -559,22 +558,24 @@ def apply_pcen( [1, smoothing_constant - 1], )[:] + spec_data = spec.data * (2**31) + # Smooth the input array along the given axis smoothed, _ = signal.lfilter( [smoothing_constant], [1, smoothing_constant - 1], - spec.data, + spec_data, zi=zi, axis=axis, # type: ignore ) smooth = np.exp(-gain * (np.log(eps) + np.log1p(smoothed / eps))) data = (bias**power) * np.expm1( - power * np.log1p(spec.data * smooth / bias) + power * np.log1p(spec_data * smooth / bias) ) return xr.DataArray( - data, + data.astype(spec.dtype), dims=spec.dims, coords=spec.coords, attrs=spec.attrs, diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index 1c50690..c976f71 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -80,7 +80,7 @@ from batdetect2.targets.transform import ( load_transformation_from_config, register_derivation, ) -from batdetect2.targets.types import Position, Size, TargetProtocol +from batdetect2.typing.targets import Position, Size, TargetProtocol __all__ = [ "ClassesConfig", @@ -99,7 +99,6 @@ __all__ = [ "TagInfo", "TargetClass", "TargetConfig", - "TargetProtocol", "Targets", "TermInfo", "TransformConfig", diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index d9083a3..95d339c 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -14,6 +14,7 @@ from batdetect2.targets.terms import ( default_term_registry, get_tag_from_info, ) +from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder __all__ = [ "DEFAULT_SPECIES_LIST", @@ -27,25 +28,6 @@ __all__ = [ ] -SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]] -"""Type alias for a sound event class encoder function. - -An encoder function takes a sound event annotation and returns the string name -of the target class it belongs to, based on a predefined set of rules. -If the annotation does not match any defined target class according to the -rules, the function returns None. -""" - - -SoundEventDecoder = Callable[[str], List[data.Tag]] -"""Type alias for a sound event class decoder function. - -A decoder function takes a class name string (as predicted by the model or -assigned during encoding) and returns a list of `soundevent.data.Tag` objects -that represent that class according to the configuration. This is used to -translate model outputs back into meaningful annotations. -""" - DEFAULT_SPECIES_LIST = [ "Barbastella barbastellus", "Eptesicus serotinus", diff --git a/src/batdetect2/targets/filtering.py b/src/batdetect2/targets/filtering.py index 28e9c43..e532cc5 100644 --- a/src/batdetect2/targets/filtering.py +++ b/src/batdetect2/targets/filtering.py @@ -1,6 +1,6 @@ import logging from functools import partial -from typing import Callable, List, Literal, Optional, Set +from typing import List, Literal, Optional, Set from pydantic import Field from soundevent import data @@ -12,11 +12,11 @@ from batdetect2.targets.terms import ( default_term_registry, get_tag_from_info, ) +from batdetect2.typing.targets import SoundEventFilter __all__ = [ "FilterConfig", "FilterRule", - "SoundEventFilter", "build_sound_event_filter", "build_filter_from_rule", "load_filter_config", @@ -24,14 +24,6 @@ __all__ = [ ] -SoundEventFilter = Callable[[data.SoundEventAnnotation], bool] -"""Type alias for a filter function. - -A filter function accepts a soundevent.data.SoundEventAnnotation object -and returns True if the annotation should be kept based on the filter's -criteria, or False if it should be discarded. -""" - logger = logging.getLogger(__name__) diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py index 330f6a0..120459f 100644 --- a/src/batdetect2/targets/rois.py +++ b/src/batdetect2/targets/rois.py @@ -28,8 +28,8 @@ from soundevent import data from batdetect2.configs import BaseConfig from batdetect2.preprocess import PreprocessingConfig, build_preprocessor -from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets.types import Position, Size +from batdetect2.typing.preprocess import PreprocessorProtocol +from batdetect2.typing.targets import Position, Size __all__ = [ "Anchor", diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py index df5dc27..fd161e0 100644 --- a/src/batdetect2/train/__init__.py +++ b/src/batdetect2/train/__init__.py @@ -24,7 +24,6 @@ from batdetect2.train.config import ( from batdetect2.train.dataset import ( LabeledDataset, RandomExampleSource, - TrainExample, list_preprocessed_files, ) from batdetect2.train.labels import build_clip_labeler, load_label_config @@ -64,7 +63,6 @@ __all__ = [ "RandomExampleSource", "SizeLossConfig", "TimeMaskAugmentationConfig", - "TrainExample", "TrainingConfig", "TrainingModule", "VolumeAugmentationConfig", diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index 6c08f31..a991e32 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -33,8 +33,7 @@ from pydantic import Field from soundevent import arrays, data from batdetect2.configs import BaseConfig, load_config -from batdetect2.preprocess import PreprocessorProtocol -from batdetect2.train.types import Augmentation +from batdetect2.typing import Augmentation, PreprocessorProtocol from batdetect2.utils.arrays import adjust_width __all__ = [ diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 554152d..f3b2830 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -14,16 +14,18 @@ from batdetect2.evaluate.match import ( MatchConfig, match_sound_events_and_raw_predictions, ) -from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol from batdetect2.plotting.evaluation import plot_example_gallery -from batdetect2.postprocess.types import ( - BatDetect2Prediction, - PostprocessorProtocol, -) -from batdetect2.targets.types import TargetProtocol -from batdetect2.train.dataset import LabeledDataset, TrainExample +from batdetect2.train.dataset import LabeledDataset from batdetect2.train.lightning import TrainingModule -from batdetect2.train.types import ModelOutput +from batdetect2.typing import ( + BatDetect2Prediction, + MatchEvaluation, + MetricsProtocol, + ModelOutput, + PostprocessorProtocol, + TargetProtocol, + TrainExample, +) class ValidationMetrics(Callback): diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py index 87ee673..2d8c62e 100644 --- a/src/batdetect2/train/clips.py +++ b/src/batdetect2/train/clips.py @@ -6,7 +6,7 @@ from loguru import logger from soundevent import arrays from batdetect2.configs import BaseConfig -from batdetect2.train.types import ClipperProtocol +from batdetect2.typing import ClipperProtocol DEFAULT_TRAIN_CLIP_DURATION = 0.513 DEFAULT_MAX_EMPTY_CLIP = 0.1 diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index a21101b..e2cb1cd 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -4,11 +4,8 @@ from pydantic import Field from soundevent import data from batdetect2.configs import BaseConfig, load_config -from batdetect2.evaluate.config import EvaluationConfig -from batdetect2.models import BackboneConfig -from batdetect2.postprocess import PostprocessConfig -from batdetect2.preprocess import PreprocessingConfig -from batdetect2.targets import TargetConfig +from batdetect2.evaluate import EvaluationConfig +from batdetect2.models import ModelConfig from batdetect2.train.augmentations import ( DEFAULT_AUGMENTATION_CONFIG, AugmentationsConfig, @@ -85,16 +82,10 @@ def load_train_config( return load_config(path, schema=TrainingConfig, field=field) -class FullTrainingConfig(BaseConfig): +class FullTrainingConfig(ModelConfig): """Full training configuration.""" train: TrainingConfig = Field(default_factory=TrainingConfig) - targets: TargetConfig = Field(default_factory=TargetConfig) - model: BackboneConfig = Field(default_factory=BackboneConfig) - preprocess: PreprocessingConfig = Field( - default_factory=PreprocessingConfig - ) - postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig) diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index 59a429a..1e67f23 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -8,7 +8,7 @@ from soundevent import data from torch.utils.data import Dataset from batdetect2.train.augmentations import Augmentation -from batdetect2.train.types import ClipperProtocol, TrainExample +from batdetect2.typing import ClipperProtocol, TrainExample from batdetect2.utils.tensors import adjust_width __all__ = [ diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index 642e71e..9865ee7 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -34,10 +34,10 @@ from scipy.ndimage import gaussian_filter from soundevent import arrays, data from batdetect2.configs import BaseConfig, load_config -from batdetect2.targets.types import TargetProtocol -from batdetect2.train.types import ( +from batdetect2.typing import ( ClipLabeller, Heatmaps, + TargetProtocol, ) __all__ = [ diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 481fd9f..b2fb261 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -3,12 +3,8 @@ import torch from torch.optim.adam import Adam from torch.optim.lr_scheduler import CosineAnnealingLR -from batdetect2.models import ModelOutput -from batdetect2.models.types import DetectionModel -from batdetect2.postprocess.types import PostprocessorProtocol -from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets.types import TargetProtocol -from batdetect2.train import TrainExample +from batdetect2.models import Model +from batdetect2.typing import ModelOutput, TrainExample __all__ = [ "TrainingModule", @@ -18,11 +14,8 @@ __all__ = [ class TrainingModule(L.LightningModule): def __init__( self, - detector: DetectionModel, + model: Model, loss: torch.nn.Module, - targets: TargetProtocol, - preprocessor: PreprocessorProtocol, - postprocessor: PostprocessorProtocol, learning_rate: float = 0.001, t_max: int = 100, ): @@ -32,18 +25,14 @@ class TrainingModule(L.LightningModule): self.t_max = t_max self.loss = loss - self.targets = targets - self.detector = detector - self.preprocessor = preprocessor - self.postprocessor = postprocessor - + self.model = model self.save_hyperparameters(logger=False) def forward(self, spec: torch.Tensor) -> ModelOutput: - return self.detector(spec) + return self.model(spec) def training_step(self, batch: TrainExample): - outputs = self.forward(batch.spec) + outputs = self.model(batch.spec) losses = self.loss(outputs, batch) self.log("total_loss/train", losses.total, prog_bar=True, logger=True) self.log("detection_loss/train", losses.total, logger=True) @@ -56,7 +45,7 @@ class TrainingModule(L.LightningModule): batch: TrainExample, batch_idx: int, ) -> ModelOutput: - outputs = self.forward(batch.spec) + outputs = self.model(batch.spec) losses = self.loss(outputs, batch) self.log("total_loss/val", losses.total, prog_bar=True, logger=True) self.log("detection_loss/val", losses.total, logger=True) diff --git a/src/batdetect2/train/losses.py b/src/batdetect2/train/losses.py index 56bc092..ea5b91a 100644 --- a/src/batdetect2/train/losses.py +++ b/src/batdetect2/train/losses.py @@ -28,9 +28,7 @@ from pydantic import Field from torch import nn from batdetect2.configs import BaseConfig -from batdetect2.models.types import ModelOutput -from batdetect2.train.dataset import TrainExample -from batdetect2.train.types import Losses, LossProtocol +from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample __all__ = [ "BBoxLoss", diff --git a/src/batdetect2/train/preprocess.py b/src/batdetect2/train/preprocess.py index a6f4f92..5c7c907 100644 --- a/src/batdetect2/train/preprocess.py +++ b/src/batdetect2/train/preprocess.py @@ -34,10 +34,9 @@ from tqdm.auto import tqdm from batdetect2.configs import BaseConfig, load_config from batdetect2.data.datasets import Dataset from batdetect2.preprocess import PreprocessingConfig, build_preprocessor -from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets import TargetConfig, build_targets from batdetect2.train.labels import LabelConfig, build_clip_labeler -from batdetect2.train.types import ClipLabeller +from batdetect2.typing import ClipLabeller, PreprocessorProtocol __all__ = [ "preprocess_annotations", diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index fed22b5..11d403f 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -14,12 +14,6 @@ from batdetect2.evaluate.metrics import ( DetectionAveragePrecision, ) from batdetect2.models import build_model -from batdetect2.postprocess import build_postprocessor -from batdetect2.preprocess import ( - PreprocessorProtocol, - build_preprocessor, -) -from batdetect2.targets import TargetProtocol, build_targets from batdetect2.train.augmentations import build_augmentations from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.clips import build_clipper @@ -32,6 +26,7 @@ from batdetect2.train.dataset import ( from batdetect2.train.lightning import TrainingModule from batdetect2.train.logging import build_logger from batdetect2.train.losses import build_loss +from batdetect2.typing import PreprocessorProtocol, TargetProtocol __all__ = [ "build_train_dataset", @@ -88,25 +83,11 @@ def train( def build_training_module(config: FullTrainingConfig) -> TrainingModule: - targets = build_targets(config=config.targets) + model = build_model(config=config) loss = build_loss(config=config.train.loss) - preprocessor = build_preprocessor(config.preprocess) - postprocessor = build_postprocessor( - targets, - config=config.postprocess, - max_freq=preprocessor.max_freq, - min_freq=preprocessor.min_freq, - ) - model = build_model( - num_classes=len(targets.class_names), - config=config.model, - ) return TrainingModule( - detector=model, + model=model, loss=loss, - preprocessor=preprocessor, - postprocessor=postprocessor, - targets=targets, learning_rate=config.train.learning_rate, t_max=config.train.t_max, ) diff --git a/src/batdetect2/typing/__init__.py b/src/batdetect2/typing/__init__.py new file mode 100644 index 0000000..a9ef09e --- /dev/null +++ b/src/batdetect2/typing/__init__.py @@ -0,0 +1,58 @@ +from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol +from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput +from batdetect2.typing.postprocess import ( + BatDetect2Prediction, + GeometryDecoder, + PostprocessorProtocol, + RawPrediction, +) +from batdetect2.typing.preprocess import ( + AudioLoader, + PreprocessorProtocol, + SpectrogramBuilder, +) +from batdetect2.typing.targets import ( + Position, + Size, + SoundEventDecoder, + SoundEventEncoder, + SoundEventFilter, + TargetProtocol, +) +from batdetect2.typing.train import ( + Augmentation, + ClipLabeller, + ClipperProtocol, + Heatmaps, + Losses, + LossProtocol, + TrainExample, +) + +__all__ = [ + "AudioLoader", + "Augmentation", + "BackboneModel", + "BatDetect2Prediction", + "ClipLabeller", + "ClipperProtocol", + "DetectionModel", + "GeometryDecoder", + "Heatmaps", + "LossProtocol", + "Losses", + "MatchEvaluation", + "MetricsProtocol", + "ModelOutput", + "Position", + "PostprocessorProtocol", + "PreprocessorProtocol", + "RawPrediction", + "Size", + "SoundEventDecoder", + "SoundEventEncoder", + "SoundEventFilter", + "SpectrogramBuilder", + "TargetProtocol", + "TrainExample", +] diff --git a/src/batdetect2/evaluate/types.py b/src/batdetect2/typing/evaluate.py similarity index 100% rename from src/batdetect2/evaluate/types.py rename to src/batdetect2/typing/evaluate.py diff --git a/src/batdetect2/models/types.py b/src/batdetect2/typing/models.py similarity index 98% rename from src/batdetect2/models/types.py rename to src/batdetect2/typing/models.py index 4f3238d..d71193f 100644 --- a/src/batdetect2/models/types.py +++ b/src/batdetect2/typing/models.py @@ -19,7 +19,6 @@ from abc import ABC, abstractmethod from typing import NamedTuple import torch -import torch.nn as nn __all__ = [ "ModelOutput", @@ -65,7 +64,7 @@ class ModelOutput(NamedTuple): features: torch.Tensor -class BackboneModel(ABC, nn.Module): +class BackboneModel(ABC, torch.nn.Module): """Abstract Base Class for generic feature extraction backbone models. Defines the minimal interface for a feature extractor network within a @@ -191,7 +190,7 @@ class EncoderDecoderModel(BackboneModel): ... -class DetectionModel(ABC, nn.Module): +class DetectionModel(ABC, torch.nn.Module): """Abstract Base Class for complete BatDetect2 detection models. Defines the interface for the overall model that takes an input spectrogram diff --git a/src/batdetect2/postprocess/types.py b/src/batdetect2/typing/postprocess.py similarity index 99% rename from src/batdetect2/postprocess/types.py rename to src/batdetect2/typing/postprocess.py index 533b3e1..9aeca94 100644 --- a/src/batdetect2/postprocess/types.py +++ b/src/batdetect2/typing/postprocess.py @@ -18,8 +18,8 @@ import numpy as np import xarray as xr from soundevent import data -from batdetect2.models.types import ModelOutput -from batdetect2.targets.types import Position, Size +from batdetect2.typing.models import ModelOutput +from batdetect2.typing.targets import Position, Size __all__ = [ "RawPrediction", diff --git a/src/batdetect2/preprocess/types.py b/src/batdetect2/typing/preprocess.py similarity index 99% rename from src/batdetect2/preprocess/types.py rename to src/batdetect2/typing/preprocess.py index b8a0650..71ad9bd 100644 --- a/src/batdetect2/preprocess/types.py +++ b/src/batdetect2/typing/preprocess.py @@ -16,6 +16,12 @@ import numpy as np import xarray as xr from soundevent import data +__all__ = [ + "AudioLoader", + "SpectrogramBuilder", + "PreprocessorProtocol", +] + class AudioLoader(Protocol): """Defines the interface for an audio loading and processing component. diff --git a/src/batdetect2/targets/types.py b/src/batdetect2/typing/targets.py similarity index 86% rename from src/batdetect2/targets/types.py rename to src/batdetect2/typing/targets.py index 221897a..2846a0e 100644 --- a/src/batdetect2/targets/types.py +++ b/src/batdetect2/typing/targets.py @@ -12,6 +12,7 @@ that components responsible for these tasks can be interacted with consistently throughout BatDetect2. """ +from collections.abc import Callable from typing import List, Optional, Protocol import numpy as np @@ -19,10 +20,40 @@ from soundevent import data __all__ = [ "TargetProtocol", + "SoundEventEncoder", + "SoundEventDecoder", + "SoundEventFilter", "Position", "Size", ] +SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]] +"""Type alias for a sound event class encoder function. + +An encoder function takes a sound event annotation and returns the string name +of the target class it belongs to, based on a predefined set of rules. +If the annotation does not match any defined target class according to the +rules, the function returns None. +""" + + +SoundEventDecoder = Callable[[str], List[data.Tag]] +"""Type alias for a sound event class decoder function. + +A decoder function takes a class name string (as predicted by the model or +assigned during encoding) and returns a list of `soundevent.data.Tag` objects +that represent that class according to the configuration. This is used to +translate model outputs back into meaningful annotations. +""" + +SoundEventFilter = Callable[[data.SoundEventAnnotation], bool] +"""Type alias for a filter function. + +A filter function accepts a soundevent.data.SoundEventAnnotation object +and returns True if the annotation should be kept based on the filter's +criteria, or False if it should be discarded. +""" + Position = tuple[float, float] """A tuple representing (time, frequency) coordinates.""" diff --git a/src/batdetect2/train/types.py b/src/batdetect2/typing/train.py similarity index 97% rename from src/batdetect2/train/types.py rename to src/batdetect2/typing/train.py index 02e84c7..950624a 100644 --- a/src/batdetect2/train/types.py +++ b/src/batdetect2/typing/train.py @@ -4,13 +4,15 @@ import torch import xarray as xr from soundevent import data -from batdetect2.models import ModelOutput +from batdetect2.typing.models import ModelOutput __all__ = [ - "Heatmaps", - "ClipLabeller", "Augmentation", + "ClipLabeller", + "ClipperProtocol", + "Heatmaps", "LossProtocol", + "Losses", "TrainExample", ] diff --git a/tests/conftest.py b/tests/conftest.py index 1c8b065..b56bce8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,6 @@ from soundevent import data, terms from batdetect2.data import DatasetConfig, load_dataset from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations from batdetect2.preprocess import build_preprocessor -from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets import ( TargetConfig, TermRegistry, @@ -22,9 +21,12 @@ from batdetect2.targets import ( from batdetect2.targets.classes import ClassesConfig, TargetClass from batdetect2.targets.filtering import FilterConfig, FilterRule from batdetect2.targets.terms import TagInfo -from batdetect2.targets.types import TargetProtocol from batdetect2.train.labels import build_clip_labeler -from batdetect2.train.types import ClipLabeller +from batdetect2.typing import ( + ClipLabeller, + PreprocessorProtocol, + TargetProtocol, +) @pytest.fixture diff --git a/tests/test_postprocessing/test_decoding.py b/tests/test_postprocessing/test_decoding.py index 4dbb431..771aa69 100644 --- a/tests/test_postprocessing/test_decoding.py +++ b/tests/test_postprocessing/test_decoding.py @@ -15,8 +15,7 @@ from batdetect2.postprocess.decoding import ( get_generic_tags, get_prediction_features, ) -from batdetect2.postprocess.types import RawPrediction -from batdetect2.targets.types import TargetProtocol +from batdetect2.typing import RawPrediction, TargetProtocol @pytest.fixture diff --git a/tests/test_preprocessing/test_audio.py b/tests/test_preprocessing/test_audio.py index 15fc13d..70a63f6 100644 --- a/tests/test_preprocessing/test_audio.py +++ b/tests/test_preprocessing/test_audio.py @@ -142,7 +142,7 @@ def test_audio_config_defaults(): assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ assert config.resample.method == "poly" assert config.scale == audio.SCALE_RAW_AUDIO - assert config.center is True + assert config.center is False assert config.duration == audio.DEFAULT_DURATION diff --git a/tests/test_preprocessing/test_spectrogram.py b/tests/test_preprocessing/test_spectrogram.py index 5c5beff..9f64494 100644 --- a/tests/test_preprocessing/test_spectrogram.py +++ b/tests/test_preprocessing/test_spectrogram.py @@ -108,7 +108,7 @@ def test_spec_size_config_defaults(): def test_pcen_config_defaults(): config = PcenConfig() - assert config.time_constant == 0.4 + assert config.time_constant == 0.01 assert config.gain == 0.98 assert config.bias == 2 assert config.power == 0.5 @@ -202,13 +202,6 @@ def test_crop_spectrogram_full_range(sample_spec: xr.DataArray): def test_apply_pcen(sample_spec: xr.DataArray): - if "original_samplerate" not in sample_spec.attrs: - sample_spec.attrs["original_samplerate"] = SAMPLERATE - if "nfft" not in sample_spec.attrs: - sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE) - if "noverlap" not in sample_spec.attrs: - sample_spec.attrs["noverlap"] = int(0.5 * sample_spec.attrs["nfft"]) - pcen_config = PcenConfig() pcen_spec = apply_pcen( sample_spec, diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py index 65579a7..4dbd2fa 100644 --- a/tests/test_train/test_augmentations.py +++ b/tests/test_train/test_augmentations.py @@ -5,14 +5,13 @@ import pytest import xarray as xr from soundevent import arrays, data -from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.train.augmentations import ( add_echo, mix_examples, ) from batdetect2.train.clips import select_subclip from batdetect2.train.preprocess import generate_train_example -from batdetect2.train.types import ClipLabeller +from batdetect2.typing import ClipLabeller, PreprocessorProtocol def test_mix_examples( diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py index 17e21bd..2ce8302 100644 --- a/tests/test_train/test_labels.py +++ b/tests/test_train/test_labels.py @@ -6,7 +6,7 @@ from soundevent import data from batdetect2.targets import TargetConfig, TargetProtocol, build_targets from batdetect2.targets.rois import AnchorBBoxMapperConfig -from batdetect2.targets.terms import TagInfo, TermRegistry +from batdetect2.targets.terms import TagInfo from batdetect2.train.labels import generate_heatmaps recording = data.Recording( diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index 828a464..13da352 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -28,8 +28,8 @@ def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip): recovered = TrainingModule.load_from_checkpoint(path) - spec1 = module.preprocessor.preprocess_clip(clip) - spec2 = recovered.preprocessor.preprocess_clip(clip) + spec1 = module.model.preprocessor.preprocess_clip(clip) + spec2 = recovered.model.preprocessor.preprocess_clip(clip) xr.testing.assert_equal(spec1, spec2) diff --git a/tests/test_train/test_preprocessing.py b/tests/test_train/test_preprocessing.py index ee97fd2..83b05c5 100644 --- a/tests/test_train/test_preprocessing.py +++ b/tests/test_train/test_preprocessing.py @@ -4,12 +4,12 @@ import xarray as xr from soundevent import data from soundevent.terms import get_term -from batdetect2.models.types import ModelOutput from batdetect2.postprocess import build_postprocessor, load_postprocess_config from batdetect2.preprocess import build_preprocessor, load_preprocessing_config from batdetect2.targets import build_targets, load_target_config from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.preprocess import generate_train_example +from batdetect2.typing import ModelOutput @pytest.fixture