mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add option to override targets when loading the model config
This commit is contained in:
parent
99b9e55c0e
commit
23ac619c50
@ -14,10 +14,7 @@ from batdetect2.models import ModelConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
|
||||
__all__ = [
|
||||
"BatDetect2Config",
|
||||
"validate_config",
|
||||
]
|
||||
__all__ = ["BatDetect2Config"]
|
||||
|
||||
|
||||
class BatDetect2Config(BaseConfig):
|
||||
@ -32,10 +29,3 @@ class BatDetect2Config(BaseConfig):
|
||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
||||
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
|
||||
|
||||
|
||||
def validate_config(config: dict | None) -> BatDetect2Config:
|
||||
if config is None:
|
||||
return BatDetect2Config()
|
||||
|
||||
return BatDetect2Config.model_validate(config)
|
||||
|
||||
@ -26,8 +26,11 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
|
||||
is the ``build_model`` factory function exported from this module.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
@ -142,6 +145,22 @@ class ModelConfig(BaseConfig):
|
||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
path: PathLike,
|
||||
field: str | None = None,
|
||||
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||
strict: bool | None = None,
|
||||
targets: TargetConfig | None = None,
|
||||
) -> "ModelConfig":
|
||||
config = super().load(path, field, extra, strict)
|
||||
|
||||
if targets is None:
|
||||
return config
|
||||
|
||||
return config.model_copy(update={"targets": targets})
|
||||
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user