mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 09:29:33 +01:00
Moved types to dedicated module
This commit is contained in:
parent
02adc19070
commit
61115d562c
@ -4,7 +4,7 @@ from typing import Optional, Tuple
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.data.datasets import Dataset
|
from batdetect2.data.datasets import Dataset
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
def iterate_over_sound_events(
|
def iterate_over_sound_events(
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from batdetect2.data.summary import (
|
|||||||
extract_recordings_df,
|
extract_recordings_df,
|
||||||
extract_sound_events_df,
|
extract_sound_events_df,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
def split_dataset_by_recordings(
|
def split_dataset_by_recordings(
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import pandas as pd
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.data.datasets import Dataset
|
from batdetect2.data.datasets import Dataset
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"extract_recordings_df",
|
"extract_recordings_df",
|
||||||
|
|||||||
@ -4,15 +4,15 @@ from typing import List, Literal, Optional, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import compute_affinity
|
from soundevent.evaluation import compute_affinity
|
||||||
from soundevent.evaluation import (
|
from soundevent.evaluation import match_geometries as optimal_match
|
||||||
match_geometries as optimal_match,
|
|
||||||
)
|
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.evaluate.types import MatchEvaluation
|
from batdetect2.typing import (
|
||||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
BatDetect2Prediction,
|
||||||
from batdetect2.targets.types import TargetProtocol
|
MatchEvaluation,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
MatchingStrategy = Literal["greedy", "optimal"]
|
MatchingStrategy = Literal["greedy", "optimal"]
|
||||||
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
|
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import pandas as pd
|
|||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
from sklearn.preprocessing import label_binarize
|
from sklearn.preprocessing import label_binarize
|
||||||
|
|
||||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
from batdetect2.typing import MatchEvaluation, MetricsProtocol
|
||||||
|
|
||||||
__all__ = ["DetectionAveragePrecision"]
|
__all__ = ["DetectionAveragePrecision"]
|
||||||
|
|
||||||
|
|||||||
@ -28,8 +28,11 @@ provided here.
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from loguru import logger
|
import torch
|
||||||
|
from lightning import LightningModule
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import (
|
||||||
Backbone,
|
Backbone,
|
||||||
BackboneConfig,
|
BackboneConfig,
|
||||||
@ -53,24 +56,25 @@ from batdetect2.models.decoder import (
|
|||||||
DecoderConfig,
|
DecoderConfig,
|
||||||
build_decoder,
|
build_decoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.detectors import (
|
from batdetect2.models.detectors import Detector, build_detector
|
||||||
Detector,
|
|
||||||
build_detector,
|
|
||||||
)
|
|
||||||
from batdetect2.models.encoder import (
|
from batdetect2.models.encoder import (
|
||||||
DEFAULT_ENCODER_CONFIG,
|
DEFAULT_ENCODER_CONFIG,
|
||||||
EncoderConfig,
|
EncoderConfig,
|
||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
||||||
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
|
from batdetect2.targets import TargetConfig, build_targets
|
||||||
|
from batdetect2.typing.models import DetectionModel, ModelOutput
|
||||||
|
from batdetect2.typing.postprocess import PostprocessorProtocol
|
||||||
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BBoxHead",
|
"BBoxHead",
|
||||||
"Backbone",
|
"Backbone",
|
||||||
"BackboneConfig",
|
"BackboneConfig",
|
||||||
"BackboneModel",
|
|
||||||
"BackboneModel",
|
|
||||||
"Bottleneck",
|
"Bottleneck",
|
||||||
"BottleneckConfig",
|
"BottleneckConfig",
|
||||||
"ClassifierHead",
|
"ClassifierHead",
|
||||||
@ -78,65 +82,75 @@ __all__ = [
|
|||||||
"DEFAULT_DECODER_CONFIG",
|
"DEFAULT_DECODER_CONFIG",
|
||||||
"DEFAULT_ENCODER_CONFIG",
|
"DEFAULT_ENCODER_CONFIG",
|
||||||
"DecoderConfig",
|
"DecoderConfig",
|
||||||
"DetectionModel",
|
|
||||||
"Detector",
|
"Detector",
|
||||||
"DetectorHead",
|
"DetectorHead",
|
||||||
"EncoderConfig",
|
"EncoderConfig",
|
||||||
"FreqCoordConvDownConfig",
|
"FreqCoordConvDownConfig",
|
||||||
"FreqCoordConvUpConfig",
|
"FreqCoordConvUpConfig",
|
||||||
"ModelOutput",
|
|
||||||
"StandardConvDownConfig",
|
"StandardConvDownConfig",
|
||||||
"StandardConvUpConfig",
|
"StandardConvUpConfig",
|
||||||
"build_backbone",
|
"build_backbone",
|
||||||
"build_bottleneck",
|
"build_bottleneck",
|
||||||
"build_decoder",
|
"build_decoder",
|
||||||
"build_detector",
|
|
||||||
"build_encoder",
|
"build_encoder",
|
||||||
"build_model",
|
"build_detector",
|
||||||
"load_backbone_config",
|
"load_backbone_config",
|
||||||
|
"Model",
|
||||||
|
"ModelConfig",
|
||||||
|
"build_model",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_model(
|
class Model(LightningModule):
|
||||||
num_classes: int,
|
detector: DetectionModel
|
||||||
config: Optional[BackboneConfig] = None,
|
preprocessor: PreprocessorProtocol
|
||||||
) -> DetectionModel:
|
postprocessor: PostprocessorProtocol
|
||||||
"""Build the complete BatDetect2 detection model.
|
targets: TargetProtocol
|
||||||
|
|
||||||
This high-level factory function constructs the standard BatDetect2 model
|
def __init__(
|
||||||
architecture. It first builds the feature extraction backbone (typically an
|
self,
|
||||||
encoder-bottleneck-decoder structure) based on the provided
|
detector: DetectionModel,
|
||||||
`BackboneConfig` (or defaults if None), and then attaches the standard
|
preprocessor: PreprocessorProtocol,
|
||||||
prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the
|
postprocessor: PostprocessorProtocol,
|
||||||
`build_detector` function.
|
targets: TargetProtocol,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.detector = detector
|
||||||
|
self.preprocessor = preprocessor
|
||||||
|
self.postprocessor = postprocessor
|
||||||
|
self.targets = targets
|
||||||
|
|
||||||
Parameters
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
----------
|
return self.detector(spec)
|
||||||
num_classes : int
|
|
||||||
The number of specific target classes the model should predict
|
|
||||||
(required for the `ClassifierHead`). Must be positive.
|
|
||||||
config : BackboneConfig, optional
|
|
||||||
Configuration object specifying the architecture of the backbone
|
|
||||||
(encoder, bottleneck, decoder). If None, default configurations defined
|
|
||||||
within the respective builder functions (`build_encoder`, etc.) will be
|
|
||||||
used to construct a default backbone architecture.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
DetectionModel
|
|
||||||
An initialized `Detector` model instance.
|
|
||||||
|
|
||||||
Raises
|
class ModelConfig(BaseConfig):
|
||||||
------
|
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||||
ValueError
|
preprocess: PreprocessingConfig = Field(
|
||||||
If `num_classes` is not positive, or if errors occur during the
|
default_factory=PreprocessingConfig
|
||||||
construction of the backbone or detector components (e.g., incompatible
|
)
|
||||||
configurations, invalid parameters).
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
"""
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
config = config or BackboneConfig()
|
|
||||||
logger.opt(lazy=True).debug(
|
|
||||||
"Building model with config: \n{}",
|
def build_model(config: Optional[ModelConfig] = None):
|
||||||
lambda: config.to_yaml_string(),
|
config = config or ModelConfig()
|
||||||
|
|
||||||
|
targets = build_targets(config=config.targets)
|
||||||
|
preprocessor = build_preprocessor(config=config.preprocess)
|
||||||
|
postprocessor = build_postprocessor(
|
||||||
|
targets=targets,
|
||||||
|
config=config.postprocess,
|
||||||
|
min_freq=preprocessor.min_freq,
|
||||||
|
max_freq=preprocessor.max_freq,
|
||||||
|
)
|
||||||
|
detector = build_detector(
|
||||||
|
num_classes=len(targets.class_names),
|
||||||
|
config=config.model,
|
||||||
|
)
|
||||||
|
return Model(
|
||||||
|
detector=detector,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
targets=targets,
|
||||||
)
|
)
|
||||||
backbone = build_backbone(config)
|
|
||||||
return build_detector(num_classes, backbone)
|
|
||||||
|
|||||||
@ -43,7 +43,7 @@ from batdetect2.models.encoder import (
|
|||||||
EncoderConfig,
|
EncoderConfig,
|
||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.types import BackboneModel
|
from batdetect2.typing.models import BackboneModel
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Backbone",
|
"Backbone",
|
||||||
|
|||||||
@ -8,18 +8,25 @@ classifying them.
|
|||||||
|
|
||||||
The primary components are:
|
The primary components are:
|
||||||
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
|
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
|
||||||
- `build_detector`: A factory function to conveniently construct a standard
|
|
||||||
`Detector` instance given a backbone and the number of target classes.
|
|
||||||
|
|
||||||
This module focuses purely on the neural network architecture definition. The
|
This module focuses purely on the neural network architecture definition. The
|
||||||
logic for preprocessing inputs and postprocessing/decoding outputs resides in
|
logic for preprocessing inputs and postprocessing/decoding outputs resides in
|
||||||
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from batdetect2.models.backbones import BackboneConfig, build_backbone
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Detector",
|
||||||
|
"build_detector",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Detector(DetectionModel):
|
class Detector(DetectionModel):
|
||||||
@ -119,36 +126,41 @@ class Detector(DetectionModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
|
def build_detector(
|
||||||
"""Factory function to build a standard Detector model instance.
|
num_classes: int, config: Optional[BackboneConfig] = None
|
||||||
|
) -> DetectionModel:
|
||||||
Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`,
|
"""Build the complete BatDetect2 detection model.
|
||||||
`BBoxHead`) configured appropriately based on the output channels of the
|
|
||||||
provided `backbone` and the specified `num_classes`. It then assembles
|
|
||||||
these components into a `Detector` model.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
num_classes : int
|
num_classes : int
|
||||||
The number of specific target classes for the classification head
|
The number of specific target classes the model should predict
|
||||||
(excluding any implicit background class). Must be positive.
|
(required for the `ClassifierHead`). Must be positive.
|
||||||
backbone : BackboneModel
|
config : BackboneConfig, optional
|
||||||
An initialized feature extraction backbone module instance. The number
|
Configuration object specifying the architecture of the backbone
|
||||||
of output channels from this backbone (`backbone.out_channels`) is used
|
(encoder, bottleneck, decoder). If None, default configurations defined
|
||||||
to configure the input channels for the prediction heads.
|
within the respective builder functions (`build_encoder`, etc.) will be
|
||||||
|
used to construct a default backbone architecture.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Detector
|
DetectionModel
|
||||||
An initialized `Detector` model instance.
|
An initialized `Detector` model instance.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If `num_classes` is not positive.
|
If `num_classes` is not positive, or if errors occur during the
|
||||||
AttributeError
|
construction of the backbone or detector components (e.g., incompatible
|
||||||
If `backbone` does not have the required `out_channels` attribute.
|
configurations, invalid parameters).
|
||||||
"""
|
"""
|
||||||
|
config = config or BackboneConfig()
|
||||||
|
|
||||||
|
logger.opt(lazy=True).debug(
|
||||||
|
"Building model with config: \n{}",
|
||||||
|
lambda: config.to_yaml_string(),
|
||||||
|
)
|
||||||
|
backbone = build_backbone(config=config)
|
||||||
classifier_head = ClassifierHead(
|
classifier_head = ClassifierHead(
|
||||||
num_classes=num_classes,
|
num_classes=num_classes,
|
||||||
in_channels=backbone.out_channels,
|
in_channels=backbone.out_channels,
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from matplotlib.axes import Axes
|
|||||||
from soundevent import data, plot
|
from soundevent import data, plot
|
||||||
|
|
||||||
from batdetect2.plotting.clips import plot_clip
|
from batdetect2.plotting.clips import plot_clip
|
||||||
from batdetect2.preprocess import PreprocessorProtocol
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_clip_annotation",
|
"plot_clip_annotation",
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry
|
|||||||
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
|
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
|
||||||
|
|
||||||
from batdetect2.plotting.clips import plot_clip
|
from batdetect2.plotting.clips import plot_clip
|
||||||
from batdetect2.preprocess import PreprocessorProtocol
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_clip_prediction",
|
"plot_clip_prediction",
|
||||||
|
|||||||
@ -7,8 +7,8 @@ import matplotlib.pyplot as plt
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from batdetect2 import plotting
|
from batdetect2 import plotting
|
||||||
from batdetect2.evaluate.types import MatchEvaluation
|
from batdetect2.typing.evaluate import MatchEvaluation
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -6,13 +6,13 @@ from soundevent import data, plot
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
from soundevent.plot.tags import TagColorMapper
|
from soundevent.plot.tags import TagColorMapper
|
||||||
|
|
||||||
from batdetect2.evaluate.types import MatchEvaluation
|
|
||||||
from batdetect2.plotting.clip_predictions import plot_prediction
|
from batdetect2.plotting.clip_predictions import plot_prediction
|
||||||
from batdetect2.plotting.clips import plot_clip
|
from batdetect2.plotting.clips import plot_clip
|
||||||
from batdetect2.preprocess import (
|
from batdetect2.preprocess import (
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
get_default_preprocessor,
|
get_default_preprocessor,
|
||||||
)
|
)
|
||||||
|
from batdetect2.typing.evaluate import MatchEvaluation
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_matches",
|
"plot_matches",
|
||||||
|
|||||||
@ -36,7 +36,6 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.models.types import ModelOutput
|
|
||||||
from batdetect2.postprocess.decoding import (
|
from batdetect2.postprocess.decoding import (
|
||||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
convert_raw_prediction_to_sound_event_prediction,
|
convert_raw_prediction_to_sound_event_prediction,
|
||||||
@ -62,13 +61,14 @@ from batdetect2.postprocess.remapping import (
|
|||||||
features_to_xarray,
|
features_to_xarray,
|
||||||
sizes_to_xarray,
|
sizes_to_xarray,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess.types import (
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
|
from batdetect2.typing.models import ModelOutput
|
||||||
|
from batdetect2.typing.postprocess import (
|
||||||
BatDetect2Prediction,
|
BatDetect2Prediction,
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
RawPrediction,
|
RawPrediction,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||||
@ -79,8 +79,6 @@ __all__ = [
|
|||||||
"NMS_KERNEL_SIZE",
|
"NMS_KERNEL_SIZE",
|
||||||
"PostprocessConfig",
|
"PostprocessConfig",
|
||||||
"Postprocessor",
|
"Postprocessor",
|
||||||
"PostprocessorProtocol",
|
|
||||||
"RawPrediction",
|
|
||||||
"TOP_K_PER_SEC",
|
"TOP_K_PER_SEC",
|
||||||
"build_postprocessor",
|
"build_postprocessor",
|
||||||
"classification_to_xarray",
|
"classification_to_xarray",
|
||||||
|
|||||||
@ -32,8 +32,8 @@ import numpy as np
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
from batdetect2.typing.postprocess import GeometryDecoder, RawPrediction
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"convert_xr_dataset_to_raw_prediction",
|
"convert_xr_dataset_to_raw_prediction",
|
||||||
|
|||||||
@ -57,7 +57,7 @@ from batdetect2.preprocess.spectrogram import (
|
|||||||
build_spectrogram_builder,
|
build_spectrogram_builder,
|
||||||
get_spectrogram_resolution,
|
get_spectrogram_resolution,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess.types import (
|
from batdetect2.typing.preprocess import (
|
||||||
AudioLoader,
|
AudioLoader,
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
SpectrogramBuilder,
|
SpectrogramBuilder,
|
||||||
@ -65,7 +65,6 @@ from batdetect2.preprocess.types import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AudioConfig",
|
"AudioConfig",
|
||||||
"AudioLoader",
|
|
||||||
"ConfigurableSpectrogramBuilder",
|
"ConfigurableSpectrogramBuilder",
|
||||||
"DEFAULT_DURATION",
|
"DEFAULT_DURATION",
|
||||||
"FrequencyConfig",
|
"FrequencyConfig",
|
||||||
@ -77,7 +76,6 @@ __all__ = [
|
|||||||
"SCALE_RAW_AUDIO",
|
"SCALE_RAW_AUDIO",
|
||||||
"STFTConfig",
|
"STFTConfig",
|
||||||
"SpecSizeConfig",
|
"SpecSizeConfig",
|
||||||
"SpectrogramBuilder",
|
|
||||||
"SpectrogramConfig",
|
"SpectrogramConfig",
|
||||||
"StandardPreprocessor",
|
"StandardPreprocessor",
|
||||||
"TARGET_SAMPLERATE_HZ",
|
"TARGET_SAMPLERATE_HZ",
|
||||||
|
|||||||
@ -32,7 +32,7 @@ from soundevent.arrays import operations as ops
|
|||||||
from soundfile import LibsndfileError
|
from soundfile import LibsndfileError
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.preprocess.types import AudioLoader
|
from batdetect2.typing.preprocess import AudioLoader
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ResampleConfig",
|
"ResampleConfig",
|
||||||
|
|||||||
@ -31,7 +31,7 @@ from soundevent.arrays import operations as ops
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.preprocess.audio import convert_to_xr
|
from batdetect2.preprocess.audio import convert_to_xr
|
||||||
from batdetect2.preprocess.types import SpectrogramBuilder
|
from batdetect2.typing.preprocess import SpectrogramBuilder
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"STFTConfig",
|
"STFTConfig",
|
||||||
@ -540,7 +540,6 @@ def apply_pcen(
|
|||||||
xr.DataArray
|
xr.DataArray
|
||||||
PCEN-scaled spectrogram.
|
PCEN-scaled spectrogram.
|
||||||
"""
|
"""
|
||||||
spec = spec * (2**31)
|
|
||||||
samplerate = 1 / spec.time.attrs["step"]
|
samplerate = 1 / spec.time.attrs["step"]
|
||||||
hop_size = spec.attrs["hop_size"]
|
hop_size = spec.attrs["hop_size"]
|
||||||
|
|
||||||
@ -559,22 +558,24 @@ def apply_pcen(
|
|||||||
[1, smoothing_constant - 1],
|
[1, smoothing_constant - 1],
|
||||||
)[:]
|
)[:]
|
||||||
|
|
||||||
|
spec_data = spec.data * (2**31)
|
||||||
|
|
||||||
# Smooth the input array along the given axis
|
# Smooth the input array along the given axis
|
||||||
smoothed, _ = signal.lfilter(
|
smoothed, _ = signal.lfilter(
|
||||||
[smoothing_constant],
|
[smoothing_constant],
|
||||||
[1, smoothing_constant - 1],
|
[1, smoothing_constant - 1],
|
||||||
spec.data,
|
spec_data,
|
||||||
zi=zi,
|
zi=zi,
|
||||||
axis=axis, # type: ignore
|
axis=axis, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
smooth = np.exp(-gain * (np.log(eps) + np.log1p(smoothed / eps)))
|
smooth = np.exp(-gain * (np.log(eps) + np.log1p(smoothed / eps)))
|
||||||
data = (bias**power) * np.expm1(
|
data = (bias**power) * np.expm1(
|
||||||
power * np.log1p(spec.data * smooth / bias)
|
power * np.log1p(spec_data * smooth / bias)
|
||||||
)
|
)
|
||||||
|
|
||||||
return xr.DataArray(
|
return xr.DataArray(
|
||||||
data,
|
data.astype(spec.dtype),
|
||||||
dims=spec.dims,
|
dims=spec.dims,
|
||||||
coords=spec.coords,
|
coords=spec.coords,
|
||||||
attrs=spec.attrs,
|
attrs=spec.attrs,
|
||||||
|
|||||||
@ -80,7 +80,7 @@ from batdetect2.targets.transform import (
|
|||||||
load_transformation_from_config,
|
load_transformation_from_config,
|
||||||
register_derivation,
|
register_derivation,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.types import Position, Size, TargetProtocol
|
from batdetect2.typing.targets import Position, Size, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClassesConfig",
|
"ClassesConfig",
|
||||||
@ -99,7 +99,6 @@ __all__ = [
|
|||||||
"TagInfo",
|
"TagInfo",
|
||||||
"TargetClass",
|
"TargetClass",
|
||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
"TargetProtocol",
|
|
||||||
"Targets",
|
"Targets",
|
||||||
"TermInfo",
|
"TermInfo",
|
||||||
"TransformConfig",
|
"TransformConfig",
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from batdetect2.targets.terms import (
|
|||||||
default_term_registry,
|
default_term_registry,
|
||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
)
|
)
|
||||||
|
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DEFAULT_SPECIES_LIST",
|
"DEFAULT_SPECIES_LIST",
|
||||||
@ -27,25 +28,6 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
|
||||||
"""Type alias for a sound event class encoder function.
|
|
||||||
|
|
||||||
An encoder function takes a sound event annotation and returns the string name
|
|
||||||
of the target class it belongs to, based on a predefined set of rules.
|
|
||||||
If the annotation does not match any defined target class according to the
|
|
||||||
rules, the function returns None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
SoundEventDecoder = Callable[[str], List[data.Tag]]
|
|
||||||
"""Type alias for a sound event class decoder function.
|
|
||||||
|
|
||||||
A decoder function takes a class name string (as predicted by the model or
|
|
||||||
assigned during encoding) and returns a list of `soundevent.data.Tag` objects
|
|
||||||
that represent that class according to the configuration. This is used to
|
|
||||||
translate model outputs back into meaningful annotations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
DEFAULT_SPECIES_LIST = [
|
DEFAULT_SPECIES_LIST = [
|
||||||
"Barbastella barbastellus",
|
"Barbastella barbastellus",
|
||||||
"Eptesicus serotinus",
|
"Eptesicus serotinus",
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, List, Literal, Optional, Set
|
from typing import List, Literal, Optional, Set
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -12,11 +12,11 @@ from batdetect2.targets.terms import (
|
|||||||
default_term_registry,
|
default_term_registry,
|
||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
)
|
)
|
||||||
|
from batdetect2.typing.targets import SoundEventFilter
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FilterConfig",
|
"FilterConfig",
|
||||||
"FilterRule",
|
"FilterRule",
|
||||||
"SoundEventFilter",
|
|
||||||
"build_sound_event_filter",
|
"build_sound_event_filter",
|
||||||
"build_filter_from_rule",
|
"build_filter_from_rule",
|
||||||
"load_filter_config",
|
"load_filter_config",
|
||||||
@ -24,14 +24,6 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
|
|
||||||
"""Type alias for a filter function.
|
|
||||||
|
|
||||||
A filter function accepts a soundevent.data.SoundEventAnnotation object
|
|
||||||
and returns True if the annotation should be kept based on the filter's
|
|
||||||
criteria, or False if it should be discarded.
|
|
||||||
"""
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -28,8 +28,8 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
from batdetect2.targets.types import Position, Size
|
from batdetect2.typing.targets import Position, Size
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Anchor",
|
"Anchor",
|
||||||
|
|||||||
@ -24,7 +24,6 @@ from batdetect2.train.config import (
|
|||||||
from batdetect2.train.dataset import (
|
from batdetect2.train.dataset import (
|
||||||
LabeledDataset,
|
LabeledDataset,
|
||||||
RandomExampleSource,
|
RandomExampleSource,
|
||||||
TrainExample,
|
|
||||||
list_preprocessed_files,
|
list_preprocessed_files,
|
||||||
)
|
)
|
||||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||||
@ -64,7 +63,6 @@ __all__ = [
|
|||||||
"RandomExampleSource",
|
"RandomExampleSource",
|
||||||
"SizeLossConfig",
|
"SizeLossConfig",
|
||||||
"TimeMaskAugmentationConfig",
|
"TimeMaskAugmentationConfig",
|
||||||
"TrainExample",
|
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
"VolumeAugmentationConfig",
|
"VolumeAugmentationConfig",
|
||||||
|
|||||||
@ -33,8 +33,7 @@ from pydantic import Field
|
|||||||
from soundevent import arrays, data
|
from soundevent import arrays, data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.preprocess import PreprocessorProtocol
|
from batdetect2.typing import Augmentation, PreprocessorProtocol
|
||||||
from batdetect2.train.types import Augmentation
|
|
||||||
from batdetect2.utils.arrays import adjust_width
|
from batdetect2.utils.arrays import adjust_width
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -14,16 +14,18 @@ from batdetect2.evaluate.match import (
|
|||||||
MatchConfig,
|
MatchConfig,
|
||||||
match_sound_events_and_raw_predictions,
|
match_sound_events_and_raw_predictions,
|
||||||
)
|
)
|
||||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
|
||||||
from batdetect2.plotting.evaluation import plot_example_gallery
|
from batdetect2.plotting.evaluation import plot_example_gallery
|
||||||
from batdetect2.postprocess.types import (
|
from batdetect2.train.dataset import LabeledDataset
|
||||||
BatDetect2Prediction,
|
|
||||||
PostprocessorProtocol,
|
|
||||||
)
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.types import ModelOutput
|
from batdetect2.typing import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
MatchEvaluation,
|
||||||
|
MetricsProtocol,
|
||||||
|
ModelOutput,
|
||||||
|
PostprocessorProtocol,
|
||||||
|
TargetProtocol,
|
||||||
|
TrainExample,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ValidationMetrics(Callback):
|
class ValidationMetrics(Callback):
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from loguru import logger
|
|||||||
from soundevent import arrays
|
from soundevent import arrays
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.train.types import ClipperProtocol
|
from batdetect2.typing import ClipperProtocol
|
||||||
|
|
||||||
DEFAULT_TRAIN_CLIP_DURATION = 0.513
|
DEFAULT_TRAIN_CLIP_DURATION = 0.513
|
||||||
DEFAULT_MAX_EMPTY_CLIP = 0.1
|
DEFAULT_MAX_EMPTY_CLIP = 0.1
|
||||||
|
|||||||
@ -4,11 +4,8 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate.config import EvaluationConfig
|
from batdetect2.evaluate import EvaluationConfig
|
||||||
from batdetect2.models import BackboneConfig
|
from batdetect2.models import ModelConfig
|
||||||
from batdetect2.postprocess import PostprocessConfig
|
|
||||||
from batdetect2.preprocess import PreprocessingConfig
|
|
||||||
from batdetect2.targets import TargetConfig
|
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
DEFAULT_AUGMENTATION_CONFIG,
|
DEFAULT_AUGMENTATION_CONFIG,
|
||||||
AugmentationsConfig,
|
AugmentationsConfig,
|
||||||
@ -85,16 +82,10 @@ def load_train_config(
|
|||||||
return load_config(path, schema=TrainingConfig, field=field)
|
return load_config(path, schema=TrainingConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
class FullTrainingConfig(BaseConfig):
|
class FullTrainingConfig(ModelConfig):
|
||||||
"""Full training configuration."""
|
"""Full training configuration."""
|
||||||
|
|
||||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
|
||||||
preprocess: PreprocessingConfig = Field(
|
|
||||||
default_factory=PreprocessingConfig
|
|
||||||
)
|
|
||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
|
||||||
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from soundevent import data
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from batdetect2.train.augmentations import Augmentation
|
from batdetect2.train.augmentations import Augmentation
|
||||||
from batdetect2.train.types import ClipperProtocol, TrainExample
|
from batdetect2.typing import ClipperProtocol, TrainExample
|
||||||
from batdetect2.utils.tensors import adjust_width
|
from batdetect2.utils.tensors import adjust_width
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -34,10 +34,10 @@ from scipy.ndimage import gaussian_filter
|
|||||||
from soundevent import arrays, data
|
from soundevent import arrays, data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.typing import (
|
||||||
from batdetect2.train.types import (
|
|
||||||
ClipLabeller,
|
ClipLabeller,
|
||||||
Heatmaps,
|
Heatmaps,
|
||||||
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -3,12 +3,8 @@ import torch
|
|||||||
from torch.optim.adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
from batdetect2.models import ModelOutput
|
from batdetect2.models import Model
|
||||||
from batdetect2.models.types import DetectionModel
|
from batdetect2.typing import ModelOutput, TrainExample
|
||||||
from batdetect2.postprocess.types import PostprocessorProtocol
|
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
from batdetect2.train import TrainExample
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
@ -18,11 +14,8 @@ __all__ = [
|
|||||||
class TrainingModule(L.LightningModule):
|
class TrainingModule(L.LightningModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
detector: DetectionModel,
|
model: Model,
|
||||||
loss: torch.nn.Module,
|
loss: torch.nn.Module,
|
||||||
targets: TargetProtocol,
|
|
||||||
preprocessor: PreprocessorProtocol,
|
|
||||||
postprocessor: PostprocessorProtocol,
|
|
||||||
learning_rate: float = 0.001,
|
learning_rate: float = 0.001,
|
||||||
t_max: int = 100,
|
t_max: int = 100,
|
||||||
):
|
):
|
||||||
@ -32,18 +25,14 @@ class TrainingModule(L.LightningModule):
|
|||||||
self.t_max = t_max
|
self.t_max = t_max
|
||||||
|
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.targets = targets
|
self.model = model
|
||||||
self.detector = detector
|
|
||||||
self.preprocessor = preprocessor
|
|
||||||
self.postprocessor = postprocessor
|
|
||||||
|
|
||||||
self.save_hyperparameters(logger=False)
|
self.save_hyperparameters(logger=False)
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
return self.detector(spec)
|
return self.model(spec)
|
||||||
|
|
||||||
def training_step(self, batch: TrainExample):
|
def training_step(self, batch: TrainExample):
|
||||||
outputs = self.forward(batch.spec)
|
outputs = self.model(batch.spec)
|
||||||
losses = self.loss(outputs, batch)
|
losses = self.loss(outputs, batch)
|
||||||
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
||||||
self.log("detection_loss/train", losses.total, logger=True)
|
self.log("detection_loss/train", losses.total, logger=True)
|
||||||
@ -56,7 +45,7 @@ class TrainingModule(L.LightningModule):
|
|||||||
batch: TrainExample,
|
batch: TrainExample,
|
||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
) -> ModelOutput:
|
) -> ModelOutput:
|
||||||
outputs = self.forward(batch.spec)
|
outputs = self.model(batch.spec)
|
||||||
losses = self.loss(outputs, batch)
|
losses = self.loss(outputs, batch)
|
||||||
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
||||||
self.log("detection_loss/val", losses.total, logger=True)
|
self.log("detection_loss/val", losses.total, logger=True)
|
||||||
|
|||||||
@ -28,9 +28,7 @@ from pydantic import Field
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
|
||||||
from batdetect2.train.dataset import TrainExample
|
|
||||||
from batdetect2.train.types import Losses, LossProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BBoxLoss",
|
"BBoxLoss",
|
||||||
|
|||||||
@ -34,10 +34,9 @@ from tqdm.auto import tqdm
|
|||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.data.datasets import Dataset
|
from batdetect2.data.datasets import Dataset
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
|
||||||
from batdetect2.targets import TargetConfig, build_targets
|
from batdetect2.targets import TargetConfig, build_targets
|
||||||
from batdetect2.train.labels import LabelConfig, build_clip_labeler
|
from batdetect2.train.labels import LabelConfig, build_clip_labeler
|
||||||
from batdetect2.train.types import ClipLabeller
|
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"preprocess_annotations",
|
"preprocess_annotations",
|
||||||
|
|||||||
@ -14,12 +14,6 @@ from batdetect2.evaluate.metrics import (
|
|||||||
DetectionAveragePrecision,
|
DetectionAveragePrecision,
|
||||||
)
|
)
|
||||||
from batdetect2.models import build_model
|
from batdetect2.models import build_model
|
||||||
from batdetect2.postprocess import build_postprocessor
|
|
||||||
from batdetect2.preprocess import (
|
|
||||||
PreprocessorProtocol,
|
|
||||||
build_preprocessor,
|
|
||||||
)
|
|
||||||
from batdetect2.targets import TargetProtocol, build_targets
|
|
||||||
from batdetect2.train.augmentations import build_augmentations
|
from batdetect2.train.augmentations import build_augmentations
|
||||||
from batdetect2.train.callbacks import ValidationMetrics
|
from batdetect2.train.callbacks import ValidationMetrics
|
||||||
from batdetect2.train.clips import build_clipper
|
from batdetect2.train.clips import build_clipper
|
||||||
@ -32,6 +26,7 @@ from batdetect2.train.dataset import (
|
|||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.logging import build_logger
|
from batdetect2.train.logging import build_logger
|
||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.losses import build_loss
|
||||||
|
from batdetect2.typing import PreprocessorProtocol, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"build_train_dataset",
|
"build_train_dataset",
|
||||||
@ -88,25 +83,11 @@ def train(
|
|||||||
|
|
||||||
|
|
||||||
def build_training_module(config: FullTrainingConfig) -> TrainingModule:
|
def build_training_module(config: FullTrainingConfig) -> TrainingModule:
|
||||||
targets = build_targets(config=config.targets)
|
model = build_model(config=config)
|
||||||
loss = build_loss(config=config.train.loss)
|
loss = build_loss(config=config.train.loss)
|
||||||
preprocessor = build_preprocessor(config.preprocess)
|
|
||||||
postprocessor = build_postprocessor(
|
|
||||||
targets,
|
|
||||||
config=config.postprocess,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
)
|
|
||||||
model = build_model(
|
|
||||||
num_classes=len(targets.class_names),
|
|
||||||
config=config.model,
|
|
||||||
)
|
|
||||||
return TrainingModule(
|
return TrainingModule(
|
||||||
detector=model,
|
model=model,
|
||||||
loss=loss,
|
loss=loss,
|
||||||
preprocessor=preprocessor,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
targets=targets,
|
|
||||||
learning_rate=config.train.learning_rate,
|
learning_rate=config.train.learning_rate,
|
||||||
t_max=config.train.t_max,
|
t_max=config.train.t_max,
|
||||||
)
|
)
|
||||||
|
|||||||
58
src/batdetect2/typing/__init__.py
Normal file
58
src/batdetect2/typing/__init__.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol
|
||||||
|
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||||
|
from batdetect2.typing.postprocess import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
GeometryDecoder,
|
||||||
|
PostprocessorProtocol,
|
||||||
|
RawPrediction,
|
||||||
|
)
|
||||||
|
from batdetect2.typing.preprocess import (
|
||||||
|
AudioLoader,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
SpectrogramBuilder,
|
||||||
|
)
|
||||||
|
from batdetect2.typing.targets import (
|
||||||
|
Position,
|
||||||
|
Size,
|
||||||
|
SoundEventDecoder,
|
||||||
|
SoundEventEncoder,
|
||||||
|
SoundEventFilter,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
from batdetect2.typing.train import (
|
||||||
|
Augmentation,
|
||||||
|
ClipLabeller,
|
||||||
|
ClipperProtocol,
|
||||||
|
Heatmaps,
|
||||||
|
Losses,
|
||||||
|
LossProtocol,
|
||||||
|
TrainExample,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AudioLoader",
|
||||||
|
"Augmentation",
|
||||||
|
"BackboneModel",
|
||||||
|
"BatDetect2Prediction",
|
||||||
|
"ClipLabeller",
|
||||||
|
"ClipperProtocol",
|
||||||
|
"DetectionModel",
|
||||||
|
"GeometryDecoder",
|
||||||
|
"Heatmaps",
|
||||||
|
"LossProtocol",
|
||||||
|
"Losses",
|
||||||
|
"MatchEvaluation",
|
||||||
|
"MetricsProtocol",
|
||||||
|
"ModelOutput",
|
||||||
|
"Position",
|
||||||
|
"PostprocessorProtocol",
|
||||||
|
"PreprocessorProtocol",
|
||||||
|
"RawPrediction",
|
||||||
|
"Size",
|
||||||
|
"SoundEventDecoder",
|
||||||
|
"SoundEventEncoder",
|
||||||
|
"SoundEventFilter",
|
||||||
|
"SpectrogramBuilder",
|
||||||
|
"TargetProtocol",
|
||||||
|
"TrainExample",
|
||||||
|
]
|
||||||
@ -19,7 +19,6 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ModelOutput",
|
"ModelOutput",
|
||||||
@ -65,7 +64,7 @@ class ModelOutput(NamedTuple):
|
|||||||
features: torch.Tensor
|
features: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class BackboneModel(ABC, nn.Module):
|
class BackboneModel(ABC, torch.nn.Module):
|
||||||
"""Abstract Base Class for generic feature extraction backbone models.
|
"""Abstract Base Class for generic feature extraction backbone models.
|
||||||
|
|
||||||
Defines the minimal interface for a feature extractor network within a
|
Defines the minimal interface for a feature extractor network within a
|
||||||
@ -191,7 +190,7 @@ class EncoderDecoderModel(BackboneModel):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class DetectionModel(ABC, nn.Module):
|
class DetectionModel(ABC, torch.nn.Module):
|
||||||
"""Abstract Base Class for complete BatDetect2 detection models.
|
"""Abstract Base Class for complete BatDetect2 detection models.
|
||||||
|
|
||||||
Defines the interface for the overall model that takes an input spectrogram
|
Defines the interface for the overall model that takes an input spectrogram
|
||||||
@ -18,8 +18,8 @@ import numpy as np
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.typing.models import ModelOutput
|
||||||
from batdetect2.targets.types import Position, Size
|
from batdetect2.typing.targets import Position, Size
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RawPrediction",
|
"RawPrediction",
|
||||||
@ -16,6 +16,12 @@ import numpy as np
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AudioLoader",
|
||||||
|
"SpectrogramBuilder",
|
||||||
|
"PreprocessorProtocol",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class AudioLoader(Protocol):
|
class AudioLoader(Protocol):
|
||||||
"""Defines the interface for an audio loading and processing component.
|
"""Defines the interface for an audio loading and processing component.
|
||||||
@ -12,6 +12,7 @@ that components responsible for these tasks can be interacted with consistently
|
|||||||
throughout BatDetect2.
|
throughout BatDetect2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import List, Optional, Protocol
|
from typing import List, Optional, Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -19,10 +20,40 @@ from soundevent import data
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TargetProtocol",
|
"TargetProtocol",
|
||||||
|
"SoundEventEncoder",
|
||||||
|
"SoundEventDecoder",
|
||||||
|
"SoundEventFilter",
|
||||||
"Position",
|
"Position",
|
||||||
"Size",
|
"Size",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||||
|
"""Type alias for a sound event class encoder function.
|
||||||
|
|
||||||
|
An encoder function takes a sound event annotation and returns the string name
|
||||||
|
of the target class it belongs to, based on a predefined set of rules.
|
||||||
|
If the annotation does not match any defined target class according to the
|
||||||
|
rules, the function returns None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
SoundEventDecoder = Callable[[str], List[data.Tag]]
|
||||||
|
"""Type alias for a sound event class decoder function.
|
||||||
|
|
||||||
|
A decoder function takes a class name string (as predicted by the model or
|
||||||
|
assigned during encoding) and returns a list of `soundevent.data.Tag` objects
|
||||||
|
that represent that class according to the configuration. This is used to
|
||||||
|
translate model outputs back into meaningful annotations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
|
||||||
|
"""Type alias for a filter function.
|
||||||
|
|
||||||
|
A filter function accepts a soundevent.data.SoundEventAnnotation object
|
||||||
|
and returns True if the annotation should be kept based on the filter's
|
||||||
|
criteria, or False if it should be discarded.
|
||||||
|
"""
|
||||||
|
|
||||||
Position = tuple[float, float]
|
Position = tuple[float, float]
|
||||||
"""A tuple representing (time, frequency) coordinates."""
|
"""A tuple representing (time, frequency) coordinates."""
|
||||||
|
|
||||||
@ -4,13 +4,15 @@ import torch
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.models import ModelOutput
|
from batdetect2.typing.models import ModelOutput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Heatmaps",
|
|
||||||
"ClipLabeller",
|
|
||||||
"Augmentation",
|
"Augmentation",
|
||||||
|
"ClipLabeller",
|
||||||
|
"ClipperProtocol",
|
||||||
|
"Heatmaps",
|
||||||
"LossProtocol",
|
"LossProtocol",
|
||||||
|
"Losses",
|
||||||
"TrainExample",
|
"TrainExample",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -12,7 +12,6 @@ from soundevent import data, terms
|
|||||||
from batdetect2.data import DatasetConfig, load_dataset
|
from batdetect2.data import DatasetConfig, load_dataset
|
||||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
TermRegistry,
|
TermRegistry,
|
||||||
@ -22,9 +21,12 @@ from batdetect2.targets import (
|
|||||||
from batdetect2.targets.classes import ClassesConfig, TargetClass
|
from batdetect2.targets.classes import ClassesConfig, TargetClass
|
||||||
from batdetect2.targets.filtering import FilterConfig, FilterRule
|
from batdetect2.targets.filtering import FilterConfig, FilterRule
|
||||||
from batdetect2.targets.terms import TagInfo
|
from batdetect2.targets.terms import TagInfo
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.types import ClipLabeller
|
from batdetect2.typing import (
|
||||||
|
ClipLabeller,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -15,8 +15,7 @@ from batdetect2.postprocess.decoding import (
|
|||||||
get_generic_tags,
|
get_generic_tags,
|
||||||
get_prediction_features,
|
get_prediction_features,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess.types import RawPrediction
|
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -142,7 +142,7 @@ def test_audio_config_defaults():
|
|||||||
assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
|
assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
|
||||||
assert config.resample.method == "poly"
|
assert config.resample.method == "poly"
|
||||||
assert config.scale == audio.SCALE_RAW_AUDIO
|
assert config.scale == audio.SCALE_RAW_AUDIO
|
||||||
assert config.center is True
|
assert config.center is False
|
||||||
assert config.duration == audio.DEFAULT_DURATION
|
assert config.duration == audio.DEFAULT_DURATION
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -108,7 +108,7 @@ def test_spec_size_config_defaults():
|
|||||||
|
|
||||||
def test_pcen_config_defaults():
|
def test_pcen_config_defaults():
|
||||||
config = PcenConfig()
|
config = PcenConfig()
|
||||||
assert config.time_constant == 0.4
|
assert config.time_constant == 0.01
|
||||||
assert config.gain == 0.98
|
assert config.gain == 0.98
|
||||||
assert config.bias == 2
|
assert config.bias == 2
|
||||||
assert config.power == 0.5
|
assert config.power == 0.5
|
||||||
@ -202,13 +202,6 @@ def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
|
|||||||
|
|
||||||
|
|
||||||
def test_apply_pcen(sample_spec: xr.DataArray):
|
def test_apply_pcen(sample_spec: xr.DataArray):
|
||||||
if "original_samplerate" not in sample_spec.attrs:
|
|
||||||
sample_spec.attrs["original_samplerate"] = SAMPLERATE
|
|
||||||
if "nfft" not in sample_spec.attrs:
|
|
||||||
sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
|
|
||||||
if "noverlap" not in sample_spec.attrs:
|
|
||||||
sample_spec.attrs["noverlap"] = int(0.5 * sample_spec.attrs["nfft"])
|
|
||||||
|
|
||||||
pcen_config = PcenConfig()
|
pcen_config = PcenConfig()
|
||||||
pcen_spec = apply_pcen(
|
pcen_spec = apply_pcen(
|
||||||
sample_spec,
|
sample_spec,
|
||||||
|
|||||||
@ -5,14 +5,13 @@ import pytest
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import arrays, data
|
from soundevent import arrays, data
|
||||||
|
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
add_echo,
|
add_echo,
|
||||||
mix_examples,
|
mix_examples,
|
||||||
)
|
)
|
||||||
from batdetect2.train.clips import select_subclip
|
from batdetect2.train.clips import select_subclip
|
||||||
from batdetect2.train.preprocess import generate_train_example
|
from batdetect2.train.preprocess import generate_train_example
|
||||||
from batdetect2.train.types import ClipLabeller
|
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
|
||||||
|
|
||||||
|
|
||||||
def test_mix_examples(
|
def test_mix_examples(
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
||||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
from batdetect2.targets.terms import TagInfo
|
||||||
from batdetect2.train.labels import generate_heatmaps
|
from batdetect2.train.labels import generate_heatmaps
|
||||||
|
|
||||||
recording = data.Recording(
|
recording = data.Recording(
|
||||||
|
|||||||
@ -28,8 +28,8 @@ def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
|
|||||||
|
|
||||||
recovered = TrainingModule.load_from_checkpoint(path)
|
recovered = TrainingModule.load_from_checkpoint(path)
|
||||||
|
|
||||||
spec1 = module.preprocessor.preprocess_clip(clip)
|
spec1 = module.model.preprocessor.preprocess_clip(clip)
|
||||||
spec2 = recovered.preprocessor.preprocess_clip(clip)
|
spec2 = recovered.model.preprocessor.preprocess_clip(clip)
|
||||||
|
|
||||||
xr.testing.assert_equal(spec1, spec2)
|
xr.testing.assert_equal(spec1, spec2)
|
||||||
|
|
||||||
|
|||||||
@ -4,12 +4,12 @@ import xarray as xr
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.terms import get_term
|
from soundevent.terms import get_term
|
||||||
|
|
||||||
from batdetect2.models.types import ModelOutput
|
|
||||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
||||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
from batdetect2.targets import build_targets, load_target_config
|
||||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||||
from batdetect2.train.preprocess import generate_train_example
|
from batdetect2.train.preprocess import generate_train_example
|
||||||
|
from batdetect2.typing import ModelOutput
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user