diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py index 525f581..dddd48e 100644 --- a/src/batdetect2/config.py +++ b/src/batdetect2/config.py @@ -14,10 +14,7 @@ from batdetect2.models import ModelConfig from batdetect2.outputs import OutputsConfig from batdetect2.train.config import TrainingConfig -__all__ = [ - "BatDetect2Config", - "validate_config", -] +__all__ = ["BatDetect2Config"] class BatDetect2Config(BaseConfig): @@ -32,10 +29,3 @@ class BatDetect2Config(BaseConfig): inference: InferenceConfig = Field(default_factory=InferenceConfig) outputs: OutputsConfig = Field(default_factory=OutputsConfig) logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig) - - -def validate_config(config: dict | None) -> BatDetect2Config: - if config is None: - return BatDetect2Config() - - return BatDetect2Config.model_validate(config) diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index c09a0ba..cd1dc1f 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -26,8 +26,11 @@ The primary entry point for building a full, ready-to-use BatDetect2 model is the ``build_model`` factory function exported from this module. """ +from typing import Literal + import torch from pydantic import Field +from soundevent.data import PathLike from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ from batdetect2.core.configs import BaseConfig @@ -142,6 +145,22 @@ class ModelConfig(BaseConfig): postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) targets: TargetConfig = Field(default_factory=TargetConfig) + @classmethod + def load( + cls, + path: PathLike, + field: str | None = None, + extra: Literal["ignore", "allow", "forbid"] | None = None, + strict: bool | None = None, + targets: TargetConfig | None = None, + ) -> "ModelConfig": + config = super().load(path, field, extra, strict) + + if targets is None: + return config + + return config.model_copy(update={"targets": targets}) + class Model(torch.nn.Module): """End-to-end BatDetect2 model wrapping preprocessing and postprocessing.