mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
refactor: remove aggregate app config
This commit is contained in:
parent
c27e7f9f52
commit
7a10b7ffff
@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
`BatDetect2API` is the main entry point for the current Python workflow.
|
`BatDetect2API` is the main entry point for the current Python workflow.
|
||||||
|
|
||||||
It wraps model loading, inference, evaluation, output formatting, and training-related entry points behind one object.
|
It wraps model loading, inference, evaluation, output formatting, and
|
||||||
|
training-related entry points behind one object.
|
||||||
|
|
||||||
Defined in `batdetect2.api_v2`.
|
Defined in `batdetect2.api_v2`.
|
||||||
|
|
||||||
@ -10,8 +11,8 @@ Defined in `batdetect2.api_v2`.
|
|||||||
|
|
||||||
- `BatDetect2API.from_checkpoint(path, ...)`
|
- `BatDetect2API.from_checkpoint(path, ...)`
|
||||||
- load a trained checkpoint and optional config overrides.
|
- load a trained checkpoint and optional config overrides.
|
||||||
- `BatDetect2API.from_config(config)`
|
- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)`
|
||||||
- build a full stack from a `BatDetect2Config` object.
|
- build a full stack from separate config objects.
|
||||||
|
|
||||||
## Inference methods
|
## Inference methods
|
||||||
|
|
||||||
@ -46,10 +47,12 @@ Defined in `batdetect2.api_v2`.
|
|||||||
|
|
||||||
## Output persistence helpers
|
## Output persistence helpers
|
||||||
|
|
||||||
- `save_predictions(predictions, path, audio_dir=None, format=None, config=None)`
|
- `save_predictions(predictions, path, audio_dir=None, format=None,
|
||||||
|
config=None)`
|
||||||
- `load_predictions(path, format=None, config=None)`
|
- `load_predictions(path, format=None, config=None)`
|
||||||
|
|
||||||
Use these when you want to save programmatic predictions without going through the CLI.
|
Use these when you want to save programmatic predictions without going through
|
||||||
|
the CLI.
|
||||||
|
|
||||||
## Training and evaluation entry points
|
## Training and evaluation entry points
|
||||||
|
|
||||||
@ -60,6 +63,9 @@ Use these when you want to save programmatic predictions without going through t
|
|||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
- Python tutorial: {doc}`../tutorials/integrate-with-a-python-pipeline`
|
- Python tutorial:
|
||||||
- Outputs config reference: {doc}`outputs-config`
|
{doc}`../tutorials/integrate-with-a-python-pipeline`
|
||||||
- Output formats reference: {doc}`output-formats`
|
- Outputs config reference:
|
||||||
|
{doc}`outputs-config`
|
||||||
|
- Output formats reference:
|
||||||
|
{doc}`output-formats`
|
||||||
|
|||||||
@ -1,38 +0,0 @@
|
|||||||
# Top-level app config reference
|
|
||||||
|
|
||||||
The top-level config object is `BatDetect2Config`.
|
|
||||||
|
|
||||||
Defined in `batdetect2.config`.
|
|
||||||
|
|
||||||
It combines the main configuration surfaces used across training, inference, evaluation, outputs, and logging.
|
|
||||||
|
|
||||||
## Fields
|
|
||||||
|
|
||||||
- `config_version`
|
|
||||||
- `train`
|
|
||||||
- training-specific config.
|
|
||||||
- `evaluation`
|
|
||||||
- evaluation task and plot config.
|
|
||||||
- `model`
|
|
||||||
- model architecture, preprocessing, postprocessing, and targets.
|
|
||||||
- `audio`
|
|
||||||
- audio loading and resampling config.
|
|
||||||
- `inference`
|
|
||||||
- clipping and loader config for prediction-time workflows.
|
|
||||||
- `outputs`
|
|
||||||
- output format and output transform config.
|
|
||||||
- `logging`
|
|
||||||
- logging backend and formatting config.
|
|
||||||
|
|
||||||
## Mental model
|
|
||||||
|
|
||||||
Think of `BatDetect2Config` as the complete application wiring for the current stack.
|
|
||||||
|
|
||||||
Use it when you want one reproducible config that describes the whole workflow.
|
|
||||||
|
|
||||||
## Related pages
|
|
||||||
|
|
||||||
- Inference config: {doc}`inference-config`
|
|
||||||
- Evaluation config: {doc}`evaluation-config`
|
|
||||||
- Outputs config: {doc}`outputs-config`
|
|
||||||
- General config reference: {doc}`configs`
|
|
||||||
@ -24,8 +24,8 @@ for full options and argument details.
|
|||||||
- Global CLI options are documented in {doc}`base`.
|
- Global CLI options are documented in {doc}`base`.
|
||||||
- Paths with spaces should be wrapped in quotes.
|
- Paths with spaces should be wrapped in quotes.
|
||||||
- Input audio is expected to be mono.
|
- Input audio is expected to be mono.
|
||||||
- Legacy `detect` uses a required threshold argument, while `predict` uses
|
- Legacy `detect` uses a required threshold argument, while `predict` uses the
|
||||||
the optional `--detection-threshold` override.
|
optional `--detection-threshold` override.
|
||||||
|
|
||||||
```{warning}
|
```{warning}
|
||||||
`batdetect2 detect` is a legacy command.
|
`batdetect2 detect` is a legacy command.
|
||||||
|
|||||||
@ -1,5 +1,15 @@
|
|||||||
Config reference
|
Config reference
|
||||||
================
|
================
|
||||||
|
|
||||||
.. automodule:: batdetect2.config
|
BatDetect2 uses separate config objects for different workflow surfaces.
|
||||||
:members:
|
|
||||||
|
Use the dedicated reference pages for each config family:
|
||||||
|
|
||||||
|
- inference config
|
||||||
|
- evaluation config
|
||||||
|
- outputs config
|
||||||
|
- preprocessing config
|
||||||
|
- postprocess config
|
||||||
|
- targets config workflow
|
||||||
|
|
||||||
|
Example config files live under `example_data/configs/`.
|
||||||
|
|||||||
@ -2,14 +2,14 @@
|
|||||||
|
|
||||||
Reference pages are the detailed lookup pages.
|
Reference pages are the detailed lookup pages.
|
||||||
|
|
||||||
Use this section when you need exact command options, setting names, output details, or Python API entries.
|
Use this section when you need exact command options, setting names, output
|
||||||
|
details, or Python API entries.
|
||||||
|
|
||||||
```{toctree}
|
```{toctree}
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
||||||
cli/index
|
cli/index
|
||||||
api
|
api
|
||||||
app-config
|
|
||||||
inference-config
|
inference-config
|
||||||
evaluation-config
|
evaluation-config
|
||||||
outputs-config
|
outputs-config
|
||||||
|
|||||||
@ -1,192 +0,0 @@
|
|||||||
config_version: v1
|
|
||||||
|
|
||||||
audio:
|
|
||||||
samplerate: 256000
|
|
||||||
resample:
|
|
||||||
enabled: true
|
|
||||||
method: poly
|
|
||||||
|
|
||||||
model:
|
|
||||||
samplerate: 256000
|
|
||||||
|
|
||||||
preprocess:
|
|
||||||
stft:
|
|
||||||
window_duration: 0.002
|
|
||||||
window_overlap: 0.75
|
|
||||||
window_fn: hann
|
|
||||||
frequencies:
|
|
||||||
max_freq: 120000
|
|
||||||
min_freq: 10000
|
|
||||||
size:
|
|
||||||
height: 128
|
|
||||||
resize_factor: 0.5
|
|
||||||
spectrogram_transforms:
|
|
||||||
- name: pcen
|
|
||||||
time_constant: 0.1
|
|
||||||
gain: 0.98
|
|
||||||
bias: 2
|
|
||||||
power: 0.5
|
|
||||||
- name: spectral_mean_subtraction
|
|
||||||
|
|
||||||
architecture:
|
|
||||||
name: UNetBackbone
|
|
||||||
input_height: 128
|
|
||||||
in_channels: 1
|
|
||||||
encoder:
|
|
||||||
layers:
|
|
||||||
- name: FreqCoordConvDown
|
|
||||||
out_channels: 32
|
|
||||||
- name: FreqCoordConvDown
|
|
||||||
out_channels: 64
|
|
||||||
- name: LayerGroup
|
|
||||||
layers:
|
|
||||||
- name: FreqCoordConvDown
|
|
||||||
out_channels: 128
|
|
||||||
- name: ConvBlock
|
|
||||||
out_channels: 256
|
|
||||||
bottleneck:
|
|
||||||
channels: 256
|
|
||||||
layers:
|
|
||||||
- name: SelfAttention
|
|
||||||
attention_channels: 256
|
|
||||||
decoder:
|
|
||||||
layers:
|
|
||||||
- name: FreqCoordConvUp
|
|
||||||
out_channels: 64
|
|
||||||
- name: FreqCoordConvUp
|
|
||||||
out_channels: 32
|
|
||||||
- name: LayerGroup
|
|
||||||
layers:
|
|
||||||
- name: FreqCoordConvUp
|
|
||||||
out_channels: 32
|
|
||||||
- name: ConvBlock
|
|
||||||
out_channels: 32
|
|
||||||
|
|
||||||
postprocess:
|
|
||||||
nms_kernel_size: 9
|
|
||||||
detection_threshold: 0.01
|
|
||||||
top_k_per_sec: 200
|
|
||||||
|
|
||||||
train:
|
|
||||||
optimizer:
|
|
||||||
name: adam
|
|
||||||
learning_rate: 0.001
|
|
||||||
|
|
||||||
scheduler:
|
|
||||||
name: cosine_annealing
|
|
||||||
t_max: 100
|
|
||||||
|
|
||||||
labels:
|
|
||||||
sigma: 3
|
|
||||||
|
|
||||||
trainer:
|
|
||||||
max_epochs: 10
|
|
||||||
check_val_every_n_epoch: 5
|
|
||||||
|
|
||||||
train_loader:
|
|
||||||
batch_size: 8
|
|
||||||
shuffle: true
|
|
||||||
|
|
||||||
clipping_strategy:
|
|
||||||
name: random_subclip
|
|
||||||
duration: 0.256
|
|
||||||
|
|
||||||
augmentations:
|
|
||||||
enabled: true
|
|
||||||
audio:
|
|
||||||
- name: mix_audio
|
|
||||||
probability: 0.2
|
|
||||||
min_weight: 0.3
|
|
||||||
max_weight: 0.7
|
|
||||||
- name: add_echo
|
|
||||||
probability: 0.2
|
|
||||||
max_delay: 0.005
|
|
||||||
min_weight: 0.0
|
|
||||||
max_weight: 1.0
|
|
||||||
spectrogram:
|
|
||||||
- name: scale_volume
|
|
||||||
probability: 0.2
|
|
||||||
min_scaling: 0.0
|
|
||||||
max_scaling: 2.0
|
|
||||||
- name: warp
|
|
||||||
probability: 0.2
|
|
||||||
delta: 0.04
|
|
||||||
- name: mask_time
|
|
||||||
probability: 0.2
|
|
||||||
max_perc: 0.05
|
|
||||||
max_masks: 3
|
|
||||||
- name: mask_freq
|
|
||||||
probability: 0.2
|
|
||||||
max_perc: 0.10
|
|
||||||
max_masks: 3
|
|
||||||
|
|
||||||
val_loader:
|
|
||||||
clipping_strategy:
|
|
||||||
name: whole_audio_padded
|
|
||||||
chunk_size: 0.256
|
|
||||||
|
|
||||||
loss:
|
|
||||||
detection:
|
|
||||||
weight: 1.0
|
|
||||||
focal:
|
|
||||||
beta: 4
|
|
||||||
alpha: 2
|
|
||||||
classification:
|
|
||||||
weight: 2.0
|
|
||||||
focal:
|
|
||||||
beta: 4
|
|
||||||
alpha: 2
|
|
||||||
size:
|
|
||||||
weight: 0.1
|
|
||||||
|
|
||||||
validation:
|
|
||||||
tasks:
|
|
||||||
- name: sound_event_detection
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
- name: sound_event_classification
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
|
|
||||||
logging:
|
|
||||||
train:
|
|
||||||
name: csv
|
|
||||||
|
|
||||||
evaluation:
|
|
||||||
tasks:
|
|
||||||
- name: sound_event_detection
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
- name: roc_auc
|
|
||||||
plots:
|
|
||||||
- name: pr_curve
|
|
||||||
- name: score_distribution
|
|
||||||
- name: example_detection
|
|
||||||
- name: sound_event_classification
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
- name: roc_auc
|
|
||||||
plots:
|
|
||||||
- name: pr_curve
|
|
||||||
- name: top_class_detection
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
plots:
|
|
||||||
- name: pr_curve
|
|
||||||
- name: confusion_matrix
|
|
||||||
- name: example_classification
|
|
||||||
- name: clip_detection
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
- name: roc_auc
|
|
||||||
plots:
|
|
||||||
- name: pr_curve
|
|
||||||
- name: roc_curve
|
|
||||||
- name: score_distribution
|
|
||||||
- name: clip_classification
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
- name: roc_auc
|
|
||||||
plots:
|
|
||||||
- name: pr_curve
|
|
||||||
- name: roc_curve
|
|
||||||
4
example_data/configs/audio.yaml
Normal file
4
example_data/configs/audio.yaml
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
samplerate: 256000
|
||||||
|
resample:
|
||||||
|
enabled: true
|
||||||
|
method: poly
|
||||||
37
example_data/configs/evaluation.yaml
Normal file
37
example_data/configs/evaluation.yaml
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
tasks:
|
||||||
|
- name: sound_event_detection
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: score_distribution
|
||||||
|
- name: example_detection
|
||||||
|
- name: sound_event_classification
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: top_class_detection
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: confusion_matrix
|
||||||
|
- name: example_classification
|
||||||
|
- name: clip_detection
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: roc_curve
|
||||||
|
- name: score_distribution
|
||||||
|
- name: clip_classification
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: roc_curve
|
||||||
9
example_data/configs/inference.yaml
Normal file
9
example_data/configs/inference.yaml
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
loader:
|
||||||
|
batch_size: 8
|
||||||
|
|
||||||
|
clipping:
|
||||||
|
enabled: true
|
||||||
|
duration: 0.5
|
||||||
|
overlap: 0.0
|
||||||
|
max_empty: 0.0
|
||||||
|
discard_empty: true
|
||||||
2
example_data/configs/logging.yaml
Normal file
2
example_data/configs/logging.yaml
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
train:
|
||||||
|
name: csv
|
||||||
59
example_data/configs/model.yaml
Normal file
59
example_data/configs/model.yaml
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
samplerate: 256000
|
||||||
|
|
||||||
|
preprocess:
|
||||||
|
stft:
|
||||||
|
window_duration: 0.002
|
||||||
|
window_overlap: 0.75
|
||||||
|
window_fn: hann
|
||||||
|
frequencies:
|
||||||
|
max_freq: 120000
|
||||||
|
min_freq: 10000
|
||||||
|
size:
|
||||||
|
height: 128
|
||||||
|
resize_factor: 0.5
|
||||||
|
spectrogram_transforms:
|
||||||
|
- name: pcen
|
||||||
|
time_constant: 0.1
|
||||||
|
gain: 0.98
|
||||||
|
bias: 2
|
||||||
|
power: 0.5
|
||||||
|
- name: spectral_mean_subtraction
|
||||||
|
|
||||||
|
architecture:
|
||||||
|
name: UNetBackbone
|
||||||
|
input_height: 128
|
||||||
|
in_channels: 1
|
||||||
|
encoder:
|
||||||
|
layers:
|
||||||
|
- name: FreqCoordConvDown
|
||||||
|
out_channels: 32
|
||||||
|
- name: FreqCoordConvDown
|
||||||
|
out_channels: 64
|
||||||
|
- name: LayerGroup
|
||||||
|
layers:
|
||||||
|
- name: FreqCoordConvDown
|
||||||
|
out_channels: 128
|
||||||
|
- name: ConvBlock
|
||||||
|
out_channels: 256
|
||||||
|
bottleneck:
|
||||||
|
channels: 256
|
||||||
|
layers:
|
||||||
|
- name: SelfAttention
|
||||||
|
attention_channels: 256
|
||||||
|
decoder:
|
||||||
|
layers:
|
||||||
|
- name: FreqCoordConvUp
|
||||||
|
out_channels: 64
|
||||||
|
- name: FreqCoordConvUp
|
||||||
|
out_channels: 32
|
||||||
|
- name: LayerGroup
|
||||||
|
layers:
|
||||||
|
- name: FreqCoordConvUp
|
||||||
|
out_channels: 32
|
||||||
|
- name: ConvBlock
|
||||||
|
out_channels: 32
|
||||||
|
|
||||||
|
postprocess:
|
||||||
|
nms_kernel_size: 9
|
||||||
|
detection_threshold: 0.01
|
||||||
|
top_k_per_sec: 200
|
||||||
9
example_data/configs/outputs.yaml
Normal file
9
example_data/configs/outputs.yaml
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
format:
|
||||||
|
name: raw
|
||||||
|
include_class_scores: true
|
||||||
|
include_features: true
|
||||||
|
include_geometry: true
|
||||||
|
|
||||||
|
transform:
|
||||||
|
detection_transforms: []
|
||||||
|
clip_transforms: []
|
||||||
79
example_data/configs/training.yaml
Normal file
79
example_data/configs/training.yaml
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
optimizer:
|
||||||
|
name: adam
|
||||||
|
learning_rate: 0.001
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
name: cosine_annealing
|
||||||
|
t_max: 100
|
||||||
|
|
||||||
|
labels:
|
||||||
|
sigma: 3
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
max_epochs: 10
|
||||||
|
check_val_every_n_epoch: 5
|
||||||
|
|
||||||
|
train_loader:
|
||||||
|
batch_size: 8
|
||||||
|
shuffle: true
|
||||||
|
|
||||||
|
clipping_strategy:
|
||||||
|
name: random_subclip
|
||||||
|
duration: 0.256
|
||||||
|
|
||||||
|
augmentations:
|
||||||
|
enabled: true
|
||||||
|
audio:
|
||||||
|
- name: mix_audio
|
||||||
|
probability: 0.2
|
||||||
|
min_weight: 0.3
|
||||||
|
max_weight: 0.7
|
||||||
|
- name: add_echo
|
||||||
|
probability: 0.2
|
||||||
|
max_delay: 0.005
|
||||||
|
min_weight: 0.0
|
||||||
|
max_weight: 1.0
|
||||||
|
spectrogram:
|
||||||
|
- name: scale_volume
|
||||||
|
probability: 0.2
|
||||||
|
min_scaling: 0.0
|
||||||
|
max_scaling: 2.0
|
||||||
|
- name: warp
|
||||||
|
probability: 0.2
|
||||||
|
delta: 0.04
|
||||||
|
- name: mask_time
|
||||||
|
probability: 0.2
|
||||||
|
max_perc: 0.05
|
||||||
|
max_masks: 3
|
||||||
|
- name: mask_freq
|
||||||
|
probability: 0.2
|
||||||
|
max_perc: 0.10
|
||||||
|
max_masks: 3
|
||||||
|
|
||||||
|
val_loader:
|
||||||
|
clipping_strategy:
|
||||||
|
name: whole_audio_padded
|
||||||
|
chunk_size: 0.256
|
||||||
|
|
||||||
|
loss:
|
||||||
|
detection:
|
||||||
|
weight: 1.0
|
||||||
|
focal:
|
||||||
|
beta: 4
|
||||||
|
alpha: 2
|
||||||
|
classification:
|
||||||
|
weight: 2.0
|
||||||
|
focal:
|
||||||
|
beta: 4
|
||||||
|
alpha: 2
|
||||||
|
size:
|
||||||
|
weight: 0.1
|
||||||
|
|
||||||
|
validation:
|
||||||
|
tasks:
|
||||||
|
- name: sound_event_detection
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: sound_event_classification
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
10
justfile
10
justfile
@ -112,6 +112,14 @@ clean: clean-build clean-pyc clean-test clean-docs
|
|||||||
example-train OPTIONS="":
|
example-train OPTIONS="":
|
||||||
uv run batdetect2 train \
|
uv run batdetect2 train \
|
||||||
--val-dataset example_data/dataset.yaml \
|
--val-dataset example_data/dataset.yaml \
|
||||||
--config example_data/config.yaml \
|
--base-dir . \
|
||||||
|
--targets example_data/targets.yaml \
|
||||||
|
--model-config example_data/configs/model.yaml \
|
||||||
|
--training-config example_data/configs/training.yaml \
|
||||||
|
--audio-config example_data/configs/audio.yaml \
|
||||||
|
--evaluation-config example_data/configs/evaluation.yaml \
|
||||||
|
--inference-config example_data/configs/inference.yaml \
|
||||||
|
--outputs-config example_data/configs/outputs.yaml \
|
||||||
|
--logging-config example_data/configs/logging.yaml \
|
||||||
{{OPTIONS}} \
|
{{OPTIONS}} \
|
||||||
example_data/dataset.yaml
|
example_data/dataset.yaml
|
||||||
|
|||||||
@ -1,33 +0,0 @@
|
|||||||
from typing import Literal
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
from batdetect2.evaluate.config import (
|
|
||||||
EvaluationConfig,
|
|
||||||
get_default_eval_config,
|
|
||||||
)
|
|
||||||
from batdetect2.inference.config import InferenceConfig
|
|
||||||
from batdetect2.logging import AppLoggingConfig
|
|
||||||
from batdetect2.models import ModelConfig
|
|
||||||
from batdetect2.outputs import OutputsConfig
|
|
||||||
from batdetect2.targets import TargetConfig
|
|
||||||
from batdetect2.train.config import TrainingConfig
|
|
||||||
|
|
||||||
__all__ = ["BatDetect2Config"]
|
|
||||||
|
|
||||||
|
|
||||||
class BatDetect2Config(BaseConfig):
|
|
||||||
config_version: Literal["v1"] = "v1"
|
|
||||||
|
|
||||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
|
||||||
evaluation: EvaluationConfig = Field(
|
|
||||||
default_factory=get_default_eval_config
|
|
||||||
)
|
|
||||||
model: ModelConfig = Field(default_factory=ModelConfig)
|
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
|
||||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
|
||||||
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
|
||||||
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
|
|
||||||
@ -203,8 +203,8 @@ in_channels: 1
|
|||||||
def test_load_backbone_config_from_example_data(example_data_dir: Path):
|
def test_load_backbone_config_from_example_data(example_data_dir: Path):
|
||||||
"""load_backbone_config loads the real example config correctly."""
|
"""load_backbone_config loads the real example config correctly."""
|
||||||
config = load_backbone_config(
|
config = load_backbone_config(
|
||||||
example_data_dir / "config.yaml",
|
example_data_dir / "configs" / "model.yaml",
|
||||||
field="model.architecture",
|
field="architecture",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(config, UNetBackboneConfig)
|
assert isinstance(config, UNetBackboneConfig)
|
||||||
|
|||||||
@ -3,20 +3,19 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.train import TrainingConfig, run_train
|
||||||
from batdetect2.train import run_train
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.slow
|
pytestmark = pytest.mark.slow
|
||||||
|
|
||||||
|
|
||||||
def _build_fast_train_config() -> BatDetect2Config:
|
def _build_fast_train_config() -> TrainingConfig:
|
||||||
config = BatDetect2Config()
|
config = TrainingConfig()
|
||||||
config.train.trainer.limit_train_batches = 1
|
config.trainer.limit_train_batches = 1
|
||||||
config.train.trainer.limit_val_batches = 1
|
config.trainer.limit_val_batches = 1
|
||||||
config.train.trainer.log_every_n_steps = 1
|
config.trainer.log_every_n_steps = 1
|
||||||
config.train.trainer.check_val_every_n_epoch = 1
|
config.trainer.check_val_every_n_epoch = 1
|
||||||
config.train.train_loader.batch_size = 1
|
config.train_loader.batch_size = 1
|
||||||
config.train.train_loader.augmentations.enabled = False
|
config.train_loader.augmentations.enabled = False
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@ -29,9 +28,7 @@ def test_train_saves_checkpoint_in_requested_experiment_run_dir(
|
|||||||
run_train(
|
run_train(
|
||||||
train_annotations=example_annotations[:1],
|
train_annotations=example_annotations[:1],
|
||||||
val_annotations=example_annotations[:1],
|
val_annotations=example_annotations[:1],
|
||||||
train_config=config.train,
|
train_config=config,
|
||||||
model_config=config.model,
|
|
||||||
audio_config=config.audio,
|
|
||||||
num_epochs=1,
|
num_epochs=1,
|
||||||
train_workers=0,
|
train_workers=0,
|
||||||
val_workers=0,
|
val_workers=0,
|
||||||
@ -50,14 +47,12 @@ def test_train_without_validation_can_still_save_last_checkpoint(
|
|||||||
example_annotations: list[data.ClipAnnotation],
|
example_annotations: list[data.ClipAnnotation],
|
||||||
) -> None:
|
) -> None:
|
||||||
config = _build_fast_train_config()
|
config = _build_fast_train_config()
|
||||||
config.train.checkpoints.save_last = True
|
config.checkpoints.save_last = True
|
||||||
|
|
||||||
run_train(
|
run_train(
|
||||||
train_annotations=example_annotations[:1],
|
train_annotations=example_annotations[:1],
|
||||||
val_annotations=None,
|
val_annotations=None,
|
||||||
train_config=config.train,
|
train_config=config,
|
||||||
model_config=config.model,
|
|
||||||
audio_config=config.audio,
|
|
||||||
num_epochs=1,
|
num_epochs=1,
|
||||||
train_workers=0,
|
train_workers=0,
|
||||||
val_workers=0,
|
val_workers=0,
|
||||||
@ -73,16 +68,14 @@ def test_train_controls_which_checkpoints_are_kept(
|
|||||||
example_annotations: list[data.ClipAnnotation],
|
example_annotations: list[data.ClipAnnotation],
|
||||||
) -> None:
|
) -> None:
|
||||||
config = _build_fast_train_config()
|
config = _build_fast_train_config()
|
||||||
config.train.checkpoints.save_top_k = 1
|
config.checkpoints.save_top_k = 1
|
||||||
config.train.checkpoints.save_last = True
|
config.checkpoints.save_last = True
|
||||||
config.train.checkpoints.filename = "epoch{epoch}"
|
config.checkpoints.filename = "epoch{epoch}"
|
||||||
|
|
||||||
run_train(
|
run_train(
|
||||||
train_annotations=example_annotations[:1],
|
train_annotations=example_annotations[:1],
|
||||||
val_annotations=example_annotations[:1],
|
val_annotations=example_annotations[:1],
|
||||||
train_config=config.train,
|
train_config=config,
|
||||||
model_config=config.model,
|
|
||||||
audio_config=config.audio,
|
|
||||||
num_epochs=3,
|
num_epochs=3,
|
||||||
train_workers=0,
|
train_workers=0,
|
||||||
val_workers=0,
|
val_workers=0,
|
||||||
|
|||||||
@ -1,12 +1,43 @@
|
|||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.core import load_config
|
from batdetect2.evaluate import EvaluationConfig
|
||||||
|
from batdetect2.inference import InferenceConfig
|
||||||
|
from batdetect2.logging import AppLoggingConfig
|
||||||
|
from batdetect2.models import ModelConfig
|
||||||
|
from batdetect2.outputs import OutputsConfig
|
||||||
|
from batdetect2.targets import TargetConfig
|
||||||
|
from batdetect2.train import TrainingConfig
|
||||||
|
|
||||||
|
|
||||||
def test_example_config_is_valid(example_data_dir):
|
def test_example_split_configs_are_valid(example_data_dir):
|
||||||
conf = load_config(
|
configs_dir = example_data_dir / "configs"
|
||||||
example_data_dir / "config.yaml",
|
|
||||||
schema=BatDetect2Config,
|
assert isinstance(
|
||||||
extra="forbid",
|
AudioConfig.load(configs_dir / "audio.yaml"), AudioConfig
|
||||||
strict=True,
|
)
|
||||||
|
assert isinstance(
|
||||||
|
ModelConfig.load(configs_dir / "model.yaml"), ModelConfig
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
TargetConfig.load(example_data_dir / "targets.yaml"),
|
||||||
|
TargetConfig,
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
TrainingConfig.load(configs_dir / "training.yaml"),
|
||||||
|
TrainingConfig,
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
EvaluationConfig.load(configs_dir / "evaluation.yaml"),
|
||||||
|
EvaluationConfig,
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
InferenceConfig.load(configs_dir / "inference.yaml"),
|
||||||
|
InferenceConfig,
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
OutputsConfig.load(configs_dir / "outputs.yaml"),
|
||||||
|
OutputsConfig,
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
AppLoggingConfig.load(configs_dir / "logging.yaml"),
|
||||||
|
AppLoggingConfig,
|
||||||
)
|
)
|
||||||
assert isinstance(conf, BatDetect2Config)
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user