diff --git a/docs/source/reference/api.md b/docs/source/reference/api.md index d514bce..916e6df 100644 --- a/docs/source/reference/api.md +++ b/docs/source/reference/api.md @@ -2,7 +2,8 @@ `BatDetect2API` is the main entry point for the current Python workflow. -It wraps model loading, inference, evaluation, output formatting, and training-related entry points behind one object. +It wraps model loading, inference, evaluation, output formatting, and +training-related entry points behind one object. Defined in `batdetect2.api_v2`. @@ -10,8 +11,8 @@ Defined in `batdetect2.api_v2`. - `BatDetect2API.from_checkpoint(path, ...)` - load a trained checkpoint and optional config overrides. -- `BatDetect2API.from_config(config)` - - build a full stack from a `BatDetect2Config` object. +- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)` + - build a full stack from separate config objects. ## Inference methods @@ -46,10 +47,12 @@ Defined in `batdetect2.api_v2`. ## Output persistence helpers -- `save_predictions(predictions, path, audio_dir=None, format=None, config=None)` +- `save_predictions(predictions, path, audio_dir=None, format=None, + config=None)` - `load_predictions(path, format=None, config=None)` -Use these when you want to save programmatic predictions without going through the CLI. +Use these when you want to save programmatic predictions without going through +the CLI. 
## Training and evaluation entry points @@ -60,6 +63,9 @@ Use these when you want to save programmatic predictions without going through t ## Related pages -- Python tutorial: {doc}`../tutorials/integrate-with-a-python-pipeline` -- Outputs config reference: {doc}`outputs-config` -- Output formats reference: {doc}`output-formats` +- Python tutorial: + {doc}`../tutorials/integrate-with-a-python-pipeline` +- Outputs config reference: + {doc}`outputs-config` +- Output formats reference: + {doc}`output-formats` diff --git a/docs/source/reference/app-config.md b/docs/source/reference/app-config.md deleted file mode 100644 index 1237c0f..0000000 --- a/docs/source/reference/app-config.md +++ /dev/null @@ -1,38 +0,0 @@ -# Top-level app config reference - -The top-level config object is `BatDetect2Config`. - -Defined in `batdetect2.config`. - -It combines the main configuration surfaces used across training, inference, evaluation, outputs, and logging. - -## Fields - -- `config_version` -- `train` - - training-specific config. -- `evaluation` - - evaluation task and plot config. -- `model` - - model architecture, preprocessing, postprocessing, and targets. -- `audio` - - audio loading and resampling config. -- `inference` - - clipping and loader config for prediction-time workflows. -- `outputs` - - output format and output transform config. -- `logging` - - logging backend and formatting config. - -## Mental model - -Think of `BatDetect2Config` as the complete application wiring for the current stack. - -Use it when you want one reproducible config that describes the whole workflow. 
- -## Related pages - -- Inference config: {doc}`inference-config` -- Evaluation config: {doc}`evaluation-config` -- Outputs config: {doc}`outputs-config` -- General config reference: {doc}`configs` diff --git a/docs/source/reference/cli/index.md b/docs/source/reference/cli/index.md index e52c828..c8a238b 100644 --- a/docs/source/reference/cli/index.md +++ b/docs/source/reference/cli/index.md @@ -24,8 +24,8 @@ for full options and argument details. - Global CLI options are documented in {doc}`base`. - Paths with spaces should be wrapped in quotes. - Input audio is expected to be mono. -- Legacy `detect` uses a required threshold argument, while `predict` uses - the optional `--detection-threshold` override. +- Legacy `detect` uses a required threshold argument, while `predict` uses the + optional `--detection-threshold` override. ```{warning} `batdetect2 detect` is a legacy command. diff --git a/docs/source/reference/configs.rst b/docs/source/reference/configs.rst index 733159a..90261d1 100644 --- a/docs/source/reference/configs.rst +++ b/docs/source/reference/configs.rst @@ -1,5 +1,15 @@ Config reference ================ -.. automodule:: batdetect2.config - :members: +BatDetect2 uses separate config objects for different workflow surfaces. + +Use the dedicated reference pages for each config family: + +- inference config +- evaluation config +- outputs config +- preprocessing config +- postprocess config +- targets config workflow + +Example config files live under ``example_data/configs/``. diff --git a/docs/source/reference/index.md b/docs/source/reference/index.md index f0f0beb..cd09bc1 100644 --- a/docs/source/reference/index.md +++ b/docs/source/reference/index.md @@ -2,14 +2,14 @@ Reference pages are the detailed lookup pages. -Use this section when you need exact command options, setting names, output details, or Python API entries. +Use this section when you need exact command options, setting names, output +details, or Python API entries. 
```{toctree} :maxdepth: 1 cli/index api -app-config inference-config evaluation-config outputs-config diff --git a/example_data/config.yaml b/example_data/config.yaml deleted file mode 100644 index 0b2ffa3..0000000 --- a/example_data/config.yaml +++ /dev/null @@ -1,192 +0,0 @@ -config_version: v1 - -audio: - samplerate: 256000 - resample: - enabled: true - method: poly - -model: - samplerate: 256000 - - preprocess: - stft: - window_duration: 0.002 - window_overlap: 0.75 - window_fn: hann - frequencies: - max_freq: 120000 - min_freq: 10000 - size: - height: 128 - resize_factor: 0.5 - spectrogram_transforms: - - name: pcen - time_constant: 0.1 - gain: 0.98 - bias: 2 - power: 0.5 - - name: spectral_mean_subtraction - - architecture: - name: UNetBackbone - input_height: 128 - in_channels: 1 - encoder: - layers: - - name: FreqCoordConvDown - out_channels: 32 - - name: FreqCoordConvDown - out_channels: 64 - - name: LayerGroup - layers: - - name: FreqCoordConvDown - out_channels: 128 - - name: ConvBlock - out_channels: 256 - bottleneck: - channels: 256 - layers: - - name: SelfAttention - attention_channels: 256 - decoder: - layers: - - name: FreqCoordConvUp - out_channels: 64 - - name: FreqCoordConvUp - out_channels: 32 - - name: LayerGroup - layers: - - name: FreqCoordConvUp - out_channels: 32 - - name: ConvBlock - out_channels: 32 - - postprocess: - nms_kernel_size: 9 - detection_threshold: 0.01 - top_k_per_sec: 200 - -train: - optimizer: - name: adam - learning_rate: 0.001 - - scheduler: - name: cosine_annealing - t_max: 100 - - labels: - sigma: 3 - - trainer: - max_epochs: 10 - check_val_every_n_epoch: 5 - - train_loader: - batch_size: 8 - shuffle: true - - clipping_strategy: - name: random_subclip - duration: 0.256 - - augmentations: - enabled: true - audio: - - name: mix_audio - probability: 0.2 - min_weight: 0.3 - max_weight: 0.7 - - name: add_echo - probability: 0.2 - max_delay: 0.005 - min_weight: 0.0 - max_weight: 1.0 - spectrogram: - - name: scale_volume - 
probability: 0.2 - min_scaling: 0.0 - max_scaling: 2.0 - - name: warp - probability: 0.2 - delta: 0.04 - - name: mask_time - probability: 0.2 - max_perc: 0.05 - max_masks: 3 - - name: mask_freq - probability: 0.2 - max_perc: 0.10 - max_masks: 3 - - val_loader: - clipping_strategy: - name: whole_audio_padded - chunk_size: 0.256 - - loss: - detection: - weight: 1.0 - focal: - beta: 4 - alpha: 2 - classification: - weight: 2.0 - focal: - beta: 4 - alpha: 2 - size: - weight: 0.1 - - validation: - tasks: - - name: sound_event_detection - metrics: - - name: average_precision - - name: sound_event_classification - metrics: - - name: average_precision - -logging: - train: - name: csv - -evaluation: - tasks: - - name: sound_event_detection - metrics: - - name: average_precision - - name: roc_auc - plots: - - name: pr_curve - - name: score_distribution - - name: example_detection - - name: sound_event_classification - metrics: - - name: average_precision - - name: roc_auc - plots: - - name: pr_curve - - name: top_class_detection - metrics: - - name: average_precision - plots: - - name: pr_curve - - name: confusion_matrix - - name: example_classification - - name: clip_detection - metrics: - - name: average_precision - - name: roc_auc - plots: - - name: pr_curve - - name: roc_curve - - name: score_distribution - - name: clip_classification - metrics: - - name: average_precision - - name: roc_auc - plots: - - name: pr_curve - - name: roc_curve diff --git a/example_data/configs/audio.yaml b/example_data/configs/audio.yaml new file mode 100644 index 0000000..0100572 --- /dev/null +++ b/example_data/configs/audio.yaml @@ -0,0 +1,4 @@ +samplerate: 256000 +resample: + enabled: true + method: poly diff --git a/example_data/configs/evaluation.yaml b/example_data/configs/evaluation.yaml new file mode 100644 index 0000000..f2577ae --- /dev/null +++ b/example_data/configs/evaluation.yaml @@ -0,0 +1,37 @@ +tasks: + - name: sound_event_detection + metrics: + - name: average_precision + - 
name: roc_auc + plots: + - name: pr_curve + - name: score_distribution + - name: example_detection + - name: sound_event_classification + metrics: + - name: average_precision + - name: roc_auc + plots: + - name: pr_curve + - name: top_class_detection + metrics: + - name: average_precision + plots: + - name: pr_curve + - name: confusion_matrix + - name: example_classification + - name: clip_detection + metrics: + - name: average_precision + - name: roc_auc + plots: + - name: pr_curve + - name: roc_curve + - name: score_distribution + - name: clip_classification + metrics: + - name: average_precision + - name: roc_auc + plots: + - name: pr_curve + - name: roc_curve diff --git a/example_data/configs/inference.yaml b/example_data/configs/inference.yaml new file mode 100644 index 0000000..f9a3078 --- /dev/null +++ b/example_data/configs/inference.yaml @@ -0,0 +1,9 @@ +loader: + batch_size: 8 + +clipping: + enabled: true + duration: 0.5 + overlap: 0.0 + max_empty: 0.0 + discard_empty: true diff --git a/example_data/configs/logging.yaml b/example_data/configs/logging.yaml new file mode 100644 index 0000000..bb5f366 --- /dev/null +++ b/example_data/configs/logging.yaml @@ -0,0 +1,2 @@ +train: + name: csv diff --git a/example_data/configs/model.yaml b/example_data/configs/model.yaml new file mode 100644 index 0000000..b03a525 --- /dev/null +++ b/example_data/configs/model.yaml @@ -0,0 +1,59 @@ +samplerate: 256000 + +preprocess: + stft: + window_duration: 0.002 + window_overlap: 0.75 + window_fn: hann + frequencies: + max_freq: 120000 + min_freq: 10000 + size: + height: 128 + resize_factor: 0.5 + spectrogram_transforms: + - name: pcen + time_constant: 0.1 + gain: 0.98 + bias: 2 + power: 0.5 + - name: spectral_mean_subtraction + +architecture: + name: UNetBackbone + input_height: 128 + in_channels: 1 + encoder: + layers: + - name: FreqCoordConvDown + out_channels: 32 + - name: FreqCoordConvDown + out_channels: 64 + - name: LayerGroup + layers: + - name: FreqCoordConvDown + 
out_channels: 128 + - name: ConvBlock + out_channels: 256 + bottleneck: + channels: 256 + layers: + - name: SelfAttention + attention_channels: 256 + decoder: + layers: + - name: FreqCoordConvUp + out_channels: 64 + - name: FreqCoordConvUp + out_channels: 32 + - name: LayerGroup + layers: + - name: FreqCoordConvUp + out_channels: 32 + - name: ConvBlock + out_channels: 32 + +postprocess: + nms_kernel_size: 9 + detection_threshold: 0.01 + top_k_per_sec: 200 diff --git a/example_data/configs/outputs.yaml b/example_data/configs/outputs.yaml new file mode 100644 index 0000000..458093a --- /dev/null +++ b/example_data/configs/outputs.yaml @@ -0,0 +1,9 @@ +format: + name: raw + include_class_scores: true + include_features: true + include_geometry: true + +transform: + detection_transforms: [] + clip_transforms: [] diff --git a/example_data/configs/training.yaml b/example_data/configs/training.yaml new file mode 100644 index 0000000..b99899e --- /dev/null +++ b/example_data/configs/training.yaml @@ -0,0 +1,79 @@ +optimizer: + name: adam + learning_rate: 0.001 + +scheduler: + name: cosine_annealing + t_max: 100 + +labels: + sigma: 3 + +trainer: + max_epochs: 10 + check_val_every_n_epoch: 5 + +train_loader: + batch_size: 8 + shuffle: true + + clipping_strategy: + name: random_subclip + duration: 0.256 + + augmentations: + enabled: true + audio: + - name: mix_audio + probability: 0.2 + min_weight: 0.3 + max_weight: 0.7 + - name: add_echo + probability: 0.2 + max_delay: 0.005 + min_weight: 0.0 + max_weight: 1.0 + spectrogram: + - name: scale_volume + probability: 0.2 + min_scaling: 0.0 + max_scaling: 2.0 + - name: warp + probability: 0.2 + delta: 0.04 + - name: mask_time + probability: 0.2 + max_perc: 0.05 + max_masks: 3 + - name: mask_freq + probability: 0.2 + max_perc: 0.10 + max_masks: 3 + +val_loader: + clipping_strategy: + name: whole_audio_padded + chunk_size: 0.256 + +loss: + detection: + weight: 1.0 + focal: + beta: 4 + alpha: 2 + classification: + weight: 2.0 + 
focal: + beta: 4 + alpha: 2 + size: + weight: 0.1 + +validation: + tasks: + - name: sound_event_detection + metrics: + - name: average_precision + - name: sound_event_classification + metrics: + - name: average_precision diff --git a/justfile b/justfile index f8d298a..0e1b870 100644 --- a/justfile +++ b/justfile @@ -112,6 +112,14 @@ clean: clean-build clean-pyc clean-test clean-docs example-train OPTIONS="": uv run batdetect2 train \ --val-dataset example_data/dataset.yaml \ - --config example_data/config.yaml \ + --base-dir . \ + --targets example_data/targets.yaml \ + --model-config example_data/configs/model.yaml \ + --training-config example_data/configs/training.yaml \ + --audio-config example_data/configs/audio.yaml \ + --evaluation-config example_data/configs/evaluation.yaml \ + --inference-config example_data/configs/inference.yaml \ + --outputs-config example_data/configs/outputs.yaml \ + --logging-config example_data/configs/logging.yaml \ {{OPTIONS}} \ example_data/dataset.yaml diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py deleted file mode 100644 index 88b23ee..0000000 --- a/src/batdetect2/config.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Literal - -from pydantic import Field - -from batdetect2.audio import AudioConfig -from batdetect2.core.configs import BaseConfig -from batdetect2.evaluate.config import ( - EvaluationConfig, - get_default_eval_config, -) -from batdetect2.inference.config import InferenceConfig -from batdetect2.logging import AppLoggingConfig -from batdetect2.models import ModelConfig -from batdetect2.outputs import OutputsConfig -from batdetect2.targets import TargetConfig -from batdetect2.train.config import TrainingConfig - -__all__ = ["BatDetect2Config"] - - -class BatDetect2Config(BaseConfig): - config_version: Literal["v1"] = "v1" - - train: TrainingConfig = Field(default_factory=TrainingConfig) - evaluation: EvaluationConfig = Field( - default_factory=get_default_eval_config - ) - model: ModelConfig 
= Field(default_factory=ModelConfig) - targets: TargetConfig = Field(default_factory=TargetConfig) - audio: AudioConfig = Field(default_factory=AudioConfig) - inference: InferenceConfig = Field(default_factory=InferenceConfig) - outputs: OutputsConfig = Field(default_factory=OutputsConfig) - logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig) diff --git a/tests/test_models/test_backbones.py b/tests/test_models/test_backbones.py index 6f26dad..9f121f9 100644 --- a/tests/test_models/test_backbones.py +++ b/tests/test_models/test_backbones.py @@ -203,8 +203,8 @@ in_channels: 1 def test_load_backbone_config_from_example_data(example_data_dir: Path): """load_backbone_config loads the real example config correctly.""" config = load_backbone_config( - example_data_dir / "config.yaml", - field="model.architecture", + example_data_dir / "configs" / "model.yaml", + field="architecture", ) assert isinstance(config, UNetBackboneConfig) diff --git a/tests/test_train/test_checkpoints.py b/tests/test_train/test_checkpoints.py index b228378..3603cac 100644 --- a/tests/test_train/test_checkpoints.py +++ b/tests/test_train/test_checkpoints.py @@ -3,20 +3,19 @@ from pathlib import Path import pytest from soundevent import data -from batdetect2.config import BatDetect2Config -from batdetect2.train import run_train +from batdetect2.train import TrainingConfig, run_train pytestmark = pytest.mark.slow -def _build_fast_train_config() -> BatDetect2Config: - config = BatDetect2Config() - config.train.trainer.limit_train_batches = 1 - config.train.trainer.limit_val_batches = 1 - config.train.trainer.log_every_n_steps = 1 - config.train.trainer.check_val_every_n_epoch = 1 - config.train.train_loader.batch_size = 1 - config.train.train_loader.augmentations.enabled = False +def _build_fast_train_config() -> TrainingConfig: + config = TrainingConfig() + config.trainer.limit_train_batches = 1 + config.trainer.limit_val_batches = 1 + config.trainer.log_every_n_steps = 1 + 
config.trainer.check_val_every_n_epoch = 1 + config.train_loader.batch_size = 1 + config.train_loader.augmentations.enabled = False return config @@ -29,9 +28,7 @@ def test_train_saves_checkpoint_in_requested_experiment_run_dir( run_train( train_annotations=example_annotations[:1], val_annotations=example_annotations[:1], - train_config=config.train, - model_config=config.model, - audio_config=config.audio, + train_config=config, num_epochs=1, train_workers=0, val_workers=0, @@ -50,14 +47,12 @@ def test_train_without_validation_can_still_save_last_checkpoint( example_annotations: list[data.ClipAnnotation], ) -> None: config = _build_fast_train_config() - config.train.checkpoints.save_last = True + config.checkpoints.save_last = True run_train( train_annotations=example_annotations[:1], val_annotations=None, - train_config=config.train, - model_config=config.model, - audio_config=config.audio, + train_config=config, num_epochs=1, train_workers=0, val_workers=0, @@ -73,16 +68,14 @@ def test_train_controls_which_checkpoints_are_kept( example_annotations: list[data.ClipAnnotation], ) -> None: config = _build_fast_train_config() - config.train.checkpoints.save_top_k = 1 - config.train.checkpoints.save_last = True - config.train.checkpoints.filename = "epoch{epoch}" + config.checkpoints.save_top_k = 1 + config.checkpoints.save_last = True + config.checkpoints.filename = "epoch{epoch}" run_train( train_annotations=example_annotations[:1], val_annotations=example_annotations[:1], - train_config=config.train, - model_config=config.model, - audio_config=config.audio, + train_config=config, num_epochs=3, train_workers=0, val_workers=0, diff --git a/tests/test_train/test_config.py b/tests/test_train/test_config.py index 1384887..547e64f 100644 --- a/tests/test_train/test_config.py +++ b/tests/test_train/test_config.py @@ -1,12 +1,43 @@ -from batdetect2.config import BatDetect2Config -from batdetect2.core import load_config +from batdetect2.audio import AudioConfig +from 
batdetect2.evaluate import EvaluationConfig +from batdetect2.inference import InferenceConfig +from batdetect2.logging import AppLoggingConfig +from batdetect2.models import ModelConfig +from batdetect2.outputs import OutputsConfig +from batdetect2.targets import TargetConfig +from batdetect2.train import TrainingConfig -def test_example_config_is_valid(example_data_dir): - conf = load_config( - example_data_dir / "config.yaml", - schema=BatDetect2Config, - extra="forbid", - strict=True, +def test_example_split_configs_are_valid(example_data_dir): + configs_dir = example_data_dir / "configs" + + assert isinstance( + AudioConfig.load(configs_dir / "audio.yaml"), AudioConfig + ) + assert isinstance( + ModelConfig.load(configs_dir / "model.yaml"), ModelConfig + ) + assert isinstance( + TargetConfig.load(example_data_dir / "targets.yaml"), + TargetConfig, + ) + assert isinstance( + TrainingConfig.load(configs_dir / "training.yaml"), + TrainingConfig, + ) + assert isinstance( + EvaluationConfig.load(configs_dir / "evaluation.yaml"), + EvaluationConfig, + ) + assert isinstance( + InferenceConfig.load(configs_dir / "inference.yaml"), + InferenceConfig, + ) + assert isinstance( + OutputsConfig.load(configs_dir / "outputs.yaml"), + OutputsConfig, + ) + assert isinstance( + AppLoggingConfig.load(configs_dir / "logging.yaml"), + AppLoggingConfig, ) - assert isinstance(conf, BatDetect2Config)