Add option to override targets when loading the model config

This commit is contained in:
mbsantiago 2026-03-18 23:36:24 +00:00
parent 99b9e55c0e
commit 23ac619c50
2 changed files with 20 additions and 11 deletions

View File

@ -14,10 +14,7 @@ from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.train.config import TrainingConfig
__all__ = [
"BatDetect2Config",
"validate_config",
]
__all__ = ["BatDetect2Config"]
class BatDetect2Config(BaseConfig):
@ -32,10 +29,3 @@ class BatDetect2Config(BaseConfig):
inference: InferenceConfig = Field(default_factory=InferenceConfig)
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
def validate_config(config: dict | None) -> BatDetect2Config:
if config is None:
return BatDetect2Config()
return BatDetect2Config.model_validate(config)

View File

@ -26,8 +26,11 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
is the ``build_model`` factory function exported from this module.
"""
from typing import Literal
import torch
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig
@ -142,6 +145,22 @@ class ModelConfig(BaseConfig):
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
@classmethod
def load(
cls,
path: PathLike,
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
targets: TargetConfig | None = None,
) -> "ModelConfig":
config = super().load(path, field, extra, strict)
if targets is None:
return config
return config.model_copy(update={"targets": targets})
class Model(torch.nn.Module):
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.