Add dynamic imports to existing registries

This commit is contained in:
mbsantiago 2026-03-16 10:04:34 +00:00
parent 038d58ed99
commit 8ac4f4c44d
22 changed files with 356 additions and 22 deletions

View File

@ -6,7 +6,12 @@ from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap from soundevent.geometry import compute_bounds, intervals_overlap
from batdetect2.core import BaseConfig, Registry from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.typing import ClipperProtocol from batdetect2.typing import ClipperProtocol
DEFAULT_TRAIN_CLIP_DURATION = 0.256 DEFAULT_TRAIN_CLIP_DURATION = 0.256
@ -16,12 +21,24 @@ DEFAULT_MAX_EMPTY_CLIP = 0.1
__all__ = [ __all__ = [
"build_clipper", "build_clipper",
"ClipConfig", "ClipConfig",
"ClipperImportConfig",
] ]
clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper") clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
@add_import_config(clipper_registry)
class ClipperImportConfig(ImportConfig):
"""Use any callable as a clipper.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class RandomClipConfig(BaseConfig): class RandomClipConfig(BaseConfig):
name: Literal["random_subclip"] = "random_subclip" name: Literal["random_subclip"] = "random_subclip"
duration: float = DEFAULT_TRAIN_CLIP_DURATION duration: float = DEFAULT_TRAIN_CLIP_DURATION

View File

@ -6,13 +6,28 @@ from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool] SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
conditions: Registry[SoundEventCondition, []] = Registry("condition") conditions: Registry[SoundEventCondition, []] = Registry("condition")
@add_import_config(conditions)
class SoundEventConditionImportConfig(ImportConfig):
"""Use any callable as a sound event condition.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class HasTagConfig(BaseConfig): class HasTagConfig(BaseConfig):
name: Literal["has_tag"] = "has_tag" name: Literal["has_tag"] = "has_tag"
tag: data.Tag tag: data.Tag

View File

@ -1,8 +1,10 @@
from typing import Literal
from pathlib import Path from pathlib import Path
from soundevent.data import PathLike from soundevent.data import PathLike
from batdetect2.core import Registry from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.typing import ( from batdetect2.typing import (
OutputFormatterProtocol, OutputFormatterProtocol,
TargetProtocol, TargetProtocol,
@ -27,3 +29,14 @@ def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = ( prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
Registry(name="output_formatter") Registry(name="output_formatter")
) )
@add_import_config(prediction_formatters)
class PredictionFormatterImportConfig(ImportConfig):
"""Use any callable as a prediction formatter.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"

View File

@ -5,7 +5,11 @@ from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.data.conditions import ( from batdetect2.data.conditions import (
SoundEventCondition, SoundEventCondition,
SoundEventConditionConfig, SoundEventConditionConfig,
@ -20,6 +24,17 @@ SoundEventTransform = Callable[
transforms: Registry[SoundEventTransform, []] = Registry("transform") transforms: Registry[SoundEventTransform, []] = Registry("transform")
@add_import_config(transforms)
class SoundEventTransformImportConfig(ImportConfig):
"""Use any callable as a sound event transform.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class SetFrequencyBoundConfig(BaseConfig): class SetFrequencyBoundConfig(BaseConfig):
name: Literal["set_frequency"] = "set_frequency" name: Literal["set_frequency"] = "set_frequency"
boundary: Literal["low", "high"] = "low" boundary: Literal["low", "high"] = "low"

View File

@ -10,7 +10,12 @@ from soundevent.geometry import (
compute_temporal_iou, compute_temporal_iou,
) )
from batdetect2.core import BaseConfig, Registry from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.typing import AffinityFunction, Detection from batdetect2.typing import AffinityFunction, Detection
affinity_functions: Registry[AffinityFunction, []] = Registry( affinity_functions: Registry[AffinityFunction, []] = Registry(
@ -18,6 +23,17 @@ affinity_functions: Registry[AffinityFunction, []] = Registry(
) )
@add_import_config(affinity_functions)
class AffinityFunctionImportConfig(ImportConfig):
"""Use any callable as an affinity function.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class TimeAffinityConfig(BaseConfig): class TimeAffinityConfig(BaseConfig):
name: Literal["time_affinity"] = "time_affinity" name: Literal["time_affinity"] = "time_affinity"
position: Literal["start", "end", "center"] | float = "start" position: Literal["start", "end", "center"] | float = "start"

View File

@ -16,7 +16,12 @@ from pydantic import Field
from sklearn import metrics from sklearn import metrics
from soundevent import data from soundevent import data
from batdetect2.core import BaseConfig, Registry from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import ( from batdetect2.evaluate.metrics.common import (
average_precision, average_precision,
compute_precision_recall, compute_precision_recall,
@ -26,6 +31,7 @@ from batdetect2.typing import Detection, TargetProtocol
__all__ = [ __all__ = [
"ClassificationMetric", "ClassificationMetric",
"ClassificationMetricConfig", "ClassificationMetricConfig",
"ClassificationMetricImportConfig",
"build_classification_metric", "build_classification_metric",
"compute_precision_recall_curves", "compute_precision_recall_curves",
] ]
@ -58,6 +64,17 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
) )
@add_import_config(classification_metrics)
class ClassificationMetricImportConfig(ImportConfig):
"""Use any callable as a classification metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class BaseClassificationConfig(BaseConfig): class BaseClassificationConfig(BaseConfig):
include: List[str] | None = None include: List[str] | None = None
exclude: List[str] | None = None exclude: List[str] | None = None

View File

@ -7,7 +7,11 @@ from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import average_precision
@ -24,6 +28,17 @@ clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
) )
@add_import_config(clip_classification_metrics)
class ClipClassificationMetricImportConfig(ImportConfig):
"""Use any callable as a clip classification metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class ClipClassificationAveragePrecisionConfig(BaseConfig): class ClipClassificationAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision" name: Literal["average_precision"] = "average_precision"
label: str = "average_precision" label: str = "average_precision"

View File

@ -6,7 +6,11 @@ from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import average_precision
@ -23,6 +27,17 @@ clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
) )
@add_import_config(clip_detection_metrics)
class ClipDetectionMetricImportConfig(ImportConfig):
"""Use any callable as a clip detection metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class ClipDetectionAveragePrecisionConfig(BaseConfig): class ClipDetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision" name: Literal["average_precision"] = "average_precision"
label: str = "average_precision" label: str = "average_precision"

