Add extra and strict arguments to load config functions, migrate example config

This commit is contained in:
mbsantiago 2026-03-18 22:06:45 +00:00
parent 652670b01d
commit 99b9e55c0e
4 changed files with 105 additions and 70 deletions

View File

@ -1,10 +1,15 @@
config_version: v1
audio:
samplerate: 256000
resample:
enabled: True
method: "poly"
enabled: true
method: poly
preprocess:
model:
samplerate: 256000
preprocess:
stft:
window_duration: 0.002
window_overlap: 0.75
@ -23,12 +28,7 @@ preprocess:
power: 0.5
- name: spectral_mean_subtraction
postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
top_k_per_sec: 200
model:
architecture:
name: UNetBackbone
input_height: 128
in_channels: 1
@ -62,6 +62,11 @@ model:
- name: ConvBlock
out_channels: 32
postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
top_k_per_sec: 200
train:
optimizer:
name: adam
@ -80,8 +85,7 @@ train:
train_loader:
batch_size: 8
num_workers: 2
shuffle: True
shuffle: true
clipping_strategy:
name: random_subclip
@ -117,7 +121,6 @@ train:
max_masks: 3
val_loader:
num_workers: 2
clipping_strategy:
name: whole_audio_padded
chunk_size: 0.256
@ -136,9 +139,6 @@ train:
size:
weight: 0.1
logger:
name: csv
validation:
tasks:
- name: sound_event_detection
@ -148,6 +148,10 @@ train:
metrics:
- name: average_precision
logging:
train:
name: csv
evaluation:
tasks:
- name: sound_event_detection

View File

@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested
configuration sections.
"""
from typing import Any, Type, TypeVar, overload
from typing import Any, Literal, Type, TypeVar, overload
import yaml
from deepmerge.merger import Merger
@ -69,8 +69,20 @@ class BaseConfig(BaseModel):
return cls.model_validate(yaml.safe_load(yaml_str))
@classmethod
def load(cls: Type[C], path: PathLike, field: str | None = None) -> C:
return load_config(path, schema=cls, field=field)
def load(
cls: Type[C],
path: PathLike,
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
) -> C:
return load_config(
path,
schema=cls,
field=field,
extra=extra,
strict=strict,
)
T = TypeVar("T")
@ -142,6 +154,8 @@ def load_config(
path: PathLike,
schema: Type[T_Model],
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
) -> T_Model: ...
@ -150,6 +164,8 @@ def load_config(
path: PathLike,
schema: TypeAdapter[T],
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
) -> T: ...
@ -157,6 +173,8 @@ def load_config(
path: PathLike,
schema: Type[T_Model] | TypeAdapter[T],
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
) -> T_Model | T:
"""Load and validate configuration data from a file against a schema.
@ -178,6 +196,17 @@ def load_config(
file content is validated against the schema.
Example: `"training.optimizer"` would extract the `optimizer` section
within the `training` section.
extra : Literal["ignore", "allow", "forbid"], optional
How to handle extra keys in the configuration data. If None (default),
the default behaviour of the schema is used. If "ignore", extra keys
are ignored. If "allow", extra keys are allowed and will be accessible
as attributes on the resulting model instance. If "forbid", extra
keys are forbidden and an exception is raised. See pydantic
documentation for more details.
strict : bool, optional
Whether to enforce types strictly. If None (default), the default
behaviour of the schema is used. See pydantic documentation for more
details.
Returns
-------
@ -209,9 +238,9 @@ def load_config(
config = get_object_field(config, field)
if isinstance(schema, TypeAdapter):
return schema.validate_python(config or {})
return schema.validate_python(config or {}, extra=extra, strict=strict)
return schema.model_validate(config or {})
return schema.model_validate(config or {}, extra=extra, strict=strict)
default_merger = Merger(

View File

@ -204,7 +204,7 @@ def test_load_backbone_config_from_example_data(example_data_dir: Path):
"""load_backbone_config loads the real example config correctly."""
config = load_backbone_config(
example_data_dir / "config.yaml",
field="model",
field="model.architecture",
)
assert isinstance(config, UNetBackboneConfig)

View File

@ -6,5 +6,7 @@ def test_example_config_is_valid(example_data_dir):
conf = load_config(
example_data_dir / "config.yaml",
schema=BatDetect2Config,
extra="forbid",
strict=True,
)
assert isinstance(conf, BatDetect2Config)