Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-05-22 22:32:18 +02:00)

Compare commits: 5a974711b0...f82ec218f0

No commits in common. "5a974711b09d7983e788db039a258cc834dc7646" and "f82ec218f073a1a6bda0b63cc9954e640a42b02e" have entirely different histories.
.gitignore (vendored, 1 change)
@@ -132,4 +132,3 @@ notebooks/tmp
 
 # Assets
 !assets/*
-/models
@@ -2,8 +2,7 @@
 
 `BatDetect2API` is the main entry point for the current Python workflow.
 
-It wraps model loading, inference, evaluation, output formatting, and
-training-related entry points behind one object.
+It wraps model loading, inference, evaluation, output formatting, and training-related entry points behind one object.
 
 Defined in `batdetect2.api_v2`.
 
@@ -11,8 +10,8 @@ Defined in `batdetect2.api_v2`.
 
 - `BatDetect2API.from_checkpoint(path, ...)`
   - load a trained checkpoint and optional config overrides.
-- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)`
-  - build a full stack from separate config objects.
+- `BatDetect2API.from_config(config)`
+  - build a full stack from a `BatDetect2Config` object.
 
 ## Inference methods
 
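The hunk above changes `from_config` from several per-surface keyword arguments to a single config object. The following is a minimal sketch of that pattern only. `AudioCfg`, `ModelCfg`, `AppConfig`, and `build_stack` are hypothetical stand-ins written for illustration; they are not batdetect2 classes or functions.

```python
from dataclasses import dataclass, field


# Hypothetical stand-ins for illustration only; NOT batdetect2 classes.
@dataclass
class AudioCfg:
    samplerate: int = 256_000


@dataclass
class ModelCfg:
    samplerate: int = 256_000


@dataclass
class AppConfig:
    """One object that groups the per-surface configs."""

    audio: AudioCfg = field(default_factory=AudioCfg)
    model: ModelCfg = field(default_factory=ModelCfg)


def build_stack(config: AppConfig) -> dict:
    """Read every setting from the single config object."""
    return {
        "audio_samplerate": config.audio.samplerate,
        "model_samplerate": config.model.samplerate,
    }


# Override one section; the other keeps its default.
stack = build_stack(AppConfig(audio=AudioCfg(samplerate=384_000)))
print(stack)
```

With one object, a caller passes a single argument and each builder reads the section it needs, which mirrors how the new `from_config` reads `config.audio`, `config.model`, and so on.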
@@ -47,12 +46,10 @@ Defined in `batdetect2.api_v2`.
 
 ## Output persistence helpers
 
-- `save_predictions(predictions, path, audio_dir=None, format=None,
-  config=None)`
+- `save_predictions(predictions, path, audio_dir=None, format=None, config=None)`
 - `load_predictions(path, format=None, config=None)`
 
-Use these when you want to save programmatic predictions without going through
-the CLI.
+Use these when you want to save programmatic predictions without going through the CLI.
 
 ## Training and evaluation entry points
 
@@ -63,9 +60,6 @@ the CLI.
 
 ## Related pages
 
-- Python tutorial:
-  {doc}`../tutorials/integrate-with-a-python-pipeline`
-- Outputs config reference:
-  {doc}`outputs-config`
-- Output formats reference:
-  {doc}`output-formats`
+- Python tutorial: {doc}`../tutorials/integrate-with-a-python-pipeline`
+- Outputs config reference: {doc}`outputs-config`
+- Output formats reference: {doc}`output-formats`
docs/source/reference/app-config.md (new file, 38 additions)
@@ -0,0 +1,38 @@
+# Top-level app config reference
+
+The top-level config object is `BatDetect2Config`.
+
+Defined in `batdetect2.config`.
+
+It combines the main configuration surfaces used across training, inference, evaluation, outputs, and logging.
+
+## Fields
+
+- `config_version`
+- `train`
+  - training-specific config.
+- `evaluation`
+  - evaluation task and plot config.
+- `model`
+  - model architecture, preprocessing, postprocessing, and targets.
+- `audio`
+  - audio loading and resampling config.
+- `inference`
+  - clipping and loader config for prediction-time workflows.
+- `outputs`
+  - output format and output transform config.
+- `logging`
+  - logging backend and formatting config.
+
+## Mental model
+
+Think of `BatDetect2Config` as the complete application wiring for the current stack.
+
+Use it when you want one reproducible config that describes the whole workflow.
+
+## Related pages
+
+- Inference config: {doc}`inference-config`
+- Evaluation config: {doc}`evaluation-config`
+- Outputs config: {doc}`outputs-config`
+- General config reference: {doc}`configs`
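The field list in the new reference page can be used as a quick sanity check on a parsed config mapping before it is handed to the library. This is a minimal sketch under stated assumptions: `EXPECTED_FIELDS` is copied from the page above, while `unknown_top_level_keys` and `raw_config` are hypothetical names. Whether `BatDetect2Config` itself rejects unknown keys is not stated in the diff.

```python
# Top-level field names as listed in the reference page above.
EXPECTED_FIELDS = {
    "config_version",
    "train",
    "evaluation",
    "model",
    "audio",
    "inference",
    "outputs",
    "logging",
}


def unknown_top_level_keys(raw: dict) -> set:
    """Return the keys in `raw` that are not listed top-level fields."""
    return set(raw) - EXPECTED_FIELDS


# Hypothetical parsed config; "trainng" is a deliberate typo.
raw_config = {"config_version": "v1", "audio": {}, "trainng": {}}
print(unknown_top_level_keys(raw_config))
```

A check like this catches misspelled section names early, before any model or loader is built.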
@@ -24,8 +24,8 @@ for full options and argument details.
 - Global CLI options are documented in {doc}`base`.
 - Paths with spaces should be wrapped in quotes.
 - Input audio is expected to be mono.
-- Legacy `detect` uses a required threshold argument, while `predict` uses the
-  optional `--detection-threshold` override.
+- Legacy `detect` uses a required threshold argument, while `predict` uses
+  the optional `--detection-threshold` override.
 
 ```{warning}
 `batdetect2 detect` is a legacy command.
@@ -1,15 +1,5 @@
 Config reference
 ================
 
-BatDetect2 uses separate config objects for different workflow surfaces.
-
-Use the dedicated reference pages for each config family:
-
-- inference config
-- evaluation config
-- outputs config
-- preprocessing config
-- postprocess config
-- targets config workflow
-
-Example config files live under `example_data/configs/`.
+.. automodule:: batdetect2.config
+   :members:
@@ -2,14 +2,14 @@
 
 Reference pages are the detailed lookup pages.
 
-Use this section when you need exact command options, setting names, output
-details, or Python API entries.
+Use this section when you need exact command options, setting names, output details, or Python API entries.
 
 ```{toctree}
 :maxdepth: 1
 
 cli/index
 api
+app-config
 inference-config
 evaluation-config
 outputs-config
192
example_data/config.yaml
Normal file
192
example_data/config.yaml
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
config_version: v1
|
||||||
|
|
||||||
|
audio:
|
||||||
|
samplerate: 256000
|
||||||
|
resample:
|
||||||
|
enabled: true
|
||||||
|
method: poly
|
||||||
|
|
||||||
|
model:
|
||||||
|
samplerate: 256000
|
||||||
|
|
||||||
|
preprocess:
|
||||||
|
stft:
|
||||||
|
window_duration: 0.002
|
||||||
|
window_overlap: 0.75
|
||||||
|
window_fn: hann
|
||||||
|
frequencies:
|
||||||
|
max_freq: 120000
|
||||||
|
min_freq: 10000
|
||||||
|
size:
|
||||||
|
height: 128
|
||||||
|
resize_factor: 0.5
|
||||||
|
spectrogram_transforms:
|
||||||
|
- name: pcen
|
||||||
|
time_constant: 0.1
|
||||||
|
gain: 0.98
|
||||||
|
bias: 2
|
||||||
|
power: 0.5
|
||||||
|
- name: spectral_mean_subtraction
|
||||||
|
|
||||||
|
architecture:
|
||||||
|
name: UNetBackbone
|
||||||
|
input_height: 128
|
||||||
|
in_channels: 1
|
||||||
|
encoder:
|
||||||
|
layers:
|
||||||
|
- name: FreqCoordConvDown
|
||||||
|
out_channels: 32
|
||||||
|
- name: FreqCoordConvDown
|
||||||
|
out_channels: 64
|
||||||
|
- name: LayerGroup
|
||||||
|
layers:
|
||||||
|
- name: FreqCoordConvDown
|
||||||
|
out_channels: 128
|
||||||
|
- name: ConvBlock
|
||||||
|
out_channels: 256
|
||||||
|
bottleneck:
|
||||||
|
channels: 256
|
||||||
|
layers:
|
||||||
|
- name: SelfAttention
|
||||||
|
attention_channels: 256
|
||||||
|
decoder:
|
||||||
|
layers:
|
||||||
|
- name: FreqCoordConvUp
|
||||||
|
out_channels: 64
|
||||||
|
- name: FreqCoordConvUp
|
||||||
|
out_channels: 32
|
||||||
|
- name: LayerGroup
|
||||||
|
layers:
|
||||||
|
- name: FreqCoordConvUp
|
||||||
|
out_channels: 32
|
||||||
|
- name: ConvBlock
|
||||||
|
out_channels: 32
|
||||||
|
|
||||||
|
postprocess:
|
||||||
|
nms_kernel_size: 9
|
||||||
|
detection_threshold: 0.01
|
||||||
|
top_k_per_sec: 200
|
||||||
|
|
||||||
|
train:
|
||||||
|
optimizer:
|
||||||
|
name: adam
|
||||||
|
learning_rate: 0.001
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
name: cosine_annealing
|
||||||
|
t_max: 100
|
||||||
|
|
||||||
|
labels:
|
||||||
|
sigma: 3
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
max_epochs: 10
|
||||||
|
check_val_every_n_epoch: 5
|
||||||
|
|
||||||
|
train_loader:
|
||||||
|
batch_size: 8
|
||||||
|
shuffle: true
|
||||||
|
|
||||||
|
clipping_strategy:
|
||||||
|
name: random_subclip
|
||||||
|
duration: 0.256
|
||||||
|
|
||||||
|
augmentations:
|
||||||
|
enabled: true
|
||||||
|
audio:
|
||||||
|
- name: mix_audio
|
||||||
|
probability: 0.2
|
||||||
|
min_weight: 0.3
|
||||||
|
max_weight: 0.7
|
||||||
|
- name: add_echo
|
||||||
|
probability: 0.2
|
||||||
|
max_delay: 0.005
|
||||||
|
min_weight: 0.0
|
||||||
|
max_weight: 1.0
|
||||||
|
spectrogram:
|
||||||
|
- name: scale_volume
|
||||||
|
probability: 0.2
|
||||||
|
min_scaling: 0.0
|
||||||
|
max_scaling: 2.0
|
||||||
|
- name: warp
|
||||||
|
probability: 0.2
|
||||||
|
delta: 0.04
|
||||||
|
- name: mask_time
|
||||||
|
probability: 0.2
|
||||||
|
max_perc: 0.05
|
||||||
|
max_masks: 3
|
||||||
|
- name: mask_freq
|
||||||
|
probability: 0.2
|
||||||
|
max_perc: 0.10
|
||||||
|
max_masks: 3
|
||||||
|
|
||||||
|
val_loader:
|
||||||
|
clipping_strategy:
|
||||||
|
name: whole_audio_padded
|
||||||
|
chunk_size: 0.256
|
||||||
|
|
||||||
|
loss:
|
||||||
|
detection:
|
||||||
|
weight: 1.0
|
||||||
|
focal:
|
||||||
|
beta: 4
|
||||||
|
alpha: 2
|
||||||
|
classification:
|
||||||
|
weight: 2.0
|
||||||
|
focal:
|
||||||
|
beta: 4
|
||||||
|
alpha: 2
|
||||||
|
size:
|
||||||
|
weight: 0.1
|
||||||
|
|
||||||
|
validation:
|
||||||
|
tasks:
|
||||||
|
- name: sound_event_detection
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: sound_event_classification
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
|
||||||
|
logging:
|
||||||
|
train:
|
||||||
|
name: csv
|
||||||
|
|
||||||
|
evaluation:
|
||||||
|
tasks:
|
||||||
|
- name: sound_event_detection
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: score_distribution
|
||||||
|
- name: example_detection
|
||||||
|
- name: sound_event_classification
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: top_class_detection
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: confusion_matrix
|
||||||
|
- name: example_classification
|
||||||
|
- name: clip_detection
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: roc_curve
|
||||||
|
- name: score_distribution
|
||||||
|
- name: clip_classification
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: roc_curve
|
||||||
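The STFT entries in the example config are given in seconds and fractions, so it can help to convert them to sample counts. The following is a worked sketch that assumes the common convention of window length = duration × samplerate and hop = window length × (1 − overlap). The diff does not show how batdetect2 rounds these values internally, so treat the exact numbers as illustrative.

```python
# Values copied from the `audio` and `model.preprocess.stft` entries above.
samplerate = 256_000      # Hz
window_duration = 0.002   # seconds
window_overlap = 0.75     # fraction of a window shared with the next one

# Assumed convention (not confirmed by the diff):
# window length = duration * samplerate, hop = window length * (1 - overlap).
window_samples = round(window_duration * samplerate)
hop_samples = round(window_samples * (1 - window_overlap))
frames_per_second = samplerate / hop_samples

print(window_samples, hop_samples, frames_per_second)
```

Under these assumptions each window covers 512 samples, consecutive windows start 128 samples apart, and the spectrogram has 2000 frames per second before any resizing.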
@@ -1,4 +0,0 @@
-samplerate: 256000
-resample:
-  enabled: true
-  method: poly
@ -1,37 +0,0 @@
|
|||||||
tasks:
|
|
||||||
- name: sound_event_detection
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
- name: roc_auc
|
|
||||||
plots:
|
|
||||||
- name: pr_curve
|
|
||||||
- name: score_distribution
|
|
||||||
- name: example_detection
|
|
||||||
- name: sound_event_classification
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
- name: roc_auc
|
|
||||||
plots:
|
|
||||||
- name: pr_curve
|
|
||||||
- name: top_class_detection
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
plots:
|
|
||||||
- name: pr_curve
|
|
||||||
- name: confusion_matrix
|
|
||||||
- name: example_classification
|
|
||||||
- name: clip_detection
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
- name: roc_auc
|
|
||||||
plots:
|
|
||||||
- name: pr_curve
|
|
||||||
- name: roc_curve
|
|
||||||
- name: score_distribution
|
|
||||||
- name: clip_classification
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
- name: roc_auc
|
|
||||||
plots:
|
|
||||||
- name: pr_curve
|
|
||||||
- name: roc_curve
|
|
||||||
@@ -1,9 +0,0 @@
-loader:
-  batch_size: 8
-
-clipping:
-  enabled: true
-  duration: 0.5
-  overlap: 0.0
-  max_empty: 0.0
-  discard_empty: true
@@ -1,2 +0,0 @@
-train:
-  name: csv
@ -1,59 +0,0 @@
|
|||||||
samplerate: 256000
|
|
||||||
|
|
||||||
preprocess:
|
|
||||||
stft:
|
|
||||||
window_duration: 0.002
|
|
||||||
window_overlap: 0.75
|
|
||||||
window_fn: hann
|
|
||||||
frequencies:
|
|
||||||
max_freq: 120000
|
|
||||||
min_freq: 10000
|
|
||||||
size:
|
|
||||||
height: 128
|
|
||||||
resize_factor: 0.5
|
|
||||||
spectrogram_transforms:
|
|
||||||
- name: pcen
|
|
||||||
time_constant: 0.1
|
|
||||||
gain: 0.98
|
|
||||||
bias: 2
|
|
||||||
power: 0.5
|
|
||||||
- name: spectral_mean_subtraction
|
|
||||||
|
|
||||||
architecture:
|
|
||||||
name: UNetBackbone
|
|
||||||
input_height: 128
|
|
||||||
in_channels: 1
|
|
||||||
encoder:
|
|
||||||
layers:
|
|
||||||
- name: FreqCoordConvDown
|
|
||||||
out_channels: 32
|
|
||||||
- name: FreqCoordConvDown
|
|
||||||
out_channels: 64
|
|
||||||
- name: LayerGroup
|
|
||||||
layers:
|
|
||||||
- name: FreqCoordConvDown
|
|
||||||
out_channels: 128
|
|
||||||
- name: ConvBlock
|
|
||||||
out_channels: 256
|
|
||||||
bottleneck:
|
|
||||||
channels: 256
|
|
||||||
layers:
|
|
||||||
- name: SelfAttention
|
|
||||||
attention_channels: 256
|
|
||||||
decoder:
|
|
||||||
layers:
|
|
||||||
- name: FreqCoordConvUp
|
|
||||||
out_channels: 64
|
|
||||||
- name: FreqCoordConvUp
|
|
||||||
out_channels: 32
|
|
||||||
- name: LayerGroup
|
|
||||||
layers:
|
|
||||||
- name: FreqCoordConvUp
|
|
||||||
out_channels: 32
|
|
||||||
- name: ConvBlock
|
|
||||||
out_channels: 32
|
|
||||||
|
|
||||||
postprocess:
|
|
||||||
nms_kernel_size: 9
|
|
||||||
detection_threshold: 0.01
|
|
||||||
top_k_per_sec: 200
|
|
||||||
@@ -1,9 +0,0 @@
-format:
-  name: raw
-  include_class_scores: true
-  include_features: true
-  include_geometry: true
-
-transform:
-  detection_transforms: []
-  clip_transforms: []
@ -1,79 +0,0 @@
|
|||||||
optimizer:
|
|
||||||
name: adam
|
|
||||||
learning_rate: 0.001
|
|
||||||
|
|
||||||
scheduler:
|
|
||||||
name: cosine_annealing
|
|
||||||
t_max: 100
|
|
||||||
|
|
||||||
labels:
|
|
||||||
sigma: 3
|
|
||||||
|
|
||||||
trainer:
|
|
||||||
max_epochs: 10
|
|
||||||
check_val_every_n_epoch: 5
|
|
||||||
|
|
||||||
train_loader:
|
|
||||||
batch_size: 8
|
|
||||||
shuffle: true
|
|
||||||
|
|
||||||
clipping_strategy:
|
|
||||||
name: random_subclip
|
|
||||||
duration: 0.256
|
|
||||||
|
|
||||||
augmentations:
|
|
||||||
enabled: true
|
|
||||||
audio:
|
|
||||||
- name: mix_audio
|
|
||||||
probability: 0.2
|
|
||||||
min_weight: 0.3
|
|
||||||
max_weight: 0.7
|
|
||||||
- name: add_echo
|
|
||||||
probability: 0.2
|
|
||||||
max_delay: 0.005
|
|
||||||
min_weight: 0.0
|
|
||||||
max_weight: 1.0
|
|
||||||
spectrogram:
|
|
||||||
- name: scale_volume
|
|
||||||
probability: 0.2
|
|
||||||
min_scaling: 0.0
|
|
||||||
max_scaling: 2.0
|
|
||||||
- name: warp
|
|
||||||
probability: 0.2
|
|
||||||
delta: 0.04
|
|
||||||
- name: mask_time
|
|
||||||
probability: 0.2
|
|
||||||
max_perc: 0.05
|
|
||||||
max_masks: 3
|
|
||||||
- name: mask_freq
|
|
||||||
probability: 0.2
|
|
||||||
max_perc: 0.10
|
|
||||||
max_masks: 3
|
|
||||||
|
|
||||||
val_loader:
|
|
||||||
clipping_strategy:
|
|
||||||
name: whole_audio_padded
|
|
||||||
chunk_size: 0.256
|
|
||||||
|
|
||||||
loss:
|
|
||||||
detection:
|
|
||||||
weight: 1.0
|
|
||||||
focal:
|
|
||||||
beta: 4
|
|
||||||
alpha: 2
|
|
||||||
classification:
|
|
||||||
weight: 2.0
|
|
||||||
focal:
|
|
||||||
beta: 4
|
|
||||||
alpha: 2
|
|
||||||
size:
|
|
||||||
weight: 0.1
|
|
||||||
|
|
||||||
validation:
|
|
||||||
tasks:
|
|
||||||
- name: sound_event_detection
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
- name: sound_event_classification
|
|
||||||
metrics:
|
|
||||||
- name: average_precision
|
|
||||||
justfile (8 changes)
@@ -112,12 +112,6 @@ clean: clean-build clean-pyc clean-test clean-docs
 example-train OPTIONS="":
     uv run batdetect2 train \
         --val-dataset example_data/dataset.yaml \
-        --base-dir . \
-        --targets example_data/targets.yaml \
-        --model-config example_data/configs/model.yaml \
-        --training-config example_data/configs/training.yaml \
-        --audio-config example_data/configs/audio.yaml \
-        --evaluation-config example_data/configs/evaluation.yaml \
-        --logging-config example_data/configs/logging.yaml \
+        --config example_data/config.yaml \
         {{OPTIONS}} \
         example_data/dataset.yaml
@@ -12,14 +12,11 @@ if TYPE_CHECKING:
     import torch
 
     from batdetect2.audio import AudioConfig, AudioLoader
+    from batdetect2.config import BatDetect2Config
     from batdetect2.data import Dataset
     from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
     from batdetect2.inference import InferenceConfig
-    from batdetect2.logging import (
-        AppLoggingConfig,
-        LoggerConfig,
-        LoggingCallback,
-    )
+    from batdetect2.logging import AppLoggingConfig, LoggerConfig
     from batdetect2.models import Model, ModelConfig
     from batdetect2.outputs import (
         OutputFormatConfig,
@@ -39,7 +36,6 @@ if TYPE_CHECKING:
         TargetProtocol,
     )
     from batdetect2.train import TrainingConfig
-    from batdetect2.train.logging import TrainLoggingContext
 
 
 DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
@@ -111,11 +107,9 @@ class BatDetect2API:
         audio_config: AudioConfig | None = None,
         train_config: TrainingConfig | None = None,
         logger_config: LoggerConfig | None = None,
-        logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
     ):
         from batdetect2.train import run_train
 
-        self.model.train()
         run_train(
             train_annotations=train_annotations,
             val_annotations=val_annotations,
@@ -136,15 +130,12 @@ class BatDetect2API:
             train_config=train_config or self.train_config,
             audio_config=audio_config or self.audio_config,
             logger_config=logger_config or self.logging_config.train,
-            logging_callbacks=logging_callbacks,
         )
-        self.model.eval()
         return self
 
     def finetune(
         self,
         train_annotations: Sequence[data.ClipAnnotation],
-        targets_config: TargetConfig,
         val_annotations: Sequence[data.ClipAnnotation] | None = None,
         trainable: Literal[
             "all", "heads", "classifier_head", "bbox_head"
@@ -157,77 +148,25 @@ class BatDetect2API:
         num_epochs: int | None = None,
         run_name: str | None = None,
         seed: int | None = None,
+        model_config: ModelConfig | None = None,
         audio_config: AudioConfig | None = None,
         train_config: TrainingConfig | None = None,
         logger_config: LoggerConfig | None = None,
-        logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
     ) -> "BatDetect2API":
-        """Fine-tune from a checkpoint using a new target definition."""
-        from batdetect2.evaluate import build_evaluator
-        from batdetect2.models import build_model_with_new_targets
-        from batdetect2.outputs import (
-            build_output_formatter,
-            build_output_transform,
-        )
-        from batdetect2.targets import (
-            TargetConfig,
-            build_roi_mapping,
-            build_targets,
-        )
+        """Fine-tune the model with trainable-parameter selection."""
         from batdetect2.train import run_train
 
-        target_config = TargetConfig.model_validate(targets_config)
-        targets = build_targets(config=target_config)
-        roi_mapper = build_roi_mapping(config=target_config.roi)
-        model = build_model_with_new_targets(
-            model=self.model,
-            targets=targets,
-            roi_mapper=roi_mapper,
-        )
-        output_transform = build_output_transform(
-            config=self.outputs_config.transform,
-            targets=targets,
-            roi_mapper=roi_mapper,
-        )
-        api = BatDetect2API(
-            model_config=self.model_config,
-            audio_config=audio_config or self.audio_config,
-            train_config=train_config or self.train_config,
-            evaluation_config=self.evaluation_config,
-            inference_config=self.inference_config,
-            outputs_config=self.outputs_config,
-            logging_config=self.logging_config,
-            targets=targets,
-            roi_mapper=roi_mapper,
-            audio_loader=self.audio_loader,
-            preprocessor=self.preprocessor,
-            postprocessor=self.postprocessor,
-            evaluator=build_evaluator(
-                config=self.evaluation_config,
-                targets=targets,
-                roi_mapper=roi_mapper,
-                transform=output_transform,
-            ),
-            formatter=build_output_formatter(
-                targets,
-                config=self.outputs_config.format,
-            ),
-            output_transform=output_transform,
-            model=model,
-        )
-
-        api._set_trainable_parameters(trainable)
-        api.model.train()
+        self._set_trainable_parameters(trainable)
 
         run_train(
             train_annotations=train_annotations,
             val_annotations=val_annotations,
-            model=api.model,
-            targets=api.targets,
-            roi_mapper=api.roi_mapper,
-            model_config=api.model_config,
-            preprocessor=api.preprocessor,
-            audio_loader=api.audio_loader,
+            model=self.model,
+            targets=self.targets,
+            roi_mapper=self.roi_mapper,
+            model_config=model_config or self.model_config,
+            preprocessor=self.preprocessor,
+            audio_loader=self.audio_loader,
             train_workers=train_workers,
             val_workers=val_workers,
             checkpoint_dir=checkpoint_dir,
@@ -236,13 +175,11 @@ class BatDetect2API:
             num_epochs=num_epochs,
             run_name=run_name,
             seed=seed,
-            audio_config=api.audio_config,
-            train_config=api.train_config,
-            logger_config=logger_config or api.logging_config.train,
-            logging_callbacks=logging_callbacks,
+            audio_config=audio_config or self.audio_config,
+            train_config=train_config or self.train_config,
+            logger_config=logger_config or self.logging_config.train,
         )
-        api.model.eval()
-        return api
+        return self
 
     def evaluate(
         self,
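Both `train` and `finetune` resolve per-call overrides with expressions such as `train_config or self.train_config`. The following is a minimal sketch of that fallback pattern. `TrainCfg` and `Api` are hypothetical stand-ins, not batdetect2 classes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainCfg:  # hypothetical stand-in, not batdetect2's TrainingConfig
    max_epochs: int = 10


class Api:  # hypothetical stand-in for an object holding default configs
    def __init__(self, train_config: TrainCfg) -> None:
        self.train_config = train_config

    def resolve(self, train_config: Optional[TrainCfg] = None) -> TrainCfg:
        # Same shape as `train_config or self.train_config` in the diff:
        # use the override when one is given, else the stored default.
        return train_config or self.train_config


api = Api(TrainCfg())
override = TrainCfg(max_epochs=3)
print(api.resolve().max_epochs, api.resolve(override).max_epochs)
```

Note that `or` falls back on any falsy value, not only `None`. For objects that define neither `__bool__` nor `__len__`, such as the dataclass here, an instance is always truthy, so the two behave the same.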
@@ -546,70 +483,46 @@ class BatDetect2API:
     @classmethod
     def from_config(
         cls,
-        model_config: ModelConfig | None = None,
-        targets_config: TargetConfig | None = None,
-        audio_config: AudioConfig | None = None,
-        train_config: TrainingConfig | None = None,
-        evaluation_config: EvaluationConfig | None = None,
-        inference_config: InferenceConfig | None = None,
-        outputs_config: OutputsConfig | None = None,
-        logging_config: AppLoggingConfig | None = None,
+        config: BatDetect2Config,
     ) -> "BatDetect2API":
-        from batdetect2.audio import AudioConfig, build_audio_loader
-        from batdetect2.evaluate import EvaluationConfig, build_evaluator
-        from batdetect2.inference import InferenceConfig
-        from batdetect2.logging import AppLoggingConfig
-        from batdetect2.models import ModelConfig, build_model
+        from batdetect2.audio import build_audio_loader
+        from batdetect2.evaluate import build_evaluator
+        from batdetect2.models import build_model
         from batdetect2.outputs import (
-            OutputsConfig,
             build_output_formatter,
             build_output_transform,
         )
         from batdetect2.postprocess import build_postprocessor
         from batdetect2.preprocess import build_preprocessor
-        from batdetect2.targets import (
-            TargetConfig,
-            build_roi_mapping,
-            build_targets,
-        )
-        from batdetect2.train import TrainingConfig
+        from batdetect2.targets import build_roi_mapping, build_targets
 
-        model_config = model_config or ModelConfig()
-        targets_config = targets_config or TargetConfig()
-        audio_config = audio_config or AudioConfig()
-        train_config = train_config or TrainingConfig()
-        evaluation_config = evaluation_config or EvaluationConfig()
-        inference_config = inference_config or InferenceConfig()
-        outputs_config = outputs_config or OutputsConfig()
-        logging_config = logging_config or AppLoggingConfig()
-
-        targets = build_targets(config=targets_config)
-        roi_mapper = build_roi_mapping(config=targets_config.roi)
+        targets = build_targets(config=config.model.targets)
+        roi_mapper = build_roi_mapping(config=config.model.targets.roi)
 
-        audio_loader = build_audio_loader(config=audio_config)
+        audio_loader = build_audio_loader(config=config.audio)
 
         preprocessor = build_preprocessor(
             input_samplerate=audio_loader.samplerate,
-            config=model_config.preprocess,
+            config=config.model.preprocess,
         )
 
         postprocessor = build_postprocessor(
             preprocessor,
-            config=model_config.postprocess,
+            config=config.model.postprocess,
         )
 
         formatter = build_output_formatter(
             targets,
-            config=outputs_config.format,
+            config=config.outputs.format,
         )
         output_transform = build_output_transform(
-            config=outputs_config.transform,
+            config=config.outputs.transform,
             targets=targets,
             roi_mapper=roi_mapper,
         )
 
         evaluator = build_evaluator(
-            config=evaluation_config,
+            config=config.evaluation,
             targets=targets,
             roi_mapper=roi_mapper,
             transform=output_transform,
@@ -618,27 +531,27 @@ class BatDetect2API:
         # NOTE: Build separate instances of preprocessor and postprocessor
         # to avoid device mismatch errors
         model = build_model(
-            config=model_config,
-            class_names=targets.class_names,
-            dimension_names=roi_mapper.dimension_names,
+            config=config.model,
+            targets=targets,
+            roi_mapper=roi_mapper,
             preprocessor=build_preprocessor(
                 input_samplerate=audio_loader.samplerate,
-                config=model_config.preprocess,
+                config=config.model.preprocess,
             ),
             postprocessor=build_postprocessor(
                 preprocessor,
-                config=model_config.postprocess,
+                config=config.model.postprocess,
             ),
         )
 
         return cls(
-            model_config=model_config,
-            audio_config=audio_config,
-            train_config=train_config,
-            evaluation_config=evaluation_config,
-            inference_config=inference_config,
-            outputs_config=outputs_config,
-            logging_config=logging_config,
+            model_config=config.model,
+            audio_config=config.audio,
+            train_config=config.train,
+            evaluation_config=config.evaluation,
+            inference_config=config.inference,
+            outputs_config=config.outputs,
+            logging_config=config.logging,
             targets=targets,
             roi_mapper=roi_mapper,
             audio_loader=audio_loader,
@@ -654,6 +567,7 @@ class BatDetect2API:
     def from_checkpoint(
         cls,
         path: data.PathLike,
+        targets_config: TargetConfig | None = None,
         audio_config: AudioConfig | None = None,
         train_config: TrainingConfig | None = None,
         evaluation_config: EvaluationConfig | None = None,
@ -665,6 +579,7 @@ class BatDetect2API:
|
|||||||
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
from batdetect2.logging import AppLoggingConfig
|
from batdetect2.logging import AppLoggingConfig
|
||||||
|
from batdetect2.models import build_model_with_new_targets
|
||||||
from batdetect2.outputs import (
|
from batdetect2.outputs import (
|
||||||
OutputsConfig,
|
OutputsConfig,
|
||||||
build_output_formatter,
|
build_output_formatter,
|
||||||
@ -672,41 +587,37 @@ class BatDetect2API:
|
|||||||
)
|
)
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import build_roi_mapping, build_targets
|
||||||
build_roi_mapping,
|
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||||
build_targets,
|
|
||||||
check_target_compatibility,
|
|
||||||
)
|
|
||||||
from batdetect2.train import load_model_from_checkpoint
|
|
||||||
|
|
||||||
model, configs = load_model_from_checkpoint(path)
|
model, model_config = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
model_config = configs.model
|
|
||||||
train_config = train_config or configs.train
|
|
||||||
|
|
||||||
audio_config = audio_config or AudioConfig(
|
audio_config = audio_config or AudioConfig(
|
||||||
samplerate=model_config.samplerate,
|
samplerate=model_config.samplerate,
|
||||||
)
|
)
|
||||||
|
train_config = train_config or TrainingConfig()
|
||||||
evaluation_config = evaluation_config or EvaluationConfig()
|
evaluation_config = evaluation_config or EvaluationConfig()
|
||||||
inference_config = inference_config or InferenceConfig()
|
inference_config = inference_config or InferenceConfig()
|
||||||
outputs_config = outputs_config or OutputsConfig()
|
outputs_config = outputs_config or OutputsConfig()
|
||||||
logging_config = logging_config or AppLoggingConfig()
|
logging_config = logging_config or AppLoggingConfig()
|
||||||
targets_config = configs.targets
|
|
||||||
|
|
||||||
|
if (
|
||||||
|
targets_config is not None
|
||||||
|
and targets_config != model_config.targets
|
||||||
|
):
|
||||||
targets = build_targets(config=targets_config)
|
targets = build_targets(config=targets_config)
|
||||||
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||||
|
model = build_model_with_new_targets(
|
||||||
if not check_target_compatibility(targets, model.class_names):
|
model=model,
|
||||||
raise ValueError(
|
targets=targets,
|
||||||
"Provided targets_config is incompatible with the "
|
roi_mapper=roi_mapper,
|
||||||
"checkpoint model: missing one or more model classes."
|
)
|
||||||
|
model_config = model_config.model_copy(
|
||||||
|
update={"targets": targets_config}
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.dimension_names != roi_mapper.dimension_names:
|
targets = build_targets(config=model_config.targets)
|
||||||
raise ValueError(
|
roi_mapper = build_roi_mapping(config=model_config.targets.roi)
|
||||||
"Provided targets_config is incompatible with the "
|
|
||||||
"checkpoint model: mismatched dimension names."
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=audio_config)
|
audio_loader = build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
|
|||||||
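One side of the `from_checkpoint` hunk above rejects a `targets_config` override when it drops classes the checkpoint model predicts or when the ROI dimension names differ. A minimal sketch of that kind of guard follows. The function name `validate_target_override` and both rules as written here are assumptions for illustration; the actual `check_target_compatibility` implementation is not shown in this diff.

```python
from typing import Sequence


def validate_target_override(
    requested_classes: Sequence[str],
    requested_dims: Sequence[str],
    model_classes: Sequence[str],
    model_dims: Sequence[str],
) -> None:
    """Raise ValueError if an override cannot reuse the checkpoint model."""
    # Assumed rule: every class the model predicts must still be defined.
    missing = [name for name in model_classes if name not in requested_classes]
    if missing:
        raise ValueError(
            "Provided targets are incompatible with the checkpoint model: "
            f"missing model classes {missing}."
        )

    # Assumed rule: ROI dimension names must match exactly, in order.
    if list(requested_dims) != list(model_dims):
        raise ValueError(
            "Provided targets are incompatible with the checkpoint model: "
            "mismatched dimension names."
        )
```

Under these assumptions an override may add classes but may not remove any that the checkpoint already predicts.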
@ -2,7 +2,6 @@ from batdetect2.cli.base import cli
|
|||||||
from batdetect2.cli.compat import detect
|
from batdetect2.cli.compat import detect
|
||||||
from batdetect2.cli.data import data
|
from batdetect2.cli.data import data
|
||||||
from batdetect2.cli.evaluate import evaluate_command
|
from batdetect2.cli.evaluate import evaluate_command
|
||||||
from batdetect2.cli.finetune import finetune_command
|
|
||||||
from batdetect2.cli.inference import predict
|
from batdetect2.cli.inference import predict
|
||||||
from batdetect2.cli.train import train_command
|
from batdetect2.cli.train import train_command
|
||||||
|
|
||||||
@ -11,7 +10,6 @@ __all__ = [
|
|||||||
"detect",
|
"detect",
|
||||||
"data",
|
"data",
|
||||||
"train_command",
|
"train_command",
|
||||||
"finetune_command",
|
|
||||||
"evaluate_command",
|
"evaluate_command",
|
||||||
"predict",
|
"predict",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -77,7 +77,6 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
|||||||
"num_workers",
|
"num_workers",
|
||||||
type=int,
|
type=int,
|
||||||
help="Number of worker processes for dataset loading.",
|
help="Number of worker processes for dataset loading.",
|
||||||
default=0,
|
|
||||||
)
|
)
|
||||||
def evaluate_command(
|
def evaluate_command(
|
||||||
model_path: Path,
|
model_path: Path,
|
||||||
@ -106,6 +105,7 @@ def evaluate_command(
|
|||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
from batdetect2.logging import AppLoggingConfig
|
from batdetect2.logging import AppLoggingConfig
|
||||||
from batdetect2.outputs import OutputsConfig
|
from batdetect2.outputs import OutputsConfig
|
||||||
|
from batdetect2.targets import TargetConfig
|
||||||
|
|
||||||
logger.info("Initiating evaluation process...")
|
logger.info("Initiating evaluation process...")
|
||||||
|
|
||||||
@ -119,6 +119,11 @@ def evaluate_command(
|
|||||||
num_annotations=len(test_annotations),
|
num_annotations=len(test_annotations),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
target_conf = (
|
||||||
|
TargetConfig.load(targets_config)
|
||||||
|
if targets_config is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
audio_conf = (
|
audio_conf = (
|
||||||
AudioConfig.load(audio_config) if audio_config is not None else None
|
AudioConfig.load(audio_config) if audio_config is not None else None
|
||||||
)
|
)
|
||||||
@ -145,6 +150,7 @@ def evaluate_command(
|
|||||||
|
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
model_path,
|
model_path,
|
||||||
|
targets_config=target_conf,
|
||||||
audio_config=audio_conf,
|
audio_config=audio_conf,
|
||||||
evaluation_config=eval_conf,
|
evaluation_config=eval_conf,
|
||||||
inference_config=inference_conf,
|
inference_config=inference_conf,
|
||||||
|
|||||||
@ -1,211 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from typing import Literal, cast
|
|
||||||
|
|
||||||
import click
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
|
||||||
|
|
||||||
__all__ = ["finetune_command"]
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command(
|
|
||||||
name="finetune", short_help="Fine-tune a checkpoint on new targets."
|
|
||||||
)
|
|
||||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
|
||||||
@click.option(
|
|
||||||
"--model",
|
|
||||||
"model_path",
|
|
||||||
required=True,
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help="Path to a checkpoint to fine-tune from.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--targets",
|
|
||||||
"targets_config",
|
|
||||||
required=True,
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help="Path to the new targets config file.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--val-dataset",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help="Path to validation dataset config file.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--base-dir",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Base directory used to resolve relative paths inside the training "
|
|
||||||
"and validation dataset configs."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--training-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help="Path to training config file.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--audio-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help="Path to audio config file.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--logging-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help="Path to logging config file.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--trainable",
|
|
||||||
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
|
|
||||||
default="heads",
|
|
||||||
show_default=True,
|
|
||||||
help="Which model parameters remain trainable during fine-tuning.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--ckpt-dir",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help="Directory where checkpoints are saved.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--log-dir",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help="Directory where logs are written.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--train-workers",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Number of worker processes for training data loading.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--val-workers",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Number of worker processes for validation data loading.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--num-epochs",
|
|
||||||
type=int,
|
|
||||||
help="Maximum number of training epochs.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--experiment-name",
|
|
||||||
type=str,
|
|
||||||
help="Experiment name used for logging backends.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--run-name",
|
|
||||||
type=str,
|
|
||||||
help="Run name used for logging backends.",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--seed",
|
|
||||||
type=int,
|
|
||||||
help="Random seed used for reproducibility.",
|
|
||||||
)
|
|
||||||
def finetune_command(
|
|
||||||
train_dataset: Path,
|
|
||||||
model_path: Path,
|
|
||||||
targets_config: Path,
|
|
||||||
val_dataset: Path | None = None,
|
|
||||||
ckpt_dir: Path | None = None,
|
|
||||||
log_dir: Path | None = None,
|
|
||||||
base_dir: Path | None = None,
|
|
||||||
training_config: Path | None = None,
|
|
||||||
audio_config: Path | None = None,
|
|
||||||
logging_config: Path | None = None,
|
|
||||||
trainable: str = "heads",
|
|
||||||
seed: int | None = None,
|
|
||||||
num_epochs: int | None = None,
|
|
||||||
train_workers: int = 0,
|
|
||||||
val_workers: int = 0,
|
|
||||||
experiment_name: str | None = None,
|
|
||||||
run_name: str | None = None,
|
|
||||||
):
|
|
||||||
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
|
||||||
from batdetect2.audio import AudioConfig
|
|
||||||
from batdetect2.data import load_dataset, load_dataset_config
|
|
||||||
from batdetect2.logging import AppLoggingConfig
|
|
||||||
from batdetect2.targets import TargetConfig
|
|
||||||
from batdetect2.train import TrainingConfig
|
|
||||||
from batdetect2.train.logging import (
|
|
||||||
DatasetConfigArtifact,
|
|
||||||
DatasetConfigArtifactLogging,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Initiating fine-tuning process...")
|
|
||||||
|
|
||||||
target_conf = TargetConfig.load(targets_config)
|
|
||||||
train_conf = (
|
|
||||||
TrainingConfig.load(training_config)
|
|
||||||
if training_config is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
audio_conf = (
|
|
||||||
AudioConfig.load(audio_config) if audio_config is not None else None
|
|
||||||
)
|
|
||||||
logging_conf = (
|
|
||||||
AppLoggingConfig.load(logging_config)
|
|
||||||
if logging_config is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
train_dataset_conf = load_dataset_config(train_dataset)
|
|
||||||
train_annotations = load_dataset(train_dataset_conf, base_dir=base_dir)
|
|
||||||
|
|
||||||
val_dataset_conf = (
|
|
||||||
load_dataset_config(val_dataset) if val_dataset else None
|
|
||||||
)
|
|
||||||
val_annotations = (
|
|
||||||
load_dataset(val_dataset_conf, base_dir=base_dir)
|
|
||||||
if val_dataset_conf
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
logging_callbacks = [
|
|
||||||
DatasetConfigArtifactLogging(
|
|
||||||
train_dataset_config=DatasetConfigArtifact(
|
|
||||||
filename="train_dataset.yaml",
|
|
||||||
config=train_dataset_conf,
|
|
||||||
),
|
|
||||||
val_dataset_config=(
|
|
||||||
DatasetConfigArtifact(
|
|
||||||
filename="val_dataset.yaml",
|
|
||||||
config=val_dataset_conf,
|
|
||||||
)
|
|
||||||
if val_dataset_conf
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
api = BatDetect2API.from_checkpoint(
|
|
||||||
model_path,
|
|
||||||
train_config=train_conf,
|
|
||||||
audio_config=audio_conf,
|
|
||||||
logging_config=logging_conf,
|
|
||||||
)
|
|
||||||
|
|
||||||
return api.finetune(
|
|
||||||
train_annotations=train_annotations,
|
|
||||||
val_annotations=val_annotations,
|
|
||||||
targets_config=target_conf,
|
|
||||||
trainable=cast(
|
|
||||||
Literal["all", "heads", "classifier_head", "bbox_head"],
|
|
||||||
trainable,
|
|
||||||
),
|
|
||||||
train_workers=train_workers,
|
|
||||||
val_workers=val_workers,
|
|
||||||
checkpoint_dir=ckpt_dir,
|
|
||||||
log_dir=log_dir,
|
|
||||||
experiment_name=experiment_name,
|
|
||||||
num_epochs=num_epochs,
|
|
||||||
run_name=run_name,
|
|
||||||
seed=seed,
|
|
||||||
train_config=train_conf,
|
|
||||||
audio_config=audio_conf,
|
|
||||||
logger_config=logging_conf.train if logging_conf is not None else None,
|
|
||||||
logging_callbacks=logging_callbacks,
|
|
||||||
)
|
|
||||||
@ -86,13 +86,11 @@ __all__ = ["train_command"]
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--train-workers",
|
"--train-workers",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
|
||||||
help="Number of worker processes for training data loading.",
|
help="Number of worker processes for training data loading.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--val-workers",
|
"--val-workers",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
|
||||||
help="Number of worker processes for validation data loading.",
|
help="Number of worker processes for validation data loading.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
@ -145,7 +143,8 @@ def train_command(
|
|||||||
"""
|
"""
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.data import load_dataset_config, load_dataset_from_config
|
from batdetect2.config import BatDetect2Config
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
from batdetect2.evaluate import EvaluationConfig
|
from batdetect2.evaluate import EvaluationConfig
|
||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
from batdetect2.logging import AppLoggingConfig
|
from batdetect2.logging import AppLoggingConfig
|
||||||
@ -153,10 +152,6 @@ def train_command(
|
|||||||
from batdetect2.outputs import OutputsConfig
|
from batdetect2.outputs import OutputsConfig
|
||||||
from batdetect2.targets import TargetConfig
|
from batdetect2.targets import TargetConfig
|
||||||
from batdetect2.train import TrainingConfig
|
from batdetect2.train import TrainingConfig
|
||||||
from batdetect2.train.logging import (
|
|
||||||
DatasetConfigArtifact,
|
|
||||||
DatasetConfigArtifactLogging,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Initiating training process...")
|
logger.info("Initiating training process...")
|
||||||
|
|
||||||
@ -201,6 +196,9 @@ def train_command(
|
|||||||
if target_conf is not None:
|
if target_conf is not None:
|
||||||
logger.info("Loaded targets configuration.")
|
logger.info("Loaded targets configuration.")
|
||||||
|
|
||||||
|
if model_conf is not None and target_conf is not None:
|
||||||
|
model_conf = model_conf.model_copy(update={"targets": target_conf})
|
||||||
|
|
||||||
logger.info("Loading training dataset...")
|
logger.info("Loading training dataset...")
|
||||||
train_annotations = load_dataset_from_config(
|
train_annotations = load_dataset_from_config(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
@ -226,49 +224,37 @@ def train_command(
|
|||||||
|
|
||||||
logger.info("Configuration and data loaded. Starting training...")
|
logger.info("Configuration and data loaded. Starting training...")
|
||||||
|
|
||||||
logging_callbacks = [
|
|
||||||
DatasetConfigArtifactLogging(
|
|
||||||
train_dataset_config=DatasetConfigArtifact(
|
|
||||||
filename="train_dataset.yaml",
|
|
||||||
config=load_dataset_config(train_dataset),
|
|
||||||
),
|
|
||||||
val_dataset_config=(
|
|
||||||
DatasetConfigArtifact(
|
|
||||||
filename="val_dataset.yaml",
|
|
||||||
config=load_dataset_config(val_dataset),
|
|
||||||
)
|
|
||||||
if val_dataset is not None
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
if model_path is not None and model_conf is not None:
|
if model_path is not None and model_conf is not None:
|
||||||
raise click.UsageError(
|
raise click.UsageError(
|
||||||
"--model-config cannot be used with --model. "
|
"--model-config cannot be used with --model. "
|
||||||
"Checkpoint model configuration is loaded from the checkpoint."
|
"Checkpoint model configuration is loaded from the checkpoint."
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_path is not None and target_conf is not None:
|
|
||||||
raise click.UsageError(
|
|
||||||
"--targets cannot be used with --model. "
|
|
||||||
"Checkpoint target configuration is loaded from the checkpoint."
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
api = BatDetect2API.from_config(
|
conf = BatDetect2Config()
|
||||||
model_config=model_conf,
|
if model_conf is not None:
|
||||||
targets_config=target_conf,
|
conf.model = model_conf
|
||||||
train_config=train_conf,
|
elif target_conf is not None:
|
||||||
audio_config=audio_conf,
|
conf.model = conf.model.model_copy(update={"targets": target_conf})
|
||||||
evaluation_config=eval_conf,
|
|
||||||
inference_config=inference_conf,
|
if train_conf is not None:
|
||||||
outputs_config=outputs_conf,
|
conf.train = train_conf
|
||||||
logging_config=logging_conf,
|
if audio_conf is not None:
|
||||||
)
|
conf.audio = audio_conf
|
||||||
|
if eval_conf is not None:
|
||||||
|
conf.evaluation = eval_conf
|
||||||
|
if inference_conf is not None:
|
||||||
|
conf.inference = inference_conf
|
||||||
|
if outputs_conf is not None:
|
||||||
|
conf.outputs = outputs_conf
|
||||||
|
if logging_conf is not None:
|
||||||
|
conf.logging = logging_conf
|
||||||
|
|
||||||
|
api = BatDetect2API.from_config(conf)
|
||||||
else:
|
else:
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
model_path,
|
model_path,
|
||||||
|
targets_config=target_conf,
|
||||||
train_config=train_conf,
|
train_config=train_conf,
|
||||||
audio_config=audio_conf,
|
audio_config=audio_conf,
|
||||||
evaluation_config=eval_conf,
|
evaluation_config=eval_conf,
|
||||||
@ -288,5 +274,4 @@ def train_command(
|
|||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
logging_callbacks=logging_callbacks,
|
|
||||||
)
|
)
|
||||||
|
|||||||
31
src/batdetect2/config.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from batdetect2.audio import AudioConfig
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.evaluate.config import (
|
||||||
|
EvaluationConfig,
|
||||||
|
get_default_eval_config,
|
||||||
|
)
|
||||||
|
from batdetect2.inference.config import InferenceConfig
|
||||||
|
from batdetect2.logging import AppLoggingConfig
|
||||||
|
from batdetect2.models import ModelConfig
|
||||||
|
from batdetect2.outputs import OutputsConfig
|
||||||
|
from batdetect2.train.config import TrainingConfig
|
||||||
|
|
||||||
|
__all__ = ["BatDetect2Config"]
|
||||||
|
|
||||||
|
|
||||||
|
class BatDetect2Config(BaseConfig):
|
||||||
|
config_version: Literal["v1"] = "v1"
|
||||||
|
|
||||||
|
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||||
|
evaluation: EvaluationConfig = Field(
|
||||||
|
default_factory=get_default_eval_config
|
||||||
|
)
|
||||||
|
model: ModelConfig = Field(default_factory=ModelConfig)
|
||||||
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
|
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||||
|
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
||||||
|
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
|
||||||
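`BatDetect2Config` above groups one section per area and gives each a `default_factory`, and `train_command` replaces only the sections that were supplied on the command line. The sketch below shows the same pattern with standard-library dataclasses as a stand-in for the pydantic model. The names `AppConfig`, `TrainSection`, `AudioSection`, `apply_overrides` and their fields are hypothetical and are not part of batdetect2.

```python
from dataclasses import dataclass, field, replace
from typing import Optional


@dataclass(frozen=True)
class TrainSection:
    max_epochs: int = 100  # hypothetical field


@dataclass(frozen=True)
class AudioSection:
    samplerate: int = 256_000  # hypothetical field


@dataclass
class AppConfig:
    """Stand-in for a top-level config that groups per-area sections."""

    config_version: str = "v1"
    # default_factory builds a fresh section for every instance, which is
    # the role Field(default_factory=...) plays in the pydantic model.
    train: TrainSection = field(default_factory=TrainSection)
    audio: AudioSection = field(default_factory=AudioSection)


def apply_overrides(
    base: AppConfig,
    train: Optional[TrainSection] = None,
    audio: Optional[AudioSection] = None,
) -> AppConfig:
    """Return a copy of ``base`` with only the provided sections replaced."""
    updates = {}
    if train is not None:
        updates["train"] = train
    if audio is not None:
        updates["audio"] = audio
    return replace(base, **updates)
```

Returning a copy rather than assigning attributes in place keeps the default config reusable. The diff itself assigns onto `conf` directly, which is equally valid for a config object that is built once and then discarded.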
@ -12,8 +12,6 @@ from typing import (
|
|||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"add_import_config",
|
"add_import_config",
|
||||||
"ImportConfig",
|
"ImportConfig",
|
||||||
@ -122,7 +120,7 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
return self._registry[name](config, *args, **kwargs)
|
return self._registry[name](config, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ImportConfig(BaseConfig):
|
class ImportConfig(BaseModel):
|
||||||
"""Base config for dynamic instantiation via Hydra.
|
"""Base config for dynamic instantiation via Hydra.
|
||||||
|
|
||||||
Subclass this to create a registry-specific import escape hatch.
|
Subclass this to create a registry-specific import escape hatch.
|
||||||
|
|||||||
@ -16,7 +16,6 @@ from batdetect2.data.conditions.common import (
|
|||||||
IdInListConfig,
|
IdInListConfig,
|
||||||
JsonList,
|
JsonList,
|
||||||
ListFormatConfig,
|
ListFormatConfig,
|
||||||
TagInfo,
|
|
||||||
TxtList,
|
TxtList,
|
||||||
)
|
)
|
||||||
from batdetect2.data.conditions.recordings import (
|
from batdetect2.data.conditions.recordings import (
|
||||||
@ -64,17 +63,16 @@ __all__ = [
|
|||||||
"NotConfig",
|
"NotConfig",
|
||||||
"Operator",
|
"Operator",
|
||||||
"PathInListConfig",
|
"PathInListConfig",
|
||||||
"RecordingAllOfConfig",
|
|
||||||
"RecordingAnyOfConfig",
|
|
||||||
"RecordingCondition",
|
"RecordingCondition",
|
||||||
"RecordingConditionConfig",
|
"RecordingConditionConfig",
|
||||||
"RecordingConditionImportConfig",
|
"RecordingConditionImportConfig",
|
||||||
|
"RecordingAllOfConfig",
|
||||||
|
"RecordingAnyOfConfig",
|
||||||
"RecordingNotConfig",
|
"RecordingNotConfig",
|
||||||
"RecordingSatisfiesConfig",
|
"RecordingSatisfiesConfig",
|
||||||
"SoundEventCondition",
|
"SoundEventCondition",
|
||||||
"SoundEventConditionConfig",
|
"SoundEventConditionConfig",
|
||||||
"SoundEventConditionImportConfig",
|
"SoundEventConditionImportConfig",
|
||||||
"TagInfo",
|
|
||||||
"TxtList",
|
"TxtList",
|
||||||
"build_clip_annotation_condition",
|
"build_clip_annotation_condition",
|
||||||
"build_recording_condition",
|
"build_recording_condition",
|
||||||
|
|||||||
@ -2,23 +2,10 @@ import csv
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
|
||||||
Annotated,
|
|
||||||
Any,
|
|
||||||
Generic,
|
|
||||||
Literal,
|
|
||||||
ParamSpec,
|
|
||||||
Protocol,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import BaseModel, Field, model_validator
|
||||||
BaseModel,
|
|
||||||
Field,
|
|
||||||
PlainSerializer,
|
|
||||||
model_validator,
|
|
||||||
)
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
@ -151,26 +138,19 @@ class IdInList(Generic[UUIDObject]):
|
|||||||
return obj.uuid in self.ids
|
return obj.uuid in self.ids
|
||||||
|
|
||||||
|
|
||||||
def dump_tag(tag: data.Tag) -> dict[str, Any]:
|
|
||||||
return {"key": tag.term.name, "value": tag.value}
|
|
||||||
|
|
||||||
|
|
||||||
TagInfo = Annotated[data.Tag, PlainSerializer(dump_tag)]
|
|
||||||
|
|
||||||
|
|
||||||
class HasTagConfig(BaseConfig):
|
class HasTagConfig(BaseConfig):
|
||||||
name: Literal["has_tag"] = "has_tag"
|
name: Literal["has_tag"] = "has_tag"
|
||||||
tag: TagInfo
|
tag: data.Tag
|
||||||
|
|
||||||
|
|
||||||
class HasAllTagsConfig(BaseConfig):
|
class HasAllTagsConfig(BaseConfig):
|
||||||
name: Literal["has_all_tags"] = "has_all_tags"
|
name: Literal["has_all_tags"] = "has_all_tags"
|
||||||
tags: list[TagInfo]
|
tags: list[data.Tag]
|
||||||
|
|
||||||
|
|
||||||
class HasAnyTagConfig(BaseConfig):
|
class HasAnyTagConfig(BaseConfig):
|
||||||
name: Literal["has_any_tag"] = "has_any_tag"
|
name: Literal["has_any_tag"] = "has_any_tag"
|
||||||
tags: list[TagInfo]
|
tags: list[data.Tag]
|
||||||
|
|
||||||
|
|
||||||
class JsonList(BaseConfig):
|
class JsonList(BaseConfig):
|
||||||
|
|||||||
@ -94,7 +94,7 @@ def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
|
|||||||
return partial(operator.ge, value)
|
return partial(operator.ge, value)
|
||||||
|
|
||||||
if op == "eq":
|
if op == "eq":
|
||||||
return partial(operator.eq, value)
|
return partial(operator.eq, b=value)
|
||||||
|
|
||||||
raise ValueError(f"Invalid operator {op}")
|
raise ValueError(f"Invalid operator {op}")
|
||||||
|
|
||||||
|
|||||||
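The two `eq` lines in the hunk above differ only in how `value` is bound to `operator.eq`. The sketch below illustrates what each binding does, assuming CPython, where the functions in `operator` accept positional arguments only. The helper names `build_comparator` and `keyword_binding_works` are hypothetical, and the operand order chosen here (`x <op> value`) is an assumption about the intended semantics.

```python
import operator
from functools import partial
from typing import Callable

_OPERATORS = {
    "gt": operator.gt,
    "ge": operator.ge,
    "lt": operator.lt,
    "le": operator.le,
    "eq": operator.eq,
}


def build_comparator(op: str, value: float) -> Callable[[float], bool]:
    """Return a predicate that evaluates ``x <op> value`` for a later ``x``."""
    if op not in _OPERATORS:
        raise ValueError(f"Invalid operator {op}")
    func = _OPERATORS[op]
    # partial(func, value) would bind the *first* operand and evaluate
    # ``value <op> x``; the lambda keeps ``x`` on the left-hand side.
    return lambda x: func(x, value)


def keyword_binding_works() -> bool:
    """Report whether ``partial(operator.eq, b=...)`` can be called."""
    try:
        partial(operator.eq, b=1.0)(1.0)
    except TypeError:
        # CPython: eq() takes no keyword arguments.
        return False
    return True
```

For `eq` the operand order does not matter for ordinary floats, so positional binding is enough. For the ordering operators, `partial(operator.ge, value)` tests `value >= x`, which is worth checking against the intended meaning of each operator name.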
@ -24,8 +24,8 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
|||||||
def run_evaluate(
|
def run_evaluate(
|
||||||
model: Model,
|
model: Model,
|
||||||
test_annotations: Sequence[data.ClipAnnotation],
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol | None = None,
|
||||||
roi_mapper: ROIMapperProtocol,
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
@ -46,6 +46,8 @@ def run_evaluate(
|
|||||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
preprocessor = preprocessor or model.preprocessor
|
preprocessor = preprocessor or model.preprocessor
|
||||||
|
targets = targets or model.targets
|
||||||
|
roi_mapper = roi_mapper or model.roi_mapper
|
||||||
|
|
||||||
loader = build_test_loader(
|
loader = build_test_loader(
|
||||||
test_annotations,
|
test_annotations,
|
||||||
|
|||||||
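The `targets = targets or model.targets` fallback above treats any falsy argument as "not provided". This sketch contrasts it with an explicit `is None` check, using a hypothetical `ClassList` container that is falsy when empty. Whether the real target and ROI mapper objects can ever be falsy is not shown in this diff, so the difference may not matter in practice.

```python
from typing import Optional, Sequence


class ClassList:
    """Hypothetical container that is falsy when it holds no classes."""

    def __init__(self, names: Sequence[str]) -> None:
        self.names = list(names)

    def __len__(self) -> int:
        return len(self.names)


def pick_with_or(override: Optional[ClassList], default: ClassList) -> ClassList:
    # Falls back whenever ``override`` is falsy, including an empty ClassList.
    return override or default


def pick_with_none_check(
    override: Optional[ClassList], default: ClassList
) -> ClassList:
    # Falls back only when no override was passed at all.
    return default if override is None else override
```

Both helpers agree when the override is `None` or non-empty; they differ only for an explicitly passed but empty object.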
@ -45,16 +45,8 @@ def run_batch_inference(
|
|||||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
preprocessor = preprocessor or model.preprocessor
|
preprocessor = preprocessor or model.preprocessor
|
||||||
|
targets = targets or model.targets
|
||||||
if targets is None:
|
roi_mapper = roi_mapper or model.roi_mapper
|
||||||
raise ValueError(
|
|
||||||
"targets must be provided when running batch inference."
|
|
||||||
)
|
|
||||||
|
|
||||||
if roi_mapper is None:
|
|
||||||
raise ValueError(
|
|
||||||
"roi_mapper must be provided when running batch inference."
|
|
||||||
)
|
|
||||||
|
|
||||||
output_transform = output_transform or build_output_transform(
|
output_transform = output_transform or build_output_transform(
|
||||||
config=output_config.transform,
|
config=output_config.transform,
|
||||||
|
|||||||
@ -7,37 +7,21 @@ from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
|||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceModule(LightningModule):
|
class InferenceModule(LightningModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Model,
|
model: Model,
|
||||||
targets: TargetProtocol | None = None,
|
|
||||||
roi_mapper: ROIMapperProtocol | None = None,
|
|
||||||
output_transform: OutputTransformProtocol | None = None,
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.detection_threshold = detection_threshold
|
self.detection_threshold = detection_threshold
|
||||||
|
|
||||||
if output_transform is None and targets is None:
|
|
||||||
raise ValueError(
|
|
||||||
"targets must be provided when building inference output "
|
|
||||||
"transforms."
|
|
||||||
)
|
|
||||||
|
|
||||||
if output_transform is None and roi_mapper is None:
|
|
||||||
raise ValueError(
|
|
||||||
"roi_mapper must be provided when building inference output "
|
|
||||||
"transforms."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.output_transform = output_transform or build_output_transform(
|
self.output_transform = output_transform or build_output_transform(
|
||||||
targets=targets,
|
targets=model.targets,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=model.roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
def predict_step(
|
def predict_step(
|
||||||
|
|||||||
@ -24,7 +24,12 @@ from batdetect2.core.configs import BaseConfig
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from lightning.pytorch.loggers import Logger
|
from lightning.pytorch.loggers import (
|
||||||
|
CSVLogger,
|
||||||
|
Logger,
|
||||||
|
MLFlowLogger,
|
||||||
|
TensorBoardLogger,
|
||||||
|
)
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -38,15 +43,11 @@ __all__ = [
|
|||||||
"DVCLiveConfig",
|
"DVCLiveConfig",
|
||||||
"LoggerConfig",
|
"LoggerConfig",
|
||||||
"MLFlowLoggerConfig",
|
"MLFlowLoggerConfig",
|
||||||
"LoggingCallback",
|
|
||||||
"TensorBoardLoggerConfig",
|
"TensorBoardLoggerConfig",
|
||||||
"build_logger",
|
"build_logger",
|
||||||
"enable_logging",
|
"enable_logging",
|
||||||
"get_image_logger",
|
"get_image_logger",
|
||||||
"get_table_logger",
|
"get_table_logger",
|
||||||
"log_artifact_file",
|
|
||||||
"log_config_artifact",
|
|
||||||
"log_csv_artifact",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -122,18 +123,6 @@ class LoggerBuilder(Protocol, Generic[T]):
|
|||||||
) -> Logger: ...
|
) -> Logger: ...
|
||||||
|
|
||||||
|
|
||||||
LoggingContext = TypeVar("LoggingContext", contravariant=True)
|
|
||||||
|
|
||||||
|
|
||||||
class LoggingCallback(Protocol, Generic[LoggingContext]):
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
logger: Logger,
|
|
||||||
artifact_path: Path,
|
|
||||||
context: LoggingContext,
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
|
|
||||||
def create_dvclive_logger(
|
def create_dvclive_logger(
|
||||||
config: DVCLiveConfig,
|
config: DVCLiveConfig,
|
||||||
log_dir: Path | None = None,
|
log_dir: Path | None = None,
|
||||||
@ -287,71 +276,6 @@ def build_logger(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def log_artifact_file(
|
|
||||||
runtime_logger: Logger,
|
|
||||||
path: Path,
|
|
||||||
artifact_path: str = "artifacts",
|
|
||||||
) -> None:
|
|
||||||
from lightning.pytorch.loggers import (
|
|
||||||
CSVLogger,
|
|
||||||
MLFlowLogger,
|
|
||||||
TensorBoardLogger,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(runtime_logger, MLFlowLogger):
|
|
||||||
runtime_logger.experiment.log_artifact( # type: ignore[call-arg]
|
|
||||||
local_path=str(path),
|
|
||||||
artifact_path=artifact_path,
|
|
||||||
run_id=runtime_logger.run_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
experiment = getattr(runtime_logger, "experiment", None)
|
|
||||||
if experiment is not None and hasattr(experiment, "log_artifact"):
|
|
||||||
experiment.log_artifact(path=path, name=path.name, copy=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(runtime_logger, (CSVLogger, TensorBoardLogger)):
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
"Skipping artifact logging for unsupported logger type {logger_type}",
|
|
||||||
logger_type=type(runtime_logger).__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def log_config_artifact(
|
|
||||||
logger: Logger,
|
|
||||||
config: BaseConfig,
|
|
||||||
filename: str,
|
|
||||||
artifact_path: Path,
|
|
||||||
) -> None:
|
|
||||||
artifact_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
path = artifact_path / filename
|
|
||||||
path.write_text(config.to_yaml_string())
|
|
||||||
log_artifact_file(
|
|
||||||
logger,
|
|
||||||
path,
|
|
||||||
artifact_path=artifact_path.name,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def log_csv_artifact(
|
|
||||||
logger: Logger,
|
|
||||||
df: pd.DataFrame,
|
|
||||||
filename: str,
|
|
||||||
artifact_path: Path,
|
|
||||||
) -> None:
|
|
||||||
artifact_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
path = artifact_path / filename
|
|
||||||
df.to_csv(path, index=False)
|
|
||||||
log_artifact_file(
|
|
||||||
logger,
|
|
||||||
path,
|
|
||||||
artifact_path=artifact_path.name,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
PlotLogger = Callable[[str, "Figure", int], None]
|
PlotLogger = Callable[[str, "Figure", int], None]
|
||||||
|
|
||||||
|
|
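Note on the `LoggingCallback` protocol that appears on one side of this file's diff: it is structural, so any object with a matching `run(logger, artifact_path, context)` method satisfies it without subclassing. The sketch below shows one way such a callback could look. `NoteContext`, `WriteNoteCallback`, `run_callbacks` and `demo` are hypothetical names that do not come from the repository, and a plain `object` stands in for the Lightning `Logger` so the example runs without Lightning installed.

```python
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Generic, Protocol, TypeVar

LoggingContext = TypeVar("LoggingContext", contravariant=True)


class LoggingCallback(Protocol, Generic[LoggingContext]):
    # Same shape as the protocol in the diff; `object` replaces Lightning's Logger.
    def run(
        self,
        logger: object,
        artifact_path: Path,
        context: LoggingContext,
    ) -> None: ...


@dataclass
class NoteContext:
    # Hypothetical context type, for illustration only.
    note: str


class WriteNoteCallback:
    # Hypothetical callback: writes the context note into the artifact directory.
    def __init__(self, filename: str) -> None:
        self.filename = filename

    def run(
        self,
        logger: object,
        artifact_path: Path,
        context: NoteContext,
    ) -> None:
        artifact_path.mkdir(parents=True, exist_ok=True)
        (artifact_path / self.filename).write_text(context.note)


def run_callbacks(callbacks, logger, artifact_path, context) -> None:
    # Invoke every callback with the same logger, directory and context.
    for callback in callbacks:
        callback.run(logger, artifact_path, context)


def demo() -> str:
    # Run one callback against a temporary directory and read back the result.
    with tempfile.TemporaryDirectory() as tmp:
        out = Path(tmp) / "artifacts"
        run_callbacks(
            [WriteNoteCallback("note.txt")],
            logger=None,
            artifact_path=out,
            context=NoteContext(note="hello"),
        )
        return (out / "note.txt").read_text()
```

Because the type variable is contravariant, a callback written for a broader context type can be used where a narrower one is expected.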
||||||
|
|||||||
@ -26,8 +26,11 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
|
|||||||
is the ``build_model`` factory function exported from this module.
|
is the ``build_model`` factory function exported from this module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
@ -70,6 +73,7 @@ from batdetect2.postprocess.types import (
|
|||||||
)
|
)
|
||||||
from batdetect2.preprocess.config import PreprocessingConfig
|
from batdetect2.preprocess.config import PreprocessingConfig
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
from batdetect2.targets.config import TargetConfig
|
||||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -127,6 +131,10 @@ class ModelConfig(BaseConfig):
|
|||||||
Parameters for converting raw model outputs into detections (NMS
|
Parameters for converting raw model outputs into detections (NMS
|
||||||
kernel, thresholds, top-k limit). Defaults to
|
kernel, thresholds, top-k limit). Defaults to
|
||||||
``PostprocessConfig()``.
|
``PostprocessConfig()``.
|
||||||
|
targets : TargetConfig
|
||||||
|
Detection and classification target definitions (class list,
|
||||||
|
detection target, bounding-box mapper). Defaults to
|
||||||
|
``TargetConfig()``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||||
@ -135,6 +143,23 @@ class ModelConfig(BaseConfig):
|
|||||||
default_factory=PreprocessingConfig
|
default_factory=PreprocessingConfig
|
||||||
)
|
)
|
||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
path: PathLike,
|
||||||
|
field: str | None = None,
|
||||||
|
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||||
|
strict: bool | None = None,
|
||||||
|
targets: TargetConfig | None = None,
|
||||||
|
) -> "ModelConfig":
|
||||||
|
config = super().load(path, field, extra, strict)
|
||||||
|
|
||||||
|
if targets is None:
|
||||||
|
return config
|
||||||
|
|
||||||
|
return config.model_copy(update={"targets": targets})
|
||||||
|
|
||||||
|
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
@ -158,32 +183,33 @@ class Model(torch.nn.Module):
|
|||||||
postprocessor : PostprocessorProtocol
|
postprocessor : PostprocessorProtocol
|
||||||
Converts the raw ``ModelOutput`` from ``detector`` into a list of
|
Converts the raw ``ModelOutput`` from ``detector`` into a list of
|
||||||
per-clip detection tensors.
|
per-clip detection tensors.
|
||||||
class_names : list[str]
|
targets : TargetProtocol
|
||||||
Class names corresponding to the model classification outputs.
|
Describes the set of target classes; used when building heads and
|
||||||
dimension_names : list[str]
|
during training target construction.
|
||||||
Size-dimension names corresponding to the model size outputs.
|
roi_mapper : ROIMapperProtocol
|
||||||
|
Maps geometries to target-size channels and back.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
detector: DetectionModel
|
detector: DetectionModel
|
||||||
preprocessor: PreprocessorProtocol
|
preprocessor: PreprocessorProtocol
|
||||||
postprocessor: PostprocessorProtocol
|
postprocessor: PostprocessorProtocol
|
||||||
class_names: list[str]
|
targets: TargetProtocol
|
||||||
dimension_names: list[str]
|
roi_mapper: ROIMapperProtocol
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
detector: DetectionModel,
|
detector: DetectionModel,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
class_names: list[str],
|
targets: TargetProtocol,
|
||||||
dimension_names: list[str],
|
roi_mapper: ROIMapperProtocol,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.detector = detector
|
self.detector = detector
|
||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.class_names = class_names
|
self.targets = targets
|
||||||
self.dimension_names = dimension_names
|
self.roi_mapper = roi_mapper
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
||||||
"""Run the full detection pipeline on a waveform tensor.
|
"""Run the full detection pipeline on a waveform tensor.
|
||||||
@ -211,9 +237,9 @@ class Model(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def build_model(
|
def build_model(
|
||||||
config: ModelConfig | dict | None = None,
|
config: ModelConfig | None = None,
|
||||||
class_names: list[str] | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
dimension_names: list[str] | None = None,
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
postprocessor: PostprocessorProtocol | None = None,
|
postprocessor: PostprocessorProtocol | None = None,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
@ -228,13 +254,11 @@ def build_model(
|
|||||||
----------
|
----------
|
||||||
config : ModelConfig, optional
|
config : ModelConfig, optional
|
||||||
Full model configuration (samplerate, architecture, preprocessing,
|
Full model configuration (samplerate, architecture, preprocessing,
|
||||||
postprocessing). Defaults to ``ModelConfig()`` if not provided.
|
postprocessing, targets). Defaults to ``ModelConfig()`` if not
|
||||||
class_names : list[str], optional
|
provided.
|
||||||
Class names used to size the classifier head. Required when building
|
targets : TargetProtocol, optional
|
||||||
a new model.
|
Pre-built targets object. If given, overrides
|
||||||
dimension_names : list[str], optional
|
``config.targets``.
|
||||||
Dimension names used to size the bbox head. Required when building a
|
|
||||||
new model.
|
|
||||||
preprocessor : PreprocessorProtocol, optional
|
preprocessor : PreprocessorProtocol, optional
|
||||||
Pre-built preprocessor. If given, overrides
|
Pre-built preprocessor. If given, overrides
|
||||||
``config.preprocess`` and ``config.samplerate`` for the
|
``config.preprocess`` and ``config.samplerate`` for the
|
||||||
@ -254,20 +278,19 @@ def build_model(
|
|||||||
"""
|
"""
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.targets import build_roi_mapping, build_targets
|
||||||
|
|
||||||
config = config or ModelConfig()
|
config = config or ModelConfig()
|
||||||
|
targets = targets or build_targets(config=config.targets)
|
||||||
|
|
||||||
if isinstance(config, dict):
|
targets_config = getattr(targets, "config", None)
|
||||||
config = ModelConfig.model_validate(config)
|
roi_config = (
|
||||||
|
targets_config.roi
|
||||||
if class_names is None:
|
if isinstance(targets_config, TargetConfig)
|
||||||
raise ValueError("class_names must be provided when building a model.")
|
else config.targets.roi
|
||||||
|
|
||||||
if dimension_names is None:
|
|
||||||
raise ValueError(
|
|
||||||
"dimension_names must be provided when building a model."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
roi_mapper = roi_mapper or build_roi_mapping(config=roi_config)
|
||||||
preprocessor = preprocessor or build_preprocessor(
|
preprocessor = preprocessor or build_preprocessor(
|
||||||
config=config.preprocess,
|
config=config.preprocess,
|
||||||
input_samplerate=config.samplerate,
|
input_samplerate=config.samplerate,
|
||||||
@ -277,16 +300,16 @@ def build_model(
|
|||||||
config=config.postprocess,
|
config=config.postprocess,
|
||||||
)
|
)
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(class_names),
|
num_classes=len(targets.class_names),
|
||||||
num_sizes=len(dimension_names),
|
num_sizes=len(roi_mapper.dimension_names),
|
||||||
config=config.architecture,
|
config=config.architecture,
|
||||||
)
|
)
|
||||||
return Model(
|
return Model(
|
||||||
detector=detector,
|
detector=detector,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
class_names=class_names,
|
targets=targets,
|
||||||
dimension_names=dimension_names,
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -306,6 +329,6 @@ def build_model_with_new_targets(
|
|||||||
detector=detector,
|
detector=detector,
|
||||||
postprocessor=model.postprocessor,
|
postprocessor=model.postprocessor,
|
||||||
preprocessor=model.preprocessor,
|
preprocessor=model.preprocessor,
|
||||||
class_names=targets.class_names,
|
targets=targets,
|
||||||
dimension_names=roi_mapper.dimension_names,
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
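Note on the `build_model` variant that accepts `targets` and `roi_mapper`: it resolves each collaborator with the pattern `explicit or build_from_config(...)`. The sketch below reproduces that pattern with stand-in types. `StubTargets`, `StubConfig`, `build_stub_targets` and the two `resolve_*` functions are hypothetical names used for illustration only. One general Python behaviour to be aware of: `or` falls back whenever the explicit object is falsy, not only when it is `None`. Whether that matters for the real `Targets` class is not visible in this diff.

```python
from dataclasses import dataclass, field


@dataclass
class StubTargets:
    # Hypothetical stand-in for a targets object.
    class_names: list[str] = field(default_factory=list)

    def __len__(self) -> int:
        # Defining __len__ makes an empty instance falsy.
        return len(self.class_names)


@dataclass
class StubConfig:
    # Hypothetical stand-in for the targets section of a model config.
    class_names: list[str] = field(default_factory=lambda: ["bat"])


def build_stub_targets(config: StubConfig) -> StubTargets:
    return StubTargets(class_names=list(config.class_names))


def resolve_with_or(targets, config):
    # Pattern used in the diff: the explicit object wins unless it is falsy.
    return targets or build_stub_targets(config)


def resolve_with_none_check(targets, config):
    # Alternative: fall back only when nothing was passed at all.
    if targets is None:
        return build_stub_targets(config)
    return targets
```

Both functions behave identically for `None` and for non-empty objects. They differ only for an explicit object that evaluates as false, which `resolve_with_or` replaces with the config-built default.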
||||||
|
|||||||
@ -53,12 +53,8 @@ import torch.nn.functional as F
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.core import (
|
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||||
BaseConfig,
|
from batdetect2.core.configs import BaseConfig
|
||||||
ImportConfig,
|
|
||||||
Registry,
|
|
||||||
add_import_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BlockImportConfig",
|
"BlockImportConfig",
|
||||||
|
|||||||
@ -6,7 +6,6 @@ from batdetect2.plotting.gallery import plot_match_gallery
|
|||||||
from batdetect2.plotting.heatmaps import (
|
from batdetect2.plotting.heatmaps import (
|
||||||
plot_classification_heatmap,
|
plot_classification_heatmap,
|
||||||
plot_detection_heatmap,
|
plot_detection_heatmap,
|
||||||
plot_size_heatmap,
|
|
||||||
)
|
)
|
||||||
from batdetect2.plotting.matches import (
|
from batdetect2.plotting.matches import (
|
||||||
plot_cross_trigger_match,
|
plot_cross_trigger_match,
|
||||||
@ -26,6 +25,5 @@ __all__ = [
|
|||||||
"plot_true_positive_match",
|
"plot_true_positive_match",
|
||||||
"plot_detection_heatmap",
|
"plot_detection_heatmap",
|
||||||
"plot_classification_heatmap",
|
"plot_classification_heatmap",
|
||||||
"plot_size_heatmap",
|
|
||||||
"plot_match_gallery",
|
"plot_match_gallery",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
"""Plot heatmaps."""
|
"""Plot heatmaps"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -8,12 +8,6 @@ from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba
|
|||||||
|
|
||||||
from batdetect2.plotting.common import create_ax
|
from batdetect2.plotting.common import create_ax
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"plot_detection_heatmap",
|
|
||||||
"plot_classification_heatmap",
|
|
||||||
"plot_size_heatmap",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def plot_detection_heatmap(
|
def plot_detection_heatmap(
|
||||||
heatmap: torch.Tensor | np.ndarray,
|
heatmap: torch.Tensor | np.ndarray,
|
||||||
@ -114,91 +108,7 @@ def plot_classification_heatmap(
|
|||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
def plot_size_heatmap(
|
def create_colormap(color: str) -> Colormap:
|
||||||
heatmap: torch.Tensor | np.ndarray,
|
|
||||||
dimension_names: list[str],
|
|
||||||
ax: axes.Axes | None = None,
|
|
||||||
figsize: tuple[int, int] = (10, 10),
|
|
||||||
color: str = "crimson",
|
|
||||||
size: float = 20,
|
|
||||||
fontsize: float = 8,
|
|
||||||
) -> axes.Axes:
|
|
||||||
"""Plot sparse size labels from a size heatmap.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
heatmap : torch.Tensor | np.ndarray
|
|
||||||
Size heatmap with shape ``[num_dims, height, width]``. Entries are
|
|
||||||
expected to be zero everywhere except at labelled positions.
|
|
||||||
dimension_names : list[str]
|
|
||||||
Names corresponding to the first heatmap dimension.
|
|
||||||
ax : matplotlib.axes.Axes | None, default=None
|
|
||||||
Axis to plot on. If ``None``, a new axis is created.
|
|
||||||
figsize : tuple[int, int], default=(10, 10)
|
|
||||||
Figure size used when creating a new axis.
|
|
||||||
color : str, default="crimson"
|
|
||||||
Color used for scatter points and text labels.
|
|
||||||
size : float, default=20
|
|
||||||
Marker size for plotted points.
|
|
||||||
fontsize : float, default=8
|
|
||||||
Font size used for the text labels.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
matplotlib.axes.Axes
|
|
||||||
Axis containing the plotted size labels.
|
|
||||||
"""
|
|
||||||
ax = create_ax(ax, figsize=figsize)
|
|
||||||
|
|
||||||
if isinstance(heatmap, torch.Tensor):
|
|
||||||
heatmap = heatmap.numpy()
|
|
||||||
|
|
||||||
if heatmap.ndim == 4:
|
|
||||||
heatmap = heatmap[0]
|
|
||||||
|
|
||||||
if heatmap.ndim != 3:
|
|
||||||
raise ValueError("Expecting a 3-dimensional array")
|
|
||||||
|
|
||||||
if len(dimension_names) != heatmap.shape[0]:
|
|
||||||
raise ValueError("Inconsistent number of dimension names")
|
|
||||||
|
|
||||||
point_mask = np.any(heatmap != 0, axis=0)
|
|
||||||
rows, cols = np.nonzero(point_mask)
|
|
||||||
|
|
||||||
if len(rows) == 0:
|
|
||||||
return ax
|
|
||||||
|
|
||||||
ax.scatter(cols, rows, c=color, s=size)
|
|
||||||
|
|
||||||
for row, col in zip(rows, cols, strict=False):
|
|
||||||
values = heatmap[:, row, col]
|
|
||||||
labels = [
|
|
||||||
f"{name}={value:.2f}"
|
|
||||||
for name, value in zip(
|
|
||||||
dimension_names,
|
|
||||||
values,
|
|
||||||
strict=False,
|
|
||||||
)
|
|
||||||
if value != 0
|
|
||||||
]
|
|
||||||
ax.text(
|
|
||||||
float(col),
|
|
||||||
float(row),
|
|
||||||
"\n".join(labels),
|
|
||||||
fontsize=fontsize,
|
|
||||||
color=color,
|
|
||||||
va="bottom",
|
|
||||||
ha="left",
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set_xlim(0, heatmap.shape[2])
|
|
||||||
ax.set_ylim(0, heatmap.shape[1])
|
|
||||||
return ax
|
|
||||||
|
|
||||||
|
|
||||||
def create_colormap(
|
|
||||||
color: str | tuple[float, float, float, float],
|
|
||||||
) -> Colormap:
|
|
||||||
(r, g, b, a) = to_rgba(color)
|
(r, g, b, a) = to_rgba(color)
|
||||||
return LinearSegmentedColormap.from_list(
|
return LinearSegmentedColormap.from_list(
|
||||||
"cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
|
"cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from batdetect2.targets.classes import (
|
|||||||
build_sound_event_encoder,
|
build_sound_event_encoder,
|
||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.config import TargetConfig, build_default_target_config
|
from batdetect2.targets.config import TargetConfig
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
AnchorBBoxMapperConfig,
|
AnchorBBoxMapperConfig,
|
||||||
ROIMapperConfig,
|
ROIMapperConfig,
|
||||||
@ -36,14 +36,13 @@ from batdetect2.targets.types import (
|
|||||||
SoundEventFilter,
|
SoundEventFilter,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.utils import check_target_compatibility
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnchorBBoxMapperConfig",
|
"AnchorBBoxMapperConfig",
|
||||||
"Position",
|
"Position",
|
||||||
"ROIMapperConfig",
|
|
||||||
"ROIMapperProtocol",
|
|
||||||
"ROIMappingConfig",
|
"ROIMappingConfig",
|
||||||
|
"ROIMapperProtocol",
|
||||||
|
"ROIMapperConfig",
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"Size",
|
"Size",
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
@ -53,14 +52,12 @@ __all__ = [
|
|||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
"TargetProtocol",
|
"TargetProtocol",
|
||||||
"Targets",
|
"Targets",
|
||||||
"build_default_target_config",
|
|
||||||
"build_roi_mapper",
|
|
||||||
"build_roi_mapping",
|
"build_roi_mapping",
|
||||||
|
"build_roi_mapper",
|
||||||
"build_sound_event_decoder",
|
"build_sound_event_decoder",
|
||||||
"build_sound_event_encoder",
|
"build_sound_event_encoder",
|
||||||
"build_targets",
|
"build_targets",
|
||||||
"call_type",
|
"call_type",
|
||||||
"check_target_compatibility",
|
|
||||||
"data_source",
|
"data_source",
|
||||||
"generic_class",
|
"generic_class",
|
||||||
"get_class_names_from_config",
|
"get_class_names_from_config",
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from batdetect2.data.conditions import (
|
|||||||
NotConfig,
|
NotConfig,
|
||||||
SoundEventCondition,
|
SoundEventCondition,
|
||||||
SoundEventConditionConfig,
|
SoundEventConditionConfig,
|
||||||
TagInfo,
|
|
||||||
build_sound_event_condition,
|
build_sound_event_condition,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.terms import call_type, generic_class
|
from batdetect2.targets.terms import call_type, generic_class
|
||||||
@ -33,12 +32,11 @@ class TargetClassConfig(BaseConfig):
|
|||||||
condition_input: SoundEventConditionConfig | None = Field(
|
condition_input: SoundEventConditionConfig | None = Field(
|
||||||
alias="match_if",
|
alias="match_if",
|
||||||
default=None,
|
default=None,
|
||||||
exclude=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tags: List[data.Tag] | None = Field(default=None, exclude=True)
|
tags: List[data.Tag] | None = Field(default=None, exclude=True)
|
||||||
|
|
||||||
assign_tags: List[TagInfo] = Field(default_factory=list)
|
assign_tags: List[data.Tag] = Field(default_factory=list)
|
||||||
|
|
||||||
_match_if: SoundEventConditionConfig = PrivateAttr()
|
_match_if: SoundEventConditionConfig = PrivateAttr()
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from collections import Counter
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.targets.classes import (
|
from batdetect2.targets.classes import (
|
||||||
@ -14,7 +13,6 @@ from batdetect2.targets.rois import ROIMappingConfig
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
"build_default_target_config",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -44,20 +42,3 @@ class TargetConfig(BaseConfig):
|
|||||||
f"{', '.join(duplicates)}"
|
f"{', '.join(duplicates)}"
|
||||||
)
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
def build_default_target_config(class_names: list[str]) -> TargetConfig:
|
|
||||||
"""Build a default target configuration object."""
|
|
||||||
return TargetConfig(
|
|
||||||
detection_target=DEFAULT_DETECTION_CLASS,
|
|
||||||
classification_targets=[
|
|
||||||
TargetClassConfig(
|
|
||||||
name=class_name,
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="class", value=class_name),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
for class_name in class_names
|
|
||||||
],
|
|
||||||
roi=ROIMappingConfig(),
|
|
||||||
)
|
|
||||||
|
|||||||
@ -50,31 +50,21 @@ class Targets(TargetProtocol):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self._filter_fn = build_sound_event_condition(
|
self._filter_fn = build_sound_event_condition(
|
||||||
self.config.detection_target.match_if
|
config.detection_target.match_if
|
||||||
)
|
)
|
||||||
self._encode_fn = build_sound_event_encoder(
|
self._encode_fn = build_sound_event_encoder(
|
||||||
self.config.classification_targets
|
config.classification_targets
|
||||||
)
|
)
|
||||||
self._decode_fn = build_sound_event_decoder(
|
self._decode_fn = build_sound_event_decoder(
|
||||||
self.config.classification_targets
|
config.classification_targets
|
||||||
)
|
)
|
||||||
|
|
||||||
self.class_names = get_class_names_from_config(
|
self.class_names = get_class_names_from_config(
|
||||||
self.config.classification_targets
|
config.classification_targets
|
||||||
)
|
)
|
||||||
|
|
||||||
self.detection_class_name = self.config.detection_target.name
|
self.detection_class_name = config.detection_target.name
|
||||||
self.detection_class_tags = self.config.detection_target.assign_tags
|
self.detection_class_tags = config.detection_target.assign_tags
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: dict) -> "Targets":
|
|
||||||
"""Build a Targets object from a serialized config dictionary."""
|
|
||||||
validated_config = TargetConfig.model_validate(config)
|
|
||||||
return cls(config=validated_config)
|
|
||||||
|
|
||||||
def get_config(self) -> dict:
|
|
||||||
"""Return the serialized target config used to build this object."""
|
|
||||||
return self.config.model_dump(mode="json")
|
|
||||||
|
|
||||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||||
"""Apply the configured filter to a sound event annotation.
|
"""Apply the configured filter to a sound event annotation.
|
||||||
@ -141,7 +131,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_targets(config: TargetConfig | dict | None = None) -> Targets:
|
def build_targets(config: TargetConfig | None = None) -> Targets:
|
||||||
"""Build a Targets object from a loaded TargetConfig.
|
"""Build a Targets object from a loaded TargetConfig.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -163,10 +153,6 @@ def build_targets(config: TargetConfig | dict | None = None) -> Targets:
|
|||||||
If dynamic import of a derivation function fails (when configured).
|
If dynamic import of a derivation function fails (when configured).
|
||||||
"""
|
"""
|
||||||
config = config or DEFAULT_TARGET_CONFIG
|
config = config or DEFAULT_TARGET_CONFIG
|
||||||
|
|
||||||
if not isinstance(config, TargetConfig):
|
|
||||||
config = TargetConfig.model_validate(config)
|
|
||||||
|
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building targets with config: \n{}",
|
"Building targets with config: \n{}",
|
||||||
lambda: config.to_yaml_string(),
|
lambda: config.to_yaml_string(),
|
||||||
|
|||||||
@ -28,11 +28,6 @@ class TargetProtocol(Protocol):
|
|||||||
detection_class_tags: list[data.Tag]
|
detection_class_tags: list[data.Tag]
|
||||||
detection_class_name: str
|
detection_class_name: str
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: dict) -> "TargetProtocol": ...
|
|
||||||
|
|
||||||
def get_config(self) -> dict: ...
|
|
||||||
|
|
||||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
|
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
|
||||||
|
|
||||||
def encode_class(
|
def encode_class(
|
||||||
|
|||||||
@ -1,29 +0,0 @@
|
|||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
|
|
||||||
|
|
||||||
def check_target_compatibility(
|
|
||||||
targets: "TargetProtocol",
|
|
||||||
class_names: list[str],
|
|
||||||
) -> bool:
|
|
||||||
"""Check if a target definition can decode a model's outputs.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
targets : TargetProtocol
|
|
||||||
Target definition that would be used with the model outputs.
|
|
||||||
class_names : list[str]
|
|
||||||
Class names produced by the model checkpoint.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
bool
|
|
||||||
True when every model class name exists in the provided targets,
|
|
||||||
False otherwise.
|
|
||||||
"""
|
|
||||||
target_class_names = set(targets.class_names)
|
|
||||||
model_class_names = set(class_names)
|
|
||||||
|
|
||||||
return model_class_names.issubset(target_class_names)
|
|
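Note on `check_target_compatibility`, shown on one side of this file's diff: it reduces to a set-inclusion test. The targets are compatible when every class name produced by the checkpoint also exists in the target definition. A minimal standalone sketch, assuming a hypothetical `is_compatible` function that takes plain lists instead of a `TargetProtocol`:

```python
def is_compatible(
    target_class_names: list[str],
    model_class_names: list[str],
) -> bool:
    # True when every model class name is also a target class name.
    # Order and duplicates are ignored because the comparison is set-based.
    return set(model_class_names).issubset(target_class_names)
```

Under this rule the target definition may contain extra classes that the model never predicts, but a model class missing from the targets makes the pair incompatible.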
||||||
@ -4,24 +4,10 @@ from batdetect2.train.lightning import (
|
|||||||
TrainingModule,
|
TrainingModule,
|
||||||
load_model_from_checkpoint,
|
load_model_from_checkpoint,
|
||||||
)
|
)
|
||||||
from batdetect2.train.logging import (
|
|
||||||
ConfigHyperparameterLogging,
|
|
||||||
DatasetConfigArtifact,
|
|
||||||
DatasetConfigArtifactLogging,
|
|
||||||
DataSummaryArtifactLogging,
|
|
||||||
TargetConfigArtifactLogging,
|
|
||||||
TrainLoggingContext,
|
|
||||||
)
|
|
||||||
from batdetect2.train.train import build_trainer, run_train
|
from batdetect2.train.train import build_trainer, run_train
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ConfigHyperparameterLogging",
|
|
||||||
"DataSummaryArtifactLogging",
|
|
||||||
"DEFAULT_CHECKPOINT_DIR",
|
"DEFAULT_CHECKPOINT_DIR",
|
||||||
"DatasetConfigArtifact",
|
|
||||||
"DatasetConfigArtifactLogging",
|
|
||||||
"TargetConfigArtifactLogging",
|
|
||||||
"TrainLoggingContext",
|
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
"build_trainer",
|
"build_trainer",
|
||||||
|
|||||||
@ -10,7 +10,6 @@ from batdetect2.logging import get_image_logger
|
|||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
|
||||||
from batdetect2.train.dataset import ValidationDataset
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.types import TrainExample
|
from batdetect2.train.types import TrainExample
|
||||||
@ -20,15 +19,11 @@ class ValidationMetrics(Callback):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
evaluator: EvaluatorProtocol,
|
evaluator: EvaluatorProtocol,
|
||||||
targets: TargetProtocol,
|
|
||||||
roi_mapper: ROIMapperProtocol,
|
|
||||||
output_transform: OutputTransformProtocol | None = None,
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
self.targets = targets
|
|
||||||
self.roi_mapper = roi_mapper
|
|
||||||
self.output_transform = output_transform
|
self.output_transform = output_transform
|
||||||
|
|
||||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||||
@ -98,8 +93,8 @@ class ValidationMetrics(Callback):
|
|||||||
model = pl_module.model
|
model = pl_module.model
|
||||||
if self.output_transform is None:
|
if self.output_transform is None:
|
||||||
self.output_transform = build_output_transform(
|
self.output_transform = build_output_transform(
|
||||||
targets=self.targets,
|
targets=model.targets,
|
||||||
roi_mapper=self.roi_mapper,
|
roi_mapper=model.roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
output_transform = self.output_transform
|
output_transform = self.output_transform
|
||||||
|
|||||||
@ -34,8 +34,6 @@ def build_checkpoint_callback(
|
|||||||
if checkpoint_dir is None:
|
if checkpoint_dir is None:
|
||||||
checkpoint_dir = config.checkpoint_dir
|
checkpoint_dir = config.checkpoint_dir
|
||||||
|
|
||||||
checkpoint_dir = Path(checkpoint_dir)
|
|
||||||
|
|
||||||
if experiment_name is not None:
|
if experiment_name is not None:
|
||||||
checkpoint_dir = checkpoint_dir / experiment_name
|
checkpoint_dir = checkpoint_dir / experiment_name
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,8 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.models import Model, ModelConfig, build_model
|
from batdetect2.models import Model, ModelConfig, build_model
|
||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.targets import TargetConfig
|
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.losses import build_loss
|
||||||
from batdetect2.train.optimizers import build_optimizer
|
from batdetect2.train.optimizers import build_optimizer
|
||||||
@ -14,7 +11,6 @@ from batdetect2.train.types import LossProtocol, TrainExample
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
"load_model_from_checkpoint",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -25,9 +21,6 @@ class TrainingModule(L.LightningModule):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config: dict | None = None,
|
model_config: dict | None = None,
|
||||||
targets_config: dict | None = None,
|
|
||||||
class_names: list[str] | None = None,
|
|
||||||
dimension_names: list[str] | None = None,
|
|
||||||
train_config: dict | None = None,
|
train_config: dict | None = None,
|
||||||
loss: LossProtocol | None = None,
|
loss: LossProtocol | None = None,
|
||||||
model: Model | None = None,
|
model: Model | None = None,
|
||||||
@ -36,34 +29,14 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
||||||
|
|
||||||
self.model_config: dict = model_config or {}
|
self.model_config = ModelConfig.model_validate(model_config or {})
|
||||||
self.targets_config: dict = targets_config or {}
|
|
||||||
self.class_names = list(class_names or [])
|
|
||||||
self.dimension_names = list(dimension_names or [])
|
|
||||||
|
|
||||||
self.train_config = TrainingConfig.model_validate(train_config or {})
|
self.train_config = TrainingConfig.model_validate(train_config or {})
|
||||||
|
|
||||||
if loss is None:
|
if loss is None:
|
||||||
loss = build_loss(config=self.train_config.loss)
|
loss = build_loss(config=self.train_config.loss)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
if not self.class_names:
|
model = build_model(config=self.model_config)
|
||||||
raise ValueError(
|
|
||||||
"class_names must be provided when rebuilding a training "
|
|
||||||
"module without a model."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.dimension_names:
|
|
||||||
raise ValueError(
|
|
||||||
"dimension_names must be provided when rebuilding a "
|
|
||||||
"training module without a model."
|
|
||||||
)
|
|
||||||
|
|
||||||
model = build_model(
|
|
||||||
config=self.model_config,
|
|
||||||
class_names=self.class_names,
|
|
||||||
dimension_names=self.dimension_names,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -122,16 +95,9 @@ class TrainingModule(L.LightningModule):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StoredConfig:
|
|
||||||
model: ModelConfig
|
|
||||||
targets: TargetConfig
|
|
||||||
train: TrainingConfig
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_checkpoint(
|
def load_model_from_checkpoint(
|
||||||
path: PathLike,
|
path: PathLike,
|
||||||
) -> tuple[Model, StoredConfig]:
|
) -> tuple[Model, ModelConfig]:
|
||||||
"""Load a model and its configuration from a Lightning checkpoint.
|
"""Load a model and its configuration from a Lightning checkpoint.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -144,24 +110,15 @@ def load_model_from_checkpoint(
|
|||||||
-------
|
-------
|
||||||
tuple[Model, ModelConfig]
|
tuple[Model, ModelConfig]
|
||||||
The restored ``Model`` instance and the ``ModelConfig`` that
|
The restored ``Model`` instance and the ``ModelConfig`` that
|
||||||
describes its architecture, preprocessing, and postprocessing.
|
describes its architecture, preprocessing, postprocessing, and
|
||||||
|
targets.
|
||||||
"""
|
"""
|
||||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||||
training_config = TrainingConfig.model_validate(module.train_config)
|
return module.model, module.model_config
|
||||||
model_config = ModelConfig.model_validate(module.model_config)
|
|
||||||
targets_config = TargetConfig.model_validate(module.targets_config)
|
|
||||||
return module.model, StoredConfig(
|
|
||||||
model=model_config,
|
|
||||||
targets=targets_config,
|
|
||||||
train=training_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
model_config: ModelConfig | None = None,
|
model_config: ModelConfig | None = None,
|
||||||
targets_config: TargetConfig | dict | None = None,
|
|
||||||
class_names: list[str] | None = None,
|
|
||||||
dimension_names: list[str] | None = None,
|
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
model: Model | None = None,
|
model: Model | None = None,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
@ -171,16 +128,8 @@ def build_training_module(
|
|||||||
if train_config is None:
|
if train_config is None:
|
||||||
train_config = TrainingConfig()
|
train_config = TrainingConfig()
|
||||||
|
|
||||||
if targets_config is None:
|
|
||||||
targets_config = TargetConfig()
|
|
||||||
|
|
||||||
targets_config = TargetConfig.model_validate(targets_config)
|
|
||||||
|
|
||||||
return TrainingModule(
|
return TrainingModule(
|
||||||
model_config=model_config.model_dump(mode="json"),
|
model_config=model_config.model_dump(mode="json"),
|
||||||
targets_config=targets_config.model_dump(mode="json"),
|
|
||||||
train_config=train_config.model_dump(mode="json"),
|
train_config=train_config.model_dump(mode="json"),
|
||||||
class_names=class_names,
|
|
||||||
dimension_names=dimension_names,
|
|
||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
|
|||||||
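Review note on the `load_model_from_checkpoint` hunk above: one side returns the model together with a `StoredConfig` dataclass bundling the model, targets, and training configs, while the other returns only the model config. A minimal sketch of the bundling pattern follows. It uses stand-in dataclasses rather than the real pydantic config classes, and every field name and the `restore_configs` helper are hypothetical, chosen only for illustration.

```python
from dataclasses import dataclass, field


# Stand-ins for the pydantic config classes named in the diff.
# The fields are invented for illustration only.
@dataclass
class ModelConfig:
    name: str = "default"


@dataclass
class TargetConfig:
    class_names: list[str] = field(default_factory=list)


@dataclass
class TrainingConfig:
    max_epochs: int = 10


@dataclass
class StoredConfig:
    """Groups every config that was saved alongside a checkpoint."""

    model: ModelConfig
    targets: TargetConfig
    train: TrainingConfig


def restore_configs(hparams: dict) -> StoredConfig:
    """Rebuild typed configs from plain-dict hyperparameters (hypothetical)."""
    return StoredConfig(
        model=ModelConfig(**hparams.get("model_config", {})),
        targets=TargetConfig(**hparams.get("targets_config", {})),
        train=TrainingConfig(**hparams.get("train_config", {})),
    )


# Missing sections fall back to defaults; present ones are validated by
# the dataclass constructor.
stored = restore_configs({"targets_config": {"class_names": ["myotis"]}})
```

Returning a single bundle keeps the function's return type stable if more configs are stored in the checkpoint later, at the cost of one extra attribute access for callers who only want the model config.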
@ -1,164 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
from lightning.pytorch.loggers import Logger
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
from batdetect2.data import Dataset, compute_class_summary
|
|
||||||
from batdetect2.logging import log_config_artifact, log_csv_artifact
|
|
||||||
from batdetect2.models import ModelConfig
|
|
||||||
from batdetect2.targets import TargetConfig, TargetProtocol
|
|
||||||
from batdetect2.train.config import TrainingConfig
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ConfigHyperparameterLogging",
|
|
||||||
"DataSummaryArtifactLogging",
|
|
||||||
"DatasetConfigArtifact",
|
|
||||||
"DatasetConfigArtifactLogging",
|
|
||||||
"TargetConfigArtifactLogging",
|
|
||||||
"TrainLoggingContext",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class TrainLoggingContext:
|
|
||||||
model_config: ModelConfig
|
|
||||||
train_config: TrainingConfig
|
|
||||||
audio_config: AudioConfig
|
|
||||||
targets: TargetProtocol
|
|
||||||
train_dataset: Dataset
|
|
||||||
val_dataset: Dataset | None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class DatasetConfigArtifact:
|
|
||||||
filename: str
|
|
||||||
config: BaseConfig
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigHyperparameterLogging:
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
logger: Logger,
|
|
||||||
artifact_path: Path,
|
|
||||||
context: TrainLoggingContext,
|
|
||||||
) -> None:
|
|
||||||
logger.log_hyperparams(
|
|
||||||
{
|
|
||||||
"model": context.model_config.model_dump(
|
|
||||||
mode="json",
|
|
||||||
exclude_none=True,
|
|
||||||
),
|
|
||||||
"training": context.train_config.model_dump(
|
|
||||||
mode="json",
|
|
||||||
exclude_none=True,
|
|
||||||
),
|
|
||||||
"audio": context.audio_config.model_dump(
|
|
||||||
mode="json",
|
|
||||||
exclude_none=True,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TargetConfigArtifactLogging:
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
logger: Logger,
|
|
||||||
artifact_path: Path,
|
|
||||||
context: TrainLoggingContext,
|
|
||||||
) -> None:
|
|
||||||
targets_config = TargetConfig.model_validate(
|
|
||||||
context.targets.get_config()
|
|
||||||
)
|
|
||||||
log_config_artifact(
|
|
||||||
logger,
|
|
||||||
targets_config,
|
|
||||||
filename="targets.yaml",
|
|
||||||
artifact_path=artifact_path / "training_artifacts",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetConfigArtifactLogging:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
train_dataset_config: DatasetConfigArtifact,
|
|
||||||
val_dataset_config: DatasetConfigArtifact | None = None,
|
|
||||||
):
|
|
||||||
self.train_dataset_config = train_dataset_config
|
|
||||||
self.val_dataset_config = val_dataset_config
|
|
||||||
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
logger: Logger,
|
|
||||||
artifact_path: Path,
|
|
||||||
context: TrainLoggingContext,
|
|
||||||
) -> None:
|
|
||||||
training_artifact_path = artifact_path / "training_artifacts"
|
|
||||||
|
|
||||||
log_config_artifact(
|
|
||||||
logger,
|
|
||||||
self.train_dataset_config.config,
|
|
||||||
filename=self.train_dataset_config.filename,
|
|
||||||
artifact_path=training_artifact_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.val_dataset_config is not None:
|
|
||||||
log_config_artifact(
|
|
||||||
logger,
|
|
||||||
self.val_dataset_config.config,
|
|
||||||
filename=self.val_dataset_config.filename,
|
|
||||||
artifact_path=training_artifact_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DataSummaryArtifactLogging:
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
logger: Logger,
|
|
||||||
artifact_path: Path,
|
|
||||||
context: TrainLoggingContext,
|
|
||||||
) -> None:
|
|
||||||
training_artifact_path = artifact_path / "training_artifacts"
|
|
||||||
|
|
||||||
log_csv_artifact(
|
|
||||||
logger,
|
|
||||||
_compute_class_summary_or_empty(
|
|
||||||
context.train_dataset,
|
|
||||||
context.targets,
|
|
||||||
),
|
|
||||||
filename="train_class_summary.csv",
|
|
||||||
artifact_path=training_artifact_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
if context.val_dataset is not None:
|
|
||||||
log_csv_artifact(
|
|
||||||
logger,
|
|
||||||
_compute_class_summary_or_empty(
|
|
||||||
context.val_dataset,
|
|
||||||
context.targets,
|
|
||||||
),
|
|
||||||
filename="val_class_summary.csv",
|
|
||||||
artifact_path=training_artifact_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_class_summary_or_empty(
|
|
||||||
dataset: Sequence[data.ClipAnnotation],
|
|
||||||
targets: TargetProtocol,
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
try:
|
|
||||||
return compute_class_summary(dataset, targets)
|
|
||||||
except KeyError as error:
|
|
||||||
if error.args != ("class_name",):
|
|
||||||
raise
|
|
||||||
|
|
||||||
return pd.DataFrame(
|
|
||||||
columns=["num calls", "num recordings", "duration", "call_rate"]
|
|
||||||
)
|
|
||||||
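Review note on `_compute_class_summary_or_empty` above: the function catches `KeyError` but re-raises unless the missing key is exactly `"class_name"`, so only the one expected failure is turned into an empty result. A stdlib-only sketch of that narrow-fallback pattern, assuming a plain dict of counts in place of the pandas `DataFrame`; the `summarize` function is a hypothetical stand-in for `compute_class_summary`.

```python
def summarize(rows: list[dict]) -> dict[str, int]:
    """Count rows per class; a row without 'class_name' raises KeyError."""
    counts: dict[str, int] = {}
    for row in rows:
        name = row["class_name"]
        counts[name] = counts.get(name, 0) + 1
    return counts


def summarize_or_empty(rows: list[dict]) -> dict[str, int]:
    """Return an empty summary only for the one expected failure mode."""
    try:
        return summarize(rows)
    except KeyError as error:
        # Any KeyError other than the missing 'class_name' key is a real
        # bug and must propagate to the caller.
        if error.args != ("class_name",):
            raise
        return {}
```

Comparing `error.args` against an exact tuple is stricter than a bare `except KeyError`, which would also hide unrelated lookup bugs inside the summary computation.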
@ -3,7 +3,6 @@
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim import Adam, Optimizer
|
from torch.optim import Adam, Optimizer
|
||||||
@ -85,10 +84,4 @@ def build_optimizer(
|
|||||||
Optimizer configuration. Defaults to ``AdamOptimizerConfig``.
|
Optimizer configuration. Defaults to ``AdamOptimizerConfig``.
|
||||||
"""
|
"""
|
||||||
config = config or AdamOptimizerConfig()
|
config = config or AdamOptimizerConfig()
|
||||||
|
|
||||||
logger.opt(lazy=True).debug(
|
|
||||||
"Building optimizer with config: \n{}",
|
|
||||||
lambda: config.to_yaml_string(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimizer_registry.build(config, parameters)
|
return optimizer_registry.build(config, parameters)
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler
|
from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler
|
||||||
@ -79,9 +78,4 @@ def build_scheduler(
|
|||||||
"""Build a scheduler from configuration."""
|
"""Build a scheduler from configuration."""
|
||||||
config = config or CosineAnnealingSchedulerConfig()
|
config = config or CosineAnnealingSchedulerConfig()
|
||||||
|
|
||||||
logger.opt(lazy=True).debug(
|
|
||||||
"Building scheduler with config: \n{}",
|
|
||||||
lambda: config.to_yaml_string(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return scheduler_registry.build(config, optimizer)
|
return scheduler_registry.build(config, optimizer)
|
||||||
|
|||||||
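Review note on the optimizer and scheduler hunks above: both differ only in a `logger.opt(lazy=True).debug(...)` call, which passes a lambda so the YAML dump is computed only if the debug message is actually emitted. The sketch below shows the same deferral idea using the standard library `logging` module instead of loguru. This is a swapped-in technique, not the code from the diff, and the `Lazy` wrapper and `expensive_dump` function are hypothetical names.

```python
import logging


class Lazy:
    """Defers an expensive computation until the log record is formatted."""

    def __init__(self, func):
        self.func = func

    def __str__(self) -> str:
        return str(self.func())


calls = []


def expensive_dump() -> str:
    """Stand-in for config.to_yaml_string(); records each invocation."""
    calls.append("called")
    return "lr: 0.001"


log = logging.getLogger("lazy_demo")
log.setLevel(logging.INFO)

# DEBUG is below the logger's level, so the record is never created,
# the arguments are never formatted, and expensive_dump is not called.
log.debug("Building optimizer with config: \n%s", Lazy(expensive_dump))
```

The deferral matters when serializing the config is costly relative to the build step; for small configs the difference is negligible and the call is mainly a diagnostic aid.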
@ -3,7 +3,6 @@ from pathlib import Path
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from lightning import Trainer, seed_everything
|
from lightning import Trainer, seed_everything
|
||||||
from lightning.pytorch.loggers import Logger
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -11,7 +10,6 @@ from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
|||||||
from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
|
from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
|
||||||
from batdetect2.logging import (
|
from batdetect2.logging import (
|
||||||
LoggerConfig,
|
LoggerConfig,
|
||||||
LoggingCallback,
|
|
||||||
TensorBoardLoggerConfig,
|
TensorBoardLoggerConfig,
|
||||||
build_logger,
|
build_logger,
|
||||||
)
|
)
|
||||||
@ -19,7 +17,6 @@ from batdetect2.models import Model, ModelConfig, build_model
|
|||||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
ROIMapperProtocol,
|
ROIMapperProtocol,
|
||||||
TargetConfig,
|
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
build_roi_mapping,
|
build_roi_mapping,
|
||||||
build_targets,
|
build_targets,
|
||||||
@ -30,12 +27,6 @@ from batdetect2.train.config import TrainingConfig
|
|||||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.lightning import build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
from batdetect2.train.logging import (
|
|
||||||
ConfigHyperparameterLogging,
|
|
||||||
DataSummaryArtifactLogging,
|
|
||||||
TargetConfigArtifactLogging,
|
|
||||||
TrainLoggingContext,
|
|
||||||
)
|
|
||||||
from batdetect2.train.types import ClipLabeller
|
from batdetect2.train.types import ClipLabeller
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -44,9 +35,6 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_LOG_DIR = Path("outputs") / "logs"
|
|
||||||
|
|
||||||
|
|
||||||
def run_train(
|
def run_train(
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
@ -58,7 +46,6 @@ def run_train(
|
|||||||
labeller: Optional["ClipLabeller"] = None,
|
labeller: Optional["ClipLabeller"] = None,
|
||||||
audio_config: Optional[AudioConfig] = None,
|
audio_config: Optional[AudioConfig] = None,
|
||||||
model_config: Optional[ModelConfig] = None,
|
model_config: Optional[ModelConfig] = None,
|
||||||
targets_config: TargetConfig | None = None,
|
|
||||||
train_config: Optional[TrainingConfig] = None,
|
train_config: Optional[TrainingConfig] = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
trainer: Trainer | None = None,
|
trainer: Trainer | None = None,
|
||||||
@ -70,42 +57,27 @@ def run_train(
|
|||||||
num_epochs: int | None = None,
|
num_epochs: int | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
seed: int | None = None,
|
seed: int | None = None,
|
||||||
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
|
||||||
):
|
):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
|
|
||||||
model_config = model_config or ModelConfig()
|
model_config = model_config or ModelConfig()
|
||||||
targets_config = targets_config or TargetConfig()
|
|
||||||
audio_config = audio_config or AudioConfig()
|
audio_config = audio_config or AudioConfig()
|
||||||
train_config = train_config or TrainingConfig()
|
train_config = train_config or TrainingConfig()
|
||||||
|
|
||||||
if model is not None:
|
if model is not None:
|
||||||
if targets is None:
|
_validate_model_compatibility(model=model, model_config=model_config)
|
||||||
raise ValueError(
|
|
||||||
"targets must be provided when training with an existing "
|
|
||||||
"model."
|
|
||||||
)
|
|
||||||
|
|
||||||
if roi_mapper is None:
|
|
||||||
raise ValueError(
|
|
||||||
"roi_mapper must be provided when training with an existing "
|
|
||||||
"model."
|
|
||||||
)
|
|
||||||
|
|
||||||
if targets is None:
|
|
||||||
targets = build_targets(config=targets_config)
|
|
||||||
else:
|
|
||||||
targets_config = TargetConfig.model_validate(targets.get_config())
|
|
||||||
|
|
||||||
roi_mapper = roi_mapper or build_roi_mapping(config=targets_config.roi)
|
|
||||||
|
|
||||||
if model is not None:
|
if model is not None:
|
||||||
_validate_model_compatibility(
|
targets = targets or model.targets
|
||||||
model=model,
|
|
||||||
model_config=model_config,
|
if roi_mapper is None and targets is model.targets:
|
||||||
class_names=targets.class_names,
|
roi_mapper = model.roi_mapper
|
||||||
dimension_names=roi_mapper.dimension_names,
|
|
||||||
|
targets = targets or build_targets(config=model_config.targets)
|
||||||
|
|
||||||
|
roi_mapper = roi_mapper or build_roi_mapping(
|
||||||
|
config=model_config.targets.roi
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||||
@ -147,57 +119,21 @@ def run_train(
|
|||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
targets_config=targets_config,
|
|
||||||
class_names=targets.class_names,
|
|
||||||
dimension_names=roi_mapper.dimension_names,
|
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
trainer = trainer or build_trainer(
|
||||||
|
train_config,
|
||||||
|
logger_config=logger_config,
|
||||||
evaluator=build_evaluator(
|
evaluator=build_evaluator(
|
||||||
train_config.validation,
|
train_config.validation,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
)
|
),
|
||||||
|
|
||||||
train_logger = build_logger(
|
|
||||||
logger_config or TensorBoardLoggerConfig(),
|
|
||||||
log_dir=log_dir,
|
|
||||||
experiment_name=experiment_name,
|
|
||||||
run_name=run_name,
|
|
||||||
)
|
|
||||||
root_artifact_path = (
|
|
||||||
Path(log_dir) if log_dir is not None else DEFAULT_LOG_DIR
|
|
||||||
)
|
|
||||||
root_artifact_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
logging_context = TrainLoggingContext(
|
|
||||||
model_config=model_config,
|
|
||||||
train_config=train_config,
|
|
||||||
audio_config=audio_config,
|
|
||||||
targets=targets,
|
|
||||||
train_dataset=train_annotations,
|
|
||||||
val_dataset=val_annotations,
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved_logging_callbacks = (
|
|
||||||
ConfigHyperparameterLogging(),
|
|
||||||
TargetConfigArtifactLogging(),
|
|
||||||
DataSummaryArtifactLogging(),
|
|
||||||
*logging_callbacks,
|
|
||||||
)
|
|
||||||
|
|
||||||
for callback in resolved_logging_callbacks:
|
|
||||||
callback.run(train_logger, root_artifact_path, logging_context)
|
|
||||||
|
|
||||||
trainer = trainer or build_trainer(
|
|
||||||
train_config,
|
|
||||||
train_logger=train_logger,
|
|
||||||
evaluator=evaluator,
|
|
||||||
targets=targets,
|
|
||||||
roi_mapper=roi_mapper,
|
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
num_epochs=num_epochs,
|
num_epochs=num_epochs,
|
||||||
|
log_dir=log_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
)
|
)
|
||||||
@ -216,14 +152,8 @@ def run_train(
|
|||||||
def _validate_model_compatibility(
|
def _validate_model_compatibility(
|
||||||
model: Model,
|
model: Model,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
class_names: list[str],
|
|
||||||
dimension_names: list[str],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
reference_model = build_model(
|
reference_model = build_model(config=model_config)
|
||||||
config=model_config,
|
|
||||||
class_names=class_names,
|
|
||||||
dimension_names=dimension_names,
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_shapes = {
|
expected_shapes = {
|
||||||
key: tuple(value.shape)
|
key: tuple(value.shape)
|
||||||
@ -264,11 +194,10 @@ def _validate_model_compatibility(
|
|||||||
|
|
||||||
def build_trainer(
|
def build_trainer(
|
||||||
config: TrainingConfig,
|
config: TrainingConfig,
|
||||||
train_logger: Logger,
|
logger_config: LoggerConfig | None,
|
||||||
evaluator: "EvaluatorProtocol",
|
evaluator: "EvaluatorProtocol",
|
||||||
targets: "TargetProtocol",
|
|
||||||
roi_mapper: "ROIMapperProtocol",
|
|
||||||
checkpoint_dir: Path | None = None,
|
checkpoint_dir: Path | None = None,
|
||||||
|
log_dir: Path | None = None,
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
num_epochs: int | None = None,
|
num_epochs: int | None = None,
|
||||||
@ -279,11 +208,25 @@ def build_trainer(
|
|||||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
if num_epochs is not None:
|
train_logger = build_logger(
|
||||||
trainer_conf.max_epochs = num_epochs
|
logger_config or TensorBoardLoggerConfig(),
|
||||||
|
log_dir=log_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
|
run_name=run_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_logger.log_hyperparams(
|
||||||
|
config.model_dump(
|
||||||
|
mode="json",
|
||||||
|
exclude_none=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
train_config = trainer_conf.model_dump(exclude_none=True)
|
train_config = trainer_conf.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
if num_epochs is not None:
|
||||||
|
train_config["max_epochs"] = num_epochs
|
||||||
|
|
||||||
return Trainer(
|
return Trainer(
|
||||||
**train_config,
|
**train_config,
|
||||||
logger=train_logger,
|
logger=train_logger,
|
||||||
@ -294,6 +237,6 @@ def build_trainer(
|
|||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
),
|
),
|
||||||
ValidationMetrics(evaluator, targets, roi_mapper),
|
ValidationMetrics(evaluator),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
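Review note on the `build_trainer` hunk above: one side applies the `num_epochs` override by assigning to `trainer_conf.max_epochs` before dumping the config, the other dumps first and overrides the `"max_epochs"` key of the resulting dict. A minimal sketch of the second variant, assuming a stand-in dataclass in place of the pydantic trainer config; the field names and the `trainer_kwargs` helper are hypothetical.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class TrainerConf:
    """Stand-in for the trainer config; fields are illustrative only."""

    max_epochs: Optional[int] = None
    precision: Optional[str] = None


def trainer_kwargs(conf: TrainerConf, num_epochs: Optional[int] = None) -> dict:
    """Build Trainer keyword arguments without touching the config object."""
    # Drop unset options, mirroring model_dump(exclude_none=True) in the diff.
    kwargs = {k: v for k, v in asdict(conf).items() if v is not None}
    # Override on the dumped dict, so conf itself is left unchanged.
    if num_epochs is not None:
        kwargs["max_epochs"] = num_epochs
    return kwargs


conf = TrainerConf(max_epochs=5)
overridden = trainer_kwargs(conf, num_epochs=2)
```

Overriding the dict avoids a side effect on the config if that object is shared with other code, for example when the same config is later logged as hyperparameters.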
@ -13,14 +13,13 @@ from soundevent import data, terms
|
|||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
from batdetect2.audio.clips import build_clipper
|
from batdetect2.audio.clips import build_clipper
|
||||||
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||||
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.data import DatasetConfig, load_dataset
|
from batdetect2.data import DatasetConfig, load_dataset
|
||||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
ROIMapperProtocol,
|
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
build_roi_mapping,
|
|
||||||
build_targets,
|
build_targets,
|
||||||
call_type,
|
call_type,
|
||||||
)
|
)
|
||||||
@ -405,13 +404,6 @@ def sample_targets(
|
|||||||
return build_targets(sample_target_config)
|
return build_targets(sample_target_config)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_roi_mapper(
|
|
||||||
sample_target_config: TargetConfig,
|
|
||||||
) -> ROIMapperProtocol:
|
|
||||||
return build_roi_mapping(sample_target_config.roi)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_labeller(
|
def sample_labeller(
|
||||||
sample_targets: TargetProtocol,
|
sample_targets: TargetProtocol,
|
||||||
@ -466,16 +458,8 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tiny_checkpoint_path(
|
def tiny_checkpoint_path(tmp_path: Path) -> Path:
|
||||||
sample_targets: TargetProtocol,
|
module = build_training_module(model_config=BatDetect2Config().model)
|
||||||
sample_roi_mapper: ROIMapperProtocol,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> Path:
|
|
||||||
module = build_training_module(
|
|
||||||
targets_config=sample_targets.get_config(),
|
|
||||||
class_names=sample_targets.class_names,
|
|
||||||
dimension_names=sample_roi_mapper.dimension_names,
|
|
||||||
)
|
|
||||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||||
checkpoint_path = tmp_path / "model.ckpt"
|
checkpoint_path = tmp_path / "model.ckpt"
|
||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
|
|||||||
@ -8,43 +8,20 @@ import torch
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.models.detectors import Detector
|
from batdetect2.models.detectors import Detector
|
||||||
from batdetect2.targets import TargetConfig
|
from batdetect2.models.heads import ClassifierHead
|
||||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
from batdetect2.train import load_model_from_checkpoint
|
||||||
from batdetect2.train.lightning import build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def train_config() -> TrainingConfig:
|
def api_v2() -> BatDetect2API:
|
||||||
"""Train config with a small batch size for testing."""
|
|
||||||
return TrainingConfig.model_validate({"train_loader": {"batch_size": 2}})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def inference_config() -> InferenceConfig:
|
|
||||||
"""Inference config with a small batch size for testing."""
|
|
||||||
return InferenceConfig.model_validate({"loader": {"batch_size": 2}})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def example_targets_config(example_data_dir: Path) -> TargetConfig:
|
|
||||||
return TargetConfig.load(example_data_dir / "targets.yaml")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def api_v2(
|
|
||||||
train_config: TrainingConfig,
|
|
||||||
inference_config: InferenceConfig,
|
|
||||||
) -> BatDetect2API:
|
|
||||||
"""User story: users can create a ready-to-use API from config."""
|
"""User story: users can create a ready-to-use API from config."""
|
||||||
|
|
||||||
api = BatDetect2API.from_config(
|
config = BatDetect2Config()
|
||||||
train_config=train_config,
|
config.inference.loader.batch_size = 2
|
||||||
inference_config=inference_config,
|
return BatDetect2API.from_config(config)
|
||||||
)
|
|
||||||
assert api.inference_config.loader.batch_size == 2
|
|
||||||
return api
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_file_returns_recording_level_predictions(
|
def test_process_file_returns_recording_level_predictions(
|
||||||
@ -53,10 +30,8 @@ def test_process_file_returns_recording_level_predictions(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""User story: process a file and get detections in recording time."""
|
"""User story: process a file and get detections in recording time."""
|
||||||
|
|
||||||
# When
|
|
||||||
prediction = api_v2.process_file(example_audio_files[0])
|
prediction = api_v2.process_file(example_audio_files[0])
|
||||||
|
|
||||||
# Then
|
|
||||||
assert prediction.clip.recording.path == example_audio_files[0]
|
assert prediction.clip.recording.path == example_audio_files[0]
|
||||||
assert prediction.clip.start_time == 0
|
assert prediction.clip.start_time == 0
|
||||||
assert prediction.clip.end_time == prediction.clip.recording.duration
|
assert prediction.clip.end_time == prediction.clip.recording.duration
|
||||||
@ -78,11 +53,9 @@ def test_process_files_is_batch_size_invariant(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""User story: changing batch size should not change predictions."""
|
"""User story: changing batch size should not change predictions."""
|
||||||
|
|
||||||
# When
|
|
||||||
preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1)
|
preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1)
|
||||||
preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3)
|
preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3)
|
||||||
|
|
||||||
# Then
|
|
||||||
assert len(preds_batch_1) == len(preds_batch_3)
|
assert len(preds_batch_1) == len(preds_batch_3)
|
||||||
|
|
||||||
by_key_1 = {
|
by_key_1 = {
|
||||||
@ -118,14 +91,12 @@ def test_process_audio_matches_process_spectrogram(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""User story: users can call either audio or spectrogram entrypoint."""
|
"""User story: users can call either audio or spectrogram entrypoint."""
|
||||||
|
|
||||||
# When
|
|
||||||
audio = api_v2.load_audio(example_audio_files[0])
|
audio = api_v2.load_audio(example_audio_files[0])
|
||||||
from_audio = api_v2.process_audio(audio)
|
from_audio = api_v2.process_audio(audio)
|
||||||
|
|
||||||
spec = api_v2.generate_spectrogram(audio)
|
spec = api_v2.generate_spectrogram(audio)
|
||||||
from_spec = api_v2.process_spectrogram(spec)
|
from_spec = api_v2.process_spectrogram(spec)
|
||||||
|
|
||||||
# Then
|
|
||||||
assert len(from_audio) == len(from_spec)
|
assert len(from_audio) == len(from_spec)
|
||||||
|
|
||||||
for det_audio, det_spec in zip(from_audio, from_spec, strict=True):
|
for det_audio, det_spec in zip(from_audio, from_spec, strict=True):
|
||||||
@@ -145,10 +116,8 @@ def test_process_spectrogram_rejects_batched_input(
 ) -> None:
     """User story: invalid batched input gives a clear error."""

-    # Given
     spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32)

-    # When/Then
     with pytest.raises(ValueError, match="Batched spectrograms not supported"):
         api_v2.process_spectrogram(spec)

@@ -215,35 +184,26 @@ def test_user_can_read_extracted_features_per_detection(
 @pytest.mark.slow
 def test_user_can_load_checkpoint_and_finetune(
     tmp_path: Path,
-    example_targets_config: TargetConfig,
     example_annotations,
 ) -> None:
     """User story: load a checkpoint and continue training from it."""

-    api = BatDetect2API.from_config(
-        targets_config=example_targets_config,
-    )
-    module = build_training_module(
-        model_config=api.model_config,
-        targets_config=example_targets_config,
-        class_names=api.targets.class_names,
-        dimension_names=api.roi_mapper.dimension_names,
-    )
+    module = build_training_module(model_config=BatDetect2Config().model)
     trainer = L.Trainer(enable_checkpointing=False, logger=False)
     checkpoint_path = tmp_path / "base.ckpt"
     trainer.strategy.connect(module)
     trainer.save_checkpoint(checkpoint_path)

-    train_config = api.train_config.model_copy(deep=True)
-    train_config.trainer.limit_train_batches = 1
-    train_config.trainer.limit_val_batches = 1
-    train_config.trainer.log_every_n_steps = 1
-    train_config.train_loader.batch_size = 1
-    train_config.train_loader.augmentations.enabled = False
+    config = BatDetect2Config()
+    config.train.trainer.limit_train_batches = 1
+    config.train.trainer.limit_val_batches = 1
+    config.train.trainer.log_every_n_steps = 1
+    config.train.train_loader.batch_size = 1
+    config.train.train_loader.augmentations.enabled = False

     api = BatDetect2API.from_checkpoint(
         checkpoint_path,
-        train_config=train_config,
+        train_config=config.train,
     )
     finetune_dir = tmp_path / "finetuned"

@@ -262,34 +222,62 @@ def test_user_can_load_checkpoint_and_finetune(
     assert checkpoints


+def test_user_can_load_checkpoint_with_new_targets(
+    tmp_path: Path,
+    sample_targets,
+) -> None:
+    """User story: start from checkpoint with a new target definition."""
+
+    module = build_training_module(model_config=BatDetect2Config().model)
+    trainer = L.Trainer(enable_checkpointing=False, logger=False)
+    checkpoint_path = tmp_path / "base_transfer.ckpt"
+    trainer.strategy.connect(module)
+    trainer.save_checkpoint(checkpoint_path)
+
+    source_model, _ = load_model_from_checkpoint(checkpoint_path)
+    api = BatDetect2API.from_checkpoint(
+        checkpoint_path,
+        targets_config=sample_targets.config,
+    )
+    source_detector = cast(Detector, source_model.detector)
+    detector = cast(Detector, api.model.detector)
+    classifier_head = cast(ClassifierHead, detector.classifier_head)
+
+    assert api.targets.config == sample_targets.config  # type: ignore
+    assert detector.num_classes == len(sample_targets.class_names)
+    assert (
+        classifier_head.classifier.out_channels
+        == len(sample_targets.class_names) + 1
+    )
+
+    source_backbone = source_detector.backbone.state_dict()
+    target_backbone = detector.backbone.state_dict()
+    assert source_backbone
+    for key, value in source_backbone.items():
+        assert key in target_backbone
+        torch.testing.assert_close(target_backbone[key], value)
+
+
 def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
-    example_targets_config: TargetConfig,
     tmp_path: Path,
 ) -> None:
     """User story: same targets config does not rebuild prediction heads."""

-    # Given
-    source_api = BatDetect2API.from_config(
-        targets_config=example_targets_config
-    )
-    module = build_training_module(
-        model_config=source_api.model_config,
-        targets_config=example_targets_config,
-        class_names=source_api.targets.class_names,
-        dimension_names=source_api.roi_mapper.dimension_names,
-    )
+    module = build_training_module(model_config=BatDetect2Config().model)
     trainer = L.Trainer(enable_checkpointing=False, logger=False)
     checkpoint_path = tmp_path / "same_targets.ckpt"
     trainer.strategy.connect(module)
     trainer.save_checkpoint(checkpoint_path)

-    source_model, _ = load_model_from_checkpoint(checkpoint_path)
+    source_model, source_model_config = load_model_from_checkpoint(
+        checkpoint_path
+    )
     source_detector = cast(Detector, source_model.detector)

-    # When
-    api = BatDetect2API.from_checkpoint(checkpoint_path)
-
-    # Then
+    api = BatDetect2API.from_checkpoint(
+        checkpoint_path,
+        targets_config=source_model_config.targets,
+    )
     detector = cast(Detector, api.model.detector)

     for key, value in source_detector.classifier_head.state_dict().items():
@@ -307,6 +295,42 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
         )


+@pytest.mark.slow
+def test_user_can_finetune_only_heads(
+    tmp_path: Path,
+    example_annotations,
+) -> None:
+    """User story: fine-tune only prediction heads."""
+
+    api = BatDetect2API.from_config(BatDetect2Config())
+    finetune_dir = tmp_path / "heads_only"
+
+    api.finetune(
+        train_annotations=example_annotations[:1],
+        val_annotations=example_annotations[:1],
+        trainable="heads",
+        train_workers=0,
+        val_workers=0,
+        checkpoint_dir=finetune_dir,
+        log_dir=tmp_path / "logs",
+        num_epochs=1,
+        seed=0,
+    )
+    detector = cast(Detector, api.model.detector)
+
+    backbone_params = list(detector.backbone.parameters())
+    classifier_params = list(detector.classifier_head.parameters())
+    bbox_params = list(detector.bbox_head.parameters())
+
+    assert backbone_params
+    assert classifier_params
+    assert bbox_params
+    assert all(not parameter.requires_grad for parameter in backbone_params)
+    assert all(parameter.requires_grad for parameter in classifier_params)
+    assert all(parameter.requires_grad for parameter in bbox_params)
+    assert list(finetune_dir.rglob("*.ckpt"))
+
+
 @pytest.mark.slow
 def test_user_can_evaluate_small_dataset_and_get_metrics(
     api_v2: BatDetect2API,
@@ -324,6 +348,8 @@ def test_user_can_evaluate_small_dataset_and_get_metrics(

     assert isinstance(metrics, list)
     assert len(metrics) == 1
+    assert isinstance(metrics[0], dict)
+    assert len(metrics[0]) > 0
     assert isinstance(predictions, list)
     assert len(predictions) == 1

@@ -424,17 +450,8 @@ def test_detection_threshold_override_changes_spectrogram_results(
     spec = api_v2.generate_spectrogram(audio)
     default_detections = api_v2.process_spectrogram(spec)
     strict_detections = api_v2.process_spectrogram(
-        spec, detection_threshold=1.0
+        spec,
+        detection_threshold=1.0,
     )

     assert len(strict_detections) <= len(default_detections)
-
-
-def test_user_can_create_api_with_custom_targets_and_model_metadata_matches(
-    sample_targets,
-) -> None:
-    """User story: custom targets define model output names for a new API."""
-
-    api = BatDetect2API.from_config(targets_config=sample_targets.config)
-
-    assert api.model.class_names == sample_targets.class_names
@@ -1,114 +0,0 @@
-from pathlib import Path
-from typing import cast
-
-import pytest
-
-from batdetect2.api_v2 import BatDetect2API
-from batdetect2.models.detectors import Detector
-from batdetect2.targets import TargetConfig
-from batdetect2.train import load_model_from_checkpoint
-
-
-@pytest.mark.slow
-def test_user_can_finetune_only_heads(
-    tmp_path: Path,
-    example_annotations,
-) -> None:
-    """User story: fine-tune only prediction heads."""
-
-    api = BatDetect2API.from_config()
-    source_classifier_head = api.model.detector.classifier_head
-    source_bbox_head = api.model.detector.bbox_head
-    source_backbone = api.model.detector.backbone
-    finetune_dir = tmp_path / "heads_only"
-
-    finetuned_api = api.finetune(
-        train_annotations=example_annotations[:1],
-        val_annotations=example_annotations[:1],
-        targets_config=TargetConfig(),
-        trainable="heads",
-        train_workers=0,
-        val_workers=0,
-        checkpoint_dir=finetune_dir,
-        log_dir=tmp_path / "logs",
-        num_epochs=1,
-        seed=0,
-    )
-
-    detector = cast(Detector, finetuned_api.model.detector)
-
-    backbone_params = list(detector.backbone.parameters())
-    classifier_params = list(detector.classifier_head.parameters())
-    bbox_params = list(detector.bbox_head.parameters())
-
-    assert backbone_params
-    assert classifier_params
-    assert bbox_params
-    assert all(not parameter.requires_grad for parameter in backbone_params)
-    assert all(parameter.requires_grad for parameter in classifier_params)
-    assert all(parameter.requires_grad for parameter in bbox_params)
-    assert finetuned_api is not api
-    assert detector.backbone is source_backbone
-    assert detector.classifier_head is not source_classifier_head
-    assert detector.bbox_head is not source_bbox_head
-    assert list(finetune_dir.rglob("*.ckpt"))
-
-
-@pytest.mark.slow
-def test_finetune_replaces_targets_and_checkpoint_owns_new_targets(
-    tmp_path: Path,
-    example_annotations,
-) -> None:
-    """User story: fine-tuning writes checkpoints with the new targets."""
-
-    source_api = BatDetect2API.from_config()
-    source_evaluator = source_api.evaluator
-    source_formatter = source_api.formatter
-    source_output_transform = source_api.output_transform
-    new_targets = TargetConfig.model_validate(
-        {
-            "classification_targets": [
-                {
-                    "name": "single_class",
-                    "tags": [{"key": "class", "value": "single_class"}],
-                }
-            ],
-            "roi": {"mapper": "top_left"},
-        }
-    )
-    finetune_dir = tmp_path / "new_targets"
-
-    finetuned_api = source_api.finetune(
-        train_annotations=example_annotations[:1],
-        val_annotations=example_annotations[:1],
-        targets_config=new_targets,
-        trainable="heads",
-        train_workers=0,
-        val_workers=0,
-        checkpoint_dir=finetune_dir,
-        log_dir=tmp_path / "logs",
-        num_epochs=1,
-        seed=0,
-    )
-
-    checkpoints = list(finetune_dir.rglob("*.ckpt"))
-
-    assert source_api.targets.get_config() != new_targets.model_dump(
-        mode="json"
-    )
-    assert finetuned_api.targets.get_config() == new_targets.model_dump(
-        mode="json"
-    )
-    assert finetuned_api.evaluator is not source_evaluator
-    assert finetuned_api.formatter is not source_formatter
-    assert finetuned_api.output_transform is not source_output_transform
-    assert finetuned_api.evaluator.targets is finetuned_api.targets
-    assert finetuned_api.evaluator.transform is finetuned_api.output_transform
-    assert finetuned_api.model.class_names == ["single_class"]
-    assert finetuned_api.model.dimension_names == ["width", "height"]
-    assert checkpoints
-
-    _, configs = load_model_from_checkpoint(checkpoints[0])
-    assert configs.targets.model_dump(mode="json") == new_targets.model_dump(
-        mode="json"
-    )
@@ -5,6 +5,7 @@ import numpy as np
 import pytest

 from batdetect2.api_v2 import BatDetect2API
+from batdetect2.config import BatDetect2Config
 from batdetect2.outputs import build_output_formatter
 from batdetect2.outputs.formats import (
     BatDetect2OutputConfig,
@@ -17,7 +18,7 @@ from batdetect2.postprocess.types import ClipDetections
 def api_v2() -> BatDetect2API:
     """User story: API object manages prediction IO formats."""

-    return BatDetect2API.from_config()
+    return BatDetect2API.from_config(BatDetect2Config())


 @pytest.fixture
@@ -1,99 +0,0 @@
-"""CLI tests for finetune command."""
-
-from pathlib import Path
-
-import pytest
-from click.testing import CliRunner
-
-from batdetect2.cli import cli
-
-
-def test_cli_finetune_help() -> None:
-    """User story: inspect finetune command interface and options."""
-
-    result = CliRunner().invoke(cli, ["finetune", "--help"])
-
-    assert result.exit_code == 0
-    assert "TRAIN_DATASET" in result.output
-    assert "--model" in result.output
-    assert "--targets" in result.output
-    assert "--training-config" in result.output
-    assert "--audio-config" in result.output
-    assert "--logging-config" in result.output
-    assert "--evaluation-config" not in result.output
-    assert "--inference-config" not in result.output
-    assert "--outputs-config" not in result.output
-
-
-def test_cli_finetune_requires_model() -> None:
-    """User story: finetune requires a checkpoint argument."""
-
-    result = CliRunner().invoke(
-        cli,
-        [
-            "finetune",
-            "example_data/dataset.yaml",
-            "--targets",
-            "example_data/targets.yaml",
-        ],
-    )
-
-    assert result.exit_code != 0
-    assert "--model" in result.output
-
-
-def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
-    """User story: finetune requires a new target definition."""
-
-    result = CliRunner().invoke(
-        cli,
-        [
-            "finetune",
-            "example_data/dataset.yaml",
-            "--model",
-            str(tiny_checkpoint_path),
-        ],
-    )
-
-    assert result.exit_code != 0
-    assert "--targets" in result.output
-
-
-@pytest.mark.slow
-def test_cli_finetune_from_checkpoint_runs_on_small_dataset(
-    tmp_path: Path,
-    tiny_checkpoint_path: Path,
-) -> None:
-    """User story: fine-tune a checkpoint via CLI with new targets."""
-
-    ckpt_dir = tmp_path / "checkpoints"
-    log_dir = tmp_path / "logs"
-    ckpt_dir.mkdir()
-    log_dir.mkdir()
-
-    result = CliRunner().invoke(
-        cli,
-        [
-            "finetune",
-            "example_data/dataset.yaml",
-            "--val-dataset",
-            "example_data/dataset.yaml",
-            "--model",
-            str(tiny_checkpoint_path),
-            "--targets",
-            "example_data/targets.yaml",
-            "--num-epochs",
-            "1",
-            "--train-workers",
-            "0",
-            "--val-workers",
-            "0",
-            "--ckpt-dir",
-            str(ckpt_dir),
-            "--log-dir",
-            str(log_dir),
-        ],
-    )
-
-    assert result.exit_code == 0
-    assert len(list(ckpt_dir.rglob("*.ckpt"))) >= 1
@@ -81,24 +81,3 @@ def test_cli_train_rejects_model_and_model_config_together(

     assert result.exit_code != 0
     assert "--model-config cannot be used with --model" in result.output
-
-
-def test_cli_train_rejects_model_and_targets_together(
-    tiny_checkpoint_path: Path,
-) -> None:
-    """User story: checkpoint training does not accept new targets."""
-
-    result = CliRunner().invoke(
-        cli,
-        [
-            "train",
-            "example_data/dataset.yaml",
-            "--model",
-            str(tiny_checkpoint_path),
-            "--targets",
-            "example_data/targets.yaml",
-        ],
-    )
-
-    assert result.exit_code != 0
-    assert "--targets cannot be used with --model" in result.output
@@ -203,8 +203,8 @@ in_channels: 1
 def test_load_backbone_config_from_example_data(example_data_dir: Path):
     """load_backbone_config loads the real example config correctly."""
     config = load_backbone_config(
-        example_data_dir / "configs" / "model.yaml",
-        field="architecture",
+        example_data_dir / "config.yaml",
+        field="model.architecture",
     )

     assert isinstance(config, UNetBackboneConfig)
@@ -1,34 +1,9 @@
-import json
 from collections.abc import Callable
 from pathlib import Path

 from soundevent import data, terms

-from batdetect2.targets import (
-    TargetConfig,
-    Targets,
-    build_roi_mapping,
-    build_targets,
-)
+from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets


-def test_targets_get_config_returns_a_json_serializable_dict() -> None:
-    targets = build_targets(TargetConfig())
-
-    config_dict = targets.get_config()
-    assert isinstance(config_dict, dict)
-    assert json.dumps(config_dict)
-
-
-def test_targets_from_config_rebuilds_equivalent_targets() -> None:
-    original = build_targets(TargetConfig())
-
-    rebuilt = Targets.from_config(original.get_config())
-
-    assert rebuilt.class_names == original.class_names
-    assert rebuilt.detection_class_name == original.detection_class_name
-    assert rebuilt.detection_class_tags == original.detection_class_tags
-    assert rebuilt.get_config() == original.get_config()
-
-
 def test_can_override_default_roi_mapper_per_class(
@@ -1,40 +0,0 @@
-from soundevent import data
-
-from batdetect2.targets import (
-    TargetClassConfig,
-    TargetConfig,
-    build_targets,
-    check_target_compatibility,
-)
-
-
-def _target_class(name: str) -> TargetClassConfig:
-    return TargetClassConfig(
-        name=name,
-        tags=[data.Tag(key="class", value=name)],
-    )
-
-
-def test_check_target_compatibility_accepts_superset_targets() -> None:
-    config = TargetConfig(
-        classification_targets=[
-            _target_class("pip35"),
-            _target_class("myo"),
-            _target_class("extra"),
-        ]
-    )
-    targets = build_targets(config)
-
-    assert check_target_compatibility(targets, ["pip35", "myo"])
-
-
-def test_check_target_compatibility_rejects_missing_model_classes() -> None:
-    config = TargetConfig(
-        classification_targets=[
-            _target_class("pip35"),
-            _target_class("myo"),
-        ]
-    )
-    targets = build_targets(config)
-
-    assert not check_target_compatibility(targets, ["pip35", "nyc"])
@@ -3,19 +3,20 @@ from pathlib import Path
 import pytest
 from soundevent import data

-from batdetect2.train import TrainingConfig, run_train
+from batdetect2.config import BatDetect2Config
+from batdetect2.train import run_train

 pytestmark = pytest.mark.slow


-def _build_fast_train_config() -> TrainingConfig:
-    config = TrainingConfig()
-    config.trainer.limit_train_batches = 1
-    config.trainer.limit_val_batches = 1
-    config.trainer.log_every_n_steps = 1
-    config.trainer.check_val_every_n_epoch = 1
-    config.train_loader.batch_size = 1
-    config.train_loader.augmentations.enabled = False
+def _build_fast_train_config() -> BatDetect2Config:
+    config = BatDetect2Config()
+    config.train.trainer.limit_train_batches = 1
+    config.train.trainer.limit_val_batches = 1
+    config.train.trainer.log_every_n_steps = 1
+    config.train.trainer.check_val_every_n_epoch = 1
+    config.train.train_loader.batch_size = 1
+    config.train.train_loader.augmentations.enabled = False
     return config


@@ -28,7 +29,9 @@ def test_train_saves_checkpoint_in_requested_experiment_run_dir(
     run_train(
         train_annotations=example_annotations[:1],
         val_annotations=example_annotations[:1],
-        train_config=config,
+        train_config=config.train,
+        model_config=config.model,
+        audio_config=config.audio,
         num_epochs=1,
         train_workers=0,
         val_workers=0,
@@ -47,12 +50,14 @@ def test_train_without_validation_can_still_save_last_checkpoint(
     example_annotations: list[data.ClipAnnotation],
 ) -> None:
     config = _build_fast_train_config()
-    config.checkpoints.save_last = True
+    config.train.checkpoints.save_last = True

     run_train(
         train_annotations=example_annotations[:1],
         val_annotations=None,
-        train_config=config,
+        train_config=config.train,
+        model_config=config.model,
+        audio_config=config.audio,
         num_epochs=1,
         train_workers=0,
         val_workers=0,
@@ -68,14 +73,16 @@ def test_train_controls_which_checkpoints_are_kept(
     example_annotations: list[data.ClipAnnotation],
 ) -> None:
     config = _build_fast_train_config()
-    config.checkpoints.save_top_k = 1
-    config.checkpoints.save_last = True
-    config.checkpoints.filename = "epoch{epoch}"
+    config.train.checkpoints.save_top_k = 1
+    config.train.checkpoints.save_last = True
+    config.train.checkpoints.filename = "epoch{epoch}"

     run_train(
         train_annotations=example_annotations[:1],
         val_annotations=example_annotations[:1],
-        train_config=config,
+        train_config=config.train,
+        model_config=config.model,
+        audio_config=config.audio,
         num_epochs=3,
         train_workers=0,
         val_workers=0,
@@ -1,43 +1,12 @@
-from batdetect2.audio import AudioConfig
-from batdetect2.evaluate import EvaluationConfig
-from batdetect2.inference import InferenceConfig
-from batdetect2.logging import AppLoggingConfig
-from batdetect2.models import ModelConfig
-from batdetect2.outputs import OutputsConfig
-from batdetect2.targets import TargetConfig
-from batdetect2.train import TrainingConfig
+from batdetect2.config import BatDetect2Config
+from batdetect2.core import load_config


-def test_example_split_configs_are_valid(example_data_dir):
-    configs_dir = example_data_dir / "configs"
-
-    assert isinstance(
-        AudioConfig.load(configs_dir / "audio.yaml"), AudioConfig
-    )
-    assert isinstance(
-        ModelConfig.load(configs_dir / "model.yaml"), ModelConfig
-    )
-    assert isinstance(
-        TargetConfig.load(example_data_dir / "targets.yaml"),
-        TargetConfig,
-    )
-    assert isinstance(
-        TrainingConfig.load(configs_dir / "training.yaml"),
-        TrainingConfig,
-    )
-    assert isinstance(
-        EvaluationConfig.load(configs_dir / "evaluation.yaml"),
-        EvaluationConfig,
-    )
-    assert isinstance(
-        InferenceConfig.load(configs_dir / "inference.yaml"),
-        InferenceConfig,
-    )
-    assert isinstance(
-        OutputsConfig.load(configs_dir / "outputs.yaml"),
-        OutputsConfig,
-    )
-    assert isinstance(
-        AppLoggingConfig.load(configs_dir / "logging.yaml"),
-        AppLoggingConfig,
+def test_example_config_is_valid(example_data_dir):
+    conf = load_config(
+        example_data_dir / "config.yaml",
+        schema=BatDetect2Config,
+        extra="forbid",
+        strict=True,
     )
+    assert isinstance(conf, BatDetect2Config)
@@ -10,42 +10,25 @@ from torch.optim.lr_scheduler import CosineAnnealingLR

 from batdetect2.api_v2 import BatDetect2API
 from batdetect2.audio.types import AudioLoader
-from batdetect2.models import (
-    ModelConfig,
-    build_model,
-    build_model_with_new_targets,
-)
-from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
+from batdetect2.config import BatDetect2Config
+from batdetect2.models import ModelConfig, build_model
+from batdetect2.targets.classes import TargetClassConfig
 from batdetect2.train import (
     TrainingConfig,
     TrainingModule,
     load_model_from_checkpoint,
     run_train,
 )
-from batdetect2.train.logging import (
-    DatasetConfigArtifact,
-    DatasetConfigArtifactLogging,
-)
 from batdetect2.train.optimizers import AdamOptimizerConfig
 from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
 from batdetect2.train.train import build_training_module


-def build_default_module(
-    target_config: TargetConfig | None = None,
-    model_config: ModelConfig | None = None,
-    train_config: TrainingConfig | None = None,
-):
-    target_config = target_config or TargetConfig()
-    model_config = model_config or ModelConfig()
-    train_config = train_config or TrainingConfig()
-    targets = build_targets(target_config)
-    roi_mapper = build_roi_mapping(target_config.roi)
+def build_default_module(config: BatDetect2Config | None = None):
+    config = config or BatDetect2Config()
     return build_training_module(
-        model_config=model_config,
-        class_names=targets.class_names,
-        dimension_names=roi_mapper.dimension_names,
-        train_config=train_config,
+        model_config=config.model,
+        train_config=config.train,
     )


@@ -81,7 +64,7 @@ def test_can_save_checkpoint(
     torch.testing.assert_close(output1, output2, rtol=0, atol=0)


-def test_load_model_from_checkpoint_returns_model_and_configs(
+def test_load_model_from_checkpoint_returns_model_and_config(
     tmp_path: Path,
 ):
     input_model_config = ModelConfig(samplerate=192_000)
@ -89,13 +72,8 @@ def test_load_model_from_checkpoint_returns_model_and_configs(
|
|||||||
input_model_config.model_dump(mode="json")
|
input_model_config.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
train_config = TrainingConfig()
|
train_config = TrainingConfig()
|
||||||
targets_config = TargetConfig()
|
|
||||||
targets = build_targets(targets_config)
|
|
||||||
roi_mapper = build_roi_mapping(targets_config.roi)
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=input_model_config,
|
model_config=input_model_config,
|
||||||
class_names=targets.class_names,
|
|
||||||
dimension_names=roi_mapper.dimension_names,
|
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
)
|
)
|
||||||
trainer = L.Trainer()
|
trainer = L.Trainer()
|
||||||
@ -103,20 +81,12 @@ def test_load_model_from_checkpoint_returns_model_and_configs(
|
|||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
trainer.save_checkpoint(path)
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
model, loaded_configs = load_model_from_checkpoint(path)
|
model, loaded_model_config = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
assert model is not None
|
assert model is not None
|
||||||
assert loaded_configs.model.model_dump(
|
assert loaded_model_config.model_dump(
|
||||||
mode="json"
|
mode="json"
|
||||||
) == expected_model_config.model_dump(mode="json")
|
) == expected_model_config.model_dump(mode="json")
|
||||||
assert loaded_configs.targets.model_dump(
|
|
||||||
mode="json"
|
|
||||||
) == targets_config.model_dump(mode="json")
|
|
||||||
assert loaded_configs.train.model_dump(
|
|
||||||
mode="json"
|
|
||||||
) == train_config.model_dump(mode="json")
|
|
||||||
assert model.class_names == targets.class_names
|
|
||||||
assert model.dimension_names == roi_mapper.dimension_names
|
|
||||||
|
|
||||||
recovered = TrainingModule.load_from_checkpoint(path)
|
recovered = TrainingModule.load_from_checkpoint(path)
|
||||||
assert recovered.train_config.model_dump(
|
assert recovered.train_config.model_dump(
|
||||||
@ -130,9 +100,6 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
|
|||||||
model_config.model_dump(mode="json")
|
model_config.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
train_config = TrainingConfig()
|
train_config = TrainingConfig()
|
||||||
targets_config = TargetConfig()
|
|
||||||
targets = build_targets(targets_config)
|
|
||||||
roi_mapper = build_roi_mapping(targets_config.roi)
|
|
||||||
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
|
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
|
||||||
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123)
|
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123)
|
||||||
train_config.trainer.max_epochs = 3
|
train_config.trainer.max_epochs = 3
|
||||||
@ -140,8 +107,6 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
|
|||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
class_names=targets.class_names,
|
|
||||||
dimension_names=roi_mapper.dimension_names,
|
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
)
|
)
|
||||||
trainer = L.Trainer()
|
trainer = L.Trainer()
|
||||||
@ -149,56 +114,28 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
|
|||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
trainer.save_checkpoint(path)
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
_, recovered_configs = load_model_from_checkpoint(path)
|
recovered = TrainingModule.load_from_checkpoint(path)
|
||||||
assert not DeepDiff(
|
assert not DeepDiff(
|
||||||
recovered_configs.model.model_dump(mode="json"),
|
recovered.model_config.model_dump(mode="json"),
|
||||||
expected_model_config.model_dump(mode="json"),
|
expected_model_config.model_dump(mode="json"),
|
||||||
)
|
)
|
||||||
assert not DeepDiff(
|
assert not DeepDiff(
|
||||||
recovered_configs.train.model_dump(mode="json"),
|
recovered.train_config.model_dump(mode="json"),
|
||||||
train_config.model_dump(mode="json"),
|
train_config.model_dump(mode="json"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_load_model_from_checkpoint_includes_targets_config(tmp_path: Path):
|
|
||||||
targets_config = TargetConfig()
|
|
||||||
targets = build_targets(targets_config)
|
|
||||||
roi_mapper = build_roi_mapping(targets_config.roi)
|
|
||||||
module = build_training_module(
|
|
||||||
model_config=ModelConfig(),
|
|
||||||
targets_config=targets_config,
|
|
||||||
class_names=targets.class_names,
|
|
||||||
dimension_names=roi_mapper.dimension_names,
|
|
||||||
train_config=TrainingConfig(),
|
|
||||||
)
|
|
||||||
trainer = L.Trainer()
|
|
||||||
path = tmp_path / "example.ckpt"
|
|
||||||
trainer.strategy.connect(module)
|
|
||||||
trainer.save_checkpoint(path)
|
|
||||||
|
|
||||||
_, loaded_configs = load_model_from_checkpoint(path)
|
|
||||||
|
|
||||||
assert loaded_configs.targets.model_dump(
|
|
||||||
mode="json"
|
|
||||||
) == targets_config.model_dump(mode="json")
|
|
||||||
|
|
||||||
|
|
||||||
def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
|
def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
|
||||||
model_config = ModelConfig()
|
model_config = ModelConfig()
|
||||||
expected_model_config = ModelConfig.model_validate(
|
expected_model_config = ModelConfig.model_validate(
|
||||||
model_config.model_dump(mode="json")
|
model_config.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
train_config = TrainingConfig()
|
train_config = TrainingConfig()
|
||||||
targets_config = TargetConfig()
|
|
||||||
targets = build_targets(targets_config)
|
|
||||||
roi_mapper = build_roi_mapping(targets_config.roi)
|
|
||||||
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
|
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
|
||||||
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321)
|
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321)
|
||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
class_names=targets.class_names,
|
|
||||||
dimension_names=roi_mapper.dimension_names,
|
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -216,16 +153,14 @@ def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
|
|||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
trainer.save_checkpoint(path)
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
_, recovered_configs = load_model_from_checkpoint(path)
|
recovered = TrainingModule.load_from_checkpoint(path)
|
||||||
assert recovered_configs.model.model_dump(
|
assert recovered.model_config.model_dump(
|
||||||
mode="json"
|
mode="json"
|
||||||
) == expected_model_config.model_dump(mode="json")
|
) == expected_model_config.model_dump(mode="json")
|
||||||
assert recovered_configs.train.model_dump(
|
assert recovered.train_config.model_dump(
|
||||||
mode="json"
|
mode="json"
|
||||||
) == train_config.model_dump(mode="json")
|
) == train_config.model_dump(mode="json")
|
||||||
|
|
||||||
recovered = TrainingModule.load_from_checkpoint(path)
|
|
||||||
|
|
||||||
loaded_optimization_config = recovered.configure_optimizers()
|
loaded_optimization_config = recovered.configure_optimizers()
|
||||||
loaded_optimizer = loaded_optimization_config["optimizer"]
|
loaded_optimizer = loaded_optimization_config["optimizer"]
|
||||||
loaded_scheduler = loaded_optimization_config["lr_scheduler"]["scheduler"]
|
loaded_scheduler = loaded_optimization_config["lr_scheduler"]["scheduler"]
|
||||||
@ -240,28 +175,12 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
|
|||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
trainer.save_checkpoint(path)
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
_, stored_configs = load_model_from_checkpoint(path)
|
|
||||||
api = BatDetect2API.from_checkpoint(path)
|
api = BatDetect2API.from_checkpoint(path)
|
||||||
|
|
||||||
assert api.model_config.model_dump(
|
assert api.model_config.model_dump(
|
||||||
mode="json"
|
mode="json"
|
||||||
) == stored_configs.model.model_dump(mode="json")
|
) == module.model_config.model_dump(mode="json")
|
||||||
assert api.audio_config.samplerate == stored_configs.model.samplerate
|
assert api.audio_config.samplerate == module.model_config.samplerate
|
||||||
|
|
||||||
|
|
||||||
def test_api_from_checkpoint_reconstructs_targets_from_checkpoint(
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
targets_config = TargetConfig()
|
|
||||||
module = build_default_module(target_config=targets_config)
|
|
||||||
trainer = L.Trainer()
|
|
||||||
path = tmp_path / "example.ckpt"
|
|
||||||
trainer.strategy.connect(module)
|
|
||||||
trainer.save_checkpoint(path)
|
|
||||||
|
|
||||||
api = BatDetect2API.from_checkpoint(path)
|
|
||||||
|
|
||||||
assert api.targets.get_config() == targets_config.model_dump(mode="json")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@ -270,26 +189,19 @@ def test_train_smoke_produces_loadable_checkpoint(
|
|||||||
example_annotations: list[data.ClipAnnotation],
|
example_annotations: list[data.ClipAnnotation],
|
||||||
sample_audio_loader: AudioLoader,
|
sample_audio_loader: AudioLoader,
|
||||||
):
|
):
|
||||||
# Given
|
config = BatDetect2Config()
|
||||||
train_config = TrainingConfig.model_validate(
|
config.train.trainer.limit_train_batches = 1
|
||||||
{
|
config.train.trainer.limit_val_batches = 1
|
||||||
"trainer": {
|
config.train.trainer.log_every_n_steps = 1
|
||||||
"limit_train_batches": 1,
|
config.train.train_loader.batch_size = 1
|
||||||
"limit_val_batches": 1,
|
config.train.train_loader.augmentations.enabled = False
|
||||||
"log_every_n_steps": 1,
|
|
||||||
},
|
|
||||||
"train_loader": {
|
|
||||||
"batch_size": 1,
|
|
||||||
"augmentations": {"enabled": False},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# When
|
|
||||||
run_train(
|
run_train(
|
||||||
train_annotations=example_annotations[:1],
|
train_annotations=example_annotations[:1],
|
||||||
val_annotations=example_annotations[:1],
|
val_annotations=example_annotations[:1],
|
||||||
train_config=train_config,
|
train_config=config.train,
|
||||||
|
model_config=config.model,
|
||||||
|
audio_config=config.audio,
|
||||||
num_epochs=1,
|
num_epochs=1,
|
||||||
train_workers=0,
|
train_workers=0,
|
||||||
val_workers=0,
|
val_workers=0,
|
||||||
@ -297,11 +209,18 @@ def test_train_smoke_produces_loadable_checkpoint(
|
|||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Then
|
|
||||||
checkpoints = list(tmp_path.rglob("*.ckpt"))
|
checkpoints = list(tmp_path.rglob("*.ckpt"))
|
||||||
assert checkpoints
|
assert checkpoints
|
||||||
|
|
||||||
model, model_config = load_model_from_checkpoint(checkpoints[0])
|
model, model_config = load_model_from_checkpoint(checkpoints[0])
|
||||||
|
assert model_config.samplerate == config.model.samplerate
|
||||||
|
assert model_config.architecture.name == config.model.architecture.name
|
||||||
|
assert model_config.preprocess.model_dump(
|
||||||
|
mode="json"
|
||||||
|
) == config.model.preprocess.model_dump(mode="json")
|
||||||
|
assert model_config.postprocess.model_dump(
|
||||||
|
mode="json"
|
||||||
|
) == config.model.postprocess.model_dump(mode="json")
|
||||||
|
|
||||||
wav = torch.tensor(
|
wav = torch.tensor(
|
||||||
sample_audio_loader.load_clip(example_annotations[0].clip)
|
sample_audio_loader.load_clip(example_annotations[0].clip)
|
||||||
@ -311,18 +230,10 @@ def test_train_smoke_produces_loadable_checkpoint(
|
|||||||
|
|
||||||
|
|
||||||
def test_build_training_module_uses_provided_model() -> None:
|
def test_build_training_module_uses_provided_model() -> None:
|
||||||
targets = build_targets(TargetConfig())
|
model = build_model(ModelConfig())
|
||||||
roi_mapper = build_roi_mapping(TargetConfig().roi)
|
|
||||||
model = build_model(
|
|
||||||
ModelConfig(),
|
|
||||||
class_names=targets.class_names,
|
|
||||||
dimension_names=roi_mapper.dimension_names,
|
|
||||||
)
|
|
||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=ModelConfig(),
|
model_config=ModelConfig(),
|
||||||
class_names=targets.class_names,
|
|
||||||
dimension_names=roi_mapper.dimension_names,
|
|
||||||
train_config=TrainingConfig(),
|
train_config=TrainingConfig(),
|
||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
@ -330,117 +241,18 @@ def test_build_training_module_uses_provided_model() -> None:
|
|||||||
assert module.model is model
|
assert module.model is model
|
||||||
|
|
||||||
|
|
||||||
def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
|
|
||||||
None
|
|
||||||
):
|
|
||||||
source_targets_config = TargetConfig()
|
|
||||||
source_targets = build_targets(source_targets_config)
|
|
||||||
source_roi_mapper = build_roi_mapping(source_targets_config.roi)
|
|
||||||
source_model = build_model(
|
|
||||||
ModelConfig(),
|
|
||||||
class_names=source_targets.class_names,
|
|
||||||
dimension_names=source_roi_mapper.dimension_names,
|
|
||||||
)
|
|
||||||
|
|
||||||
new_targets_config = TargetConfig.model_validate(
|
|
||||||
{
|
|
||||||
"classification_targets": [
|
|
||||||
{
|
|
||||||
"name": "single_class",
|
|
||||||
"tags": [{"key": "class", "value": "single_class"}],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
new_targets = build_targets(new_targets_config)
|
|
||||||
new_roi_mapper = build_roi_mapping(new_targets_config.roi)
|
|
||||||
|
|
||||||
rebuilt_model = build_model_with_new_targets(
|
|
||||||
model=source_model,
|
|
||||||
targets=new_targets,
|
|
||||||
roi_mapper=new_roi_mapper,
|
|
||||||
)
|
|
||||||
|
|
||||||
source_detector = source_model.detector
|
|
||||||
rebuilt_detector = rebuilt_model.detector
|
|
||||||
|
|
||||||
assert rebuilt_detector.backbone is source_detector.backbone
|
|
||||||
assert (
|
|
||||||
rebuilt_detector.classifier_head is not source_detector.classifier_head
|
|
||||||
)
|
|
||||||
assert rebuilt_detector.bbox_head is not source_detector.bbox_head
|
|
||||||
assert rebuilt_model.class_names == ["single_class"]
|
|
||||||
assert rebuilt_model.dimension_names == ["width", "height"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_run_train_logs_training_artifacts(
|
|
||||||
tmp_path: Path,
|
|
||||||
example_annotations: list[data.ClipAnnotation],
|
|
||||||
example_dataset,
|
|
||||||
) -> None:
|
|
||||||
train_config = TrainingConfig.model_validate(
|
|
||||||
{
|
|
||||||
"trainer": {
|
|
||||||
"limit_train_batches": 1,
|
|
||||||
"limit_val_batches": 1,
|
|
||||||
"log_every_n_steps": 1,
|
|
||||||
},
|
|
||||||
"train_loader": {
|
|
||||||
"batch_size": 1,
|
|
||||||
"augmentations": {"enabled": False},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
run_train(
|
|
||||||
train_annotations=example_annotations[:1],
|
|
||||||
val_annotations=example_annotations[:1],
|
|
||||||
train_config=train_config,
|
|
||||||
num_epochs=1,
|
|
||||||
train_workers=0,
|
|
||||||
val_workers=0,
|
|
||||||
checkpoint_dir=tmp_path / "checkpoints",
|
|
||||||
log_dir=tmp_path / "logs",
|
|
||||||
seed=0,
|
|
||||||
logging_callbacks=[
|
|
||||||
DatasetConfigArtifactLogging(
|
|
||||||
train_dataset_config=DatasetConfigArtifact(
|
|
||||||
filename="train_dataset.yaml",
|
|
||||||
config=example_dataset,
|
|
||||||
),
|
|
||||||
val_dataset_config=DatasetConfigArtifact(
|
|
||||||
filename="val_dataset.yaml",
|
|
||||||
config=example_dataset,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
artifact_root = next((tmp_path / "logs").rglob("training_artifacts"))
|
|
||||||
|
|
||||||
assert (artifact_root / "targets.yaml").exists()
|
|
||||||
assert (artifact_root / "train_dataset.yaml").exists()
|
|
||||||
assert (artifact_root / "val_dataset.yaml").exists()
|
|
||||||
assert (artifact_root / "train_class_summary.csv").exists()
|
|
||||||
assert (artifact_root / "val_class_summary.csv").exists()
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_train_rejects_incompatible_model_config(
|
def test_run_train_rejects_incompatible_model_config(
|
||||||
example_annotations: list[data.ClipAnnotation],
|
example_annotations: list[data.ClipAnnotation],
|
||||||
) -> None:
|
) -> None:
|
||||||
# Given
|
model = build_model(ModelConfig())
|
||||||
targets_config = TargetConfig()
|
|
||||||
targets = build_targets(targets_config)
|
|
||||||
roi_mapper = build_roi_mapping(targets_config.roi)
|
|
||||||
incompatible_config = ModelConfig()
|
incompatible_config = ModelConfig()
|
||||||
incompatible_model = build_model(
|
incompatible_config.targets.classification_targets.append(
|
||||||
incompatible_config,
|
TargetClassConfig(
|
||||||
class_names=targets.class_names,
|
name="dummy_class",
|
||||||
dimension_names=[*roi_mapper.dimension_names, "extra_dim"],
|
tags=[data.Tag(key="class", value="Dummy class")],
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# When/Then
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match="Provided model is incompatible with model_config",
|
match="Provided model is incompatible with model_config",
|
||||||
@ -448,10 +260,7 @@ def test_run_train_rejects_incompatible_model_config(
|
|||||||
run_train(
|
run_train(
|
||||||
train_annotations=example_annotations[:1],
|
train_annotations=example_annotations[:1],
|
||||||
val_annotations=None,
|
val_annotations=None,
|
||||||
model=incompatible_model,
|
model=model,
|
||||||
targets=targets,
|
|
||||||
roi_mapper=roi_mapper,
|
|
||||||
model_config=incompatible_config,
|
model_config=incompatible_config,
|
||||||
targets_config=targets_config,
|
|
||||||
train_config=TrainingConfig(),
|
train_config=TrainingConfig(),
|
||||||
)
|
)
|
||||||
|
|||||||