View File

@ -13,13 +13,19 @@ from pydantic import Field
from sklearn import metrics from sklearn import metrics
from soundevent import data from soundevent import data
from batdetect2.core import BaseConfig, Registry from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import Detection from batdetect2.typing import Detection
__all__ = [ __all__ = [
"DetectionMetricConfig", "DetectionMetricConfig",
"DetectionMetric", "DetectionMetric",
"DetectionMetricImportConfig",
"build_detection_metric", "build_detection_metric",
] ]
@ -46,6 +52,17 @@ DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric") detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")
@add_import_config(detection_metrics)
class DetectionMetricImportConfig(ImportConfig):
"""Use any callable as a detection metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class DetectionAveragePrecisionConfig(BaseConfig): class DetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision" name: Literal["average_precision"] = "average_precision"
label: str = "average_precision" label: str = "average_precision"

View File

@ -13,7 +13,12 @@ from pydantic import Field
from sklearn import metrics, preprocessing from sklearn import metrics, preprocessing
from soundevent import data from soundevent import data
from batdetect2.core import BaseConfig, Registry from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import Detection from batdetect2.typing import Detection
from batdetect2.typing.targets import TargetProtocol from batdetect2.typing.targets import TargetProtocol
@ -21,6 +26,7 @@ from batdetect2.typing.targets import TargetProtocol
__all__ = [ __all__ = [
"TopClassMetricConfig", "TopClassMetricConfig",
"TopClassMetric", "TopClassMetric",
"TopClassMetricImportConfig",
"build_top_class_metric", "build_top_class_metric",
] ]
@ -51,6 +57,17 @@ TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric") top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")
@add_import_config(top_class_metrics)
class TopClassMetricImportConfig(ImportConfig):
"""Use any callable as a top-class metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class TopClassAveragePrecisionConfig(BaseConfig): class TopClassAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision" name: Literal["average_precision"] = "average_precision"
label: str = "average_precision" label: str = "average_precision"

View File

@ -12,7 +12,7 @@ from matplotlib.figure import Figure
from pydantic import Field from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.core import Registry from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.classification import ( from batdetect2.evaluate.metrics.classification import (
ClipEval, ClipEval,
_extract_per_class_metric_data, _extract_per_class_metric_data,
@ -40,6 +40,17 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
) )
@add_import_config(classification_plots)
class ClassificationPlotImportConfig(ImportConfig):
"""Use any callable as a classification plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"

View File

@ -12,7 +12,7 @@ from matplotlib.figure import Figure
from pydantic import Field from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.core import Registry from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.clip_classification import ClipEval from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
@ -26,6 +26,7 @@ from batdetect2.typing import TargetProtocol
__all__ = [ __all__ = [
"ClipClassificationPlotConfig", "ClipClassificationPlotConfig",
"ClipClassificationPlotImportConfig",
"ClipClassificationPlotter", "ClipClassificationPlotter",
"build_clip_classification_plotter", "build_clip_classification_plotter",
] ]
@ -39,6 +40,17 @@ clip_classification_plots: Registry[
] = Registry("clip_classification_plot") ] = Registry("clip_classification_plot")
@add_import_config(clip_classification_plots)
class ClipClassificationPlotImportConfig(ImportConfig):
"""Use any callable as a clip classification plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"

