Add option to override targets when loading the model config

This commit is contained in:
mbsantiago 2026-03-18 23:36:24 +00:00
parent 99b9e55c0e
commit 23ac619c50
2 changed files with 20 additions and 11 deletions

View File

@ -14,10 +14,7 @@ from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig from batdetect2.outputs import OutputsConfig
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import TrainingConfig
__all__ = [ __all__ = ["BatDetect2Config"]
"BatDetect2Config",
"validate_config",
]
class BatDetect2Config(BaseConfig): class BatDetect2Config(BaseConfig):
@ -32,10 +29,3 @@ class BatDetect2Config(BaseConfig):
inference: InferenceConfig = Field(default_factory=InferenceConfig) inference: InferenceConfig = Field(default_factory=InferenceConfig)
outputs: OutputsConfig = Field(default_factory=OutputsConfig) outputs: OutputsConfig = Field(default_factory=OutputsConfig)
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig) logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
def validate_config(config: dict | None) -> BatDetect2Config:
if config is None:
return BatDetect2Config()
return BatDetect2Config.model_validate(config)

View File

@ -26,8 +26,11 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
is the ``build_model`` factory function exported from this module. is the ``build_model`` factory function exported from this module.
""" """
from typing import Literal
import torch import torch
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
@ -142,6 +145,22 @@ class ModelConfig(BaseConfig):
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig) targets: TargetConfig = Field(default_factory=TargetConfig)
@classmethod
def load(
cls,
path: PathLike,
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
targets: TargetConfig | None = None,
) -> "ModelConfig":
config = super().load(path, field, extra, strict)
if targets is None:
return config
return config.model_copy(update={"targets": targets})
class Model(torch.nn.Module): class Model(torch.nn.Module):
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing. """End-to-end BatDetect2 model wrapping preprocessing and postprocessing.