mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add extra and strict arguments to load config functions, migrate example config
This commit is contained in:
parent
652670b01d
commit
99b9e55c0e
@ -1,8 +1,13 @@
|
|||||||
|
config_version: v1
|
||||||
|
|
||||||
audio:
|
audio:
|
||||||
samplerate: 256000
|
samplerate: 256000
|
||||||
resample:
|
resample:
|
||||||
enabled: True
|
enabled: true
|
||||||
method: "poly"
|
method: poly
|
||||||
|
|
||||||
|
model:
|
||||||
|
samplerate: 256000
|
||||||
|
|
||||||
preprocess:
|
preprocess:
|
||||||
stft:
|
stft:
|
||||||
@ -23,12 +28,7 @@ preprocess:
|
|||||||
power: 0.5
|
power: 0.5
|
||||||
- name: spectral_mean_subtraction
|
- name: spectral_mean_subtraction
|
||||||
|
|
||||||
postprocess:
|
architecture:
|
||||||
nms_kernel_size: 9
|
|
||||||
detection_threshold: 0.01
|
|
||||||
top_k_per_sec: 200
|
|
||||||
|
|
||||||
model:
|
|
||||||
name: UNetBackbone
|
name: UNetBackbone
|
||||||
input_height: 128
|
input_height: 128
|
||||||
in_channels: 1
|
in_channels: 1
|
||||||
@ -62,6 +62,11 @@ model:
|
|||||||
- name: ConvBlock
|
- name: ConvBlock
|
||||||
out_channels: 32
|
out_channels: 32
|
||||||
|
|
||||||
|
postprocess:
|
||||||
|
nms_kernel_size: 9
|
||||||
|
detection_threshold: 0.01
|
||||||
|
top_k_per_sec: 200
|
||||||
|
|
||||||
train:
|
train:
|
||||||
optimizer:
|
optimizer:
|
||||||
name: adam
|
name: adam
|
||||||
@ -80,8 +85,7 @@ train:
|
|||||||
|
|
||||||
train_loader:
|
train_loader:
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
num_workers: 2
|
shuffle: true
|
||||||
shuffle: True
|
|
||||||
|
|
||||||
clipping_strategy:
|
clipping_strategy:
|
||||||
name: random_subclip
|
name: random_subclip
|
||||||
@ -117,7 +121,6 @@ train:
|
|||||||
max_masks: 3
|
max_masks: 3
|
||||||
|
|
||||||
val_loader:
|
val_loader:
|
||||||
num_workers: 2
|
|
||||||
clipping_strategy:
|
clipping_strategy:
|
||||||
name: whole_audio_padded
|
name: whole_audio_padded
|
||||||
chunk_size: 0.256
|
chunk_size: 0.256
|
||||||
@ -136,9 +139,6 @@ train:
|
|||||||
size:
|
size:
|
||||||
weight: 0.1
|
weight: 0.1
|
||||||
|
|
||||||
logger:
|
|
||||||
name: csv
|
|
||||||
|
|
||||||
validation:
|
validation:
|
||||||
tasks:
|
tasks:
|
||||||
- name: sound_event_detection
|
- name: sound_event_detection
|
||||||
@ -148,6 +148,10 @@ train:
|
|||||||
metrics:
|
metrics:
|
||||||
- name: average_precision
|
- name: average_precision
|
||||||
|
|
||||||
|
logging:
|
||||||
|
train:
|
||||||
|
name: csv
|
||||||
|
|
||||||
evaluation:
|
evaluation:
|
||||||
tasks:
|
tasks:
|
||||||
- name: sound_event_detection
|
- name: sound_event_detection
|
||||||
|
|||||||
@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested
|
|||||||
configuration sections.
|
configuration sections.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Type, TypeVar, overload
|
from typing import Any, Literal, Type, TypeVar, overload
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from deepmerge.merger import Merger
|
from deepmerge.merger import Merger
|
||||||
@ -69,8 +69,20 @@ class BaseConfig(BaseModel):
|
|||||||
return cls.model_validate(yaml.safe_load(yaml_str))
|
return cls.model_validate(yaml.safe_load(yaml_str))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls: Type[C], path: PathLike, field: str | None = None) -> C:
|
def load(
|
||||||
return load_config(path, schema=cls, field=field)
|
cls: Type[C],
|
||||||
|
path: PathLike,
|
||||||
|
field: str | None = None,
|
||||||
|
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||||
|
strict: bool | None = None,
|
||||||
|
) -> C:
|
||||||
|
return load_config(
|
||||||
|
path,
|
||||||
|
schema=cls,
|
||||||
|
field=field,
|
||||||
|
extra=extra,
|
||||||
|
strict=strict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@ -142,6 +154,8 @@ def load_config(
|
|||||||
path: PathLike,
|
path: PathLike,
|
||||||
schema: Type[T_Model],
|
schema: Type[T_Model],
|
||||||
field: str | None = None,
|
field: str | None = None,
|
||||||
|
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||||
|
strict: bool | None = None,
|
||||||
) -> T_Model: ...
|
) -> T_Model: ...
|
||||||
|
|
||||||
|
|
||||||
@ -150,6 +164,8 @@ def load_config(
|
|||||||
path: PathLike,
|
path: PathLike,
|
||||||
schema: TypeAdapter[T],
|
schema: TypeAdapter[T],
|
||||||
field: str | None = None,
|
field: str | None = None,
|
||||||
|
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||||
|
strict: bool | None = None,
|
||||||
) -> T: ...
|
) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@ -157,6 +173,8 @@ def load_config(
|
|||||||
path: PathLike,
|
path: PathLike,
|
||||||
schema: Type[T_Model] | TypeAdapter[T],
|
schema: Type[T_Model] | TypeAdapter[T],
|
||||||
field: str | None = None,
|
field: str | None = None,
|
||||||
|
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||||
|
strict: bool | None = None,
|
||||||
) -> T_Model | T:
|
) -> T_Model | T:
|
||||||
"""Load and validate configuration data from a file against a schema.
|
"""Load and validate configuration data from a file against a schema.
|
||||||
|
|
||||||
@ -178,6 +196,17 @@ def load_config(
|
|||||||
file content is validated against the schema.
|
file content is validated against the schema.
|
||||||
Example: `"training.optimizer"` would extract the `optimizer` section
|
Example: `"training.optimizer"` would extract the `optimizer` section
|
||||||
within the `training` section.
|
within the `training` section.
|
||||||
|
extra : Literal["ignore", "allow", "forbid"], optional
|
||||||
|
How to handle extra keys in the configuration data. If None (default),
|
||||||
|
the default behaviour of the schema is used. If "ignore", extra keys
|
||||||
|
are ignored. If "allow", extra keys are allowed and will be accessible
|
||||||
|
as attributes on the resulting model instance. If "forbid", extra
|
||||||
|
keys are forbidden and an exception is raised. See pydantic
|
||||||
|
documentation for more details.
|
||||||
|
strict : bool, optional
|
||||||
|
Whether to enforce types strictly. If None (default), the default
|
||||||
|
behaviour of the schema is used. See pydantic documentation for more
|
||||||
|
details.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -209,9 +238,9 @@ def load_config(
|
|||||||
config = get_object_field(config, field)
|
config = get_object_field(config, field)
|
||||||
|
|
||||||
if isinstance(schema, TypeAdapter):
|
if isinstance(schema, TypeAdapter):
|
||||||
return schema.validate_python(config or {})
|
return schema.validate_python(config or {}, extra=extra, strict=strict)
|
||||||
|
|
||||||
return schema.model_validate(config or {})
|
return schema.model_validate(config or {}, extra=extra, strict=strict)
|
||||||
|
|
||||||
|
|
||||||
default_merger = Merger(
|
default_merger = Merger(
|
||||||
|
|||||||
@ -204,7 +204,7 @@ def test_load_backbone_config_from_example_data(example_data_dir: Path):
|
|||||||
"""load_backbone_config loads the real example config correctly."""
|
"""load_backbone_config loads the real example config correctly."""
|
||||||
config = load_backbone_config(
|
config = load_backbone_config(
|
||||||
example_data_dir / "config.yaml",
|
example_data_dir / "config.yaml",
|
||||||
field="model",
|
field="model.architecture",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(config, UNetBackboneConfig)
|
assert isinstance(config, UNetBackboneConfig)
|
||||||
|
|||||||
@ -6,5 +6,7 @@ def test_example_config_is_valid(example_data_dir):
|
|||||||
conf = load_config(
|
conf = load_config(
|
||||||
example_data_dir / "config.yaml",
|
example_data_dir / "config.yaml",
|
||||||
schema=BatDetect2Config,
|
schema=BatDetect2Config,
|
||||||
|
extra="forbid",
|
||||||
|
strict=True,
|
||||||
)
|
)
|
||||||
assert isinstance(conf, BatDetect2Config)
|
assert isinstance(conf, BatDetect2Config)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user