mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Moved types to dedicated module
This commit is contained in:
parent
02adc19070
commit
61115d562c
@ -4,7 +4,7 @@ from typing import Optional, Tuple
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.datasets import Dataset
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
|
||||
def iterate_over_sound_events(
|
||||
|
||||
@ -7,7 +7,7 @@ from batdetect2.data.summary import (
|
||||
extract_recordings_df,
|
||||
extract_sound_events_df,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
|
||||
def split_dataset_by_recordings(
|
||||
|
||||
@ -2,7 +2,7 @@ import pandas as pd
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.data.datasets import Dataset
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"extract_recordings_df",
|
||||
|
||||
@ -4,15 +4,15 @@ from typing import List, Literal, Optional, Tuple
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import compute_affinity
|
||||
from soundevent.evaluation import (
|
||||
match_geometries as optimal_match,
|
||||
)
|
||||
from soundevent.evaluation import match_geometries as optimal_match
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.evaluate.types import MatchEvaluation
|
||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
MatchEvaluation,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
MatchingStrategy = Literal["greedy", "optimal"]
|
||||
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
|
||||
|
||||
@ -4,7 +4,7 @@ import pandas as pd
|
||||
from sklearn import metrics
|
||||
from sklearn.preprocessing import label_binarize
|
||||
|
||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||
from batdetect2.typing import MatchEvaluation, MetricsProtocol
|
||||
|
||||
__all__ = ["DetectionAveragePrecision"]
|
||||
|
||||
|
||||
@ -28,8 +28,11 @@ provided here.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
import torch
|
||||
from lightning import LightningModule
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.backbones import (
|
||||
Backbone,
|
||||
BackboneConfig,
|
||||
@ -53,24 +56,25 @@ from batdetect2.models.decoder import (
|
||||
DecoderConfig,
|
||||
build_decoder,
|
||||
)
|
||||
from batdetect2.models.detectors import (
|
||||
Detector,
|
||||
build_detector,
|
||||
)
|
||||
from batdetect2.models.detectors import Detector, build_detector
|
||||
from batdetect2.models.encoder import (
|
||||
DEFAULT_ENCODER_CONFIG,
|
||||
EncoderConfig,
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
||||
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.targets import TargetConfig, build_targets
|
||||
from batdetect2.typing.models import DetectionModel, ModelOutput
|
||||
from batdetect2.typing.postprocess import PostprocessorProtocol
|
||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"BBoxHead",
|
||||
"Backbone",
|
||||
"BackboneConfig",
|
||||
"BackboneModel",
|
||||
"BackboneModel",
|
||||
"Bottleneck",
|
||||
"BottleneckConfig",
|
||||
"ClassifierHead",
|
||||
@ -78,65 +82,75 @@ __all__ = [
|
||||
"DEFAULT_DECODER_CONFIG",
|
||||
"DEFAULT_ENCODER_CONFIG",
|
||||
"DecoderConfig",
|
||||
"DetectionModel",
|
||||
"Detector",
|
||||
"DetectorHead",
|
||||
"EncoderConfig",
|
||||
"FreqCoordConvDownConfig",
|
||||
"FreqCoordConvUpConfig",
|
||||
"ModelOutput",
|
||||
"StandardConvDownConfig",
|
||||
"StandardConvUpConfig",
|
||||
"build_backbone",
|
||||
"build_bottleneck",
|
||||
"build_decoder",
|
||||
"build_detector",
|
||||
"build_encoder",
|
||||
"build_model",
|
||||
"build_detector",
|
||||
"load_backbone_config",
|
||||
"Model",
|
||||
"ModelConfig",
|
||||
"build_model",
|
||||
]
|
||||
|
||||
|
||||
def build_model(
|
||||
num_classes: int,
|
||||
config: Optional[BackboneConfig] = None,
|
||||
) -> DetectionModel:
|
||||
"""Build the complete BatDetect2 detection model.
|
||||
class Model(LightningModule):
|
||||
detector: DetectionModel
|
||||
preprocessor: PreprocessorProtocol
|
||||
postprocessor: PostprocessorProtocol
|
||||
targets: TargetProtocol
|
||||
|
||||
This high-level factory function constructs the standard BatDetect2 model
|
||||
architecture. It first builds the feature extraction backbone (typically an
|
||||
encoder-bottleneck-decoder structure) based on the provided
|
||||
`BackboneConfig` (or defaults if None), and then attaches the standard
|
||||
prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the
|
||||
`build_detector` function.
|
||||
def __init__(
|
||||
self,
|
||||
detector: DetectionModel,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
targets: TargetProtocol,
|
||||
):
|
||||
super().__init__()
|
||||
self.detector = detector
|
||||
self.preprocessor = preprocessor
|
||||
self.postprocessor = postprocessor
|
||||
self.targets = targets
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes : int
|
||||
The number of specific target classes the model should predict
|
||||
(required for the `ClassifierHead`). Must be positive.
|
||||
config : BackboneConfig, optional
|
||||
Configuration object specifying the architecture of the backbone
|
||||
(encoder, bottleneck, decoder). If None, default configurations defined
|
||||
within the respective builder functions (`build_encoder`, etc.) will be
|
||||
used to construct a default backbone architecture.
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
return self.detector(spec)
|
||||
|
||||
Returns
|
||||
-------
|
||||
DetectionModel
|
||||
An initialized `Detector` model instance.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `num_classes` is not positive, or if errors occur during the
|
||||
construction of the backbone or detector components (e.g., incompatible
|
||||
configurations, invalid parameters).
|
||||
"""
|
||||
config = config or BackboneConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
class ModelConfig(BaseConfig):
|
||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||
preprocess: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
|
||||
|
||||
def build_model(config: Optional[ModelConfig] = None):
|
||||
config = config or ModelConfig()
|
||||
|
||||
targets = build_targets(config=config.targets)
|
||||
preprocessor = build_preprocessor(config=config.preprocess)
|
||||
postprocessor = build_postprocessor(
|
||||
targets=targets,
|
||||
config=config.postprocess,
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
)
|
||||
detector = build_detector(
|
||||
num_classes=len(targets.class_names),
|
||||
config=config.model,
|
||||
)
|
||||
return Model(
|
||||
detector=detector,
|
||||
postprocessor=postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
targets=targets,
|
||||
)
|
||||
backbone = build_backbone(config)
|
||||
return build_detector(num_classes, backbone)
|
||||
|
||||
@ -43,7 +43,7 @@ from batdetect2.models.encoder import (
|
||||
EncoderConfig,
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.models.types import BackboneModel
|
||||
from batdetect2.typing.models import BackboneModel
|
||||
|
||||
__all__ = [
|
||||
"Backbone",
|
||||
|
||||
@ -8,18 +8,25 @@ classifying them.
|
||||
|
||||
The primary components are:
|
||||
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
|
||||
- `build_detector`: A factory function to conveniently construct a standard
|
||||
`Detector` instance given a backbone and the number of target classes.
|
||||
|
||||
This module focuses purely on the neural network architecture definition. The
|
||||
logic for preprocessing inputs and postprocessing/decoding outputs resides in
|
||||
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.models.backbones import BackboneConfig, build_backbone
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"Detector",
|
||||
"build_detector",
|
||||
]
|
||||
|
||||
|
||||
class Detector(DetectionModel):
|
||||
@ -119,36 +126,41 @@ class Detector(DetectionModel):
|
||||
)
|
||||
|
||||
|
||||
def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
|
||||
"""Factory function to build a standard Detector model instance.
|
||||
|
||||
Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`,
|
||||
`BBoxHead`) configured appropriately based on the output channels of the
|
||||
provided `backbone` and the specified `num_classes`. It then assembles
|
||||
these components into a `Detector` model.
|
||||
def build_detector(
|
||||
num_classes: int, config: Optional[BackboneConfig] = None
|
||||
) -> DetectionModel:
|
||||
"""Build the complete BatDetect2 detection model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes : int
|
||||
The number of specific target classes for the classification head
|
||||
(excluding any implicit background class). Must be positive.
|
||||
backbone : BackboneModel
|
||||
An initialized feature extraction backbone module instance. The number
|
||||
of output channels from this backbone (`backbone.out_channels`) is used
|
||||
to configure the input channels for the prediction heads.
|
||||
The number of specific target classes the model should predict
|
||||
(required for the `ClassifierHead`). Must be positive.
|
||||
config : BackboneConfig, optional
|
||||
Configuration object specifying the architecture of the backbone
|
||||
(encoder, bottleneck, decoder). If None, default configurations defined
|
||||
within the respective builder functions (`build_encoder`, etc.) will be
|
||||
used to construct a default backbone architecture.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Detector
|
||||
DetectionModel
|
||||
An initialized `Detector` model instance.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `num_classes` is not positive.
|
||||
AttributeError
|
||||
If `backbone` does not have the required `out_channels` attribute.
|
||||
If `num_classes` is not positive, or if errors occur during the
|
||||
construction of the backbone or detector components (e.g., incompatible
|
||||
configurations, invalid parameters).
|
||||
"""
|
||||
config = config or BackboneConfig()
|
||||
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
backbone = build_backbone(config=config)
|
||||
classifier_head = ClassifierHead(
|
||||
num_classes=num_classes,
|
||||
in_channels=backbone.out_channels,
|
||||
|
||||
@ -4,7 +4,7 @@ from matplotlib.axes import Axes
|
||||
from soundevent import data, plot
|
||||
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.preprocess import PreprocessorProtocol
|
||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"plot_clip_annotation",
|
||||
|
||||
@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry
|
||||
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
|
||||
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.preprocess import PreprocessorProtocol
|
||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"plot_clip_prediction",
|
||||
|
||||
@ -7,8 +7,8 @@ import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
from batdetect2 import plotting
|
||||
from batdetect2.evaluate.types import MatchEvaluation
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.typing.evaluate import MatchEvaluation
|
||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -6,13 +6,13 @@ from soundevent import data, plot
|
||||
from soundevent.geometry import compute_bounds
|
||||
from soundevent.plot.tags import TagColorMapper
|
||||
|
||||
from batdetect2.evaluate.types import MatchEvaluation
|
||||
from batdetect2.plotting.clip_predictions import plot_prediction
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.preprocess import (
|
||||
PreprocessorProtocol,
|
||||
get_default_preprocessor,
|
||||
)
|
||||
from batdetect2.typing.evaluate import MatchEvaluation
|
||||
|
||||
__all__ = [
|
||||
"plot_matches",
|
||||
|
||||
@ -36,7 +36,6 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.postprocess.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_raw_prediction_to_sound_event_prediction,
|
||||
@ -62,13 +61,14 @@ from batdetect2.postprocess.remapping import (
|
||||
features_to_xarray,
|
||||
sizes_to_xarray,
|
||||
)
|
||||
from batdetect2.postprocess.types import (
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
from batdetect2.typing.postprocess import (
|
||||
BatDetect2Prediction,
|
||||
PostprocessorProtocol,
|
||||
RawPrediction,
|
||||
)
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||
@ -79,8 +79,6 @@ __all__ = [
|
||||
"NMS_KERNEL_SIZE",
|
||||
"PostprocessConfig",
|
||||
"Postprocessor",
|
||||
"PostprocessorProtocol",
|
||||
"RawPrediction",
|
||||
"TOP_K_PER_SEC",
|
||||
"build_postprocessor",
|
||||
"classification_to_xarray",
|
||||
|
||||
@ -32,8 +32,8 @@ import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.typing.postprocess import GeometryDecoder, RawPrediction
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"convert_xr_dataset_to_raw_prediction",
|
||||
|
||||
@ -57,7 +57,7 @@ from batdetect2.preprocess.spectrogram import (
|
||||
build_spectrogram_builder,
|
||||
get_spectrogram_resolution,
|
||||
)
|
||||
from batdetect2.preprocess.types import (
|
||||
from batdetect2.typing.preprocess import (
|
||||
AudioLoader,
|
||||
PreprocessorProtocol,
|
||||
SpectrogramBuilder,
|
||||
@ -65,7 +65,6 @@ from batdetect2.preprocess.types import (
|
||||
|
||||
__all__ = [
|
||||
"AudioConfig",
|
||||
"AudioLoader",
|
||||
"ConfigurableSpectrogramBuilder",
|
||||
"DEFAULT_DURATION",
|
||||
"FrequencyConfig",
|
||||
@ -77,7 +76,6 @@ __all__ = [
|
||||
"SCALE_RAW_AUDIO",
|
||||
"STFTConfig",
|
||||
"SpecSizeConfig",
|
||||
"SpectrogramBuilder",
|
||||
"SpectrogramConfig",
|
||||
"StandardPreprocessor",
|
||||
"TARGET_SAMPLERATE_HZ",
|
||||
|
||||
@ -32,7 +32,7 @@ from soundevent.arrays import operations as ops
|
||||
from soundfile import LibsndfileError
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess.types import AudioLoader
|
||||
from batdetect2.typing.preprocess import AudioLoader
|
||||
|
||||
__all__ = [
|
||||
"ResampleConfig",
|
||||
|
||||
@ -31,7 +31,7 @@ from soundevent.arrays import operations as ops
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess.audio import convert_to_xr
|
||||
from batdetect2.preprocess.types import SpectrogramBuilder
|
||||
from batdetect2.typing.preprocess import SpectrogramBuilder
|
||||
|
||||
__all__ = [
|
||||
"STFTConfig",
|
||||
@ -540,7 +540,6 @@ def apply_pcen(
|
||||
xr.DataArray
|
||||
PCEN-scaled spectrogram.
|
||||
"""
|
||||
spec = spec * (2**31)
|
||||
samplerate = 1 / spec.time.attrs["step"]
|
||||
hop_size = spec.attrs["hop_size"]
|
||||
|
||||
@ -559,22 +558,24 @@ def apply_pcen(
|
||||
[1, smoothing_constant - 1],
|
||||
)[:]
|
||||
|
||||
spec_data = spec.data * (2**31)
|
||||
|
||||
# Smooth the input array along the given axis
|
||||
smoothed, _ = signal.lfilter(
|
||||
[smoothing_constant],
|
||||
[1, smoothing_constant - 1],
|
||||
spec.data,
|
||||
spec_data,
|
||||
zi=zi,
|
||||
axis=axis, # type: ignore
|
||||
)
|
||||
|
||||
smooth = np.exp(-gain * (np.log(eps) + np.log1p(smoothed / eps)))
|
||||
data = (bias**power) * np.expm1(
|
||||
power * np.log1p(spec.data * smooth / bias)
|
||||
power * np.log1p(spec_data * smooth / bias)
|
||||
)
|
||||
|
||||
return xr.DataArray(
|
||||
data,
|
||||
data.astype(spec.dtype),
|
||||
dims=spec.dims,
|
||||
coords=spec.coords,
|
||||
attrs=spec.attrs,
|
||||
|
||||
@ -80,7 +80,7 @@ from batdetect2.targets.transform import (
|
||||
load_transformation_from_config,
|
||||
register_derivation,
|
||||
)
|
||||
from batdetect2.targets.types import Position, Size, TargetProtocol
|
||||
from batdetect2.typing.targets import Position, Size, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClassesConfig",
|
||||
@ -99,7 +99,6 @@ __all__ = [
|
||||
"TagInfo",
|
||||
"TargetClass",
|
||||
"TargetConfig",
|
||||
"TargetProtocol",
|
||||
"Targets",
|
||||
"TermInfo",
|
||||
"TransformConfig",
|
||||
|
||||
@ -14,6 +14,7 @@ from batdetect2.targets.terms import (
|
||||
default_term_registry,
|
||||
get_tag_from_info,
|
||||
)
|
||||
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_SPECIES_LIST",
|
||||
@ -27,25 +28,6 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||
"""Type alias for a sound event class encoder function.
|
||||
|
||||
An encoder function takes a sound event annotation and returns the string name
|
||||
of the target class it belongs to, based on a predefined set of rules.
|
||||
If the annotation does not match any defined target class according to the
|
||||
rules, the function returns None.
|
||||
"""
|
||||
|
||||
|
||||
SoundEventDecoder = Callable[[str], List[data.Tag]]
|
||||
"""Type alias for a sound event class decoder function.
|
||||
|
||||
A decoder function takes a class name string (as predicted by the model or
|
||||
assigned during encoding) and returns a list of `soundevent.data.Tag` objects
|
||||
that represent that class according to the configuration. This is used to
|
||||
translate model outputs back into meaningful annotations.
|
||||
"""
|
||||
|
||||
DEFAULT_SPECIES_LIST = [
|
||||
"Barbastella barbastellus",
|
||||
"Eptesicus serotinus",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Callable, List, Literal, Optional, Set
|
||||
from typing import List, Literal, Optional, Set
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -12,11 +12,11 @@ from batdetect2.targets.terms import (
|
||||
default_term_registry,
|
||||
get_tag_from_info,
|
||||
)
|
||||
from batdetect2.typing.targets import SoundEventFilter
|
||||
|
||||
__all__ = [
|
||||
"FilterConfig",
|
||||
"FilterRule",
|
||||
"SoundEventFilter",
|
||||
"build_sound_event_filter",
|
||||
"build_filter_from_rule",
|
||||
"load_filter_config",
|
||||
@ -24,14 +24,6 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
|
||||
"""Type alias for a filter function.
|
||||
|
||||
A filter function accepts a soundevent.data.SoundEventAnnotation object
|
||||
and returns True if the annotation should be kept based on the filter's
|
||||
criteria, or False if it should be discarded.
|
||||
"""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@ -28,8 +28,8 @@ from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import Position, Size
|
||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||
from batdetect2.typing.targets import Position, Size
|
||||
|
||||
__all__ = [
|
||||
"Anchor",
|
||||
|
||||
@ -24,7 +24,6 @@ from batdetect2.train.config import (
|
||||
from batdetect2.train.dataset import (
|
||||
LabeledDataset,
|
||||
RandomExampleSource,
|
||||
TrainExample,
|
||||
list_preprocessed_files,
|
||||
)
|
||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||
@ -64,7 +63,6 @@ __all__ = [
|
||||
"RandomExampleSource",
|
||||
"SizeLossConfig",
|
||||
"TimeMaskAugmentationConfig",
|
||||
"TrainExample",
|
||||
"TrainingConfig",
|
||||
"TrainingModule",
|
||||
"VolumeAugmentationConfig",
|
||||
|
||||
@ -33,8 +33,7 @@ from pydantic import Field
|
||||
from soundevent import arrays, data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.preprocess import PreprocessorProtocol
|
||||
from batdetect2.train.types import Augmentation
|
||||
from batdetect2.typing import Augmentation, PreprocessorProtocol
|
||||
from batdetect2.utils.arrays import adjust_width
|
||||
|
||||
__all__ = [
|
||||
|
||||
@ -14,16 +14,18 @@ from batdetect2.evaluate.match import (
|
||||
MatchConfig,
|
||||
match_sound_events_and_raw_predictions,
|
||||
)
|
||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||
from batdetect2.plotting.evaluation import plot_example_gallery
|
||||
from batdetect2.postprocess.types import (
|
||||
BatDetect2Prediction,
|
||||
PostprocessorProtocol,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||
from batdetect2.train.dataset import LabeledDataset
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.train.types import ModelOutput
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
MatchEvaluation,
|
||||
MetricsProtocol,
|
||||
ModelOutput,
|
||||
PostprocessorProtocol,
|
||||
TargetProtocol,
|
||||
TrainExample,
|
||||
)
|
||||
|
||||
|
||||
class ValidationMetrics(Callback):
|
||||
|
||||
@ -6,7 +6,7 @@ from loguru import logger
|
||||
from soundevent import arrays
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.train.types import ClipperProtocol
|
||||
from batdetect2.typing import ClipperProtocol
|
||||
|
||||
DEFAULT_TRAIN_CLIP_DURATION = 0.513
|
||||
DEFAULT_MAX_EMPTY_CLIP = 0.1
|
||||
|
||||
@ -4,11 +4,8 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.models import BackboneConfig
|
||||
from batdetect2.postprocess import PostprocessConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.evaluate import EvaluationConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.train.augmentations import (
|
||||
DEFAULT_AUGMENTATION_CONFIG,
|
||||
AugmentationsConfig,
|
||||
@ -85,16 +82,10 @@ def load_train_config(
|
||||
return load_config(path, schema=TrainingConfig, field=field)
|
||||
|
||||
|
||||
class FullTrainingConfig(BaseConfig):
|
||||
class FullTrainingConfig(ModelConfig):
|
||||
"""Full training configuration."""
|
||||
|
||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||
preprocess: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from soundevent import data
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from batdetect2.train.augmentations import Augmentation
|
||||
from batdetect2.train.types import ClipperProtocol, TrainExample
|
||||
from batdetect2.typing import ClipperProtocol, TrainExample
|
||||
from batdetect2.utils.tensors import adjust_width
|
||||
|
||||
__all__ = [
|
||||
|
||||
@ -34,10 +34,10 @@ from scipy.ndimage import gaussian_filter
|
||||
from soundevent import arrays, data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.types import (
|
||||
from batdetect2.typing import (
|
||||
ClipLabeller,
|
||||
Heatmaps,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
||||
@ -3,12 +3,8 @@ import torch
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from batdetect2.models import ModelOutput
|
||||
from batdetect2.models.types import DetectionModel
|
||||
from batdetect2.postprocess.types import PostprocessorProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train import TrainExample
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.typing import ModelOutput, TrainExample
|
||||
|
||||
__all__ = [
|
||||
"TrainingModule",
|
||||
@ -18,11 +14,8 @@ __all__ = [
|
||||
class TrainingModule(L.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
detector: DetectionModel,
|
||||
model: Model,
|
||||
loss: torch.nn.Module,
|
||||
targets: TargetProtocol,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
learning_rate: float = 0.001,
|
||||
t_max: int = 100,
|
||||
):
|
||||
@ -32,18 +25,14 @@ class TrainingModule(L.LightningModule):
|
||||
self.t_max = t_max
|
||||
|
||||
self.loss = loss
|
||||
self.targets = targets
|
||||
self.detector = detector
|
||||
self.preprocessor = preprocessor
|
||||
self.postprocessor = postprocessor
|
||||
|
||||
self.model = model
|
||||
self.save_hyperparameters(logger=False)
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
return self.detector(spec)
|
||||
return self.model(spec)
|
||||
|
||||
def training_step(self, batch: TrainExample):
|
||||
outputs = self.forward(batch.spec)
|
||||
outputs = self.model(batch.spec)
|
||||
losses = self.loss(outputs, batch)
|
||||
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
||||
self.log("detection_loss/train", losses.total, logger=True)
|
||||
@ -56,7 +45,7 @@ class TrainingModule(L.LightningModule):
|
||||
batch: TrainExample,
|
||||
batch_idx: int,
|
||||
) -> ModelOutput:
|
||||
outputs = self.forward(batch.spec)
|
||||
outputs = self.model(batch.spec)
|
||||
losses = self.loss(outputs, batch)
|
||||
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
||||
self.log("detection_loss/val", losses.total, logger=True)
|
||||
|
||||
@ -28,9 +28,7 @@ from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.train.dataset import TrainExample
|
||||
from batdetect2.train.types import Losses, LossProtocol
|
||||
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
|
||||
|
||||
__all__ = [
|
||||
"BBoxLoss",
|
||||
|
||||
@ -34,10 +34,9 @@ from tqdm.auto import tqdm
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.data.datasets import Dataset
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import TargetConfig, build_targets
|
||||
from batdetect2.train.labels import LabelConfig, build_clip_labeler
|
||||
from batdetect2.train.types import ClipLabeller
|
||||
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"preprocess_annotations",
|
||||
|
||||
@ -14,12 +14,6 @@ from batdetect2.evaluate.metrics import (
|
||||
DetectionAveragePrecision,
|
||||
)
|
||||
from batdetect2.models import build_model
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import (
|
||||
PreprocessorProtocol,
|
||||
build_preprocessor,
|
||||
)
|
||||
from batdetect2.targets import TargetProtocol, build_targets
|
||||
from batdetect2.train.augmentations import build_augmentations
|
||||
from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.clips import build_clipper
|
||||
@ -32,6 +26,7 @@ from batdetect2.train.dataset import (
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.train.logging import build_logger
|
||||
from batdetect2.train.losses import build_loss
|
||||
from batdetect2.typing import PreprocessorProtocol, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"build_train_dataset",
|
||||
@ -88,25 +83,11 @@ def train(
|
||||
|
||||
|
||||
def build_training_module(config: FullTrainingConfig) -> TrainingModule:
|
||||
targets = build_targets(config=config.targets)
|
||||
model = build_model(config=config)
|
||||
loss = build_loss(config=config.train.loss)
|
||||
preprocessor = build_preprocessor(config.preprocess)
|
||||
postprocessor = build_postprocessor(
|
||||
targets,
|
||||
config=config.postprocess,
|
||||
max_freq=preprocessor.max_freq,
|
||||
min_freq=preprocessor.min_freq,
|
||||
)
|
||||
model = build_model(
|
||||
num_classes=len(targets.class_names),
|
||||
config=config.model,
|
||||
)
|
||||
return TrainingModule(
|
||||
detector=model,
|
||||
model=model,
|
||||
loss=loss,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
targets=targets,
|
||||
learning_rate=config.train.learning_rate,
|
||||
t_max=config.train.t_max,
|
||||
)
|
||||
|
||||
58
src/batdetect2/typing/__init__.py
Normal file
58
src/batdetect2/typing/__init__.py
Normal file
@ -0,0 +1,58 @@
|
||||
from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol
|
||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||
from batdetect2.typing.postprocess import (
|
||||
BatDetect2Prediction,
|
||||
GeometryDecoder,
|
||||
PostprocessorProtocol,
|
||||
RawPrediction,
|
||||
)
|
||||
from batdetect2.typing.preprocess import (
|
||||
AudioLoader,
|
||||
PreprocessorProtocol,
|
||||
SpectrogramBuilder,
|
||||
)
|
||||
from batdetect2.typing.targets import (
|
||||
Position,
|
||||
Size,
|
||||
SoundEventDecoder,
|
||||
SoundEventEncoder,
|
||||
SoundEventFilter,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.typing.train import (
|
||||
Augmentation,
|
||||
ClipLabeller,
|
||||
ClipperProtocol,
|
||||
Heatmaps,
|
||||
Losses,
|
||||
LossProtocol,
|
||||
TrainExample,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AudioLoader",
|
||||
"Augmentation",
|
||||
"BackboneModel",
|
||||
"BatDetect2Prediction",
|
||||
"ClipLabeller",
|
||||
"ClipperProtocol",
|
||||
"DetectionModel",
|
||||
"GeometryDecoder",
|
||||
"Heatmaps",
|
||||
"LossProtocol",
|
||||
"Losses",
|
||||
"MatchEvaluation",
|
||||
"MetricsProtocol",
|
||||
"ModelOutput",
|
||||
"Position",
|
||||
"PostprocessorProtocol",
|
||||
"PreprocessorProtocol",
|
||||
"RawPrediction",
|
||||
"Size",
|
||||
"SoundEventDecoder",
|
||||
"SoundEventEncoder",
|
||||
"SoundEventFilter",
|
||||
"SpectrogramBuilder",
|
||||
"TargetProtocol",
|
||||
"TrainExample",
|
||||
]
|
||||
@ -19,7 +19,6 @@ from abc import ABC, abstractmethod
|
||||
from typing import NamedTuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = [
|
||||
"ModelOutput",
|
||||
@ -65,7 +64,7 @@ class ModelOutput(NamedTuple):
|
||||
features: torch.Tensor
|
||||
|
||||
|
||||
class BackboneModel(ABC, nn.Module):
|
||||
class BackboneModel(ABC, torch.nn.Module):
|
||||
"""Abstract Base Class for generic feature extraction backbone models.
|
||||
|
||||
Defines the minimal interface for a feature extractor network within a
|
||||
@ -191,7 +190,7 @@ class EncoderDecoderModel(BackboneModel):
|
||||
...
|
||||
|
||||
|
||||
class DetectionModel(ABC, nn.Module):
|
||||
class DetectionModel(ABC, torch.nn.Module):
|
||||
"""Abstract Base Class for complete BatDetect2 detection models.
|
||||
|
||||
Defines the interface for the overall model that takes an input spectrogram
|
||||
@ -18,8 +18,8 @@ import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.targets.types import Position, Size
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
from batdetect2.typing.targets import Position, Size
|
||||
|
||||
__all__ = [
|
||||
"RawPrediction",
|
||||
@ -16,6 +16,12 @@ import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"AudioLoader",
|
||||
"SpectrogramBuilder",
|
||||
"PreprocessorProtocol",
|
||||
]
|
||||
|
||||
|
||||
class AudioLoader(Protocol):
|
||||
"""Defines the interface for an audio loading and processing component.
|
||||
@ -12,6 +12,7 @@ that components responsible for these tasks can be interacted with consistently
|
||||
throughout BatDetect2.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
import numpy as np
|
||||
@ -19,10 +20,40 @@ from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"TargetProtocol",
|
||||
"SoundEventEncoder",
|
||||
"SoundEventDecoder",
|
||||
"SoundEventFilter",
|
||||
"Position",
|
||||
"Size",
|
||||
]
|
||||
|
||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||
"""Type alias for a sound event class encoder function.
|
||||
|
||||
An encoder function takes a sound event annotation and returns the string name
|
||||
of the target class it belongs to, based on a predefined set of rules.
|
||||
If the annotation does not match any defined target class according to the
|
||||
rules, the function returns None.
|
||||
"""
|
||||
|
||||
|
||||
SoundEventDecoder = Callable[[str], List[data.Tag]]
|
||||
"""Type alias for a sound event class decoder function.
|
||||
|
||||
A decoder function takes a class name string (as predicted by the model or
|
||||
assigned during encoding) and returns a list of `soundevent.data.Tag` objects
|
||||
that represent that class according to the configuration. This is used to
|
||||
translate model outputs back into meaningful annotations.
|
||||
"""
|
||||
|
||||
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
|
||||
"""Type alias for a filter function.
|
||||
|
||||
A filter function accepts a soundevent.data.SoundEventAnnotation object
|
||||
and returns True if the annotation should be kept based on the filter's
|
||||
criteria, or False if it should be discarded.
|
||||
"""
|
||||
|
||||
Position = tuple[float, float]
|
||||
"""A tuple representing (time, frequency) coordinates."""
|
||||
|
||||
@ -4,13 +4,15 @@ import torch
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.models import ModelOutput
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"Heatmaps",
|
||||
"ClipLabeller",
|
||||
"Augmentation",
|
||||
"ClipLabeller",
|
||||
"ClipperProtocol",
|
||||
"Heatmaps",
|
||||
"LossProtocol",
|
||||
"Losses",
|
||||
"TrainExample",
|
||||
]
|
||||
|
||||
@ -12,7 +12,6 @@ from soundevent import data, terms
|
||||
from batdetect2.data import DatasetConfig, load_dataset
|
||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import (
|
||||
TargetConfig,
|
||||
TermRegistry,
|
||||
@ -22,9 +21,12 @@ from batdetect2.targets import (
|
||||
from batdetect2.targets.classes import ClassesConfig, TargetClass
|
||||
from batdetect2.targets.filtering import FilterConfig, FilterRule
|
||||
from batdetect2.targets.terms import TagInfo
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.types import ClipLabeller
|
||||
from batdetect2.typing import (
|
||||
ClipLabeller,
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -15,8 +15,7 @@ from batdetect2.postprocess.decoding import (
|
||||
get_generic_tags,
|
||||
get_prediction_features,
|
||||
)
|
||||
from batdetect2.postprocess.types import RawPrediction
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -142,7 +142,7 @@ def test_audio_config_defaults():
|
||||
assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
|
||||
assert config.resample.method == "poly"
|
||||
assert config.scale == audio.SCALE_RAW_AUDIO
|
||||
assert config.center is True
|
||||
assert config.center is False
|
||||
assert config.duration == audio.DEFAULT_DURATION
|
||||
|
||||
|
||||
|
||||
@ -108,7 +108,7 @@ def test_spec_size_config_defaults():
|
||||
|
||||
def test_pcen_config_defaults():
|
||||
config = PcenConfig()
|
||||
assert config.time_constant == 0.4
|
||||
assert config.time_constant == 0.01
|
||||
assert config.gain == 0.98
|
||||
assert config.bias == 2
|
||||
assert config.power == 0.5
|
||||
@ -202,13 +202,6 @@ def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
|
||||
|
||||
|
||||
def test_apply_pcen(sample_spec: xr.DataArray):
|
||||
if "original_samplerate" not in sample_spec.attrs:
|
||||
sample_spec.attrs["original_samplerate"] = SAMPLERATE
|
||||
if "nfft" not in sample_spec.attrs:
|
||||
sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
|
||||
if "noverlap" not in sample_spec.attrs:
|
||||
sample_spec.attrs["noverlap"] = int(0.5 * sample_spec.attrs["nfft"])
|
||||
|
||||
pcen_config = PcenConfig()
|
||||
pcen_spec = apply_pcen(
|
||||
sample_spec,
|
||||
|
||||
@ -5,14 +5,13 @@ import pytest
|
||||
import xarray as xr
|
||||
from soundevent import arrays, data
|
||||
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.train.augmentations import (
|
||||
add_echo,
|
||||
mix_examples,
|
||||
)
|
||||
from batdetect2.train.clips import select_subclip
|
||||
from batdetect2.train.preprocess import generate_train_example
|
||||
from batdetect2.train.types import ClipLabeller
|
||||
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
|
||||
|
||||
|
||||
def test_mix_examples(
|
||||
|
||||
@ -6,7 +6,7 @@ from soundevent import data
|
||||
|
||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||
from batdetect2.targets.terms import TagInfo
|
||||
from batdetect2.train.labels import generate_heatmaps
|
||||
|
||||
recording = data.Recording(
|
||||
|
||||
@ -28,8 +28,8 @@ def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
|
||||
|
||||
recovered = TrainingModule.load_from_checkpoint(path)
|
||||
|
||||
spec1 = module.preprocessor.preprocess_clip(clip)
|
||||
spec2 = recovered.preprocessor.preprocess_clip(clip)
|
||||
spec1 = module.model.preprocessor.preprocess_clip(clip)
|
||||
spec2 = recovered.model.preprocessor.preprocess_clip(clip)
|
||||
|
||||
xr.testing.assert_equal(spec1, spec2)
|
||||
|
||||
|
||||
@ -4,12 +4,12 @@ import xarray as xr
|
||||
from soundevent import data
|
||||
from soundevent.terms import get_term
|
||||
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||
from batdetect2.targets import build_targets, load_target_config
|
||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||
from batdetect2.train.preprocess import generate_train_example
|
||||
from batdetect2.typing import ModelOutput
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Loading…
Reference in New Issue
Block a user