mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add dynamic imports to existing registries
This commit is contained in:
parent
038d58ed99
commit
8ac4f4c44d
@ -6,7 +6,12 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds, intervals_overlap
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.core import (
|
||||
BaseConfig,
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.typing import ClipperProtocol
|
||||
|
||||
DEFAULT_TRAIN_CLIP_DURATION = 0.256
|
||||
@ -16,12 +21,24 @@ DEFAULT_MAX_EMPTY_CLIP = 0.1
|
||||
__all__ = [
|
||||
"build_clipper",
|
||||
"ClipConfig",
|
||||
"ClipperImportConfig",
|
||||
]
|
||||
|
||||
|
||||
clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
|
||||
|
||||
|
||||
@add_import_config(clipper_registry)
|
||||
class ClipperImportConfig(ImportConfig):
|
||||
"""Use any callable as a clipper.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class RandomClipConfig(BaseConfig):
|
||||
name: Literal["random_subclip"] = "random_subclip"
|
||||
duration: float = DEFAULT_TRAIN_CLIP_DURATION
|
||||
|
||||
@ -6,13 +6,28 @@ from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
|
||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||
|
||||
conditions: Registry[SoundEventCondition, []] = Registry("condition")
|
||||
|
||||
|
||||
@add_import_config(conditions)
|
||||
class SoundEventConditionImportConfig(ImportConfig):
|
||||
"""Use any callable as a sound event condition.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class HasTagConfig(BaseConfig):
|
||||
name: Literal["has_tag"] = "has_tag"
|
||||
tag: data.Tag
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from typing import Literal
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.typing import (
|
||||
OutputFormatterProtocol,
|
||||
TargetProtocol,
|
||||
@ -27,3 +29,14 @@ def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
|
||||
prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
|
||||
Registry(name="output_formatter")
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(prediction_formatters)
|
||||
class PredictionFormatterImportConfig(ImportConfig):
|
||||
"""Use any callable as a prediction formatter.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
@ -5,7 +5,11 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.data.conditions import (
|
||||
SoundEventCondition,
|
||||
SoundEventConditionConfig,
|
||||
@ -20,6 +24,17 @@ SoundEventTransform = Callable[
|
||||
transforms: Registry[SoundEventTransform, []] = Registry("transform")
|
||||
|
||||
|
||||
@add_import_config(transforms)
|
||||
class SoundEventTransformImportConfig(ImportConfig):
|
||||
"""Use any callable as a sound event transform.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class SetFrequencyBoundConfig(BaseConfig):
|
||||
name: Literal["set_frequency"] = "set_frequency"
|
||||
boundary: Literal["low", "high"] = "low"
|
||||
|
||||
@ -10,7 +10,12 @@ from soundevent.geometry import (
|
||||
compute_temporal_iou,
|
||||
)
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.core import (
|
||||
BaseConfig,
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.typing import AffinityFunction, Detection
|
||||
|
||||
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
||||
@ -18,6 +23,17 @@ affinity_functions: Registry[AffinityFunction, []] = Registry(
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(affinity_functions)
|
||||
class AffinityFunctionImportConfig(ImportConfig):
|
||||
"""Use any callable as an affinity function.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class TimeAffinityConfig(BaseConfig):
|
||||
name: Literal["time_affinity"] = "time_affinity"
|
||||
position: Literal["start", "end", "center"] | float = "start"
|
||||
|
||||
@ -16,7 +16,12 @@ from pydantic import Field
|
||||
from sklearn import metrics
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.core import (
|
||||
BaseConfig,
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.evaluate.metrics.common import (
|
||||
average_precision,
|
||||
compute_precision_recall,
|
||||
@ -26,6 +31,7 @@ from batdetect2.typing import Detection, TargetProtocol
|
||||
__all__ = [
|
||||
"ClassificationMetric",
|
||||
"ClassificationMetricConfig",
|
||||
"ClassificationMetricImportConfig",
|
||||
"build_classification_metric",
|
||||
"compute_precision_recall_curves",
|
||||
]
|
||||
@ -58,6 +64,17 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(classification_metrics)
|
||||
class ClassificationMetricImportConfig(ImportConfig):
|
||||
"""Use any callable as a classification metric.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class BaseClassificationConfig(BaseConfig):
|
||||
include: List[str] | None = None
|
||||
exclude: List[str] | None = None
|
||||
|
||||
@ -7,7 +7,11 @@ from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
|
||||
|
||||
@ -24,6 +28,17 @@ clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(clip_classification_metrics)
|
||||
class ClipClassificationMetricImportConfig(ImportConfig):
|
||||
"""Use any callable as a clip classification metric.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class ClipClassificationAveragePrecisionConfig(BaseConfig):
|
||||
name: Literal["average_precision"] = "average_precision"
|
||||
label: str = "average_precision"
|
||||
|
||||
@ -6,7 +6,11 @@ from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
|
||||
|
||||
@ -23,6 +27,17 @@ clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(clip_detection_metrics)
|
||||
class ClipDetectionMetricImportConfig(ImportConfig):
|
||||
"""Use any callable as a clip detection metric.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class ClipDetectionAveragePrecisionConfig(BaseConfig):
|
||||
name: Literal["average_precision"] = "average_precision"
|
||||
label: str = "average_precision"
|
||||
|
||||
@ -13,13 +13,19 @@ from pydantic import Field
|
||||
from sklearn import metrics
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.core import (
|
||||
BaseConfig,
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import Detection
|
||||
|
||||
__all__ = [
|
||||
"DetectionMetricConfig",
|
||||
"DetectionMetric",
|
||||
"DetectionMetricImportConfig",
|
||||
"build_detection_metric",
|
||||
]
|
||||
|
||||
@ -46,6 +52,17 @@ DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
|
||||
detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")
|
||||
|
||||
|
||||
@add_import_config(detection_metrics)
|
||||
class DetectionMetricImportConfig(ImportConfig):
|
||||
"""Use any callable as a detection metric.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class DetectionAveragePrecisionConfig(BaseConfig):
|
||||
name: Literal["average_precision"] = "average_precision"
|
||||
label: str = "average_precision"
|
||||
|
||||
@ -13,7 +13,12 @@ from pydantic import Field
|
||||
from sklearn import metrics, preprocessing
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.core import (
|
||||
BaseConfig,
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import Detection
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
@ -21,6 +26,7 @@ from batdetect2.typing.targets import TargetProtocol
|
||||
__all__ = [
|
||||
"TopClassMetricConfig",
|
||||
"TopClassMetric",
|
||||
"TopClassMetricImportConfig",
|
||||
"build_top_class_metric",
|
||||
]
|
||||
|
||||
@ -51,6 +57,17 @@ TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
|
||||
top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")
|
||||
|
||||
|
||||
@add_import_config(top_class_metrics)
|
||||
class TopClassMetricImportConfig(ImportConfig):
|
||||
"""Use any callable as a top-class metric.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class TopClassAveragePrecisionConfig(BaseConfig):
|
||||
name: Literal["average_precision"] = "average_precision"
|
||||
label: str = "average_precision"
|
||||
|
||||
@ -12,7 +12,7 @@ from matplotlib.figure import Figure
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.evaluate.metrics.classification import (
|
||||
ClipEval,
|
||||
_extract_per_class_metric_data,
|
||||
@ -40,6 +40,17 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(classification_plots)
|
||||
class ClassificationPlotImportConfig(ImportConfig):
|
||||
"""Use any callable as a classification plot.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
|
||||
@ -12,7 +12,7 @@ from matplotlib.figure import Figure
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.evaluate.metrics.clip_classification import ClipEval
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
@ -26,6 +26,7 @@ from batdetect2.typing import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClipClassificationPlotConfig",
|
||||
"ClipClassificationPlotImportConfig",
|
||||
"ClipClassificationPlotter",
|
||||
"build_clip_classification_plotter",
|
||||
]
|
||||
@ -39,6 +40,17 @@ clip_classification_plots: Registry[
|
||||
] = Registry("clip_classification_plot")
|
||||
|
||||
|
||||
@add_import_config(clip_classification_plots)
|
||||
class ClipClassificationPlotImportConfig(ImportConfig):
|
||||
"""Use any callable as a clip classification plot.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
|
||||
@ -13,7 +13,7 @@ from matplotlib.figure import Figure
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.evaluate.metrics.clip_detection import ClipEval
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
@ -22,6 +22,7 @@ from batdetect2.typing import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClipDetectionPlotConfig",
|
||||
"ClipDetectionPlotImportConfig",
|
||||
"ClipDetectionPlotter",
|
||||
"build_clip_detection_plotter",
|
||||
]
|
||||
@ -36,6 +37,17 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(clip_detection_plots)
|
||||
class ClipDetectionPlotImportConfig(ImportConfig):
|
||||
"""Use any callable as a clip detection plot.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
|
||||
@ -16,7 +16,7 @@ from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.metrics.detection import ClipEval
|
||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
@ -32,6 +32,17 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(detection_plots)
|
||||
class DetectionPlotImportConfig(ImportConfig):
|
||||
"""Use any callable as a detection plot.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
|
||||
@ -19,7 +19,7 @@ from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.metrics.top_class import (
|
||||
ClipEval,
|
||||
@ -39,6 +39,17 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(top_class_plots)
|
||||
class TopClassPlotImportConfig(ImportConfig):
|
||||
"""Use any callable as a top-class plot.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import (
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
@ -14,7 +15,12 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.core import (
|
||||
BaseConfig,
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.evaluate.affinity import (
|
||||
AffinityConfig,
|
||||
TimeAffinityConfig,
|
||||
@ -31,6 +37,7 @@ from batdetect2.typing import (
|
||||
__all__ = [
|
||||
"BaseTaskConfig",
|
||||
"BaseTask",
|
||||
"TaskImportConfig",
|
||||
]
|
||||
|
||||
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
|
||||
@ -38,6 +45,17 @@ tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(tasks_registry)
|
||||
class TaskImportConfig(ImportConfig):
|
||||
"""Use any callable as an evaluation task.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
T_Output = TypeVar("T_Output")
|
||||
|
||||
|
||||
|
||||
@ -31,7 +31,11 @@ from pydantic import Field, TypeAdapter
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.models.bottleneck import (
|
||||
DEFAULT_BOTTLENECK_CONFIG,
|
||||
BottleneckConfig,
|
||||
@ -94,7 +98,20 @@ class UNetBackboneConfig(BaseConfig):
|
||||
|
||||
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
||||
|
||||
|
||||
@add_import_config(backbone_registry)
|
||||
class BackboneImportConfig(ImportConfig):
|
||||
"""Use any callable as a backbone model.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BackboneImportConfig",
|
||||
"UNetBackbone",
|
||||
"BackboneConfig",
|
||||
"load_backbone_config",
|
||||
|
||||
@ -53,10 +53,11 @@ import torch.nn.functional as F
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"BlockImportConfig",
|
||||
"ConvBlock",
|
||||
"LayerGroupConfig",
|
||||
"VerticalConv",
|
||||
@ -119,6 +120,17 @@ class Block(nn.Module):
|
||||
block_registry: Registry[Block, [int, int]] = Registry("block")
|
||||
|
||||
|
||||
@add_import_config(block_registry)
|
||||
class BlockImportConfig(ImportConfig):
|
||||
"""Use any callable as a model block.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class SelfAttentionConfig(BaseConfig):
|
||||
"""Configuration for a ``SelfAttention`` block.
|
||||
|
||||
|
||||
@ -18,10 +18,16 @@ import torch
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.audio import TARGET_SAMPLERATE_HZ
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.core import (
|
||||
BaseConfig,
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.preprocess.common import center_tensor, peak_normalize
|
||||
|
||||
__all__ = [
|
||||
"AudioTransformImportConfig",
|
||||
"CenterAudioConfig",
|
||||
"ScaleAudioConfig",
|
||||
"FixDurationConfig",
|
||||
@ -35,6 +41,17 @@ audio_transforms: Registry[torch.nn.Module, [int]] = Registry(
|
||||
"""Registry mapping audio transform config classes to their builder methods."""
|
||||
|
||||
|
||||
@add_import_config(audio_transforms)
|
||||
class AudioTransformImportConfig(ImportConfig):
|
||||
"""Use any callable as an audio transform.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class CenterAudioConfig(BaseConfig):
|
||||
"""Configuration for the DC-offset removal transform.
|
||||
|
||||
|
||||
@ -19,11 +19,16 @@ from pydantic import Field
|
||||
|
||||
from batdetect2.audio import TARGET_SAMPLERATE_HZ
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.preprocess.common import peak_normalize
|
||||
|
||||
__all__ = [
|
||||
"STFTConfig",
|
||||
"SpectrogramTransformImportConfig",
|
||||
"build_spectrogram_transform",
|
||||
"build_spectrogram_builder",
|
||||
]
|
||||
@ -426,6 +431,17 @@ spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry(
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(spectrogram_transforms)
|
||||
class SpectrogramTransformImportConfig(ImportConfig):
|
||||
"""Use any callable as a spectrogram transform.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class PcenConfig(BaseConfig):
|
||||
"""Configuration for Per-Channel Energy Normalisation (PCEN).
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.core.arrays import spec_to_xarray
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
@ -49,6 +49,7 @@ __all__ = [
|
||||
"PeakEnergyBBoxMapper",
|
||||
"PeakEnergyBBoxMapperConfig",
|
||||
"ROIMapperConfig",
|
||||
"ROIMapperImportConfig",
|
||||
"ROITargetMapper",
|
||||
"SIZE_HEIGHT",
|
||||
"SIZE_ORDER",
|
||||
@ -92,6 +93,17 @@ DEFAULT_ANCHOR = "bottom-left"
|
||||
roi_mapper_registry: Registry[ROITargetMapper, []] = Registry("roi_mapper")
|
||||
|
||||
|
||||
@add_import_config(roi_mapper_registry)
|
||||
class ROIMapperImportConfig(ImportConfig):
|
||||
"""Use any callable as an ROI mapper.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class AnchorBBoxMapperConfig(BaseConfig):
|
||||
"""Configuration for `AnchorBBoxMapper`.
|
||||
|
||||
|
||||
@ -15,12 +15,17 @@ from batdetect2.audio.clips import get_subclip_annotation
|
||||
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
||||
from batdetect2.core.arrays import adjust_width
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.typing import AudioLoader, Augmentation
|
||||
|
||||
__all__ = [
|
||||
"AugmentationConfig",
|
||||
"AugmentationsConfig",
|
||||
"AudioAugmentationImportConfig",
|
||||
"DEFAULT_AUGMENTATION_CONFIG",
|
||||
"AddEchoConfig",
|
||||
"AudioSource",
|
||||
@ -28,6 +33,7 @@ __all__ = [
|
||||
"MixAudioConfig",
|
||||
"MaskTimeConfig",
|
||||
"ScaleVolumeConfig",
|
||||
"SpecAugmentationImportConfig",
|
||||
"WarpConfig",
|
||||
"add_echo",
|
||||
"build_augmentations",
|
||||
@ -51,6 +57,28 @@ spec_augmentations: Registry[Augmentation, []] = Registry(
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(audio_augmentations)
|
||||
class AudioAugmentationImportConfig(ImportConfig):
|
||||
"""Use any callable as an audio augmentation.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
@add_import_config(spec_augmentations)
|
||||
class SpecAugmentationImportConfig(ImportConfig):
|
||||
"""Use any callable as a spectrogram augmentation.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class MixAudioConfig(BaseConfig):
|
||||
"""Configuration for MixUp augmentation (mixing two examples)."""
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user