View File

@ -13,7 +13,7 @@ from matplotlib.figure import Figure
from pydantic import Field from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.core import Registry from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.clip_detection import ClipEval from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
@ -22,6 +22,7 @@ from batdetect2.typing import TargetProtocol
__all__ = [ __all__ = [
"ClipDetectionPlotConfig", "ClipDetectionPlotConfig",
"ClipDetectionPlotImportConfig",
"ClipDetectionPlotter", "ClipDetectionPlotter",
"build_clip_detection_plotter", "build_clip_detection_plotter",
] ]
@ -36,6 +37,17 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
) )
@add_import_config(clip_detection_plots)
class ClipDetectionPlotImportConfig(ImportConfig):
"""Use any callable as a clip detection plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"

View File

@ -16,7 +16,7 @@ from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
@ -32,6 +32,17 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
) )
@add_import_config(detection_plots)
class DetectionPlotImportConfig(ImportConfig):
"""Use any callable as a detection plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"

View File

@ -19,7 +19,7 @@ from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ( from batdetect2.evaluate.metrics.top_class import (
ClipEval, ClipEval,
@ -39,6 +39,17 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
) )
@add_import_config(top_class_plots)
class TopClassPlotImportConfig(ImportConfig):
"""Use any callable as a top-class plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"

View File

@ -4,6 +4,7 @@ from typing import (
Generic, Generic,
Iterable, Iterable,
List, List,
Literal,
Sequence, Sequence,
Tuple, Tuple,
TypeVar, TypeVar,
@ -14,7 +15,12 @@ from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig, Registry from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.affinity import ( from batdetect2.evaluate.affinity import (
AffinityConfig, AffinityConfig,
TimeAffinityConfig, TimeAffinityConfig,
@ -31,6 +37,7 @@ from batdetect2.typing import (
__all__ = [ __all__ = [
"BaseTaskConfig", "BaseTaskConfig",
"BaseTask", "BaseTask",
"TaskImportConfig",
] ]
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry( tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
@ -38,6 +45,17 @@ tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
) )
@add_import_config(tasks_registry)
class TaskImportConfig(ImportConfig):
"""Use any callable as an evaluation task.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
T_Output = TypeVar("T_Output") T_Output = TypeVar("T_Output")

View File

