mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add option to override targets when loading the model config
This commit is contained in:
parent
99b9e55c0e
commit
23ac619c50
@ -14,10 +14,7 @@ from batdetect2.models import ModelConfig
|
|||||||
from batdetect2.outputs import OutputsConfig
|
from batdetect2.outputs import OutputsConfig
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["BatDetect2Config"]
|
||||||
"BatDetect2Config",
|
|
||||||
"validate_config",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class BatDetect2Config(BaseConfig):
|
class BatDetect2Config(BaseConfig):
|
||||||
@ -32,10 +29,3 @@ class BatDetect2Config(BaseConfig):
|
|||||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||||
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
||||||
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
|
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
|
||||||
|
|
||||||
|
|
||||||
def validate_config(config: dict | None) -> BatDetect2Config:
|
|
||||||
if config is None:
|
|
||||||
return BatDetect2Config()
|
|
||||||
|
|
||||||
return BatDetect2Config.model_validate(config)
|
|
||||||
|
|||||||
@ -26,8 +26,11 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
|
|||||||
is the ``build_model`` factory function exported from this module.
|
is the ``build_model`` factory function exported from this module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
@ -142,6 +145,22 @@ class ModelConfig(BaseConfig):
|
|||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
path: PathLike,
|
||||||
|
field: str | None = None,
|
||||||
|
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||||
|
strict: bool | None = None,
|
||||||
|
targets: TargetConfig | None = None,
|
||||||
|
) -> "ModelConfig":
|
||||||
|
config = super().load(path, field, extra, strict)
|
||||||
|
|
||||||
|
if targets is None:
|
||||||
|
return config
|
||||||
|
|
||||||
|
return config.model_copy(update={"targets": targets})
|
||||||
|
|
||||||
|
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
|
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user