diff --git a/example_data/config.yaml b/example_data/config.yaml index 4a274e5..0b2ffa3 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -1,66 +1,71 @@ +config_version: v1 + audio: samplerate: 256000 resample: - enabled: True - method: "poly" - -preprocess: - stft: - window_duration: 0.002 - window_overlap: 0.75 - window_fn: hann - frequencies: - max_freq: 120000 - min_freq: 10000 - size: - height: 128 - resize_factor: 0.5 - spectrogram_transforms: - - name: pcen - time_constant: 0.1 - gain: 0.98 - bias: 2 - power: 0.5 - - name: spectral_mean_subtraction - -postprocess: - nms_kernel_size: 9 - detection_threshold: 0.01 - top_k_per_sec: 200 + enabled: true + method: poly model: - name: UNetBackbone - input_height: 128 - in_channels: 1 - encoder: - layers: - - name: FreqCoordConvDown - out_channels: 32 - - name: FreqCoordConvDown - out_channels: 64 - - name: LayerGroup - layers: - - name: FreqCoordConvDown - out_channels: 128 - - name: ConvBlock - out_channels: 256 - bottleneck: - channels: 256 - layers: - - name: SelfAttention - attention_channels: 256 - decoder: - layers: - - name: FreqCoordConvUp - out_channels: 64 - - name: FreqCoordConvUp - out_channels: 32 - - name: LayerGroup - layers: - - name: FreqCoordConvUp - out_channels: 32 - - name: ConvBlock - out_channels: 32 + samplerate: 256000 + + preprocess: + stft: + window_duration: 0.002 + window_overlap: 0.75 + window_fn: hann + frequencies: + max_freq: 120000 + min_freq: 10000 + size: + height: 128 + resize_factor: 0.5 + spectrogram_transforms: + - name: pcen + time_constant: 0.1 + gain: 0.98 + bias: 2 + power: 0.5 + - name: spectral_mean_subtraction + + architecture: + name: UNetBackbone + input_height: 128 + in_channels: 1 + encoder: + layers: + - name: FreqCoordConvDown + out_channels: 32 + - name: FreqCoordConvDown + out_channels: 64 + - name: LayerGroup + layers: + - name: FreqCoordConvDown + out_channels: 128 + - name: ConvBlock + out_channels: 256 + bottleneck: + channels: 256 + layers: + - name: SelfAttention + attention_channels: 256 + decoder: + layers: + - name: FreqCoordConvUp + out_channels: 64 + - name: FreqCoordConvUp + out_channels: 32 + - name: LayerGroup + layers: + - name: FreqCoordConvUp + out_channels: 32 + - name: ConvBlock + out_channels: 32 + + postprocess: + nms_kernel_size: 9 + detection_threshold: 0.01 + top_k_per_sec: 200 train: optimizer: @@ -80,8 +85,7 @@ train: train_loader: batch_size: 8 - num_workers: 2 - shuffle: True + shuffle: true clipping_strategy: name: random_subclip @@ -117,7 +121,6 @@ train: max_masks: 3 val_loader: - num_workers: 2 clipping_strategy: name: whole_audio_padded chunk_size: 0.256 @@ -136,9 +139,6 @@ train: size: weight: 0.1 - logger: - name: csv - validation: tasks: - name: sound_event_detection @@ -148,6 +148,10 @@ train: metrics: - name: average_precision +logging: + train: + name: csv + evaluation: tasks: - name: sound_event_detection diff --git a/src/batdetect2/core/configs.py b/src/batdetect2/core/configs.py index 49bb435..0fd7dba 100644 --- a/src/batdetect2/core/configs.py +++ b/src/batdetect2/core/configs.py @@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested configuration sections. """ -from typing import Any, Type, TypeVar, overload +from typing import Any, Literal, Type, TypeVar, overload import yaml from deepmerge.merger import Merger @@ -69,8 +69,20 @@ class BaseConfig(BaseModel): return cls.model_validate(yaml.safe_load(yaml_str)) @classmethod - def load(cls: Type[C], path: PathLike, field: str | None = None) -> C: - return load_config(path, schema=cls, field=field) + def load( + cls: Type[C], + path: PathLike, + field: str | None = None, + extra: Literal["ignore", "allow", "forbid"] | None = None, + strict: bool | None = None, + ) -> C: + return load_config( + path, + schema=cls, + field=field, + extra=extra, + strict=strict, + ) T = TypeVar("T") @@ -142,6 +154,8 @@ def load_config( path: PathLike, schema: Type[T_Model], field: str | None = None, + extra: Literal["ignore", "allow", "forbid"] | None = None, + strict: bool | None = None, ) -> T_Model: ... @@ -150,6 +164,8 @@ def load_config( path: PathLike, schema: TypeAdapter[T], field: str | None = None, + extra: Literal["ignore", "allow", "forbid"] | None = None, + strict: bool | None = None, ) -> T: ... @@ -157,6 +173,8 @@ def load_config( path: PathLike, schema: Type[T_Model] | TypeAdapter[T], field: str | None = None, + extra: Literal["ignore", "allow", "forbid"] | None = None, + strict: bool | None = None, ) -> T_Model | T: """Load and validate configuration data from a file against a schema. @@ -178,6 +196,17 @@ def load_config( file content is validated against the schema. Example: `"training.optimizer"` would extract the `optimizer` section within the `training` section. + extra : Literal["ignore", "allow", "forbid"], optional + How to handle extra keys in the configuration data. If None (default), + the default behaviour of the schema is used. If "ignore", extra keys + are ignored. If "allow", extra keys are allowed and will be accessible + as attributes on the resulting model instance. If "forbid", extra + keys are forbidden and an exception is raised. See pydantic + documentation for more details. + strict : bool, optional + Whether to enforce types strictly. If None (default), the default + behaviour of the schema is used. See pydantic documentation for more + details. Returns ------- @@ -209,9 +238,9 @@ def load_config( config = get_object_field(config, field) if isinstance(schema, TypeAdapter): - return schema.validate_python(config or {}) + return schema.validate_python(config or {}, extra=extra, strict=strict) - return schema.model_validate(config or {}) + return schema.model_validate(config or {}, extra=extra, strict=strict) default_merger = Merger( diff --git a/tests/test_models/test_backbones.py b/tests/test_models/test_backbones.py index 7af92cd..8565a07 100644 --- a/tests/test_models/test_backbones.py +++ b/tests/test_models/test_backbones.py @@ -204,7 +204,7 @@ def test_load_backbone_config_from_example_data(example_data_dir: Path): """load_backbone_config loads the real example config correctly.""" config = load_backbone_config( example_data_dir / "config.yaml", - field="model", + field="model.architecture", ) assert isinstance(config, UNetBackboneConfig) diff --git a/tests/test_train/test_config.py b/tests/test_train/test_config.py index 4142149..1384887 100644 --- a/tests/test_train/test_config.py +++ b/tests/test_train/test_config.py @@ -6,5 +6,7 @@ def test_example_config_is_valid(example_data_dir): conf = load_config( example_data_dir / "config.yaml", schema=BatDetect2Config, + extra="forbid", + strict=True, ) assert isinstance(conf, BatDetect2Config)