mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
refactor: remove aggregate app config
This commit is contained in:
parent
c27e7f9f52
commit
7a10b7ffff
@ -2,7 +2,8 @@
|
||||
|
||||
`BatDetect2API` is the main entry point for the current Python workflow.
|
||||
|
||||
It wraps model loading, inference, evaluation, output formatting, and training-related entry points behind one object.
|
||||
It wraps model loading, inference, evaluation, output formatting, and
|
||||
training-related entry points behind one object.
|
||||
|
||||
Defined in `batdetect2.api_v2`.
|
||||
|
||||
@ -10,8 +11,8 @@ Defined in `batdetect2.api_v2`.
|
||||
|
||||
- `BatDetect2API.from_checkpoint(path, ...)`
|
||||
- load a trained checkpoint and optional config overrides.
|
||||
- `BatDetect2API.from_config(config)`
|
||||
- build a full stack from a `BatDetect2Config` object.
|
||||
- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)`
|
||||
- build a full stack from separate config objects.
|
||||
|
||||
## Inference methods
|
||||
|
||||
@ -46,10 +47,12 @@ Defined in `batdetect2.api_v2`.
|
||||
|
||||
## Output persistence helpers
|
||||
|
||||
- `save_predictions(predictions, path, audio_dir=None, format=None, config=None)`
|
||||
- `save_predictions(predictions, path, audio_dir=None, format=None,
|
||||
config=None)`
|
||||
- `load_predictions(path, format=None, config=None)`
|
||||
|
||||
Use these when you want to save programmatic predictions without going through the CLI.
|
||||
Use these when you want to save programmatic predictions without going through
|
||||
the CLI.
|
||||
|
||||
## Training and evaluation entry points
|
||||
|
||||
@ -60,6 +63,9 @@ Use these when you want to save programmatic predictions without going through t
|
||||
|
||||
## Related pages
|
||||
|
||||
- Python tutorial: {doc}`../tutorials/integrate-with-a-python-pipeline`
|
||||
- Outputs config reference: {doc}`outputs-config`
|
||||
- Output formats reference: {doc}`output-formats`
|
||||
- Python tutorial:
|
||||
{doc}`../tutorials/integrate-with-a-python-pipeline`
|
||||
- Outputs config reference:
|
||||
{doc}`outputs-config`
|
||||
- Output formats reference:
|
||||
{doc}`output-formats`
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
# Top-level app config reference
|
||||
|
||||
The top-level config object is `BatDetect2Config`.
|
||||
|
||||
Defined in `batdetect2.config`.
|
||||
|
||||
It combines the main configuration surfaces used across training, inference, evaluation, outputs, and logging.
|
||||
|
||||
## Fields
|
||||
|
||||
- `config_version`
|
||||
- `train`
|
||||
- training-specific config.
|
||||
- `evaluation`
|
||||
- evaluation task and plot config.
|
||||
- `model`
|
||||
- model architecture, preprocessing, postprocessing, and targets.
|
||||
- `audio`
|
||||
- audio loading and resampling config.
|
||||
- `inference`
|
||||
- clipping and loader config for prediction-time workflows.
|
||||
- `outputs`
|
||||
- output format and output transform config.
|
||||
- `logging`
|
||||
- logging backend and formatting config.
|
||||
|
||||
## Mental model
|
||||
|
||||
Think of `BatDetect2Config` as the complete application wiring for the current stack.
|
||||
|
||||
Use it when you want one reproducible config that describes the whole workflow.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Inference config: {doc}`inference-config`
|
||||
- Evaluation config: {doc}`evaluation-config`
|
||||
- Outputs config: {doc}`outputs-config`
|
||||
- General config reference: {doc}`configs`
|
||||
@ -24,8 +24,8 @@ for full options and argument details.
|
||||
- Global CLI options are documented in {doc}`base`.
|
||||
- Paths with spaces should be wrapped in quotes.
|
||||
- Input audio is expected to be mono.
|
||||
- Legacy `detect` uses a required threshold argument, while `predict` uses
|
||||
the optional `--detection-threshold` override.
|
||||
- Legacy `detect` uses a required threshold argument, while `predict` uses the
|
||||
optional `--detection-threshold` override.
|
||||
|
||||
```{warning}
|
||||
`batdetect2 detect` is a legacy command.
|
||||
|
||||
@ -1,5 +1,15 @@
|
||||
Config reference
|
||||
================
|
||||
|
||||
.. automodule:: batdetect2.config
|
||||
:members:
|
||||
BatDetect2 uses separate config objects for different workflow surfaces.
|
||||
|
||||
Use the dedicated reference pages for each config family:
|
||||
|
||||
- inference config
|
||||
- evaluation config
|
||||
- outputs config
|
||||
- preprocessing config
|
||||
- postprocess config
|
||||
- targets config workflow
|
||||
|
||||
Example config files live under `example_data/configs/`.
|
||||
|
||||
@ -2,14 +2,14 @@
|
||||
|
||||
Reference pages are the detailed lookup pages.
|
||||
|
||||
Use this section when you need exact command options, setting names, output details, or Python API entries.
|
||||
Use this section when you need exact command options, setting names, output
|
||||
details, or Python API entries.
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
cli/index
|
||||
api
|
||||
app-config
|
||||
inference-config
|
||||
evaluation-config
|
||||
outputs-config
|
||||
|
||||
@ -1,192 +0,0 @@
|
||||
config_version: v1
|
||||
|
||||
audio:
|
||||
samplerate: 256000
|
||||
resample:
|
||||
enabled: true
|
||||
method: poly
|
||||
|
||||
model:
|
||||
samplerate: 256000
|
||||
|
||||
preprocess:
|
||||
stft:
|
||||
window_duration: 0.002
|
||||
window_overlap: 0.75
|
||||
window_fn: hann
|
||||
frequencies:
|
||||
max_freq: 120000
|
||||
min_freq: 10000
|
||||
size:
|
||||
height: 128
|
||||
resize_factor: 0.5
|
||||
spectrogram_transforms:
|
||||
- name: pcen
|
||||
time_constant: 0.1
|
||||
gain: 0.98
|
||||
bias: 2
|
||||
power: 0.5
|
||||
- name: spectral_mean_subtraction
|
||||
|
||||
architecture:
|
||||
name: UNetBackbone
|
||||
input_height: 128
|
||||
in_channels: 1
|
||||
encoder:
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 32
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 64
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 128
|
||||
- name: ConvBlock
|
||||
out_channels: 256
|
||||
bottleneck:
|
||||
channels: 256
|
||||
layers:
|
||||
- name: SelfAttention
|
||||
attention_channels: 256
|
||||
decoder:
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 64
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: ConvBlock
|
||||
out_channels: 32
|
||||
|
||||
postprocess:
|
||||
nms_kernel_size: 9
|
||||
detection_threshold: 0.01
|
||||
top_k_per_sec: 200
|
||||
|
||||
train:
|
||||
optimizer:
|
||||
name: adam
|
||||
learning_rate: 0.001
|
||||
|
||||
scheduler:
|
||||
name: cosine_annealing
|
||||
t_max: 100
|
||||
|
||||
labels:
|
||||
sigma: 3
|
||||
|
||||
trainer:
|
||||
max_epochs: 10
|
||||
check_val_every_n_epoch: 5
|
||||
|
||||
train_loader:
|
||||
batch_size: 8
|
||||
shuffle: true
|
||||
|
||||
clipping_strategy:
|
||||
name: random_subclip
|
||||
duration: 0.256
|
||||
|
||||
augmentations:
|
||||
enabled: true
|
||||
audio:
|
||||
- name: mix_audio
|
||||
probability: 0.2
|
||||
min_weight: 0.3
|
||||
max_weight: 0.7
|
||||
- name: add_echo
|
||||
probability: 0.2
|
||||
max_delay: 0.005
|
||||
min_weight: 0.0
|
||||
max_weight: 1.0
|
||||
spectrogram:
|
||||
- name: scale_volume
|
||||
probability: 0.2
|
||||
min_scaling: 0.0
|
||||
max_scaling: 2.0
|
||||
- name: warp
|
||||
probability: 0.2
|
||||
delta: 0.04
|
||||
- name: mask_time
|
||||
probability: 0.2
|
||||
max_perc: 0.05
|
||||
max_masks: 3
|
||||
- name: mask_freq
|
||||
probability: 0.2
|
||||
max_perc: 0.10
|
||||
max_masks: 3
|
||||
|
||||
val_loader:
|
||||
clipping_strategy:
|
||||
name: whole_audio_padded
|
||||
chunk_size: 0.256
|
||||
|
||||
loss:
|
||||
detection:
|
||||
weight: 1.0
|
||||
focal:
|
||||
beta: 4
|
||||
alpha: 2
|
||||
classification:
|
||||
weight: 2.0
|
||||
focal:
|
||||
beta: 4
|
||||
alpha: 2
|
||||
size:
|
||||
weight: 0.1
|
||||
|
||||
validation:
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: sound_event_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
|
||||
logging:
|
||||
train:
|
||||
name: csv
|
||||
|
||||
evaluation:
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: score_distribution
|
||||
- name: example_detection
|
||||
- name: sound_event_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: top_class_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: confusion_matrix
|
||||
- name: example_classification
|
||||
- name: clip_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: roc_curve
|
||||
- name: score_distribution
|
||||
- name: clip_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: roc_curve
|
||||
4
example_data/configs/audio.yaml
Normal file
4
example_data/configs/audio.yaml
Normal file
@ -0,0 +1,4 @@
|
||||
samplerate: 256000
|
||||
resample:
|
||||
enabled: true
|
||||
method: poly
|
||||
37
example_data/configs/evaluation.yaml
Normal file
37
example_data/configs/evaluation.yaml
Normal file
@ -0,0 +1,37 @@
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: score_distribution
|
||||
- name: example_detection
|
||||
- name: sound_event_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: top_class_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: confusion_matrix
|
||||
- name: example_classification
|
||||
- name: clip_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: roc_curve
|
||||
- name: score_distribution
|
||||
- name: clip_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: roc_curve
|
||||
9
example_data/configs/inference.yaml
Normal file
9
example_data/configs/inference.yaml
Normal file
@ -0,0 +1,9 @@
|
||||
loader:
|
||||
batch_size: 8
|
||||
|
||||
clipping:
|
||||
enabled: true
|
||||
duration: 0.5
|
||||
overlap: 0.0
|
||||
max_empty: 0.0
|
||||
discard_empty: true
|
||||
2
example_data/configs/logging.yaml
Normal file
2
example_data/configs/logging.yaml
Normal file
@ -0,0 +1,2 @@
|
||||
train:
|
||||
name: csv
|
||||
59
example_data/configs/model.yaml
Normal file
59
example_data/configs/model.yaml
Normal file
@ -0,0 +1,59 @@
|
||||
samplerate: 256000
|
||||
|
||||
preprocess:
|
||||
stft:
|
||||
window_duration: 0.002
|
||||
window_overlap: 0.75
|
||||
window_fn: hann
|
||||
frequencies:
|
||||
max_freq: 120000
|
||||
min_freq: 10000
|
||||
size:
|
||||
height: 128
|
||||
resize_factor: 0.5
|
||||
spectrogram_transforms:
|
||||
- name: pcen
|
||||
time_constant: 0.1
|
||||
gain: 0.98
|
||||
bias: 2
|
||||
power: 0.5
|
||||
- name: spectral_mean_subtraction
|
||||
|
||||
architecture:
|
||||
name: UNetBackbone
|
||||
input_height: 128
|
||||
in_channels: 1
|
||||
encoder:
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 32
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 64
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 128
|
||||
- name: ConvBlock
|
||||
out_channels: 256
|
||||
bottleneck:
|
||||
channels: 256
|
||||
layers:
|
||||
- name: SelfAttention
|
||||
attention_channels: 256
|
||||
decoder:
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 64
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: ConvBlock
|
||||
out_channels: 32
|
||||
|
||||
postprocess:
|
||||
nms_kernel_size: 9
|
||||
detection_threshold: 0.01
|
||||
top_k_per_sec: 200
|
||||
9
example_data/configs/outputs.yaml
Normal file
9
example_data/configs/outputs.yaml
Normal file
@ -0,0 +1,9 @@
|
||||
format:
|
||||
name: raw
|
||||
include_class_scores: true
|
||||
include_features: true
|
||||
include_geometry: true
|
||||
|
||||
transform:
|
||||
detection_transforms: []
|
||||
clip_transforms: []
|
||||
79
example_data/configs/training.yaml
Normal file
79
example_data/configs/training.yaml
Normal file
@ -0,0 +1,79 @@
|
||||
optimizer:
|
||||
name: adam
|
||||
learning_rate: 0.001
|
||||
|
||||
scheduler:
|
||||
name: cosine_annealing
|
||||
t_max: 100
|
||||
|
||||
labels:
|
||||
sigma: 3
|
||||
|
||||
trainer:
|
||||
max_epochs: 10
|
||||
check_val_every_n_epoch: 5
|
||||
|
||||
train_loader:
|
||||
batch_size: 8
|
||||
shuffle: true
|
||||
|
||||
clipping_strategy:
|
||||
name: random_subclip
|
||||
duration: 0.256
|
||||
|
||||
augmentations:
|
||||
enabled: true
|
||||
audio:
|
||||
- name: mix_audio
|
||||
probability: 0.2
|
||||
min_weight: 0.3
|
||||
max_weight: 0.7
|
||||
- name: add_echo
|
||||
probability: 0.2
|
||||
max_delay: 0.005
|
||||
min_weight: 0.0
|
||||
max_weight: 1.0
|
||||
spectrogram:
|
||||
- name: scale_volume
|
||||
probability: 0.2
|
||||
min_scaling: 0.0
|
||||
max_scaling: 2.0
|
||||
- name: warp
|
||||
probability: 0.2
|
||||
delta: 0.04
|
||||
- name: mask_time
|
||||
probability: 0.2
|
||||
max_perc: 0.05
|
||||
max_masks: 3
|
||||
- name: mask_freq
|
||||
probability: 0.2
|
||||
max_perc: 0.10
|
||||
max_masks: 3
|
||||
|
||||
val_loader:
|
||||
clipping_strategy:
|
||||
name: whole_audio_padded
|
||||
chunk_size: 0.256
|
||||
|
||||
loss:
|
||||
detection:
|
||||
weight: 1.0
|
||||
focal:
|
||||
beta: 4
|
||||
alpha: 2
|
||||
classification:
|
||||
weight: 2.0
|
||||
focal:
|
||||
beta: 4
|
||||
alpha: 2
|
||||
size:
|
||||
weight: 0.1
|
||||
|
||||
validation:
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: sound_event_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
10
justfile
10
justfile
@ -112,6 +112,14 @@ clean: clean-build clean-pyc clean-test clean-docs
|
||||
example-train OPTIONS="":
|
||||
uv run batdetect2 train \
|
||||
--val-dataset example_data/dataset.yaml \
|
||||
--config example_data/config.yaml \
|
||||
--base-dir . \
|
||||
--targets example_data/targets.yaml \
|
||||
--model-config example_data/configs/model.yaml \
|
||||
--training-config example_data/configs/training.yaml \
|
||||
--audio-config example_data/configs/audio.yaml \
|
||||
--evaluation-config example_data/configs/evaluation.yaml \
|
||||
--inference-config example_data/configs/inference.yaml \
|
||||
--outputs-config example_data/configs/outputs.yaml \
|
||||
--logging-config example_data/configs/logging.yaml \
|
||||
{{OPTIONS}} \
|
||||
example_data/dataset.yaml
|
||||
|
||||
@ -1,33 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.evaluate.config import (
|
||||
EvaluationConfig,
|
||||
get_default_eval_config,
|
||||
)
|
||||
from batdetect2.inference.config import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
|
||||
__all__ = ["BatDetect2Config"]
|
||||
|
||||
|
||||
class BatDetect2Config(BaseConfig):
|
||||
config_version: Literal["v1"] = "v1"
|
||||
|
||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
evaluation: EvaluationConfig = Field(
|
||||
default_factory=get_default_eval_config
|
||||
)
|
||||
model: ModelConfig = Field(default_factory=ModelConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
||||
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
|
||||
@ -203,8 +203,8 @@ in_channels: 1
|
||||
def test_load_backbone_config_from_example_data(example_data_dir: Path):
|
||||
"""load_backbone_config loads the real example config correctly."""
|
||||
config = load_backbone_config(
|
||||
example_data_dir / "config.yaml",
|
||||
field="model.architecture",
|
||||
example_data_dir / "configs" / "model.yaml",
|
||||
field="architecture",
|
||||
)
|
||||
|
||||
assert isinstance(config, UNetBackboneConfig)
|
||||
|
||||
@ -3,20 +3,19 @@ from pathlib import Path
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.train import run_train
|
||||
from batdetect2.train import TrainingConfig, run_train
|
||||
|
||||
pytestmark = pytest.mark.slow
|
||||
|
||||
|
||||
def _build_fast_train_config() -> BatDetect2Config:
|
||||
config = BatDetect2Config()
|
||||
config.train.trainer.limit_train_batches = 1
|
||||
config.train.trainer.limit_val_batches = 1
|
||||
config.train.trainer.log_every_n_steps = 1
|
||||
config.train.trainer.check_val_every_n_epoch = 1
|
||||
config.train.train_loader.batch_size = 1
|
||||
config.train.train_loader.augmentations.enabled = False
|
||||
def _build_fast_train_config() -> TrainingConfig:
|
||||
config = TrainingConfig()
|
||||
config.trainer.limit_train_batches = 1
|
||||
config.trainer.limit_val_batches = 1
|
||||
config.trainer.log_every_n_steps = 1
|
||||
config.trainer.check_val_every_n_epoch = 1
|
||||
config.train_loader.batch_size = 1
|
||||
config.train_loader.augmentations.enabled = False
|
||||
return config
|
||||
|
||||
|
||||
@ -29,9 +28,7 @@ def test_train_saves_checkpoint_in_requested_experiment_run_dir(
|
||||
run_train(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=example_annotations[:1],
|
||||
train_config=config.train,
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
train_config=config,
|
||||
num_epochs=1,
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
@ -50,14 +47,12 @@ def test_train_without_validation_can_still_save_last_checkpoint(
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
) -> None:
|
||||
config = _build_fast_train_config()
|
||||
config.train.checkpoints.save_last = True
|
||||
config.checkpoints.save_last = True
|
||||
|
||||
run_train(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=None,
|
||||
train_config=config.train,
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
train_config=config,
|
||||
num_epochs=1,
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
@ -73,16 +68,14 @@ def test_train_controls_which_checkpoints_are_kept(
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
) -> None:
|
||||
config = _build_fast_train_config()
|
||||
config.train.checkpoints.save_top_k = 1
|
||||
config.train.checkpoints.save_last = True
|
||||
config.train.checkpoints.filename = "epoch{epoch}"
|
||||
config.checkpoints.save_top_k = 1
|
||||
config.checkpoints.save_last = True
|
||||
config.checkpoints.filename = "epoch{epoch}"
|
||||
|
||||
run_train(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=example_annotations[:1],
|
||||
train_config=config.train,
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
train_config=config,
|
||||
num_epochs=3,
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
|
||||
@ -1,12 +1,43 @@
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import load_config
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.evaluate import EvaluationConfig
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train import TrainingConfig
|
||||
|
||||
|
||||
def test_example_config_is_valid(example_data_dir):
|
||||
conf = load_config(
|
||||
example_data_dir / "config.yaml",
|
||||
schema=BatDetect2Config,
|
||||
extra="forbid",
|
||||
strict=True,
|
||||
def test_example_split_configs_are_valid(example_data_dir):
|
||||
configs_dir = example_data_dir / "configs"
|
||||
|
||||
assert isinstance(
|
||||
AudioConfig.load(configs_dir / "audio.yaml"), AudioConfig
|
||||
)
|
||||
assert isinstance(
|
||||
ModelConfig.load(configs_dir / "model.yaml"), ModelConfig
|
||||
)
|
||||
assert isinstance(
|
||||
TargetConfig.load(example_data_dir / "targets.yaml"),
|
||||
TargetConfig,
|
||||
)
|
||||
assert isinstance(
|
||||
TrainingConfig.load(configs_dir / "training.yaml"),
|
||||
TrainingConfig,
|
||||
)
|
||||
assert isinstance(
|
||||
EvaluationConfig.load(configs_dir / "evaluation.yaml"),
|
||||
EvaluationConfig,
|
||||
)
|
||||
assert isinstance(
|
||||
InferenceConfig.load(configs_dir / "inference.yaml"),
|
||||
InferenceConfig,
|
||||
)
|
||||
assert isinstance(
|
||||
OutputsConfig.load(configs_dir / "outputs.yaml"),
|
||||
OutputsConfig,
|
||||
)
|
||||
assert isinstance(
|
||||
AppLoggingConfig.load(configs_dir / "logging.yaml"),
|
||||
AppLoggingConfig,
|
||||
)
|
||||
assert isinstance(conf, BatDetect2Config)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user