mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add extra and strict arguments to load config functions, migrate example config
This commit is contained in:
parent
652670b01d
commit
99b9e55c0e
@ -1,66 +1,71 @@
|
||||
config_version: v1
|
||||
|
||||
audio:
|
||||
samplerate: 256000
|
||||
resample:
|
||||
enabled: True
|
||||
method: "poly"
|
||||
|
||||
preprocess:
|
||||
stft:
|
||||
window_duration: 0.002
|
||||
window_overlap: 0.75
|
||||
window_fn: hann
|
||||
frequencies:
|
||||
max_freq: 120000
|
||||
min_freq: 10000
|
||||
size:
|
||||
height: 128
|
||||
resize_factor: 0.5
|
||||
spectrogram_transforms:
|
||||
- name: pcen
|
||||
time_constant: 0.1
|
||||
gain: 0.98
|
||||
bias: 2
|
||||
power: 0.5
|
||||
- name: spectral_mean_subtraction
|
||||
|
||||
postprocess:
|
||||
nms_kernel_size: 9
|
||||
detection_threshold: 0.01
|
||||
top_k_per_sec: 200
|
||||
enabled: true
|
||||
method: poly
|
||||
|
||||
model:
|
||||
name: UNetBackbone
|
||||
input_height: 128
|
||||
in_channels: 1
|
||||
encoder:
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 32
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 64
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 128
|
||||
- name: ConvBlock
|
||||
out_channels: 256
|
||||
bottleneck:
|
||||
channels: 256
|
||||
layers:
|
||||
- name: SelfAttention
|
||||
attention_channels: 256
|
||||
decoder:
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 64
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: ConvBlock
|
||||
out_channels: 32
|
||||
samplerate: 256000
|
||||
|
||||
preprocess:
|
||||
stft:
|
||||
window_duration: 0.002
|
||||
window_overlap: 0.75
|
||||
window_fn: hann
|
||||
frequencies:
|
||||
max_freq: 120000
|
||||
min_freq: 10000
|
||||
size:
|
||||
height: 128
|
||||
resize_factor: 0.5
|
||||
spectrogram_transforms:
|
||||
- name: pcen
|
||||
time_constant: 0.1
|
||||
gain: 0.98
|
||||
bias: 2
|
||||
power: 0.5
|
||||
- name: spectral_mean_subtraction
|
||||
|
||||
architecture:
|
||||
name: UNetBackbone
|
||||
input_height: 128
|
||||
in_channels: 1
|
||||
encoder:
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 32
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 64
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 128
|
||||
- name: ConvBlock
|
||||
out_channels: 256
|
||||
bottleneck:
|
||||
channels: 256
|
||||
layers:
|
||||
- name: SelfAttention
|
||||
attention_channels: 256
|
||||
decoder:
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 64
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: ConvBlock
|
||||
out_channels: 32
|
||||
|
||||
postprocess:
|
||||
nms_kernel_size: 9
|
||||
detection_threshold: 0.01
|
||||
top_k_per_sec: 200
|
||||
|
||||
train:
|
||||
optimizer:
|
||||
@ -80,8 +85,7 @@ train:
|
||||
|
||||
train_loader:
|
||||
batch_size: 8
|
||||
num_workers: 2
|
||||
shuffle: True
|
||||
shuffle: true
|
||||
|
||||
clipping_strategy:
|
||||
name: random_subclip
|
||||
@ -117,7 +121,6 @@ train:
|
||||
max_masks: 3
|
||||
|
||||
val_loader:
|
||||
num_workers: 2
|
||||
clipping_strategy:
|
||||
name: whole_audio_padded
|
||||
chunk_size: 0.256
|
||||
@ -136,9 +139,6 @@ train:
|
||||
size:
|
||||
weight: 0.1
|
||||
|
||||
logger:
|
||||
name: csv
|
||||
|
||||
validation:
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
@ -148,6 +148,10 @@ train:
|
||||
metrics:
|
||||
- name: average_precision
|
||||
|
||||
logging:
|
||||
train:
|
||||
name: csv
|
||||
|
||||
evaluation:
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
|
||||
@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested
|
||||
configuration sections.
|
||||
"""
|
||||
|
||||
from typing import Any, Type, TypeVar, overload
|
||||
from typing import Any, Literal, Type, TypeVar, overload
|
||||
|
||||
import yaml
|
||||
from deepmerge.merger import Merger
|
||||
@ -69,8 +69,20 @@ class BaseConfig(BaseModel):
|
||||
return cls.model_validate(yaml.safe_load(yaml_str))
|
||||
|
||||
@classmethod
|
||||
def load(cls: Type[C], path: PathLike, field: str | None = None) -> C:
|
||||
return load_config(path, schema=cls, field=field)
|
||||
def load(
|
||||
cls: Type[C],
|
||||
path: PathLike,
|
||||
field: str | None = None,
|
||||
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||
strict: bool | None = None,
|
||||
) -> C:
|
||||
return load_config(
|
||||
path,
|
||||
schema=cls,
|
||||
field=field,
|
||||
extra=extra,
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -142,6 +154,8 @@ def load_config(
|
||||
path: PathLike,
|
||||
schema: Type[T_Model],
|
||||
field: str | None = None,
|
||||
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||
strict: bool | None = None,
|
||||
) -> T_Model: ...
|
||||
|
||||
|
||||
@ -150,6 +164,8 @@ def load_config(
|
||||
path: PathLike,
|
||||
schema: TypeAdapter[T],
|
||||
field: str | None = None,
|
||||
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||
strict: bool | None = None,
|
||||
) -> T: ...
|
||||
|
||||
|
||||
@ -157,6 +173,8 @@ def load_config(
|
||||
path: PathLike,
|
||||
schema: Type[T_Model] | TypeAdapter[T],
|
||||
field: str | None = None,
|
||||
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||
strict: bool | None = None,
|
||||
) -> T_Model | T:
|
||||
"""Load and validate configuration data from a file against a schema.
|
||||
|
||||
@ -178,6 +196,17 @@ def load_config(
|
||||
file content is validated against the schema.
|
||||
Example: `"training.optimizer"` would extract the `optimizer` section
|
||||
within the `training` section.
|
||||
extra : Literal["ignore", "allow", "forbid"], optional
|
||||
How to handle extra keys in the configuration data. If None (default),
|
||||
the default behaviour of the schema is used. If "ignore", extra keys
|
||||
are ignored. If "allow", extra keys are allowed and will be accessible
|
||||
as attributes on the resulting model instance. If "forbid", extra
|
||||
keys are forbidden and an exception is raised. See pydantic
|
||||
documentation for more details.
|
||||
strict : bool, optional
|
||||
Whether to enforce types strictly. If None (default), the default
|
||||
behaviour of the schema is used. See pydantic documentation for more
|
||||
details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -209,9 +238,9 @@ def load_config(
|
||||
config = get_object_field(config, field)
|
||||
|
||||
if isinstance(schema, TypeAdapter):
|
||||
return schema.validate_python(config or {})
|
||||
return schema.validate_python(config or {}, extra=extra, strict=strict)
|
||||
|
||||
return schema.model_validate(config or {})
|
||||
return schema.model_validate(config or {}, extra=extra, strict=strict)
|
||||
|
||||
|
||||
default_merger = Merger(
|
||||
|
||||
@ -204,7 +204,7 @@ def test_load_backbone_config_from_example_data(example_data_dir: Path):
|
||||
"""load_backbone_config loads the real example config correctly."""
|
||||
config = load_backbone_config(
|
||||
example_data_dir / "config.yaml",
|
||||
field="model",
|
||||
field="model.architecture",
|
||||
)
|
||||
|
||||
assert isinstance(config, UNetBackboneConfig)
|
||||
|
||||
@ -6,5 +6,7 @@ def test_example_config_is_valid(example_data_dir):
|
||||
conf = load_config(
|
||||
example_data_dir / "config.yaml",
|
||||
schema=BatDetect2Config,
|
||||
extra="forbid",
|
||||
strict=True,
|
||||
)
|
||||
assert isinstance(conf, BatDetect2Config)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user