diff --git a/src/batdetect2/audio/clips.py b/src/batdetect2/audio/clips.py index 8364a40..dd70756 100644 --- a/src/batdetect2/audio/clips.py +++ b/src/batdetect2/audio/clips.py @@ -6,7 +6,12 @@ from pydantic import Field from soundevent import data from soundevent.geometry import compute_bounds, intervals_overlap -from batdetect2.core import BaseConfig, Registry +from batdetect2.core import ( + BaseConfig, + ImportConfig, + Registry, + add_import_config, +) from batdetect2.typing import ClipperProtocol DEFAULT_TRAIN_CLIP_DURATION = 0.256 @@ -16,12 +21,24 @@ DEFAULT_MAX_EMPTY_CLIP = 0.1 __all__ = [ "build_clipper", "ClipConfig", + "ClipperImportConfig", ] clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper") +@add_import_config(clipper_registry) +class ClipperImportConfig(ImportConfig): + """Use any callable as a clipper. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class RandomClipConfig(BaseConfig): name: Literal["random_subclip"] = "random_subclip" duration: float = DEFAULT_TRAIN_CLIP_DURATION diff --git a/src/batdetect2/data/conditions.py b/src/batdetect2/data/conditions.py index 015ea2f..ac52152 100644 --- a/src/batdetect2/data/conditions.py +++ b/src/batdetect2/data/conditions.py @@ -6,13 +6,28 @@ from soundevent import data from soundevent.geometry import compute_bounds from batdetect2.core.configs import BaseConfig -from batdetect2.core.registries import Registry +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) SoundEventCondition = Callable[[data.SoundEventAnnotation], bool] conditions: Registry[SoundEventCondition, []] = Registry("condition") +@add_import_config(conditions) +class SoundEventConditionImportConfig(ImportConfig): + """Use any callable as a sound event condition. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class HasTagConfig(BaseConfig): name: Literal["has_tag"] = "has_tag" tag: data.Tag diff --git a/src/batdetect2/data/predictions/base.py b/src/batdetect2/data/predictions/base.py index f6f253e..0d92a2d 100644 --- a/src/batdetect2/data/predictions/base.py +++ b/src/batdetect2/data/predictions/base.py @@ -1,8 +1,10 @@ +from typing import Literal + from pathlib import Path from soundevent.data import PathLike -from batdetect2.core import Registry +from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.typing import ( OutputFormatterProtocol, TargetProtocol, @@ -27,3 +29,14 @@ def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path: prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = ( Registry(name="output_formatter") ) + + +@add_import_config(prediction_formatters) +class PredictionFormatterImportConfig(ImportConfig): + """Use any callable as a prediction formatter. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" diff --git a/src/batdetect2/data/transforms.py b/src/batdetect2/data/transforms.py index 0e9b5d2..63bede7 100644 --- a/src/batdetect2/data/transforms.py +++ b/src/batdetect2/data/transforms.py @@ -5,7 +5,11 @@ from pydantic import Field from soundevent import data from batdetect2.core.configs import BaseConfig -from batdetect2.core.registries import Registry +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) from batdetect2.data.conditions import ( SoundEventCondition, SoundEventConditionConfig, @@ -20,6 +24,17 @@ SoundEventTransform = Callable[ transforms: Registry[SoundEventTransform, []] = Registry("transform") +@add_import_config(transforms) +class SoundEventTransformImportConfig(ImportConfig): + """Use any callable as a sound event transform. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class SetFrequencyBoundConfig(BaseConfig): name: Literal["set_frequency"] = "set_frequency" boundary: Literal["low", "high"] = "low" diff --git a/src/batdetect2/evaluate/affinity.py b/src/batdetect2/evaluate/affinity.py index a6a4d6b..a4141fa 100644 --- a/src/batdetect2/evaluate/affinity.py +++ b/src/batdetect2/evaluate/affinity.py @@ -10,7 +10,12 @@ from soundevent.geometry import ( compute_temporal_iou, ) -from batdetect2.core import BaseConfig, Registry +from batdetect2.core import ( + BaseConfig, + ImportConfig, + Registry, + add_import_config, +) from batdetect2.typing import AffinityFunction, Detection affinity_functions: Registry[AffinityFunction, []] = Registry( @@ -18,6 +23,17 @@ affinity_functions: Registry[AffinityFunction, []] = Registry( ) +@add_import_config(affinity_functions) +class AffinityFunctionImportConfig(ImportConfig): + """Use any callable as an affinity function. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class TimeAffinityConfig(BaseConfig): name: Literal["time_affinity"] = "time_affinity" position: Literal["start", "end", "center"] | float = "start" diff --git a/src/batdetect2/evaluate/metrics/classification.py b/src/batdetect2/evaluate/metrics/classification.py index 97b4134..6ef5971 100644 --- a/src/batdetect2/evaluate/metrics/classification.py +++ b/src/batdetect2/evaluate/metrics/classification.py @@ -16,7 +16,12 @@ from pydantic import Field from sklearn import metrics from soundevent import data -from batdetect2.core import BaseConfig, Registry +from batdetect2.core import ( + BaseConfig, + ImportConfig, + Registry, + add_import_config, +) from batdetect2.evaluate.metrics.common import ( average_precision, compute_precision_recall, @@ -26,6 +31,7 @@ from batdetect2.typing import Detection, TargetProtocol __all__ = [ "ClassificationMetric", "ClassificationMetricConfig", + "ClassificationMetricImportConfig", "build_classification_metric", "compute_precision_recall_curves", ] @@ -58,6 +64,17 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = ( ) +@add_import_config(classification_metrics) +class ClassificationMetricImportConfig(ImportConfig): + """Use any callable as a classification metric. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class BaseClassificationConfig(BaseConfig): include: List[str] | None = None exclude: List[str] | None = None diff --git a/src/batdetect2/evaluate/metrics/clip_classification.py b/src/batdetect2/evaluate/metrics/clip_classification.py index 80e5020..8595a68 100644 --- a/src/batdetect2/evaluate/metrics/clip_classification.py +++ b/src/batdetect2/evaluate/metrics/clip_classification.py @@ -7,7 +7,11 @@ from pydantic import Field from sklearn import metrics from batdetect2.core.configs import BaseConfig -from batdetect2.core.registries import Registry +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) from batdetect2.evaluate.metrics.common import average_precision @@ -24,6 +28,17 @@ clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry( ) +@add_import_config(clip_classification_metrics) +class ClipClassificationMetricImportConfig(ImportConfig): + """Use any callable as a clip classification metric. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class ClipClassificationAveragePrecisionConfig(BaseConfig): name: Literal["average_precision"] = "average_precision" label: str = "average_precision" diff --git a/src/batdetect2/evaluate/metrics/clip_detection.py b/src/batdetect2/evaluate/metrics/clip_detection.py index 7613228..f21bebf 100644 --- a/src/batdetect2/evaluate/metrics/clip_detection.py +++ b/src/batdetect2/evaluate/metrics/clip_detection.py @@ -6,7 +6,11 @@ from pydantic import Field from sklearn import metrics from batdetect2.core.configs import BaseConfig -from batdetect2.core.registries import Registry +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) from batdetect2.evaluate.metrics.common import average_precision @@ -23,6 +27,17 @@ clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry( ) +@add_import_config(clip_detection_metrics) +class ClipDetectionMetricImportConfig(ImportConfig): + """Use any callable as a clip detection metric. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class ClipDetectionAveragePrecisionConfig(BaseConfig): name: Literal["average_precision"] = "average_precision" label: str = "average_precision" diff --git a/src/batdetect2/evaluate/metrics/detection.py b/src/batdetect2/evaluate/metrics/detection.py index d2e5a15..2f13915 100644 --- a/src/batdetect2/evaluate/metrics/detection.py +++ b/src/batdetect2/evaluate/metrics/detection.py @@ -13,13 +13,19 @@ from pydantic import Field from sklearn import metrics from soundevent import data -from batdetect2.core import BaseConfig, Registry +from batdetect2.core import ( + BaseConfig, + ImportConfig, + Registry, + add_import_config, +) from batdetect2.evaluate.metrics.common import average_precision from batdetect2.typing import Detection __all__ = [ "DetectionMetricConfig", "DetectionMetric", + "DetectionMetricImportConfig", "build_detection_metric", ] @@ -46,6 +52,17 @@ DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]] detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric") +@add_import_config(detection_metrics) +class DetectionMetricImportConfig(ImportConfig): + """Use any callable as a detection metric. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class DetectionAveragePrecisionConfig(BaseConfig): name: Literal["average_precision"] = "average_precision" label: str = "average_precision" diff --git a/src/batdetect2/evaluate/metrics/top_class.py b/src/batdetect2/evaluate/metrics/top_class.py index e1f00d1..5fa2605 100644 --- a/src/batdetect2/evaluate/metrics/top_class.py +++ b/src/batdetect2/evaluate/metrics/top_class.py @@ -13,7 +13,12 @@ from pydantic import Field from sklearn import metrics, preprocessing from soundevent import data -from batdetect2.core import BaseConfig, Registry +from batdetect2.core import ( + BaseConfig, + ImportConfig, + Registry, + add_import_config, +) from batdetect2.evaluate.metrics.common import average_precision from batdetect2.typing import Detection from batdetect2.typing.targets import TargetProtocol @@ -21,6 +26,7 @@ from batdetect2.typing.targets import TargetProtocol __all__ = [ "TopClassMetricConfig", "TopClassMetric", + "TopClassMetricImportConfig", "build_top_class_metric", ] @@ -51,6 +57,17 @@ TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]] top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric") +@add_import_config(top_class_metrics) +class TopClassMetricImportConfig(ImportConfig): + """Use any callable as a top-class metric. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class TopClassAveragePrecisionConfig(BaseConfig): name: Literal["average_precision"] = "average_precision" label: str = "average_precision" diff --git a/src/batdetect2/evaluate/plots/classification.py b/src/batdetect2/evaluate/plots/classification.py index 881ed93..f8727b0 100644 --- a/src/batdetect2/evaluate/plots/classification.py +++ b/src/batdetect2/evaluate/plots/classification.py @@ -12,7 +12,7 @@ from matplotlib.figure import Figure from pydantic import Field from sklearn import metrics -from batdetect2.core import Registry +from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.evaluate.metrics.classification import ( ClipEval, _extract_per_class_metric_data, @@ -40,6 +40,17 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = ( ) +@add_import_config(classification_plots) +class ClassificationPlotImportConfig(ImportConfig): + """Use any callable as a classification plot. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" diff --git a/src/batdetect2/evaluate/plots/clip_classification.py b/src/batdetect2/evaluate/plots/clip_classification.py index 7def4d0..df34482 100644 --- a/src/batdetect2/evaluate/plots/clip_classification.py +++ b/src/batdetect2/evaluate/plots/clip_classification.py @@ -12,7 +12,7 @@ from matplotlib.figure import Figure from pydantic import Field from sklearn import metrics -from batdetect2.core import Registry +from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.evaluate.metrics.clip_classification import ClipEval from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig @@ -26,6 +26,7 @@ from batdetect2.typing import TargetProtocol __all__ = [ "ClipClassificationPlotConfig", + "ClipClassificationPlotImportConfig", "ClipClassificationPlotter", "build_clip_classification_plotter", ] @@ -39,6 +40,17 @@ clip_classification_plots: Registry[ ] = Registry("clip_classification_plot") +@add_import_config(clip_classification_plots) +class ClipClassificationPlotImportConfig(ImportConfig): + """Use any callable as a clip classification plot. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" diff --git a/src/batdetect2/evaluate/plots/clip_detection.py b/src/batdetect2/evaluate/plots/clip_detection.py index 9dd2108..3e44804 100644 --- a/src/batdetect2/evaluate/plots/clip_detection.py +++ b/src/batdetect2/evaluate/plots/clip_detection.py @@ -13,7 +13,7 @@ from matplotlib.figure import Figure from pydantic import Field from sklearn import metrics -from batdetect2.core import Registry +from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.evaluate.metrics.clip_detection import ClipEval from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig @@ -22,6 +22,7 @@ from batdetect2.typing import TargetProtocol __all__ = [ "ClipDetectionPlotConfig", + "ClipDetectionPlotImportConfig", "ClipDetectionPlotter", "build_clip_detection_plotter", ] @@ -36,6 +37,17 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = ( ) +@add_import_config(clip_detection_plots) +class ClipDetectionPlotImportConfig(ImportConfig): + """Use any callable as a clip detection plot. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" diff --git a/src/batdetect2/evaluate/plots/detection.py b/src/batdetect2/evaluate/plots/detection.py index de0e309..cf73e5e 100644 --- a/src/batdetect2/evaluate/plots/detection.py +++ b/src/batdetect2/evaluate/plots/detection.py @@ -16,7 +16,7 @@ from pydantic import Field from sklearn import metrics from batdetect2.audio import AudioConfig, build_audio_loader -from batdetect2.core import Registry +from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.detection import ClipEval from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig @@ -32,6 +32,17 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry( ) +@add_import_config(detection_plots) +class DetectionPlotImportConfig(ImportConfig): + """Use any callable as a detection plot. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" diff --git a/src/batdetect2/evaluate/plots/top_class.py b/src/batdetect2/evaluate/plots/top_class.py index 5678b0f..d7a3ba1 100644 --- a/src/batdetect2/evaluate/plots/top_class.py +++ b/src/batdetect2/evaluate/plots/top_class.py @@ -19,7 +19,7 @@ from pydantic import Field from sklearn import metrics from batdetect2.audio import AudioConfig, build_audio_loader -from batdetect2.core import Registry +from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.top_class import ( ClipEval, @@ -39,6 +39,17 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry( ) +@add_import_config(top_class_plots) +class TopClassPlotImportConfig(ImportConfig): + """Use any callable as a top-class plot. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" diff --git a/src/batdetect2/evaluate/tasks/base.py b/src/batdetect2/evaluate/tasks/base.py index bcfe38f..a229137 100644 --- a/src/batdetect2/evaluate/tasks/base.py +++ b/src/batdetect2/evaluate/tasks/base.py @@ -4,6 +4,7 @@ from typing import ( Generic, Iterable, List, + Literal, Sequence, Tuple, TypeVar, @@ -14,7 +15,12 @@ from pydantic import Field from soundevent import data from soundevent.geometry import compute_bounds -from batdetect2.core import BaseConfig, Registry +from batdetect2.core import ( + BaseConfig, + ImportConfig, + Registry, + add_import_config, +) from batdetect2.evaluate.affinity import ( AffinityConfig, TimeAffinityConfig, @@ -31,6 +37,7 @@ from batdetect2.typing import ( __all__ = [ "BaseTaskConfig", "BaseTask", + "TaskImportConfig", ] tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry( @@ -38,6 +45,17 @@ tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry( ) +@add_import_config(tasks_registry) +class TaskImportConfig(ImportConfig): + """Use any callable as an evaluation task. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + T_Output = TypeVar("T_Output") diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index ca7c937..8f7a2b1 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -31,7 +31,11 @@ from pydantic import Field, TypeAdapter from soundevent import data from batdetect2.core.configs import BaseConfig, load_config -from batdetect2.core.registries import Registry +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) from batdetect2.models.bottleneck import ( DEFAULT_BOTTLENECK_CONFIG, BottleneckConfig, @@ -94,7 +98,20 @@ class UNetBackboneConfig(BaseConfig): backbone_registry: Registry[BackboneModel, []] = Registry("backbone") + +@add_import_config(backbone_registry) +class BackboneImportConfig(ImportConfig): + """Use any callable as a backbone model. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + __all__ = [ + "BackboneImportConfig", "UNetBackbone", "BackboneConfig", "load_backbone_config", diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index 71d8496..7e083c7 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -53,10 +53,11 @@ import torch.nn.functional as F from pydantic import Field from torch import nn -from batdetect2.core import Registry +from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.core.configs import BaseConfig __all__ = [ + "BlockImportConfig", "ConvBlock", "LayerGroupConfig", "VerticalConv", @@ -119,6 +120,17 @@ class Block(nn.Module): block_registry: Registry[Block, [int, int]] = Registry("block") +@add_import_config(block_registry) +class BlockImportConfig(ImportConfig): + """Use any callable as a model block. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class SelfAttentionConfig(BaseConfig): """Configuration for a ``SelfAttention`` block. diff --git a/src/batdetect2/preprocess/audio.py b/src/batdetect2/preprocess/audio.py index a872e38..9d5c7f0 100644 --- a/src/batdetect2/preprocess/audio.py +++ b/src/batdetect2/preprocess/audio.py @@ -18,10 +18,16 @@ import torch from pydantic import Field from batdetect2.audio import TARGET_SAMPLERATE_HZ -from batdetect2.core import BaseConfig, Registry +from batdetect2.core import ( + BaseConfig, + ImportConfig, + Registry, + add_import_config, +) from batdetect2.preprocess.common import center_tensor, peak_normalize __all__ = [ + "AudioTransformImportConfig", "CenterAudioConfig", "ScaleAudioConfig", "FixDurationConfig", @@ -35,6 +41,17 @@ audio_transforms: Registry[torch.nn.Module, [int]] = Registry( """Registry mapping audio transform config classes to their builder methods.""" +@add_import_config(audio_transforms) +class AudioTransformImportConfig(ImportConfig): + """Use any callable as an audio transform. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class CenterAudioConfig(BaseConfig): """Configuration for the DC-offset removal transform. diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index c579c1f..9d98365 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -19,11 +19,16 @@ from pydantic import Field from batdetect2.audio import TARGET_SAMPLERATE_HZ from batdetect2.core.configs import BaseConfig -from batdetect2.core.registries import Registry +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) from batdetect2.preprocess.common import peak_normalize __all__ = [ "STFTConfig", + "SpectrogramTransformImportConfig", "build_spectrogram_transform", "build_spectrogram_builder", ] @@ -426,6 +431,17 @@ spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry( ) +@add_import_config(spectrogram_transforms) +class SpectrogramTransformImportConfig(ImportConfig): + """Use any callable as a spectrogram transform. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class PcenConfig(BaseConfig): """Configuration for Per-Channel Energy Normalisation (PCEN). diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py index 4967710..eb4200c 100644 --- a/src/batdetect2/targets/rois.py +++ b/src/batdetect2/targets/rois.py @@ -27,7 +27,7 @@ from pydantic import Field from soundevent import data from batdetect2.audio import AudioConfig, build_audio_loader -from batdetect2.core import Registry +from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.core.arrays import spec_to_xarray from batdetect2.core.configs import BaseConfig from batdetect2.preprocess import PreprocessingConfig, build_preprocessor @@ -49,6 +49,7 @@ __all__ = [ "PeakEnergyBBoxMapper", "PeakEnergyBBoxMapperConfig", "ROIMapperConfig", + "ROIMapperImportConfig", "ROITargetMapper", "SIZE_HEIGHT", "SIZE_ORDER", @@ -92,6 +93,17 @@ DEFAULT_ANCHOR = "bottom-left" roi_mapper_registry: Registry[ROITargetMapper, []] = Registry("roi_mapper") +@add_import_config(roi_mapper_registry) +class ROIMapperImportConfig(ImportConfig): + """Use any callable as an ROI mapper. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class AnchorBBoxMapperConfig(BaseConfig): """Configuration for `AnchorBBoxMapper`. diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index 67a5490..2c452cc 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -15,12 +15,17 @@ from batdetect2.audio.clips import get_subclip_annotation from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ from batdetect2.core.arrays import adjust_width from batdetect2.core.configs import BaseConfig, load_config -from batdetect2.core.registries import Registry +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) from batdetect2.typing import AudioLoader, Augmentation __all__ = [ "AugmentationConfig", "AugmentationsConfig", + "AudioAugmentationImportConfig", "DEFAULT_AUGMENTATION_CONFIG", "AddEchoConfig", "AudioSource", @@ -28,6 +33,7 @@ __all__ = [ "MixAudioConfig", "MaskTimeConfig", "ScaleVolumeConfig", + "SpecAugmentationImportConfig", "WarpConfig", "add_echo", "build_augmentations", @@ -51,6 +57,28 @@ spec_augmentations: Registry[Augmentation, []] = Registry( ) +@add_import_config(audio_augmentations) +class AudioAugmentationImportConfig(ImportConfig): + """Use any callable as an audio augmentation. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + +@add_import_config(spec_augmentations) +class SpecAugmentationImportConfig(ImportConfig): + """Use any callable as a spectrogram augmentation. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + class MixAudioConfig(BaseConfig): """Configuration for MixUp augmentation (mixing two examples)."""