@ -31,7 +31,11 @@ from pydantic import Field, TypeAdapter
from soundevent import data from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.models.bottleneck import ( from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG, DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig, BottleneckConfig,
@ -94,7 +98,20 @@ class UNetBackboneConfig(BaseConfig):
backbone_registry: Registry[BackboneModel, []] = Registry("backbone") backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
@add_import_config(backbone_registry)
class BackboneImportConfig(ImportConfig):
"""Use any callable as a backbone model.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
__all__ = [ __all__ = [
"BackboneImportConfig",
"UNetBackbone", "UNetBackbone",
"BackboneConfig", "BackboneConfig",
"load_backbone_config", "load_backbone_config",

View File

@ -53,10 +53,11 @@ import torch.nn.functional as F
from pydantic import Field from pydantic import Field
from torch import nn from torch import nn
from batdetect2.core import Registry from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
__all__ = [ __all__ = [
"BlockImportConfig",
"ConvBlock", "ConvBlock",
"LayerGroupConfig", "LayerGroupConfig",
"VerticalConv", "VerticalConv",
@ -119,6 +120,17 @@ class Block(nn.Module):
block_registry: Registry[Block, [int, int]] = Registry("block") block_registry: Registry[Block, [int, int]] = Registry("block")
@add_import_config(block_registry)
class BlockImportConfig(ImportConfig):
"""Use any callable as a model block.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class SelfAttentionConfig(BaseConfig): class SelfAttentionConfig(BaseConfig):
"""Configuration for a ``SelfAttention`` block. """Configuration for a ``SelfAttention`` block.

View File

@ -18,10 +18,16 @@ import torch
from pydantic import Field from pydantic import Field
from batdetect2.audio import TARGET_SAMPLERATE_HZ from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.core import BaseConfig, Registry from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.preprocess.common import center_tensor, peak_normalize from batdetect2.preprocess.common import center_tensor, peak_normalize
__all__ = [ __all__ = [
"AudioTransformImportConfig",
"CenterAudioConfig", "CenterAudioConfig",
"ScaleAudioConfig", "ScaleAudioConfig",
"FixDurationConfig", "FixDurationConfig",
@ -35,6 +41,17 @@ audio_transforms: Registry[torch.nn.Module, [int]] = Registry(
"""Registry mapping audio transform config classes to their builder methods.""" """Registry mapping audio transform config classes to their builder methods."""
@add_import_config(audio_transforms)
class AudioTransformImportConfig(ImportConfig):
"""Use any callable as an audio transform.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class CenterAudioConfig(BaseConfig): class CenterAudioConfig(BaseConfig):
"""Configuration for the DC-offset removal transform. """Configuration for the DC-offset removal transform.

View File

@ -19,11 +19,16 @@ from pydantic import Field
from batdetect2.audio import TARGET_SAMPLERATE_HZ from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.preprocess.common import peak_normalize from batdetect2.preprocess.common import peak_normalize
__all__ = [ __all__ = [
"STFTConfig", "STFTConfig",
"SpectrogramTransformImportConfig",
"build_spectrogram_transform", "build_spectrogram_transform",
"build_spectrogram_builder", "build_spectrogram_builder",
] ]
@ -426,6 +431,17 @@ spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry(
) )
@add_import_config(spectrogram_transforms)
class SpectrogramTransformImportConfig(ImportConfig):
"""Use any callable as a spectrogram transform.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PcenConfig(BaseConfig): class PcenConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalisation (PCEN). """Configuration for Per-Channel Energy Normalisation (PCEN).

View File

@ -27,7 +27,7 @@ from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.core.arrays import spec_to_xarray from batdetect2.core.arrays import spec_to_xarray
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
@ -49,6 +49,7 @@ __all__ = [
"PeakEnergyBBoxMapper", "PeakEnergyBBoxMapper",
"PeakEnergyBBoxMapperConfig", "PeakEnergyBBoxMapperConfig",
"ROIMapperConfig", "ROIMapperConfig",
"ROIMapperImportConfig",
"ROITargetMapper", "ROITargetMapper",
"SIZE_HEIGHT", "SIZE_HEIGHT",
"SIZE_ORDER", "SIZE_ORDER",
@ -92,6 +93,17 @@ DEFAULT_ANCHOR = "bottom-left"
roi_mapper_registry: Registry[ROITargetMapper, []] = Registry("roi_mapper") roi_mapper_registry: Registry[ROITargetMapper, []] = Registry("roi_mapper")
@add_import_config(roi_mapper_registry)
class ROIMapperImportConfig(ImportConfig):
"""Use any callable as an ROI mapper.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class AnchorBBoxMapperConfig(BaseConfig): class AnchorBBoxMapperConfig(BaseConfig):
"""Configuration for `AnchorBBoxMapper`. """Configuration for `AnchorBBoxMapper`.

View File

@ -15,12 +15,17 @@ from batdetect2.audio.clips import get_subclip_annotation
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.arrays import adjust_width from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.typing import AudioLoader, Augmentation from batdetect2.typing import AudioLoader, Augmentation
__all__ = [ __all__ = [
"AugmentationConfig", "AugmentationConfig",
"AugmentationsConfig", "AugmentationsConfig",
"AudioAugmentationImportConfig",
"DEFAULT_AUGMENTATION_CONFIG", "DEFAULT_AUGMENTATION_CONFIG",
"AddEchoConfig", "AddEchoConfig",
"AudioSource", "AudioSource",
@ -28,6 +33,7 @@ __all__ = [
"MixAudioConfig", "MixAudioConfig",
"MaskTimeConfig", "MaskTimeConfig",
"ScaleVolumeConfig", "ScaleVolumeConfig",
"SpecAugmentationImportConfig",
"WarpConfig", "WarpConfig",
"add_echo", "add_echo",
"build_augmentations", "build_augmentations",
@ -51,6 +57,28 @@ spec_augmentations: Registry[Augmentation, []] = Registry(
) )
@add_import_config(audio_augmentations)
class AudioAugmentationImportConfig(ImportConfig):
"""Use any callable as an audio augmentation.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
@add_import_config(spec_augmentations)
class SpecAugmentationImportConfig(ImportConfig):
"""Use any callable as a spectrogram augmentation.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class MixAudioConfig(BaseConfig): class MixAudioConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples).""" """Configuration for MixUp augmentation (mixing two examples)."""