mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Compare commits
17 Commits
f82ec218f0
...
5a974711b0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5a974711b0 | ||
|
|
7b2699786f | ||
|
|
75e52cc548 | ||
|
|
7d416e0f99 | ||
|
|
2d0b810ed3 | ||
|
|
7a46fa021b | ||
|
|
cbe428fc3f | ||
|
|
7a10b7ffff | ||
|
|
c27e7f9f52 | ||
|
|
aa36df668f | ||
|
|
20a7c058fc | ||
|
|
eec126a502 | ||
|
|
57236fc82a | ||
|
|
e33053614a | ||
|
|
ae4f742345 | ||
|
|
44f9870e9e | ||
|
|
d7e61ccd43 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -132,3 +132,4 @@ notebooks/tmp
|
||||
|
||||
# Assets
|
||||
!assets/*
|
||||
/models
|
||||
|
||||
@ -2,7 +2,8 @@
|
||||
|
||||
`BatDetect2API` is the main entry point for the current Python workflow.
|
||||
|
||||
It wraps model loading, inference, evaluation, output formatting, and training-related entry points behind one object.
|
||||
It wraps model loading, inference, evaluation, output formatting, and
|
||||
training-related entry points behind one object.
|
||||
|
||||
Defined in `batdetect2.api_v2`.
|
||||
|
||||
@ -10,8 +11,8 @@ Defined in `batdetect2.api_v2`.
|
||||
|
||||
- `BatDetect2API.from_checkpoint(path, ...)`
|
||||
- load a trained checkpoint and optional config overrides.
|
||||
- `BatDetect2API.from_config(config)`
|
||||
- build a full stack from a `BatDetect2Config` object.
|
||||
- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)`
|
||||
- build a full stack from separate config objects.
|
||||
|
||||
## Inference methods
|
||||
|
||||
@ -46,10 +47,12 @@ Defined in `batdetect2.api_v2`.
|
||||
|
||||
## Output persistence helpers
|
||||
|
||||
- `save_predictions(predictions, path, audio_dir=None, format=None, config=None)`
|
||||
- `save_predictions(predictions, path, audio_dir=None, format=None,
|
||||
config=None)`
|
||||
- `load_predictions(path, format=None, config=None)`
|
||||
|
||||
Use these when you want to save programmatic predictions without going through the CLI.
|
||||
Use these when you want to save programmatic predictions without going through
|
||||
the CLI.
|
||||
|
||||
## Training and evaluation entry points
|
||||
|
||||
@ -60,6 +63,9 @@ Use these when you want to save programmatic predictions without going through t
|
||||
|
||||
## Related pages
|
||||
|
||||
- Python tutorial: {doc}`../tutorials/integrate-with-a-python-pipeline`
|
||||
- Outputs config reference: {doc}`outputs-config`
|
||||
- Output formats reference: {doc}`output-formats`
|
||||
- Python tutorial:
|
||||
{doc}`../tutorials/integrate-with-a-python-pipeline`
|
||||
- Outputs config reference:
|
||||
{doc}`outputs-config`
|
||||
- Output formats reference:
|
||||
{doc}`output-formats`
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
# Top-level app config reference
|
||||
|
||||
The top-level config object is `BatDetect2Config`.
|
||||
|
||||
Defined in `batdetect2.config`.
|
||||
|
||||
It combines the main configuration surfaces used across training, inference, evaluation, outputs, and logging.
|
||||
|
||||
## Fields
|
||||
|
||||
- `config_version`
|
||||
- `train`
|
||||
- training-specific config.
|
||||
- `evaluation`
|
||||
- evaluation task and plot config.
|
||||
- `model`
|
||||
- model architecture, preprocessing, postprocessing, and targets.
|
||||
- `audio`
|
||||
- audio loading and resampling config.
|
||||
- `inference`
|
||||
- clipping and loader config for prediction-time workflows.
|
||||
- `outputs`
|
||||
- output format and output transform config.
|
||||
- `logging`
|
||||
- logging backend and formatting config.
|
||||
|
||||
## Mental model
|
||||
|
||||
Think of `BatDetect2Config` as the complete application wiring for the current stack.
|
||||
|
||||
Use it when you want one reproducible config that describes the whole workflow.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Inference config: {doc}`inference-config`
|
||||
- Evaluation config: {doc}`evaluation-config`
|
||||
- Outputs config: {doc}`outputs-config`
|
||||
- General config reference: {doc}`configs`
|
||||
@ -24,8 +24,8 @@ for full options and argument details.
|
||||
- Global CLI options are documented in {doc}`base`.
|
||||
- Paths with spaces should be wrapped in quotes.
|
||||
- Input audio is expected to be mono.
|
||||
- Legacy `detect` uses a required threshold argument, while `predict` uses
|
||||
the optional `--detection-threshold` override.
|
||||
- Legacy `detect` uses a required threshold argument, while `predict` uses the
|
||||
optional `--detection-threshold` override.
|
||||
|
||||
```{warning}
|
||||
`batdetect2 detect` is a legacy command.
|
||||
|
||||
@ -1,5 +1,15 @@
|
||||
Config reference
|
||||
================
|
||||
|
||||
.. automodule:: batdetect2.config
|
||||
:members:
|
||||
BatDetect2 uses separate config objects for different workflow surfaces.
|
||||
|
||||
Use the dedicated reference pages for each config family:
|
||||
|
||||
- inference config
|
||||
- evaluation config
|
||||
- outputs config
|
||||
- preprocessing config
|
||||
- postprocess config
|
||||
- targets config workflow
|
||||
|
||||
Example config files live under `example_data/configs/`.
|
||||
|
||||
@ -2,14 +2,14 @@
|
||||
|
||||
Reference pages are the detailed lookup pages.
|
||||
|
||||
Use this section when you need exact command options, setting names, output details, or Python API entries.
|
||||
Use this section when you need exact command options, setting names, output
|
||||
details, or Python API entries.
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
cli/index
|
||||
api
|
||||
app-config
|
||||
inference-config
|
||||
evaluation-config
|
||||
outputs-config
|
||||
|
||||
@ -1,192 +0,0 @@
|
||||
config_version: v1
|
||||
|
||||
audio:
|
||||
samplerate: 256000
|
||||
resample:
|
||||
enabled: true
|
||||
method: poly
|
||||
|
||||
model:
|
||||
samplerate: 256000
|
||||
|
||||
preprocess:
|
||||
stft:
|
||||
window_duration: 0.002
|
||||
window_overlap: 0.75
|
||||
window_fn: hann
|
||||
frequencies:
|
||||
max_freq: 120000
|
||||
min_freq: 10000
|
||||
size:
|
||||
height: 128
|
||||
resize_factor: 0.5
|
||||
spectrogram_transforms:
|
||||
- name: pcen
|
||||
time_constant: 0.1
|
||||
gain: 0.98
|
||||
bias: 2
|
||||
power: 0.5
|
||||
- name: spectral_mean_subtraction
|
||||
|
||||
architecture:
|
||||
name: UNetBackbone
|
||||
input_height: 128
|
||||
in_channels: 1
|
||||
encoder:
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 32
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 64
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 128
|
||||
- name: ConvBlock
|
||||
out_channels: 256
|
||||
bottleneck:
|
||||
channels: 256
|
||||
layers:
|
||||
- name: SelfAttention
|
||||
attention_channels: 256
|
||||
decoder:
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 64
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: ConvBlock
|
||||
out_channels: 32
|
||||
|
||||
postprocess:
|
||||
nms_kernel_size: 9
|
||||
detection_threshold: 0.01
|
||||
top_k_per_sec: 200
|
||||
|
||||
train:
|
||||
optimizer:
|
||||
name: adam
|
||||
learning_rate: 0.001
|
||||
|
||||
scheduler:
|
||||
name: cosine_annealing
|
||||
t_max: 100
|
||||
|
||||
labels:
|
||||
sigma: 3
|
||||
|
||||
trainer:
|
||||
max_epochs: 10
|
||||
check_val_every_n_epoch: 5
|
||||
|
||||
train_loader:
|
||||
batch_size: 8
|
||||
shuffle: true
|
||||
|
||||
clipping_strategy:
|
||||
name: random_subclip
|
||||
duration: 0.256
|
||||
|
||||
augmentations:
|
||||
enabled: true
|
||||
audio:
|
||||
- name: mix_audio
|
||||
probability: 0.2
|
||||
min_weight: 0.3
|
||||
max_weight: 0.7
|
||||
- name: add_echo
|
||||
probability: 0.2
|
||||
max_delay: 0.005
|
||||
min_weight: 0.0
|
||||
max_weight: 1.0
|
||||
spectrogram:
|
||||
- name: scale_volume
|
||||
probability: 0.2
|
||||
min_scaling: 0.0
|
||||
max_scaling: 2.0
|
||||
- name: warp
|
||||
probability: 0.2
|
||||
delta: 0.04
|
||||
- name: mask_time
|
||||
probability: 0.2
|
||||
max_perc: 0.05
|
||||
max_masks: 3
|
||||
- name: mask_freq
|
||||
probability: 0.2
|
||||
max_perc: 0.10
|
||||
max_masks: 3
|
||||
|
||||
val_loader:
|
||||
clipping_strategy:
|
||||
name: whole_audio_padded
|
||||
chunk_size: 0.256
|
||||
|
||||
loss:
|
||||
detection:
|
||||
weight: 1.0
|
||||
focal:
|
||||
beta: 4
|
||||
alpha: 2
|
||||
classification:
|
||||
weight: 2.0
|
||||
focal:
|
||||
beta: 4
|
||||
alpha: 2
|
||||
size:
|
||||
weight: 0.1
|
||||
|
||||
validation:
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: sound_event_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
|
||||
logging:
|
||||
train:
|
||||
name: csv
|
||||
|
||||
evaluation:
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: score_distribution
|
||||
- name: example_detection
|
||||
- name: sound_event_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: top_class_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: confusion_matrix
|
||||
- name: example_classification
|
||||
- name: clip_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: roc_curve
|
||||
- name: score_distribution
|
||||
- name: clip_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: roc_curve
|
||||
4
example_data/configs/audio.yaml
Normal file
4
example_data/configs/audio.yaml
Normal file
@ -0,0 +1,4 @@
|
||||
samplerate: 256000
|
||||
resample:
|
||||
enabled: true
|
||||
method: poly
|
||||
37
example_data/configs/evaluation.yaml
Normal file
37
example_data/configs/evaluation.yaml
Normal file
@ -0,0 +1,37 @@
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: score_distribution
|
||||
- name: example_detection
|
||||
- name: sound_event_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: top_class_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: confusion_matrix
|
||||
- name: example_classification
|
||||
- name: clip_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: roc_curve
|
||||
- name: score_distribution
|
||||
- name: clip_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: roc_curve
|
||||
9
example_data/configs/inference.yaml
Normal file
9
example_data/configs/inference.yaml
Normal file
@ -0,0 +1,9 @@
|
||||
loader:
|
||||
batch_size: 8
|
||||
|
||||
clipping:
|
||||
enabled: true
|
||||
duration: 0.5
|
||||
overlap: 0.0
|
||||
max_empty: 0.0
|
||||
discard_empty: true
|
||||
2
example_data/configs/logging.yaml
Normal file
2
example_data/configs/logging.yaml
Normal file
@ -0,0 +1,2 @@
|
||||
train:
|
||||
name: csv
|
||||
59
example_data/configs/model.yaml
Normal file
59
example_data/configs/model.yaml
Normal file
@ -0,0 +1,59 @@
|
||||
samplerate: 256000
|
||||
|
||||
preprocess:
|
||||
stft:
|
||||
window_duration: 0.002
|
||||
window_overlap: 0.75
|
||||
window_fn: hann
|
||||
frequencies:
|
||||
max_freq: 120000
|
||||
min_freq: 10000
|
||||
size:
|
||||
height: 128
|
||||
resize_factor: 0.5
|
||||
spectrogram_transforms:
|
||||
- name: pcen
|
||||
time_constant: 0.1
|
||||
gain: 0.98
|
||||
bias: 2
|
||||
power: 0.5
|
||||
- name: spectral_mean_subtraction
|
||||
|
||||
architecture:
|
||||
name: UNetBackbone
|
||||
input_height: 128
|
||||
in_channels: 1
|
||||
encoder:
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 32
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 64
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
out_channels: 128
|
||||
- name: ConvBlock
|
||||
out_channels: 256
|
||||
bottleneck:
|
||||
channels: 256
|
||||
layers:
|
||||
- name: SelfAttention
|
||||
attention_channels: 256
|
||||
decoder:
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 64
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: LayerGroup
|
||||
layers:
|
||||
- name: FreqCoordConvUp
|
||||
out_channels: 32
|
||||
- name: ConvBlock
|
||||
out_channels: 32
|
||||
|
||||
postprocess:
|
||||
nms_kernel_size: 9
|
||||
detection_threshold: 0.01
|
||||
top_k_per_sec: 200
|
||||
9
example_data/configs/outputs.yaml
Normal file
9
example_data/configs/outputs.yaml
Normal file
@ -0,0 +1,9 @@
|
||||
format:
|
||||
name: raw
|
||||
include_class_scores: true
|
||||
include_features: true
|
||||
include_geometry: true
|
||||
|
||||
transform:
|
||||
detection_transforms: []
|
||||
clip_transforms: []
|
||||
79
example_data/configs/training.yaml
Normal file
79
example_data/configs/training.yaml
Normal file
@ -0,0 +1,79 @@
|
||||
optimizer:
|
||||
name: adam
|
||||
learning_rate: 0.001
|
||||
|
||||
scheduler:
|
||||
name: cosine_annealing
|
||||
t_max: 100
|
||||
|
||||
labels:
|
||||
sigma: 3
|
||||
|
||||
trainer:
|
||||
max_epochs: 10
|
||||
check_val_every_n_epoch: 5
|
||||
|
||||
train_loader:
|
||||
batch_size: 8
|
||||
shuffle: true
|
||||
|
||||
clipping_strategy:
|
||||
name: random_subclip
|
||||
duration: 0.256
|
||||
|
||||
augmentations:
|
||||
enabled: true
|
||||
audio:
|
||||
- name: mix_audio
|
||||
probability: 0.2
|
||||
min_weight: 0.3
|
||||
max_weight: 0.7
|
||||
- name: add_echo
|
||||
probability: 0.2
|
||||
max_delay: 0.005
|
||||
min_weight: 0.0
|
||||
max_weight: 1.0
|
||||
spectrogram:
|
||||
- name: scale_volume
|
||||
probability: 0.2
|
||||
min_scaling: 0.0
|
||||
max_scaling: 2.0
|
||||
- name: warp
|
||||
probability: 0.2
|
||||
delta: 0.04
|
||||
- name: mask_time
|
||||
probability: 0.2
|
||||
max_perc: 0.05
|
||||
max_masks: 3
|
||||
- name: mask_freq
|
||||
probability: 0.2
|
||||
max_perc: 0.10
|
||||
max_masks: 3
|
||||
|
||||
val_loader:
|
||||
clipping_strategy:
|
||||
name: whole_audio_padded
|
||||
chunk_size: 0.256
|
||||
|
||||
loss:
|
||||
detection:
|
||||
weight: 1.0
|
||||
focal:
|
||||
beta: 4
|
||||
alpha: 2
|
||||
classification:
|
||||
weight: 2.0
|
||||
focal:
|
||||
beta: 4
|
||||
alpha: 2
|
||||
size:
|
||||
weight: 0.1
|
||||
|
||||
validation:
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: sound_event_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
8
justfile
8
justfile
@ -112,6 +112,12 @@ clean: clean-build clean-pyc clean-test clean-docs
|
||||
example-train OPTIONS="":
|
||||
uv run batdetect2 train \
|
||||
--val-dataset example_data/dataset.yaml \
|
||||
--config example_data/config.yaml \
|
||||
--base-dir . \
|
||||
--targets example_data/targets.yaml \
|
||||
--model-config example_data/configs/model.yaml \
|
||||
--training-config example_data/configs/training.yaml \
|
||||
--audio-config example_data/configs/audio.yaml \
|
||||
--evaluation-config example_data/configs/evaluation.yaml \
|
||||
--logging-config example_data/configs/logging.yaml \
|
||||
{{OPTIONS}} \
|
||||
example_data/dataset.yaml
|
||||
|
||||
@ -12,11 +12,14 @@ if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from batdetect2.audio import AudioConfig, AudioLoader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.data import Dataset
|
||||
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig, LoggerConfig
|
||||
from batdetect2.logging import (
|
||||
AppLoggingConfig,
|
||||
LoggerConfig,
|
||||
LoggingCallback,
|
||||
)
|
||||
from batdetect2.models import Model, ModelConfig
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
@ -36,6 +39,7 @@ if TYPE_CHECKING:
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.train import TrainingConfig
|
||||
from batdetect2.train.logging import TrainLoggingContext
|
||||
|
||||
|
||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||
@ -107,9 +111,11 @@ class BatDetect2API:
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||
):
|
||||
from batdetect2.train import run_train
|
||||
|
||||
self.model.train()
|
||||
run_train(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
@ -130,12 +136,15 @@ class BatDetect2API:
|
||||
train_config=train_config or self.train_config,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
logger_config=logger_config or self.logging_config.train,
|
||||
logging_callbacks=logging_callbacks,
|
||||
)
|
||||
self.model.eval()
|
||||
return self
|
||||
|
||||
def finetune(
|
||||
self,
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
targets_config: TargetConfig,
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
trainable: Literal[
|
||||
"all", "heads", "classifier_head", "bbox_head"
|
||||
@ -148,25 +157,77 @@ class BatDetect2API:
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||
) -> "BatDetect2API":
|
||||
"""Fine-tune the model with trainable-parameter selection."""
|
||||
"""Fine-tune from a checkpoint using a new target definition."""
|
||||
from batdetect2.evaluate import build_evaluator
|
||||
from batdetect2.models import build_model_with_new_targets
|
||||
from batdetect2.outputs import (
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
)
|
||||
from batdetect2.targets import (
|
||||
TargetConfig,
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
)
|
||||
from batdetect2.train import run_train
|
||||
|
||||
self._set_trainable_parameters(trainable)
|
||||
target_config = TargetConfig.model_validate(targets_config)
|
||||
targets = build_targets(config=target_config)
|
||||
roi_mapper = build_roi_mapping(config=target_config.roi)
|
||||
model = build_model_with_new_targets(
|
||||
model=self.model,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
output_transform = build_output_transform(
|
||||
config=self.outputs_config.transform,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
api = BatDetect2API(
|
||||
model_config=self.model_config,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
train_config=train_config or self.train_config,
|
||||
evaluation_config=self.evaluation_config,
|
||||
inference_config=self.inference_config,
|
||||
outputs_config=self.outputs_config,
|
||||
logging_config=self.logging_config,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
postprocessor=self.postprocessor,
|
||||
evaluator=build_evaluator(
|
||||
config=self.evaluation_config,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
transform=output_transform,
|
||||
),
|
||||
formatter=build_output_formatter(
|
||||
targets,
|
||||
config=self.outputs_config.format,
|
||||
),
|
||||
output_transform=output_transform,
|
||||
model=model,
|
||||
)
|
||||
|
||||
api._set_trainable_parameters(trainable)
|
||||
api.model.train()
|
||||
|
||||
run_train(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
model=self.model,
|
||||
targets=self.targets,
|
||||
roi_mapper=self.roi_mapper,
|
||||
model_config=model_config or self.model_config,
|
||||
preprocessor=self.preprocessor,
|
||||
audio_loader=self.audio_loader,
|
||||
model=api.model,
|
||||
targets=api.targets,
|
||||
roi_mapper=api.roi_mapper,
|
||||
model_config=api.model_config,
|
||||
preprocessor=api.preprocessor,
|
||||
audio_loader=api.audio_loader,
|
||||
train_workers=train_workers,
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
@ -175,11 +236,13 @@ class BatDetect2API:
|
||||
num_epochs=num_epochs,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
train_config=train_config or self.train_config,
|
||||
logger_config=logger_config or self.logging_config.train,
|
||||
audio_config=api.audio_config,
|
||||
train_config=api.train_config,
|
||||
logger_config=logger_config or api.logging_config.train,
|
||||
logging_callbacks=logging_callbacks,
|
||||
)
|
||||
return self
|
||||
api.model.eval()
|
||||
return api
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
@ -483,46 +546,70 @@ class BatDetect2API:
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: BatDetect2Config,
|
||||
model_config: ModelConfig | None = None,
|
||||
targets_config: TargetConfig | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
inference_config: InferenceConfig | None = None,
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
logging_config: AppLoggingConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.evaluate import build_evaluator
|
||||
from batdetect2.models import build_model
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.models import ModelConfig, build_model
|
||||
from batdetect2.outputs import (
|
||||
OutputsConfig,
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
)
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_roi_mapping, build_targets
|
||||
from batdetect2.targets import (
|
||||
TargetConfig,
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
)
|
||||
from batdetect2.train import TrainingConfig
|
||||
|
||||
targets = build_targets(config=config.model.targets)
|
||||
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
|
||||
model_config = model_config or ModelConfig()
|
||||
targets_config = targets_config or TargetConfig()
|
||||
audio_config = audio_config or AudioConfig()
|
||||
train_config = train_config or TrainingConfig()
|
||||
evaluation_config = evaluation_config or EvaluationConfig()
|
||||
inference_config = inference_config or InferenceConfig()
|
||||
outputs_config = outputs_config or OutputsConfig()
|
||||
logging_config = logging_config or AppLoggingConfig()
|
||||
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
targets = build_targets(config=targets_config)
|
||||
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||
|
||||
audio_loader = build_audio_loader(config=audio_config)
|
||||
|
||||
preprocessor = build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=config.model.preprocess,
|
||||
config=model_config.preprocess,
|
||||
)
|
||||
|
||||
postprocessor = build_postprocessor(
|
||||
preprocessor,
|
||||
config=config.model.postprocess,
|
||||
config=model_config.postprocess,
|
||||
)
|
||||
|
||||
formatter = build_output_formatter(
|
||||
targets,
|
||||
config=config.outputs.format,
|
||||
config=outputs_config.format,
|
||||
)
|
||||
output_transform = build_output_transform(
|
||||
config=config.outputs.transform,
|
||||
config=outputs_config.transform,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(
|
||||
config=config.evaluation,
|
||||
config=evaluation_config,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
transform=output_transform,
|
||||
@ -531,27 +618,27 @@ class BatDetect2API:
|
||||
# NOTE: Build separate instances of preprocessor and postprocessor
|
||||
# to avoid device mismatch errors
|
||||
model = build_model(
|
||||
config=config.model,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
config=model_config,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
preprocessor=build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=config.model.preprocess,
|
||||
config=model_config.preprocess,
|
||||
),
|
||||
postprocessor=build_postprocessor(
|
||||
preprocessor,
|
||||
config=config.model.postprocess,
|
||||
config=model_config.postprocess,
|
||||
),
|
||||
)
|
||||
|
||||
return cls(
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
train_config=config.train,
|
||||
evaluation_config=config.evaluation,
|
||||
inference_config=config.inference,
|
||||
outputs_config=config.outputs,
|
||||
logging_config=config.logging,
|
||||
model_config=model_config,
|
||||
audio_config=audio_config,
|
||||
train_config=train_config,
|
||||
evaluation_config=evaluation_config,
|
||||
inference_config=inference_config,
|
||||
outputs_config=outputs_config,
|
||||
logging_config=logging_config,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
audio_loader=audio_loader,
|
||||
@ -567,7 +654,6 @@ class BatDetect2API:
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
path: data.PathLike,
|
||||
targets_config: TargetConfig | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
@ -579,7 +665,6 @@ class BatDetect2API:
|
||||
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.models import build_model_with_new_targets
|
||||
from batdetect2.outputs import (
|
||||
OutputsConfig,
|
||||
build_output_formatter,
|
||||
@ -587,37 +672,41 @@ class BatDetect2API:
|
||||
)
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_roi_mapping, build_targets
|
||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||
from batdetect2.targets import (
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
check_target_compatibility,
|
||||
)
|
||||
from batdetect2.train import load_model_from_checkpoint
|
||||
|
||||
model, model_config = load_model_from_checkpoint(path)
|
||||
model, configs = load_model_from_checkpoint(path)
|
||||
|
||||
model_config = configs.model
|
||||
train_config = train_config or configs.train
|
||||
|
||||
audio_config = audio_config or AudioConfig(
|
||||
samplerate=model_config.samplerate,
|
||||
)
|
||||
train_config = train_config or TrainingConfig()
|
||||
evaluation_config = evaluation_config or EvaluationConfig()
|
||||
inference_config = inference_config or InferenceConfig()
|
||||
outputs_config = outputs_config or OutputsConfig()
|
||||
logging_config = logging_config or AppLoggingConfig()
|
||||
targets_config = configs.targets
|
||||
|
||||
if (
|
||||
targets_config is not None
|
||||
and targets_config != model_config.targets
|
||||
):
|
||||
targets = build_targets(config=targets_config)
|
||||
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||
model = build_model_with_new_targets(
|
||||
model=model,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
model_config = model_config.model_copy(
|
||||
update={"targets": targets_config}
|
||||
targets = build_targets(config=targets_config)
|
||||
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||
|
||||
if not check_target_compatibility(targets, model.class_names):
|
||||
raise ValueError(
|
||||
"Provided targets_config is incompatible with the "
|
||||
"checkpoint model: missing one or more model classes."
|
||||
)
|
||||
|
||||
targets = build_targets(config=model_config.targets)
|
||||
roi_mapper = build_roi_mapping(config=model_config.targets.roi)
|
||||
if model.dimension_names != roi_mapper.dimension_names:
|
||||
raise ValueError(
|
||||
"Provided targets_config is incompatible with the "
|
||||
"checkpoint model: mismatched dimension names."
|
||||
)
|
||||
|
||||
audio_loader = build_audio_loader(config=audio_config)
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ from batdetect2.cli.base import cli
|
||||
from batdetect2.cli.compat import detect
|
||||
from batdetect2.cli.data import data
|
||||
from batdetect2.cli.evaluate import evaluate_command
|
||||
from batdetect2.cli.finetune import finetune_command
|
||||
from batdetect2.cli.inference import predict
|
||||
from batdetect2.cli.train import train_command
|
||||
|
||||
@ -10,6 +11,7 @@ __all__ = [
|
||||
"detect",
|
||||
"data",
|
||||
"train_command",
|
||||
"finetune_command",
|
||||
"evaluate_command",
|
||||
"predict",
|
||||
]
|
||||
|
||||
@ -77,6 +77,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
||||
"num_workers",
|
||||
type=int,
|
||||
help="Number of worker processes for dataset loading.",
|
||||
default=0,
|
||||
)
|
||||
def evaluate_command(
|
||||
model_path: Path,
|
||||
@ -105,7 +106,6 @@ def evaluate_command(
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
|
||||
logger.info("Initiating evaluation process...")
|
||||
|
||||
@ -119,11 +119,6 @@ def evaluate_command(
|
||||
num_annotations=len(test_annotations),
|
||||
)
|
||||
|
||||
target_conf = (
|
||||
TargetConfig.load(targets_config)
|
||||
if targets_config is not None
|
||||
else None
|
||||
)
|
||||
audio_conf = (
|
||||
AudioConfig.load(audio_config) if audio_config is not None else None
|
||||
)
|
||||
@ -150,7 +145,6 @@ def evaluate_command(
|
||||
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
targets_config=target_conf,
|
||||
audio_config=audio_conf,
|
||||
evaluation_config=eval_conf,
|
||||
inference_config=inference_conf,
|
||||
|
||||
211
src/batdetect2/cli/finetune.py
Normal file
211
src/batdetect2/cli/finetune.py
Normal file
@ -0,0 +1,211 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal, cast
|
||||
|
||||
import click
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
|
||||
__all__ = ["finetune_command"]
|
||||
|
||||
|
||||
@cli.command(
|
||||
name="finetune", short_help="Fine-tune a checkpoint on new targets."
|
||||
)
|
||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||
@click.option(
|
||||
"--model",
|
||||
"model_path",
|
||||
required=True,
|
||||
type=click.Path(exists=True),
|
||||
help="Path to a checkpoint to fine-tune from.",
|
||||
)
|
||||
@click.option(
|
||||
"--targets",
|
||||
"targets_config",
|
||||
required=True,
|
||||
type=click.Path(exists=True),
|
||||
help="Path to the new targets config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--val-dataset",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to validation dataset config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--base-dir",
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"Base directory used to resolve relative paths inside the training "
|
||||
"and validation dataset configs."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--training-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to training config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--audio-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to audio config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--logging-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to logging config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--trainable",
|
||||
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
|
||||
default="heads",
|
||||
show_default=True,
|
||||
help="Which model parameters remain trainable during fine-tuning.",
|
||||
)
|
||||
@click.option(
|
||||
"--ckpt-dir",
|
||||
type=click.Path(exists=True),
|
||||
help="Directory where checkpoints are saved.",
|
||||
)
|
||||
@click.option(
|
||||
"--log-dir",
|
||||
type=click.Path(exists=True),
|
||||
help="Directory where logs are written.",
|
||||
)
|
||||
@click.option(
|
||||
"--train-workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of worker processes for training data loading.",
|
||||
)
|
||||
@click.option(
|
||||
"--val-workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of worker processes for validation data loading.",
|
||||
)
|
||||
@click.option(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
help="Maximum number of training epochs.",
|
||||
)
|
||||
@click.option(
|
||||
"--experiment-name",
|
||||
type=str,
|
||||
help="Experiment name used for logging backends.",
|
||||
)
|
||||
@click.option(
|
||||
"--run-name",
|
||||
type=str,
|
||||
help="Run name used for logging backends.",
|
||||
)
|
||||
@click.option(
|
||||
"--seed",
|
||||
type=int,
|
||||
help="Random seed used for reproducibility.",
|
||||
)
|
||||
def finetune_command(
|
||||
train_dataset: Path,
|
||||
model_path: Path,
|
||||
targets_config: Path,
|
||||
val_dataset: Path | None = None,
|
||||
ckpt_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
base_dir: Path | None = None,
|
||||
training_config: Path | None = None,
|
||||
audio_config: Path | None = None,
|
||||
logging_config: Path | None = None,
|
||||
trainable: str = "heads",
|
||||
seed: int | None = None,
|
||||
num_epochs: int | None = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
):
|
||||
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.data import load_dataset, load_dataset_config
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train import TrainingConfig
|
||||
from batdetect2.train.logging import (
|
||||
DatasetConfigArtifact,
|
||||
DatasetConfigArtifactLogging,
|
||||
)
|
||||
|
||||
logger.info("Initiating fine-tuning process...")
|
||||
|
||||
target_conf = TargetConfig.load(targets_config)
|
||||
train_conf = (
|
||||
TrainingConfig.load(training_config)
|
||||
if training_config is not None
|
||||
else None
|
||||
)
|
||||
audio_conf = (
|
||||
AudioConfig.load(audio_config) if audio_config is not None else None
|
||||
)
|
||||
logging_conf = (
|
||||
AppLoggingConfig.load(logging_config)
|
||||
if logging_config is not None
|
||||
else None
|
||||
)
|
||||
|
||||
train_dataset_conf = load_dataset_config(train_dataset)
|
||||
train_annotations = load_dataset(train_dataset_conf, base_dir=base_dir)
|
||||
|
||||
val_dataset_conf = (
|
||||
load_dataset_config(val_dataset) if val_dataset else None
|
||||
)
|
||||
val_annotations = (
|
||||
load_dataset(val_dataset_conf, base_dir=base_dir)
|
||||
if val_dataset_conf
|
||||
else None
|
||||
)
|
||||
|
||||
logging_callbacks = [
|
||||
DatasetConfigArtifactLogging(
|
||||
train_dataset_config=DatasetConfigArtifact(
|
||||
filename="train_dataset.yaml",
|
||||
config=train_dataset_conf,
|
||||
),
|
||||
val_dataset_config=(
|
||||
DatasetConfigArtifact(
|
||||
filename="val_dataset.yaml",
|
||||
config=val_dataset_conf,
|
||||
)
|
||||
if val_dataset_conf
|
||||
else None
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
train_config=train_conf,
|
||||
audio_config=audio_conf,
|
||||
logging_config=logging_conf,
|
||||
)
|
||||
|
||||
return api.finetune(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
targets_config=target_conf,
|
||||
trainable=cast(
|
||||
Literal["all", "heads", "classifier_head", "bbox_head"],
|
||||
trainable,
|
||||
),
|
||||
train_workers=train_workers,
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=ckpt_dir,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
num_epochs=num_epochs,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
train_config=train_conf,
|
||||
audio_config=audio_conf,
|
||||
logger_config=logging_conf.train if logging_conf is not None else None,
|
||||
logging_callbacks=logging_callbacks,
|
||||
)
|
||||
@ -86,11 +86,13 @@ __all__ = ["train_command"]
|
||||
@click.option(
|
||||
"--train-workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of worker processes for training data loading.",
|
||||
)
|
||||
@click.option(
|
||||
"--val-workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of worker processes for validation data loading.",
|
||||
)
|
||||
@click.option(
|
||||
@ -143,8 +145,7 @@ def train_command(
|
||||
"""
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.data import load_dataset_config, load_dataset_from_config
|
||||
from batdetect2.evaluate import EvaluationConfig
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
@ -152,6 +153,10 @@ def train_command(
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train import TrainingConfig
|
||||
from batdetect2.train.logging import (
|
||||
DatasetConfigArtifact,
|
||||
DatasetConfigArtifactLogging,
|
||||
)
|
||||
|
||||
logger.info("Initiating training process...")
|
||||
|
||||
@ -196,9 +201,6 @@ def train_command(
|
||||
if target_conf is not None:
|
||||
logger.info("Loaded targets configuration.")
|
||||
|
||||
if model_conf is not None and target_conf is not None:
|
||||
model_conf = model_conf.model_copy(update={"targets": target_conf})
|
||||
|
||||
logger.info("Loading training dataset...")
|
||||
train_annotations = load_dataset_from_config(
|
||||
train_dataset,
|
||||
@ -224,37 +226,49 @@ def train_command(
|
||||
|
||||
logger.info("Configuration and data loaded. Starting training...")
|
||||
|
||||
logging_callbacks = [
|
||||
DatasetConfigArtifactLogging(
|
||||
train_dataset_config=DatasetConfigArtifact(
|
||||
filename="train_dataset.yaml",
|
||||
config=load_dataset_config(train_dataset),
|
||||
),
|
||||
val_dataset_config=(
|
||||
DatasetConfigArtifact(
|
||||
filename="val_dataset.yaml",
|
||||
config=load_dataset_config(val_dataset),
|
||||
)
|
||||
if val_dataset is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
if model_path is not None and model_conf is not None:
|
||||
raise click.UsageError(
|
||||
"--model-config cannot be used with --model. "
|
||||
"Checkpoint model configuration is loaded from the checkpoint."
|
||||
)
|
||||
|
||||
if model_path is not None and target_conf is not None:
|
||||
raise click.UsageError(
|
||||
"--targets cannot be used with --model. "
|
||||
"Checkpoint target configuration is loaded from the checkpoint."
|
||||
)
|
||||
|
||||
if model_path is None:
|
||||
conf = BatDetect2Config()
|
||||
if model_conf is not None:
|
||||
conf.model = model_conf
|
||||
elif target_conf is not None:
|
||||
conf.model = conf.model.model_copy(update={"targets": target_conf})
|
||||
|
||||
if train_conf is not None:
|
||||
conf.train = train_conf
|
||||
if audio_conf is not None:
|
||||
conf.audio = audio_conf
|
||||
if eval_conf is not None:
|
||||
conf.evaluation = eval_conf
|
||||
if inference_conf is not None:
|
||||
conf.inference = inference_conf
|
||||
if outputs_conf is not None:
|
||||
conf.outputs = outputs_conf
|
||||
if logging_conf is not None:
|
||||
conf.logging = logging_conf
|
||||
|
||||
api = BatDetect2API.from_config(conf)
|
||||
api = BatDetect2API.from_config(
|
||||
model_config=model_conf,
|
||||
targets_config=target_conf,
|
||||
train_config=train_conf,
|
||||
audio_config=audio_conf,
|
||||
evaluation_config=eval_conf,
|
||||
inference_config=inference_conf,
|
||||
outputs_config=outputs_conf,
|
||||
logging_config=logging_conf,
|
||||
)
|
||||
else:
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
targets_config=target_conf,
|
||||
train_config=train_conf,
|
||||
audio_config=audio_conf,
|
||||
evaluation_config=eval_conf,
|
||||
@ -274,4 +288,5 @@ def train_command(
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
logging_callbacks=logging_callbacks,
|
||||
)
|
||||
|
||||
@ -1,31 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.evaluate.config import (
|
||||
EvaluationConfig,
|
||||
get_default_eval_config,
|
||||
)
|
||||
from batdetect2.inference.config import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
|
||||
__all__ = ["BatDetect2Config"]
|
||||
|
||||
|
||||
class BatDetect2Config(BaseConfig):
|
||||
config_version: Literal["v1"] = "v1"
|
||||
|
||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
evaluation: EvaluationConfig = Field(
|
||||
default_factory=get_default_eval_config
|
||||
)
|
||||
model: ModelConfig = Field(default_factory=ModelConfig)
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
||||
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
|
||||
@ -12,6 +12,8 @@ from typing import (
|
||||
from hydra.utils import instantiate
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"add_import_config",
|
||||
"ImportConfig",
|
||||
@ -120,7 +122,7 @@ class Registry(Generic[T_Type, P_Type]):
|
||||
return self._registry[name](config, *args, **kwargs)
|
||||
|
||||
|
||||
class ImportConfig(BaseModel):
|
||||
class ImportConfig(BaseConfig):
|
||||
"""Base config for dynamic instantiation via Hydra.
|
||||
|
||||
Subclass this to create a registry-specific import escape hatch.
|
||||
|
||||
@ -16,6 +16,7 @@ from batdetect2.data.conditions.common import (
|
||||
IdInListConfig,
|
||||
JsonList,
|
||||
ListFormatConfig,
|
||||
TagInfo,
|
||||
TxtList,
|
||||
)
|
||||
from batdetect2.data.conditions.recordings import (
|
||||
@ -63,16 +64,17 @@ __all__ = [
|
||||
"NotConfig",
|
||||
"Operator",
|
||||
"PathInListConfig",
|
||||
"RecordingAllOfConfig",
|
||||
"RecordingAnyOfConfig",
|
||||
"RecordingCondition",
|
||||
"RecordingConditionConfig",
|
||||
"RecordingConditionImportConfig",
|
||||
"RecordingAllOfConfig",
|
||||
"RecordingAnyOfConfig",
|
||||
"RecordingNotConfig",
|
||||
"RecordingSatisfiesConfig",
|
||||
"SoundEventCondition",
|
||||
"SoundEventConditionConfig",
|
||||
"SoundEventConditionImportConfig",
|
||||
"TagInfo",
|
||||
"TxtList",
|
||||
"build_clip_annotation_condition",
|
||||
"build_recording_condition",
|
||||
|
||||
@ -2,10 +2,23 @@ import csv
|
||||
import json
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
PlainSerializer,
|
||||
model_validator,
|
||||
)
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
@ -138,19 +151,26 @@ class IdInList(Generic[UUIDObject]):
|
||||
return obj.uuid in self.ids
|
||||
|
||||
|
||||
def dump_tag(tag: data.Tag) -> dict[str, Any]:
|
||||
return {"key": tag.term.name, "value": tag.value}
|
||||
|
||||
|
||||
TagInfo = Annotated[data.Tag, PlainSerializer(dump_tag)]
|
||||
|
||||
|
||||
class HasTagConfig(BaseConfig):
|
||||
name: Literal["has_tag"] = "has_tag"
|
||||
tag: data.Tag
|
||||
tag: TagInfo
|
||||
|
||||
|
||||
class HasAllTagsConfig(BaseConfig):
|
||||
name: Literal["has_all_tags"] = "has_all_tags"
|
||||
tags: list[data.Tag]
|
||||
tags: list[TagInfo]
|
||||
|
||||
|
||||
class HasAnyTagConfig(BaseConfig):
|
||||
name: Literal["has_any_tag"] = "has_any_tag"
|
||||
tags: list[data.Tag]
|
||||
tags: list[TagInfo]
|
||||
|
||||
|
||||
class JsonList(BaseConfig):
|
||||
|
||||
@ -94,7 +94,7 @@ def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
|
||||
return partial(operator.ge, value)
|
||||
|
||||
if op == "eq":
|
||||
return partial(operator.eq, b=value)
|
||||
return partial(operator.eq, value)
|
||||
|
||||
raise ValueError(f"Invalid operator {op}")
|
||||
|
||||
|
||||
@ -24,8 +24,8 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||
def run_evaluate(
|
||||
model: Model,
|
||||
test_annotations: Sequence[data.ClipAnnotation],
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
@ -46,8 +46,6 @@ def run_evaluate(
|
||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||
|
||||
preprocessor = preprocessor or model.preprocessor
|
||||
targets = targets or model.targets
|
||||
roi_mapper = roi_mapper or model.roi_mapper
|
||||
|
||||
loader = build_test_loader(
|
||||
test_annotations,
|
||||
|
||||
@ -45,8 +45,16 @@ def run_batch_inference(
|
||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||
|
||||
preprocessor = preprocessor or model.preprocessor
|
||||
targets = targets or model.targets
|
||||
roi_mapper = roi_mapper or model.roi_mapper
|
||||
|
||||
if targets is None:
|
||||
raise ValueError(
|
||||
"targets must be provided when running batch inference."
|
||||
)
|
||||
|
||||
if roi_mapper is None:
|
||||
raise ValueError(
|
||||
"roi_mapper must be provided when running batch inference."
|
||||
)
|
||||
|
||||
output_transform = output_transform or build_output_transform(
|
||||
config=output_config.transform,
|
||||
|
||||
@ -7,21 +7,37 @@ from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
|
||||
|
||||
class InferenceModule(LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.detection_threshold = detection_threshold
|
||||
|
||||
if output_transform is None and targets is None:
|
||||
raise ValueError(
|
||||
"targets must be provided when building inference output "
|
||||
"transforms."
|
||||
)
|
||||
|
||||
if output_transform is None and roi_mapper is None:
|
||||
raise ValueError(
|
||||
"roi_mapper must be provided when building inference output "
|
||||
"transforms."
|
||||
)
|
||||
|
||||
self.output_transform = output_transform or build_output_transform(
|
||||
targets=model.targets,
|
||||
roi_mapper=model.roi_mapper,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
|
||||
def predict_step(
|
||||
|
||||
@ -24,12 +24,7 @@ from batdetect2.core.configs import BaseConfig
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lightning.pytorch.loggers import (
|
||||
CSVLogger,
|
||||
Logger,
|
||||
MLFlowLogger,
|
||||
TensorBoardLogger,
|
||||
)
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from matplotlib.figure import Figure
|
||||
from soundevent import data
|
||||
|
||||
@ -43,11 +38,15 @@ __all__ = [
|
||||
"DVCLiveConfig",
|
||||
"LoggerConfig",
|
||||
"MLFlowLoggerConfig",
|
||||
"LoggingCallback",
|
||||
"TensorBoardLoggerConfig",
|
||||
"build_logger",
|
||||
"enable_logging",
|
||||
"get_image_logger",
|
||||
"get_table_logger",
|
||||
"log_artifact_file",
|
||||
"log_config_artifact",
|
||||
"log_csv_artifact",
|
||||
]
|
||||
|
||||
|
||||
@ -123,6 +122,18 @@ class LoggerBuilder(Protocol, Generic[T]):
|
||||
) -> Logger: ...
|
||||
|
||||
|
||||
LoggingContext = TypeVar("LoggingContext", contravariant=True)
|
||||
|
||||
|
||||
class LoggingCallback(Protocol, Generic[LoggingContext]):
|
||||
def run(
|
||||
self,
|
||||
logger: Logger,
|
||||
artifact_path: Path,
|
||||
context: LoggingContext,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
def create_dvclive_logger(
|
||||
config: DVCLiveConfig,
|
||||
log_dir: Path | None = None,
|
||||
@ -276,6 +287,71 @@ def build_logger(
|
||||
)
|
||||
|
||||
|
||||
def log_artifact_file(
|
||||
runtime_logger: Logger,
|
||||
path: Path,
|
||||
artifact_path: str = "artifacts",
|
||||
) -> None:
|
||||
from lightning.pytorch.loggers import (
|
||||
CSVLogger,
|
||||
MLFlowLogger,
|
||||
TensorBoardLogger,
|
||||
)
|
||||
|
||||
if isinstance(runtime_logger, MLFlowLogger):
|
||||
runtime_logger.experiment.log_artifact( # type: ignore[call-arg]
|
||||
local_path=str(path),
|
||||
artifact_path=artifact_path,
|
||||
run_id=runtime_logger.run_id,
|
||||
)
|
||||
return
|
||||
|
||||
experiment = getattr(runtime_logger, "experiment", None)
|
||||
if experiment is not None and hasattr(experiment, "log_artifact"):
|
||||
experiment.log_artifact(path=path, name=path.name, copy=True)
|
||||
return
|
||||
|
||||
if isinstance(runtime_logger, (CSVLogger, TensorBoardLogger)):
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"Skipping artifact logging for unsupported logger type {logger_type}",
|
||||
logger_type=type(runtime_logger).__name__,
|
||||
)
|
||||
|
||||
|
||||
def log_config_artifact(
|
||||
logger: Logger,
|
||||
config: BaseConfig,
|
||||
filename: str,
|
||||
artifact_path: Path,
|
||||
) -> None:
|
||||
artifact_path.mkdir(parents=True, exist_ok=True)
|
||||
path = artifact_path / filename
|
||||
path.write_text(config.to_yaml_string())
|
||||
log_artifact_file(
|
||||
logger,
|
||||
path,
|
||||
artifact_path=artifact_path.name,
|
||||
)
|
||||
|
||||
|
||||
def log_csv_artifact(
|
||||
logger: Logger,
|
||||
df: pd.DataFrame,
|
||||
filename: str,
|
||||
artifact_path: Path,
|
||||
) -> None:
|
||||
artifact_path.mkdir(parents=True, exist_ok=True)
|
||||
path = artifact_path / filename
|
||||
df.to_csv(path, index=False)
|
||||
log_artifact_file(
|
||||
logger,
|
||||
path,
|
||||
artifact_path=artifact_path.name,
|
||||
)
|
||||
|
||||
|
||||
PlotLogger = Callable[[str, "Figure", int], None]
|
||||
|
||||
|
||||
|
||||
@ -26,11 +26,8 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
|
||||
is the ``build_model`` factory function exported from this module.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
@ -73,7 +70,6 @@ from batdetect2.postprocess.types import (
|
||||
)
|
||||
from batdetect2.preprocess.config import PreprocessingConfig
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.config import TargetConfig
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
@ -131,10 +127,6 @@ class ModelConfig(BaseConfig):
|
||||
Parameters for converting raw model outputs into detections (NMS
|
||||
kernel, thresholds, top-k limit). Defaults to
|
||||
``PostprocessConfig()``.
|
||||
targets : TargetConfig
|
||||
Detection and classification target definitions (class list,
|
||||
detection target, bounding-box mapper). Defaults to
|
||||
``TargetConfig()``.
|
||||
"""
|
||||
|
||||
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||
@ -143,23 +135,6 @@ class ModelConfig(BaseConfig):
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
path: PathLike,
|
||||
field: str | None = None,
|
||||
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
||||
strict: bool | None = None,
|
||||
targets: TargetConfig | None = None,
|
||||
) -> "ModelConfig":
|
||||
config = super().load(path, field, extra, strict)
|
||||
|
||||
if targets is None:
|
||||
return config
|
||||
|
||||
return config.model_copy(update={"targets": targets})
|
||||
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
@ -183,33 +158,32 @@ class Model(torch.nn.Module):
|
||||
postprocessor : PostprocessorProtocol
|
||||
Converts the raw ``ModelOutput`` from ``detector`` into a list of
|
||||
per-clip detection tensors.
|
||||
targets : TargetProtocol
|
||||
Describes the set of target classes; used when building heads and
|
||||
during training target construction.
|
||||
roi_mapper : ROIMapperProtocol
|
||||
Maps geometries to target-size channels and back.
|
||||
class_names : list[str]
|
||||
Class names corresponding to the model classification outputs.
|
||||
dimension_names : list[str]
|
||||
Size-dimension names corresponding to the model size outputs.
|
||||
"""
|
||||
|
||||
detector: DetectionModel
|
||||
preprocessor: PreprocessorProtocol
|
||||
postprocessor: PostprocessorProtocol
|
||||
targets: TargetProtocol
|
||||
roi_mapper: ROIMapperProtocol
|
||||
class_names: list[str]
|
||||
dimension_names: list[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detector: DetectionModel,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
class_names: list[str],
|
||||
dimension_names: list[str],
|
||||
):
|
||||
super().__init__()
|
||||
self.detector = detector
|
||||
self.preprocessor = preprocessor
|
||||
self.postprocessor = postprocessor
|
||||
self.targets = targets
|
||||
self.roi_mapper = roi_mapper
|
||||
self.class_names = class_names
|
||||
self.dimension_names = dimension_names
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
||||
"""Run the full detection pipeline on a waveform tensor.
|
||||
@ -237,9 +211,9 @@ class Model(torch.nn.Module):
|
||||
|
||||
|
||||
def build_model(
|
||||
config: ModelConfig | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
config: ModelConfig | dict | None = None,
|
||||
class_names: list[str] | None = None,
|
||||
dimension_names: list[str] | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
postprocessor: PostprocessorProtocol | None = None,
|
||||
) -> Model:
|
||||
@ -254,11 +228,13 @@ def build_model(
|
||||
----------
|
||||
config : ModelConfig, optional
|
||||
Full model configuration (samplerate, architecture, preprocessing,
|
||||
postprocessing, targets). Defaults to ``ModelConfig()`` if not
|
||||
provided.
|
||||
targets : TargetProtocol, optional
|
||||
Pre-built targets object. If given, overrides
|
||||
``config.targets``.
|
||||
postprocessing). Defaults to ``ModelConfig()`` if not provided.
|
||||
class_names : list[str], optional
|
||||
Class names used to size the classifier head. Required when building
|
||||
a new model.
|
||||
dimension_names : list[str], optional
|
||||
Dimension names used to size the bbox head. Required when building a
|
||||
new model.
|
||||
preprocessor : PreprocessorProtocol, optional
|
||||
Pre-built preprocessor. If given, overrides
|
||||
``config.preprocess`` and ``config.samplerate`` for the
|
||||
@ -278,19 +254,20 @@ def build_model(
|
||||
"""
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_roi_mapping, build_targets
|
||||
|
||||
config = config or ModelConfig()
|
||||
targets = targets or build_targets(config=config.targets)
|
||||
|
||||
targets_config = getattr(targets, "config", None)
|
||||
roi_config = (
|
||||
targets_config.roi
|
||||
if isinstance(targets_config, TargetConfig)
|
||||
else config.targets.roi
|
||||
)
|
||||
if isinstance(config, dict):
|
||||
config = ModelConfig.model_validate(config)
|
||||
|
||||
if class_names is None:
|
||||
raise ValueError("class_names must be provided when building a model.")
|
||||
|
||||
if dimension_names is None:
|
||||
raise ValueError(
|
||||
"dimension_names must be provided when building a model."
|
||||
)
|
||||
|
||||
roi_mapper = roi_mapper or build_roi_mapping(config=roi_config)
|
||||
preprocessor = preprocessor or build_preprocessor(
|
||||
config=config.preprocess,
|
||||
input_samplerate=config.samplerate,
|
||||
@ -300,16 +277,16 @@ def build_model(
|
||||
config=config.postprocess,
|
||||
)
|
||||
detector = build_detector(
|
||||
num_classes=len(targets.class_names),
|
||||
num_sizes=len(roi_mapper.dimension_names),
|
||||
num_classes=len(class_names),
|
||||
num_sizes=len(dimension_names),
|
||||
config=config.architecture,
|
||||
)
|
||||
return Model(
|
||||
detector=detector,
|
||||
postprocessor=postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
class_names=class_names,
|
||||
dimension_names=dimension_names,
|
||||
)
|
||||
|
||||
|
||||
@ -329,6 +306,6 @@ def build_model_with_new_targets(
|
||||
detector=detector,
|
||||
postprocessor=model.postprocessor,
|
||||
preprocessor=model.preprocessor,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
)
|
||||
|
||||
@ -53,8 +53,12 @@ import torch.nn.functional as F
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core import (
|
||||
BaseConfig,
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BlockImportConfig",
|
||||
|
||||
@ -6,6 +6,7 @@ from batdetect2.plotting.gallery import plot_match_gallery
|
||||
from batdetect2.plotting.heatmaps import (
|
||||
plot_classification_heatmap,
|
||||
plot_detection_heatmap,
|
||||
plot_size_heatmap,
|
||||
)
|
||||
from batdetect2.plotting.matches import (
|
||||
plot_cross_trigger_match,
|
||||
@ -25,5 +26,6 @@ __all__ = [
|
||||
"plot_true_positive_match",
|
||||
"plot_detection_heatmap",
|
||||
"plot_classification_heatmap",
|
||||
"plot_size_heatmap",
|
||||
"plot_match_gallery",
|
||||
]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
"""Plot heatmaps"""
|
||||
"""Plot heatmaps."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -8,6 +8,12 @@ from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba
|
||||
|
||||
from batdetect2.plotting.common import create_ax
|
||||
|
||||
__all__ = [
|
||||
"plot_detection_heatmap",
|
||||
"plot_classification_heatmap",
|
||||
"plot_size_heatmap",
|
||||
]
|
||||
|
||||
|
||||
def plot_detection_heatmap(
|
||||
heatmap: torch.Tensor | np.ndarray,
|
||||
@ -108,7 +114,91 @@ def plot_classification_heatmap(
|
||||
return ax
|
||||
|
||||
|
||||
def create_colormap(color: str) -> Colormap:
|
||||
def plot_size_heatmap(
|
||||
heatmap: torch.Tensor | np.ndarray,
|
||||
dimension_names: list[str],
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: tuple[int, int] = (10, 10),
|
||||
color: str = "crimson",
|
||||
size: float = 20,
|
||||
fontsize: float = 8,
|
||||
) -> axes.Axes:
|
||||
"""Plot sparse size labels from a size heatmap.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
heatmap : torch.Tensor | np.ndarray
|
||||
Size heatmap with shape ``[num_dims, height, width]``. Entries are
|
||||
expected to be zero everywhere except at labelled positions.
|
||||
dimension_names : list[str]
|
||||
Names corresponding to the first heatmap dimension.
|
||||
ax : matplotlib.axes.Axes | None, default=None
|
||||
Axis to plot on. If ``None``, a new axis is created.
|
||||
figsize : tuple[int, int], default=(10, 10)
|
||||
Figure size used when creating a new axis.
|
||||
color : str, default="crimson"
|
||||
Color used for scatter points and text labels.
|
||||
size : float, default=20
|
||||
Marker size for plotted points.
|
||||
fontsize : float, default=8
|
||||
Font size used for the text labels.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib.axes.Axes
|
||||
Axis containing the plotted size labels.
|
||||
"""
|
||||
ax = create_ax(ax, figsize=figsize)
|
||||
|
||||
if isinstance(heatmap, torch.Tensor):
|
||||
heatmap = heatmap.numpy()
|
||||
|
||||
if heatmap.ndim == 4:
|
||||
heatmap = heatmap[0]
|
||||
|
||||
if heatmap.ndim != 3:
|
||||
raise ValueError("Expecting a 3-dimensional array")
|
||||
|
||||
if len(dimension_names) != heatmap.shape[0]:
|
||||
raise ValueError("Inconsistent number of dimension names")
|
||||
|
||||
point_mask = np.any(heatmap != 0, axis=0)
|
||||
rows, cols = np.nonzero(point_mask)
|
||||
|
||||
if len(rows) == 0:
|
||||
return ax
|
||||
|
||||
ax.scatter(cols, rows, c=color, s=size)
|
||||
|
||||
for row, col in zip(rows, cols, strict=False):
|
||||
values = heatmap[:, row, col]
|
||||
labels = [
|
||||
f"{name}={value:.2f}"
|
||||
for name, value in zip(
|
||||
dimension_names,
|
||||
values,
|
||||
strict=False,
|
||||
)
|
||||
if value != 0
|
||||
]
|
||||
ax.text(
|
||||
float(col),
|
||||
float(row),
|
||||
"\n".join(labels),
|
||||
fontsize=fontsize,
|
||||
color=color,
|
||||
va="bottom",
|
||||
ha="left",
|
||||
)
|
||||
|
||||
ax.set_xlim(0, heatmap.shape[2])
|
||||
ax.set_ylim(0, heatmap.shape[1])
|
||||
return ax
|
||||
|
||||
|
||||
def create_colormap(
|
||||
color: str | tuple[float, float, float, float],
|
||||
) -> Colormap:
|
||||
(r, g, b, a) = to_rgba(color)
|
||||
return LinearSegmentedColormap.from_list(
|
||||
"cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
|
||||
|
||||
@ -6,7 +6,7 @@ from batdetect2.targets.classes import (
|
||||
build_sound_event_encoder,
|
||||
get_class_names_from_config,
|
||||
)
|
||||
from batdetect2.targets.config import TargetConfig
|
||||
from batdetect2.targets.config import TargetConfig, build_default_target_config
|
||||
from batdetect2.targets.rois import (
|
||||
AnchorBBoxMapperConfig,
|
||||
ROIMapperConfig,
|
||||
@ -36,13 +36,14 @@ from batdetect2.targets.types import (
|
||||
SoundEventFilter,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.targets.utils import check_target_compatibility
|
||||
|
||||
__all__ = [
|
||||
"AnchorBBoxMapperConfig",
|
||||
"Position",
|
||||
"ROIMappingConfig",
|
||||
"ROIMapperProtocol",
|
||||
"ROIMapperConfig",
|
||||
"ROIMapperProtocol",
|
||||
"ROIMappingConfig",
|
||||
"ROITargetMapper",
|
||||
"Size",
|
||||
"SoundEventDecoder",
|
||||
@ -52,12 +53,14 @@ __all__ = [
|
||||
"TargetConfig",
|
||||
"TargetProtocol",
|
||||
"Targets",
|
||||
"build_roi_mapping",
|
||||
"build_default_target_config",
|
||||
"build_roi_mapper",
|
||||
"build_roi_mapping",
|
||||
"build_sound_event_decoder",
|
||||
"build_sound_event_encoder",
|
||||
"build_targets",
|
||||
"call_type",
|
||||
"check_target_compatibility",
|
||||
"data_source",
|
||||
"generic_class",
|
||||
"get_class_names_from_config",
|
||||
|
||||
@ -12,6 +12,7 @@ from batdetect2.data.conditions import (
|
||||
NotConfig,
|
||||
SoundEventCondition,
|
||||
SoundEventConditionConfig,
|
||||
TagInfo,
|
||||
build_sound_event_condition,
|
||||
)
|
||||
from batdetect2.targets.terms import call_type, generic_class
|
||||
@ -32,11 +33,12 @@ class TargetClassConfig(BaseConfig):
|
||||
condition_input: SoundEventConditionConfig | None = Field(
|
||||
alias="match_if",
|
||||
default=None,
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
tags: List[data.Tag] | None = Field(default=None, exclude=True)
|
||||
|
||||
assign_tags: List[data.Tag] = Field(default_factory=list)
|
||||
assign_tags: List[TagInfo] = Field(default_factory=list)
|
||||
|
||||
_match_if: SoundEventConditionConfig = PrivateAttr()
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ from collections import Counter
|
||||
from typing import List
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.targets.classes import (
|
||||
@ -13,6 +14,7 @@ from batdetect2.targets.rois import ROIMappingConfig
|
||||
|
||||
__all__ = [
|
||||
"TargetConfig",
|
||||
"build_default_target_config",
|
||||
]
|
||||
|
||||
|
||||
@ -42,3 +44,20 @@ class TargetConfig(BaseConfig):
|
||||
f"{', '.join(duplicates)}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
def build_default_target_config(class_names: list[str]) -> TargetConfig:
|
||||
"""Build a default target configuration object."""
|
||||
return TargetConfig(
|
||||
detection_target=DEFAULT_DETECTION_CLASS,
|
||||
classification_targets=[
|
||||
TargetClassConfig(
|
||||
name=class_name,
|
||||
tags=[
|
||||
data.Tag(key="class", value=class_name),
|
||||
],
|
||||
)
|
||||
for class_name in class_names
|
||||
],
|
||||
roi=ROIMappingConfig(),
|
||||
)
|
||||
|
||||
@ -50,21 +50,31 @@ class Targets(TargetProtocol):
|
||||
self.config = config
|
||||
|
||||
self._filter_fn = build_sound_event_condition(
|
||||
config.detection_target.match_if
|
||||
self.config.detection_target.match_if
|
||||
)
|
||||
self._encode_fn = build_sound_event_encoder(
|
||||
config.classification_targets
|
||||
self.config.classification_targets
|
||||
)
|
||||
self._decode_fn = build_sound_event_decoder(
|
||||
config.classification_targets
|
||||
self.config.classification_targets
|
||||
)
|
||||
|
||||
self.class_names = get_class_names_from_config(
|
||||
config.classification_targets
|
||||
self.config.classification_targets
|
||||
)
|
||||
|
||||
self.detection_class_name = config.detection_target.name
|
||||
self.detection_class_tags = config.detection_target.assign_tags
|
||||
self.detection_class_name = self.config.detection_target.name
|
||||
self.detection_class_tags = self.config.detection_target.assign_tags
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "Targets":
|
||||
"""Build a Targets object from a serialized config dictionary."""
|
||||
validated_config = TargetConfig.model_validate(config)
|
||||
return cls(config=validated_config)
|
||||
|
||||
def get_config(self) -> dict:
|
||||
"""Return the serialized target config used to build this object."""
|
||||
return self.config.model_dump(mode="json")
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||
"""Apply the configured filter to a sound event annotation.
|
||||
@ -131,7 +141,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||
)
|
||||
|
||||
|
||||
def build_targets(config: TargetConfig | None = None) -> Targets:
|
||||
def build_targets(config: TargetConfig | dict | None = None) -> Targets:
|
||||
"""Build a Targets object from a loaded TargetConfig.
|
||||
|
||||
Parameters
|
||||
@ -153,6 +163,10 @@ def build_targets(config: TargetConfig | None = None) -> Targets:
|
||||
If dynamic import of a derivation function fails (when configured).
|
||||
"""
|
||||
config = config or DEFAULT_TARGET_CONFIG
|
||||
|
||||
if not isinstance(config, TargetConfig):
|
||||
config = TargetConfig.model_validate(config)
|
||||
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building targets with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
|
||||
@ -28,6 +28,11 @@ class TargetProtocol(Protocol):
|
||||
detection_class_tags: list[data.Tag]
|
||||
detection_class_name: str
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "TargetProtocol": ...
|
||||
|
||||
def get_config(self) -> dict: ...
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
|
||||
|
||||
def encode_class(
|
||||
|
||||
29
src/batdetect2/targets/utils.py
Normal file
29
src/batdetect2/targets/utils.py
Normal file
@ -0,0 +1,29 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def check_target_compatibility(
|
||||
targets: "TargetProtocol",
|
||||
class_names: list[str],
|
||||
) -> bool:
|
||||
"""Check if a target definition can decode a model's outputs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
Target definition that would be used with the model outputs.
|
||||
class_names : list[str]
|
||||
Class names produced by the model checkpoint.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True when every model class name exists in the provided targets,
|
||||
False otherwise.
|
||||
"""
|
||||
target_class_names = set(targets.class_names)
|
||||
model_class_names = set(class_names)
|
||||
|
||||
return model_class_names.issubset(target_class_names)
|
||||
@ -4,10 +4,24 @@ from batdetect2.train.lightning import (
|
||||
TrainingModule,
|
||||
load_model_from_checkpoint,
|
||||
)
|
||||
from batdetect2.train.logging import (
|
||||
ConfigHyperparameterLogging,
|
||||
DatasetConfigArtifact,
|
||||
DatasetConfigArtifactLogging,
|
||||
DataSummaryArtifactLogging,
|
||||
TargetConfigArtifactLogging,
|
||||
TrainLoggingContext,
|
||||
)
|
||||
from batdetect2.train.train import build_trainer, run_train
|
||||
|
||||
__all__ = [
|
||||
"ConfigHyperparameterLogging",
|
||||
"DataSummaryArtifactLogging",
|
||||
"DEFAULT_CHECKPOINT_DIR",
|
||||
"DatasetConfigArtifact",
|
||||
"DatasetConfigArtifactLogging",
|
||||
"TargetConfigArtifactLogging",
|
||||
"TrainLoggingContext",
|
||||
"TrainingConfig",
|
||||
"TrainingModule",
|
||||
"build_trainer",
|
||||
|
||||
@ -10,6 +10,7 @@ from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
from batdetect2.train.dataset import ValidationDataset
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.train.types import TrainExample
|
||||
@ -19,11 +20,15 @@ class ValidationMetrics(Callback):
|
||||
def __init__(
|
||||
self,
|
||||
evaluator: EvaluatorProtocol,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.evaluator = evaluator
|
||||
self.targets = targets
|
||||
self.roi_mapper = roi_mapper
|
||||
self.output_transform = output_transform
|
||||
|
||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||
@ -93,8 +98,8 @@ class ValidationMetrics(Callback):
|
||||
model = pl_module.model
|
||||
if self.output_transform is None:
|
||||
self.output_transform = build_output_transform(
|
||||
targets=model.targets,
|
||||
roi_mapper=model.roi_mapper,
|
||||
targets=self.targets,
|
||||
roi_mapper=self.roi_mapper,
|
||||
)
|
||||
|
||||
output_transform = self.output_transform
|
||||
|
||||
@ -34,6 +34,8 @@ def build_checkpoint_callback(
|
||||
if checkpoint_dir is None:
|
||||
checkpoint_dir = config.checkpoint_dir
|
||||
|
||||
checkpoint_dir = Path(checkpoint_dir)
|
||||
|
||||
if experiment_name is not None:
|
||||
checkpoint_dir = checkpoint_dir / experiment_name
|
||||
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import lightning as L
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.losses import build_loss
|
||||
from batdetect2.train.optimizers import build_optimizer
|
||||
@ -11,6 +14,7 @@ from batdetect2.train.types import LossProtocol, TrainExample
|
||||
|
||||
__all__ = [
|
||||
"TrainingModule",
|
||||
"load_model_from_checkpoint",
|
||||
]
|
||||
|
||||
|
||||
@ -21,6 +25,9 @@ class TrainingModule(L.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: dict | None = None,
|
||||
targets_config: dict | None = None,
|
||||
class_names: list[str] | None = None,
|
||||
dimension_names: list[str] | None = None,
|
||||
train_config: dict | None = None,
|
||||
loss: LossProtocol | None = None,
|
||||
model: Model | None = None,
|
||||
@ -29,14 +36,34 @@ class TrainingModule(L.LightningModule):
|
||||
|
||||
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
||||
|
||||
self.model_config = ModelConfig.model_validate(model_config or {})
|
||||
self.model_config: dict = model_config or {}
|
||||
self.targets_config: dict = targets_config or {}
|
||||
self.class_names = list(class_names or [])
|
||||
self.dimension_names = list(dimension_names or [])
|
||||
|
||||
self.train_config = TrainingConfig.model_validate(train_config or {})
|
||||
|
||||
if loss is None:
|
||||
loss = build_loss(config=self.train_config.loss)
|
||||
|
||||
if model is None:
|
||||
model = build_model(config=self.model_config)
|
||||
if not self.class_names:
|
||||
raise ValueError(
|
||||
"class_names must be provided when rebuilding a training "
|
||||
"module without a model."
|
||||
)
|
||||
|
||||
if not self.dimension_names:
|
||||
raise ValueError(
|
||||
"dimension_names must be provided when rebuilding a "
|
||||
"training module without a model."
|
||||
)
|
||||
|
||||
model = build_model(
|
||||
config=self.model_config,
|
||||
class_names=self.class_names,
|
||||
dimension_names=self.dimension_names,
|
||||
)
|
||||
|
||||
self.loss = loss
|
||||
self.model = model
|
||||
@ -95,9 +122,16 @@ class TrainingModule(L.LightningModule):
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoredConfig:
|
||||
model: ModelConfig
|
||||
targets: TargetConfig
|
||||
train: TrainingConfig
|
||||
|
||||
|
||||
def load_model_from_checkpoint(
|
||||
path: PathLike,
|
||||
) -> tuple[Model, ModelConfig]:
|
||||
) -> tuple[Model, StoredConfig]:
|
||||
"""Load a model and its configuration from a Lightning checkpoint.
|
||||
|
||||
Parameters
|
||||
@ -110,15 +144,24 @@ def load_model_from_checkpoint(
|
||||
-------
|
||||
tuple[Model, ModelConfig]
|
||||
The restored ``Model`` instance and the ``ModelConfig`` that
|
||||
describes its architecture, preprocessing, postprocessing, and
|
||||
targets.
|
||||
describes its architecture, preprocessing, and postprocessing.
|
||||
"""
|
||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||
return module.model, module.model_config
|
||||
training_config = TrainingConfig.model_validate(module.train_config)
|
||||
model_config = ModelConfig.model_validate(module.model_config)
|
||||
targets_config = TargetConfig.model_validate(module.targets_config)
|
||||
return module.model, StoredConfig(
|
||||
model=model_config,
|
||||
targets=targets_config,
|
||||
train=training_config,
|
||||
)
|
||||
|
||||
|
||||
def build_training_module(
|
||||
model_config: ModelConfig | None = None,
|
||||
targets_config: TargetConfig | dict | None = None,
|
||||
class_names: list[str] | None = None,
|
||||
dimension_names: list[str] | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
model: Model | None = None,
|
||||
) -> TrainingModule:
|
||||
@ -128,8 +171,16 @@ def build_training_module(
|
||||
if train_config is None:
|
||||
train_config = TrainingConfig()
|
||||
|
||||
if targets_config is None:
|
||||
targets_config = TargetConfig()
|
||||
|
||||
targets_config = TargetConfig.model_validate(targets_config)
|
||||
|
||||
return TrainingModule(
|
||||
model_config=model_config.model_dump(mode="json"),
|
||||
targets_config=targets_config.model_dump(mode="json"),
|
||||
train_config=train_config.model_dump(mode="json"),
|
||||
class_names=class_names,
|
||||
dimension_names=dimension_names,
|
||||
model=model,
|
||||
)
|
||||
|
||||
164
src/batdetect2/train/logging.py
Normal file
164
src/batdetect2/train/logging.py
Normal file
@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.data import Dataset, compute_class_summary
|
||||
from batdetect2.logging import log_config_artifact, log_csv_artifact
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.targets import TargetConfig, TargetProtocol
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
|
||||
__all__ = [
|
||||
"ConfigHyperparameterLogging",
|
||||
"DataSummaryArtifactLogging",
|
||||
"DatasetConfigArtifact",
|
||||
"DatasetConfigArtifactLogging",
|
||||
"TargetConfigArtifactLogging",
|
||||
"TrainLoggingContext",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrainLoggingContext:
|
||||
model_config: ModelConfig
|
||||
train_config: TrainingConfig
|
||||
audio_config: AudioConfig
|
||||
targets: TargetProtocol
|
||||
train_dataset: Dataset
|
||||
val_dataset: Dataset | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DatasetConfigArtifact:
|
||||
filename: str
|
||||
config: BaseConfig
|
||||
|
||||
|
||||
class ConfigHyperparameterLogging:
|
||||
def run(
|
||||
self,
|
||||
logger: Logger,
|
||||
artifact_path: Path,
|
||||
context: TrainLoggingContext,
|
||||
) -> None:
|
||||
logger.log_hyperparams(
|
||||
{
|
||||
"model": context.model_config.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
"training": context.train_config.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
"audio": context.audio_config.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TargetConfigArtifactLogging:
|
||||
def run(
|
||||
self,
|
||||
logger: Logger,
|
||||
artifact_path: Path,
|
||||
context: TrainLoggingContext,
|
||||
) -> None:
|
||||
targets_config = TargetConfig.model_validate(
|
||||
context.targets.get_config()
|
||||
)
|
||||
log_config_artifact(
|
||||
logger,
|
||||
targets_config,
|
||||
filename="targets.yaml",
|
||||
artifact_path=artifact_path / "training_artifacts",
|
||||
)
|
||||
|
||||
|
||||
class DatasetConfigArtifactLogging:
|
||||
def __init__(
|
||||
self,
|
||||
train_dataset_config: DatasetConfigArtifact,
|
||||
val_dataset_config: DatasetConfigArtifact | None = None,
|
||||
):
|
||||
self.train_dataset_config = train_dataset_config
|
||||
self.val_dataset_config = val_dataset_config
|
||||
|
||||
def run(
|
||||
self,
|
||||
logger: Logger,
|
||||
artifact_path: Path,
|
||||
context: TrainLoggingContext,
|
||||
) -> None:
|
||||
training_artifact_path = artifact_path / "training_artifacts"
|
||||
|
||||
log_config_artifact(
|
||||
logger,
|
||||
self.train_dataset_config.config,
|
||||
filename=self.train_dataset_config.filename,
|
||||
artifact_path=training_artifact_path,
|
||||
)
|
||||
|
||||
if self.val_dataset_config is not None:
|
||||
log_config_artifact(
|
||||
logger,
|
||||
self.val_dataset_config.config,
|
||||
filename=self.val_dataset_config.filename,
|
||||
artifact_path=training_artifact_path,
|
||||
)
|
||||
|
||||
|
||||
class DataSummaryArtifactLogging:
|
||||
def run(
|
||||
self,
|
||||
logger: Logger,
|
||||
artifact_path: Path,
|
||||
context: TrainLoggingContext,
|
||||
) -> None:
|
||||
training_artifact_path = artifact_path / "training_artifacts"
|
||||
|
||||
log_csv_artifact(
|
||||
logger,
|
||||
_compute_class_summary_or_empty(
|
||||
context.train_dataset,
|
||||
context.targets,
|
||||
),
|
||||
filename="train_class_summary.csv",
|
||||
artifact_path=training_artifact_path,
|
||||
)
|
||||
|
||||
if context.val_dataset is not None:
|
||||
log_csv_artifact(
|
||||
logger,
|
||||
_compute_class_summary_or_empty(
|
||||
context.val_dataset,
|
||||
context.targets,
|
||||
),
|
||||
filename="val_class_summary.csv",
|
||||
artifact_path=training_artifact_path,
|
||||
)
|
||||
|
||||
|
||||
def _compute_class_summary_or_empty(
|
||||
dataset: Sequence[data.ClipAnnotation],
|
||||
targets: TargetProtocol,
|
||||
) -> pd.DataFrame:
|
||||
try:
|
||||
return compute_class_summary(dataset, targets)
|
||||
except KeyError as error:
|
||||
if error.args != ("class_name",):
|
||||
raise
|
||||
|
||||
return pd.DataFrame(
|
||||
columns=["num calls", "num recordings", "duration", "call_rate"]
|
||||
)
|
||||
@ -3,6 +3,7 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
from torch.optim import Adam, Optimizer
|
||||
@ -84,4 +85,10 @@ def build_optimizer(
|
||||
Optimizer configuration. Defaults to ``AdamOptimizerConfig``.
|
||||
"""
|
||||
config = config or AdamOptimizerConfig()
|
||||
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building optimizer with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
|
||||
return optimizer_registry.build(config, parameters)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler
|
||||
@ -78,4 +79,9 @@ def build_scheduler(
|
||||
"""Build a scheduler from configuration."""
|
||||
config = config or CosineAnnealingSchedulerConfig()
|
||||
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building scheduler with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
|
||||
return scheduler_registry.build(config, optimizer)
|
||||
|
||||
@ -3,6 +3,7 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from lightning import Trainer, seed_everything
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
|
||||
@ -10,6 +11,7 @@ from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
||||
from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
|
||||
from batdetect2.logging import (
|
||||
LoggerConfig,
|
||||
LoggingCallback,
|
||||
TensorBoardLoggerConfig,
|
||||
build_logger,
|
||||
)
|
||||
@ -17,6 +19,7 @@ from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.targets import (
|
||||
ROIMapperProtocol,
|
||||
TargetConfig,
|
||||
TargetProtocol,
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
@ -27,6 +30,12 @@ from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
from batdetect2.train.logging import (
|
||||
ConfigHyperparameterLogging,
|
||||
DataSummaryArtifactLogging,
|
||||
TargetConfigArtifactLogging,
|
||||
TrainLoggingContext,
|
||||
)
|
||||
from batdetect2.train.types import ClipLabeller
|
||||
|
||||
__all__ = [
|
||||
@ -35,6 +44,9 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
DEFAULT_LOG_DIR = Path("outputs") / "logs"
|
||||
|
||||
|
||||
def run_train(
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
@ -46,6 +58,7 @@ def run_train(
|
||||
labeller: Optional["ClipLabeller"] = None,
|
||||
audio_config: Optional[AudioConfig] = None,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
targets_config: TargetConfig | None = None,
|
||||
train_config: Optional[TrainingConfig] = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
trainer: Trainer | None = None,
|
||||
@ -57,28 +70,43 @@ def run_train(
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||
):
|
||||
if seed is not None:
|
||||
seed_everything(seed)
|
||||
|
||||
model_config = model_config or ModelConfig()
|
||||
targets_config = targets_config or TargetConfig()
|
||||
audio_config = audio_config or AudioConfig()
|
||||
train_config = train_config or TrainingConfig()
|
||||
|
||||
if model is not None:
|
||||
_validate_model_compatibility(model=model, model_config=model_config)
|
||||
if targets is None:
|
||||
raise ValueError(
|
||||
"targets must be provided when training with an existing "
|
||||
"model."
|
||||
)
|
||||
|
||||
if roi_mapper is None:
|
||||
raise ValueError(
|
||||
"roi_mapper must be provided when training with an existing "
|
||||
"model."
|
||||
)
|
||||
|
||||
if targets is None:
|
||||
targets = build_targets(config=targets_config)
|
||||
else:
|
||||
targets_config = TargetConfig.model_validate(targets.get_config())
|
||||
|
||||
roi_mapper = roi_mapper or build_roi_mapping(config=targets_config.roi)
|
||||
|
||||
if model is not None:
|
||||
targets = targets or model.targets
|
||||
|
||||
if roi_mapper is None and targets is model.targets:
|
||||
roi_mapper = model.roi_mapper
|
||||
|
||||
targets = targets or build_targets(config=model_config.targets)
|
||||
|
||||
roi_mapper = roi_mapper or build_roi_mapping(
|
||||
config=model_config.targets.roi
|
||||
)
|
||||
_validate_model_compatibility(
|
||||
model=model,
|
||||
model_config=model_config,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
)
|
||||
|
||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||
|
||||
@ -119,21 +147,57 @@ def run_train(
|
||||
|
||||
module = build_training_module(
|
||||
model_config=model_config,
|
||||
targets_config=targets_config,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
train_config=train_config,
|
||||
model=model,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(
|
||||
train_config.validation,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
|
||||
train_logger = build_logger(
|
||||
logger_config or TensorBoardLoggerConfig(),
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
)
|
||||
root_artifact_path = (
|
||||
Path(log_dir) if log_dir is not None else DEFAULT_LOG_DIR
|
||||
)
|
||||
root_artifact_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging_context = TrainLoggingContext(
|
||||
model_config=model_config,
|
||||
train_config=train_config,
|
||||
audio_config=audio_config,
|
||||
targets=targets,
|
||||
train_dataset=train_annotations,
|
||||
val_dataset=val_annotations,
|
||||
)
|
||||
|
||||
resolved_logging_callbacks = (
|
||||
ConfigHyperparameterLogging(),
|
||||
TargetConfigArtifactLogging(),
|
||||
DataSummaryArtifactLogging(),
|
||||
*logging_callbacks,
|
||||
)
|
||||
|
||||
for callback in resolved_logging_callbacks:
|
||||
callback.run(train_logger, root_artifact_path, logging_context)
|
||||
|
||||
trainer = trainer or build_trainer(
|
||||
train_config,
|
||||
logger_config=logger_config,
|
||||
evaluator=build_evaluator(
|
||||
train_config.validation,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
),
|
||||
train_logger=train_logger,
|
||||
evaluator=evaluator,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
num_epochs=num_epochs,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
)
|
||||
@ -152,8 +216,14 @@ def run_train(
|
||||
def _validate_model_compatibility(
|
||||
model: Model,
|
||||
model_config: ModelConfig,
|
||||
class_names: list[str],
|
||||
dimension_names: list[str],
|
||||
) -> None:
|
||||
reference_model = build_model(config=model_config)
|
||||
reference_model = build_model(
|
||||
config=model_config,
|
||||
class_names=class_names,
|
||||
dimension_names=dimension_names,
|
||||
)
|
||||
|
||||
expected_shapes = {
|
||||
key: tuple(value.shape)
|
||||
@ -194,10 +264,11 @@ def _validate_model_compatibility(
|
||||
|
||||
def build_trainer(
|
||||
config: TrainingConfig,
|
||||
logger_config: LoggerConfig | None,
|
||||
train_logger: Logger,
|
||||
evaluator: "EvaluatorProtocol",
|
||||
targets: "TargetProtocol",
|
||||
roi_mapper: "ROIMapperProtocol",
|
||||
checkpoint_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
num_epochs: int | None = None,
|
||||
@ -208,25 +279,11 @@ def build_trainer(
|
||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
|
||||
train_logger = build_logger(
|
||||
logger_config or TensorBoardLoggerConfig(),
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
)
|
||||
|
||||
train_logger.log_hyperparams(
|
||||
config.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
)
|
||||
)
|
||||
if num_epochs is not None:
|
||||
trainer_conf.max_epochs = num_epochs
|
||||
|
||||
train_config = trainer_conf.model_dump(exclude_none=True)
|
||||
|
||||
if num_epochs is not None:
|
||||
train_config["max_epochs"] = num_epochs
|
||||
|
||||
return Trainer(
|
||||
**train_config,
|
||||
logger=train_logger,
|
||||
@ -237,6 +294,6 @@ def build_trainer(
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
),
|
||||
ValidationMetrics(evaluator),
|
||||
ValidationMetrics(evaluator, targets, roi_mapper),
|
||||
],
|
||||
)
|
||||
|
||||
@ -13,13 +13,14 @@ from soundevent import data, terms
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.audio.clips import build_clipper
|
||||
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.data import DatasetConfig, load_dataset
|
||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import (
|
||||
ROIMapperProtocol,
|
||||
TargetConfig,
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
call_type,
|
||||
)
|
||||
@ -404,6 +405,13 @@ def sample_targets(
|
||||
return build_targets(sample_target_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_roi_mapper(
|
||||
sample_target_config: TargetConfig,
|
||||
) -> ROIMapperProtocol:
|
||||
return build_roi_mapping(sample_target_config.roi)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_labeller(
|
||||
sample_targets: TargetProtocol,
|
||||
@ -458,8 +466,16 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tiny_checkpoint_path(tmp_path: Path) -> Path:
|
||||
module = build_training_module(model_config=BatDetect2Config().model)
|
||||
def tiny_checkpoint_path(
|
||||
sample_targets: TargetProtocol,
|
||||
sample_roi_mapper: ROIMapperProtocol,
|
||||
tmp_path: Path,
|
||||
) -> Path:
|
||||
module = build_training_module(
|
||||
targets_config=sample_targets.get_config(),
|
||||
class_names=sample_targets.class_names,
|
||||
dimension_names=sample_roi_mapper.dimension_names,
|
||||
)
|
||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||
checkpoint_path = tmp_path / "model.ckpt"
|
||||
trainer.strategy.connect(module)
|
||||
|
||||
0
tests/test_api_v2/__init__.py
Normal file
0
tests/test_api_v2/__init__.py
Normal file
@ -8,20 +8,43 @@ import torch
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.models.detectors import Detector
|
||||
from batdetect2.models.heads import ClassifierHead
|
||||
from batdetect2.train import load_model_from_checkpoint
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_v2() -> BatDetect2API:
|
||||
def train_config() -> TrainingConfig:
|
||||
"""Train config with a small batch size for testing."""
|
||||
return TrainingConfig.model_validate({"train_loader": {"batch_size": 2}})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def inference_config() -> InferenceConfig:
|
||||
"""Inference config with a small batch size for testing."""
|
||||
return InferenceConfig.model_validate({"loader": {"batch_size": 2}})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_targets_config(example_data_dir: Path) -> TargetConfig:
|
||||
return TargetConfig.load(example_data_dir / "targets.yaml")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_v2(
|
||||
train_config: TrainingConfig,
|
||||
inference_config: InferenceConfig,
|
||||
) -> BatDetect2API:
|
||||
"""User story: users can create a ready-to-use API from config."""
|
||||
|
||||
config = BatDetect2Config()
|
||||
config.inference.loader.batch_size = 2
|
||||
return BatDetect2API.from_config(config)
|
||||
api = BatDetect2API.from_config(
|
||||
train_config=train_config,
|
||||
inference_config=inference_config,
|
||||
)
|
||||
assert api.inference_config.loader.batch_size == 2
|
||||
return api
|
||||
|
||||
|
||||
def test_process_file_returns_recording_level_predictions(
|
||||
@ -30,8 +53,10 @@ def test_process_file_returns_recording_level_predictions(
|
||||
) -> None:
|
||||
"""User story: process a file and get detections in recording time."""
|
||||
|
||||
# When
|
||||
prediction = api_v2.process_file(example_audio_files[0])
|
||||
|
||||
# Then
|
||||
assert prediction.clip.recording.path == example_audio_files[0]
|
||||
assert prediction.clip.start_time == 0
|
||||
assert prediction.clip.end_time == prediction.clip.recording.duration
|
||||
@ -53,9 +78,11 @@ def test_process_files_is_batch_size_invariant(
|
||||
) -> None:
|
||||
"""User story: changing batch size should not change predictions."""
|
||||
|
||||
# When
|
||||
preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1)
|
||||
preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3)
|
||||
|
||||
# Then
|
||||
assert len(preds_batch_1) == len(preds_batch_3)
|
||||
|
||||
by_key_1 = {
|
||||
@ -91,12 +118,14 @@ def test_process_audio_matches_process_spectrogram(
|
||||
) -> None:
|
||||
"""User story: users can call either audio or spectrogram entrypoint."""
|
||||
|
||||
# When
|
||||
audio = api_v2.load_audio(example_audio_files[0])
|
||||
from_audio = api_v2.process_audio(audio)
|
||||
|
||||
spec = api_v2.generate_spectrogram(audio)
|
||||
from_spec = api_v2.process_spectrogram(spec)
|
||||
|
||||
# Then
|
||||
assert len(from_audio) == len(from_spec)
|
||||
|
||||
for det_audio, det_spec in zip(from_audio, from_spec, strict=True):
|
||||
@ -116,8 +145,10 @@ def test_process_spectrogram_rejects_batched_input(
|
||||
) -> None:
|
||||
"""User story: invalid batched input gives a clear error."""
|
||||
|
||||
# Given
|
||||
spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32)
|
||||
|
||||
# When/Then
|
||||
with pytest.raises(ValueError, match="Batched spectrograms not supported"):
|
||||
api_v2.process_spectrogram(spec)
|
||||
|
||||
@ -184,26 +215,35 @@ def test_user_can_read_extracted_features_per_detection(
|
||||
@pytest.mark.slow
|
||||
def test_user_can_load_checkpoint_and_finetune(
|
||||
tmp_path: Path,
|
||||
example_targets_config: TargetConfig,
|
||||
example_annotations,
|
||||
) -> None:
|
||||
"""User story: load a checkpoint and continue training from it."""
|
||||
|
||||
module = build_training_module(model_config=BatDetect2Config().model)
|
||||
api = BatDetect2API.from_config(
|
||||
targets_config=example_targets_config,
|
||||
)
|
||||
module = build_training_module(
|
||||
model_config=api.model_config,
|
||||
targets_config=example_targets_config,
|
||||
class_names=api.targets.class_names,
|
||||
dimension_names=api.roi_mapper.dimension_names,
|
||||
)
|
||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||
checkpoint_path = tmp_path / "base.ckpt"
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(checkpoint_path)
|
||||
|
||||
config = BatDetect2Config()
|
||||
config.train.trainer.limit_train_batches = 1
|
||||
config.train.trainer.limit_val_batches = 1
|
||||
config.train.trainer.log_every_n_steps = 1
|
||||
config.train.train_loader.batch_size = 1
|
||||
config.train.train_loader.augmentations.enabled = False
|
||||
train_config = api.train_config.model_copy(deep=True)
|
||||
train_config.trainer.limit_train_batches = 1
|
||||
train_config.trainer.limit_val_batches = 1
|
||||
train_config.trainer.log_every_n_steps = 1
|
||||
train_config.train_loader.batch_size = 1
|
||||
train_config.train_loader.augmentations.enabled = False
|
||||
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
checkpoint_path,
|
||||
train_config=config.train,
|
||||
train_config=train_config,
|
||||
)
|
||||
finetune_dir = tmp_path / "finetuned"
|
||||
|
||||
@ -222,62 +262,34 @@ def test_user_can_load_checkpoint_and_finetune(
|
||||
assert checkpoints
|
||||
|
||||
|
||||
def test_user_can_load_checkpoint_with_new_targets(
|
||||
tmp_path: Path,
|
||||
sample_targets,
|
||||
) -> None:
|
||||
"""User story: start from checkpoint with a new target definition."""
|
||||
|
||||
module = build_training_module(model_config=BatDetect2Config().model)
|
||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||
checkpoint_path = tmp_path / "base_transfer.ckpt"
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(checkpoint_path)
|
||||
|
||||
source_model, _ = load_model_from_checkpoint(checkpoint_path)
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
checkpoint_path,
|
||||
targets_config=sample_targets.config,
|
||||
)
|
||||
source_detector = cast(Detector, source_model.detector)
|
||||
detector = cast(Detector, api.model.detector)
|
||||
classifier_head = cast(ClassifierHead, detector.classifier_head)
|
||||
|
||||
assert api.targets.config == sample_targets.config # type: ignore
|
||||
assert detector.num_classes == len(sample_targets.class_names)
|
||||
assert (
|
||||
classifier_head.classifier.out_channels
|
||||
== len(sample_targets.class_names) + 1
|
||||
)
|
||||
|
||||
source_backbone = source_detector.backbone.state_dict()
|
||||
target_backbone = detector.backbone.state_dict()
|
||||
assert source_backbone
|
||||
for key, value in source_backbone.items():
|
||||
assert key in target_backbone
|
||||
torch.testing.assert_close(target_backbone[key], value)
|
||||
|
||||
|
||||
def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
||||
example_targets_config: TargetConfig,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""User story: same targets config does not rebuild prediction heads."""
|
||||
|
||||
module = build_training_module(model_config=BatDetect2Config().model)
|
||||
# Given
|
||||
source_api = BatDetect2API.from_config(
|
||||
targets_config=example_targets_config
|
||||
)
|
||||
module = build_training_module(
|
||||
model_config=source_api.model_config,
|
||||
targets_config=example_targets_config,
|
||||
class_names=source_api.targets.class_names,
|
||||
dimension_names=source_api.roi_mapper.dimension_names,
|
||||
)
|
||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||
checkpoint_path = tmp_path / "same_targets.ckpt"
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(checkpoint_path)
|
||||
|
||||
source_model, source_model_config = load_model_from_checkpoint(
|
||||
checkpoint_path
|
||||
)
|
||||
source_model, _ = load_model_from_checkpoint(checkpoint_path)
|
||||
source_detector = cast(Detector, source_model.detector)
|
||||
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
checkpoint_path,
|
||||
targets_config=source_model_config.targets,
|
||||
)
|
||||
# When
|
||||
api = BatDetect2API.from_checkpoint(checkpoint_path)
|
||||
|
||||
# Then
|
||||
detector = cast(Detector, api.model.detector)
|
||||
|
||||
for key, value in source_detector.classifier_head.state_dict().items():
|
||||
@ -295,42 +307,6 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_user_can_finetune_only_heads(
|
||||
tmp_path: Path,
|
||||
example_annotations,
|
||||
) -> None:
|
||||
"""User story: fine-tune only prediction heads."""
|
||||
|
||||
api = BatDetect2API.from_config(BatDetect2Config())
|
||||
finetune_dir = tmp_path / "heads_only"
|
||||
|
||||
api.finetune(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=example_annotations[:1],
|
||||
trainable="heads",
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
checkpoint_dir=finetune_dir,
|
||||
log_dir=tmp_path / "logs",
|
||||
num_epochs=1,
|
||||
seed=0,
|
||||
)
|
||||
detector = cast(Detector, api.model.detector)
|
||||
|
||||
backbone_params = list(detector.backbone.parameters())
|
||||
classifier_params = list(detector.classifier_head.parameters())
|
||||
bbox_params = list(detector.bbox_head.parameters())
|
||||
|
||||
assert backbone_params
|
||||
assert classifier_params
|
||||
assert bbox_params
|
||||
assert all(not parameter.requires_grad for parameter in backbone_params)
|
||||
assert all(parameter.requires_grad for parameter in classifier_params)
|
||||
assert all(parameter.requires_grad for parameter in bbox_params)
|
||||
assert list(finetune_dir.rglob("*.ckpt"))
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
||||
api_v2: BatDetect2API,
|
||||
@ -348,8 +324,6 @@ def test_user_can_evaluate_small_dataset_and_get_metrics(
|
||||
|
||||
assert isinstance(metrics, list)
|
||||
assert len(metrics) == 1
|
||||
assert isinstance(metrics[0], dict)
|
||||
assert len(metrics[0]) > 0
|
||||
assert isinstance(predictions, list)
|
||||
assert len(predictions) == 1
|
||||
|
||||
@ -450,8 +424,17 @@ def test_detection_threshold_override_changes_spectrogram_results(
|
||||
spec = api_v2.generate_spectrogram(audio)
|
||||
default_detections = api_v2.process_spectrogram(spec)
|
||||
strict_detections = api_v2.process_spectrogram(
|
||||
spec,
|
||||
detection_threshold=1.0,
|
||||
spec, detection_threshold=1.0
|
||||
)
|
||||
|
||||
assert len(strict_detections) <= len(default_detections)
|
||||
|
||||
|
||||
def test_user_can_create_api_with_custom_targets_and_model_metadata_matches(
|
||||
sample_targets,
|
||||
) -> None:
|
||||
"""User story: custom targets define model output names for a new API."""
|
||||
|
||||
api = BatDetect2API.from_config(targets_config=sample_targets.config)
|
||||
|
||||
assert api.model.class_names == sample_targets.class_names
|
||||
|
||||
114
tests/test_api_v2/test_finetune.py
Normal file
114
tests/test_api_v2/test_finetune.py
Normal file
@ -0,0 +1,114 @@
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.models.detectors import Detector
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train import load_model_from_checkpoint
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_user_can_finetune_only_heads(
|
||||
tmp_path: Path,
|
||||
example_annotations,
|
||||
) -> None:
|
||||
"""User story: fine-tune only prediction heads."""
|
||||
|
||||
api = BatDetect2API.from_config()
|
||||
source_classifier_head = api.model.detector.classifier_head
|
||||
source_bbox_head = api.model.detector.bbox_head
|
||||
source_backbone = api.model.detector.backbone
|
||||
finetune_dir = tmp_path / "heads_only"
|
||||
|
||||
finetuned_api = api.finetune(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=example_annotations[:1],
|
||||
targets_config=TargetConfig(),
|
||||
trainable="heads",
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
checkpoint_dir=finetune_dir,
|
||||
log_dir=tmp_path / "logs",
|
||||
num_epochs=1,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
detector = cast(Detector, finetuned_api.model.detector)
|
||||
|
||||
backbone_params = list(detector.backbone.parameters())
|
||||
classifier_params = list(detector.classifier_head.parameters())
|
||||
bbox_params = list(detector.bbox_head.parameters())
|
||||
|
||||
assert backbone_params
|
||||
assert classifier_params
|
||||
assert bbox_params
|
||||
assert all(not parameter.requires_grad for parameter in backbone_params)
|
||||
assert all(parameter.requires_grad for parameter in classifier_params)
|
||||
assert all(parameter.requires_grad for parameter in bbox_params)
|
||||
assert finetuned_api is not api
|
||||
assert detector.backbone is source_backbone
|
||||
assert detector.classifier_head is not source_classifier_head
|
||||
assert detector.bbox_head is not source_bbox_head
|
||||
assert list(finetune_dir.rglob("*.ckpt"))
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_finetune_replaces_targets_and_checkpoint_owns_new_targets(
|
||||
tmp_path: Path,
|
||||
example_annotations,
|
||||
) -> None:
|
||||
"""User story: fine-tuning writes checkpoints with the new targets."""
|
||||
|
||||
source_api = BatDetect2API.from_config()
|
||||
source_evaluator = source_api.evaluator
|
||||
source_formatter = source_api.formatter
|
||||
source_output_transform = source_api.output_transform
|
||||
new_targets = TargetConfig.model_validate(
|
||||
{
|
||||
"classification_targets": [
|
||||
{
|
||||
"name": "single_class",
|
||||
"tags": [{"key": "class", "value": "single_class"}],
|
||||
}
|
||||
],
|
||||
"roi": {"mapper": "top_left"},
|
||||
}
|
||||
)
|
||||
finetune_dir = tmp_path / "new_targets"
|
||||
|
||||
finetuned_api = source_api.finetune(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=example_annotations[:1],
|
||||
targets_config=new_targets,
|
||||
trainable="heads",
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
checkpoint_dir=finetune_dir,
|
||||
log_dir=tmp_path / "logs",
|
||||
num_epochs=1,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
checkpoints = list(finetune_dir.rglob("*.ckpt"))
|
||||
|
||||
assert source_api.targets.get_config() != new_targets.model_dump(
|
||||
mode="json"
|
||||
)
|
||||
assert finetuned_api.targets.get_config() == new_targets.model_dump(
|
||||
mode="json"
|
||||
)
|
||||
assert finetuned_api.evaluator is not source_evaluator
|
||||
assert finetuned_api.formatter is not source_formatter
|
||||
assert finetuned_api.output_transform is not source_output_transform
|
||||
assert finetuned_api.evaluator.targets is finetuned_api.targets
|
||||
assert finetuned_api.evaluator.transform is finetuned_api.output_transform
|
||||
assert finetuned_api.model.class_names == ["single_class"]
|
||||
assert finetuned_api.model.dimension_names == ["width", "height"]
|
||||
assert checkpoints
|
||||
|
||||
_, configs = load_model_from_checkpoint(checkpoints[0])
|
||||
assert configs.targets.model_dump(mode="json") == new_targets.model_dump(
|
||||
mode="json"
|
||||
)
|
||||
@ -5,7 +5,6 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.outputs import build_output_formatter
|
||||
from batdetect2.outputs.formats import (
|
||||
BatDetect2OutputConfig,
|
||||
@ -18,7 +17,7 @@ from batdetect2.postprocess.types import ClipDetections
|
||||
def api_v2() -> BatDetect2API:
|
||||
"""User story: API object manages prediction IO formats."""
|
||||
|
||||
return BatDetect2API.from_config(BatDetect2Config())
|
||||
return BatDetect2API.from_config()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
0
tests/test_cli/__init__.py
Normal file
0
tests/test_cli/__init__.py
Normal file
99
tests/test_cli/test_finetune.py
Normal file
99
tests/test_cli/test_finetune.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""CLI tests for finetune command."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
|
||||
|
||||
def test_cli_finetune_help() -> None:
|
||||
"""User story: inspect finetune command interface and options."""
|
||||
|
||||
result = CliRunner().invoke(cli, ["finetune", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "TRAIN_DATASET" in result.output
|
||||
assert "--model" in result.output
|
||||
assert "--targets" in result.output
|
||||
assert "--training-config" in result.output
|
||||
assert "--audio-config" in result.output
|
||||
assert "--logging-config" in result.output
|
||||
assert "--evaluation-config" not in result.output
|
||||
assert "--inference-config" not in result.output
|
||||
assert "--outputs-config" not in result.output
|
||||
|
||||
|
||||
def test_cli_finetune_requires_model() -> None:
|
||||
"""User story: finetune requires a checkpoint argument."""
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"finetune",
|
||||
"example_data/dataset.yaml",
|
||||
"--targets",
|
||||
"example_data/targets.yaml",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "--model" in result.output
|
||||
|
||||
|
||||
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
|
||||
"""User story: finetune requires a new target definition."""
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"finetune",
|
||||
"example_data/dataset.yaml",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "--targets" in result.output
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_cli_finetune_from_checkpoint_runs_on_small_dataset(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
) -> None:
|
||||
"""User story: fine-tune a checkpoint via CLI with new targets."""
|
||||
|
||||
ckpt_dir = tmp_path / "checkpoints"
|
||||
log_dir = tmp_path / "logs"
|
||||
ckpt_dir.mkdir()
|
||||
log_dir.mkdir()
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"finetune",
|
||||
"example_data/dataset.yaml",
|
||||
"--val-dataset",
|
||||
"example_data/dataset.yaml",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
"--targets",
|
||||
"example_data/targets.yaml",
|
||||
"--num-epochs",
|
||||
"1",
|
||||
"--train-workers",
|
||||
"0",
|
||||
"--val-workers",
|
||||
"0",
|
||||
"--ckpt-dir",
|
||||
str(ckpt_dir),
|
||||
"--log-dir",
|
||||
str(log_dir),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert len(list(ckpt_dir.rglob("*.ckpt"))) >= 1
|
||||
@ -81,3 +81,24 @@ def test_cli_train_rejects_model_and_model_config_together(
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "--model-config cannot be used with --model" in result.output
|
||||
|
||||
|
||||
def test_cli_train_rejects_model_and_targets_together(
|
||||
tiny_checkpoint_path: Path,
|
||||
) -> None:
|
||||
"""User story: checkpoint training does not accept new targets."""
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
"example_data/dataset.yaml",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
"--targets",
|
||||
"example_data/targets.yaml",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "--targets cannot be used with --model" in result.output
|
||||
|
||||
@ -203,8 +203,8 @@ in_channels: 1
|
||||
def test_load_backbone_config_from_example_data(example_data_dir: Path):
|
||||
"""load_backbone_config loads the real example config correctly."""
|
||||
config = load_backbone_config(
|
||||
example_data_dir / "config.yaml",
|
||||
field="model.architecture",
|
||||
example_data_dir / "configs" / "model.yaml",
|
||||
field="architecture",
|
||||
)
|
||||
|
||||
assert isinstance(config, UNetBackboneConfig)
|
||||
|
||||
@ -1,9 +1,34 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
||||
from batdetect2.targets import (
|
||||
TargetConfig,
|
||||
Targets,
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
)
|
||||
|
||||
|
||||
def test_targets_get_config_returns_a_json_serializable_dict() -> None:
|
||||
targets = build_targets(TargetConfig())
|
||||
|
||||
config_dict = targets.get_config()
|
||||
assert isinstance(config_dict, dict)
|
||||
assert json.dumps(config_dict)
|
||||
|
||||
|
||||
def test_targets_from_config_rebuilds_equivalent_targets() -> None:
|
||||
original = build_targets(TargetConfig())
|
||||
|
||||
rebuilt = Targets.from_config(original.get_config())
|
||||
|
||||
assert rebuilt.class_names == original.class_names
|
||||
assert rebuilt.detection_class_name == original.detection_class_name
|
||||
assert rebuilt.detection_class_tags == original.detection_class_tags
|
||||
assert rebuilt.get_config() == original.get_config()
|
||||
|
||||
|
||||
def test_can_override_default_roi_mapper_per_class(
|
||||
|
||||
40
tests/test_targets/test_utils.py
Normal file
40
tests/test_targets/test_utils.py
Normal file
@ -0,0 +1,40 @@
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import (
|
||||
TargetClassConfig,
|
||||
TargetConfig,
|
||||
build_targets,
|
||||
check_target_compatibility,
|
||||
)
|
||||
|
||||
|
||||
def _target_class(name: str) -> TargetClassConfig:
|
||||
return TargetClassConfig(
|
||||
name=name,
|
||||
tags=[data.Tag(key="class", value=name)],
|
||||
)
|
||||
|
||||
|
||||
def test_check_target_compatibility_accepts_superset_targets() -> None:
|
||||
config = TargetConfig(
|
||||
classification_targets=[
|
||||
_target_class("pip35"),
|
||||
_target_class("myo"),
|
||||
_target_class("extra"),
|
||||
]
|
||||
)
|
||||
targets = build_targets(config)
|
||||
|
||||
assert check_target_compatibility(targets, ["pip35", "myo"])
|
||||
|
||||
|
||||
def test_check_target_compatibility_rejects_missing_model_classes() -> None:
|
||||
config = TargetConfig(
|
||||
classification_targets=[
|
||||
_target_class("pip35"),
|
||||
_target_class("myo"),
|
||||
]
|
||||
)
|
||||
targets = build_targets(config)
|
||||
|
||||
assert not check_target_compatibility(targets, ["pip35", "nyc"])
|
||||
@ -3,20 +3,19 @@ from pathlib import Path
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.train import run_train
|
||||
from batdetect2.train import TrainingConfig, run_train
|
||||
|
||||
pytestmark = pytest.mark.slow
|
||||
|
||||
|
||||
def _build_fast_train_config() -> BatDetect2Config:
|
||||
config = BatDetect2Config()
|
||||
config.train.trainer.limit_train_batches = 1
|
||||
config.train.trainer.limit_val_batches = 1
|
||||
config.train.trainer.log_every_n_steps = 1
|
||||
config.train.trainer.check_val_every_n_epoch = 1
|
||||
config.train.train_loader.batch_size = 1
|
||||
config.train.train_loader.augmentations.enabled = False
|
||||
def _build_fast_train_config() -> TrainingConfig:
|
||||
config = TrainingConfig()
|
||||
config.trainer.limit_train_batches = 1
|
||||
config.trainer.limit_val_batches = 1
|
||||
config.trainer.log_every_n_steps = 1
|
||||
config.trainer.check_val_every_n_epoch = 1
|
||||
config.train_loader.batch_size = 1
|
||||
config.train_loader.augmentations.enabled = False
|
||||
return config
|
||||
|
||||
|
||||
@ -29,9 +28,7 @@ def test_train_saves_checkpoint_in_requested_experiment_run_dir(
|
||||
run_train(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=example_annotations[:1],
|
||||
train_config=config.train,
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
train_config=config,
|
||||
num_epochs=1,
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
@ -50,14 +47,12 @@ def test_train_without_validation_can_still_save_last_checkpoint(
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
) -> None:
|
||||
config = _build_fast_train_config()
|
||||
config.train.checkpoints.save_last = True
|
||||
config.checkpoints.save_last = True
|
||||
|
||||
run_train(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=None,
|
||||
train_config=config.train,
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
train_config=config,
|
||||
num_epochs=1,
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
@ -73,16 +68,14 @@ def test_train_controls_which_checkpoints_are_kept(
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
) -> None:
|
||||
config = _build_fast_train_config()
|
||||
config.train.checkpoints.save_top_k = 1
|
||||
config.train.checkpoints.save_last = True
|
||||
config.train.checkpoints.filename = "epoch{epoch}"
|
||||
config.checkpoints.save_top_k = 1
|
||||
config.checkpoints.save_last = True
|
||||
config.checkpoints.filename = "epoch{epoch}"
|
||||
|
||||
run_train(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=example_annotations[:1],
|
||||
train_config=config.train,
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
train_config=config,
|
||||
num_epochs=3,
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
|
||||
@ -1,12 +1,43 @@
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import load_config
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.evaluate import EvaluationConfig
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train import TrainingConfig
|
||||
|
||||
|
||||
def test_example_config_is_valid(example_data_dir):
|
||||
conf = load_config(
|
||||
example_data_dir / "config.yaml",
|
||||
schema=BatDetect2Config,
|
||||
extra="forbid",
|
||||
strict=True,
|
||||
def test_example_split_configs_are_valid(example_data_dir):
|
||||
configs_dir = example_data_dir / "configs"
|
||||
|
||||
assert isinstance(
|
||||
AudioConfig.load(configs_dir / "audio.yaml"), AudioConfig
|
||||
)
|
||||
assert isinstance(
|
||||
ModelConfig.load(configs_dir / "model.yaml"), ModelConfig
|
||||
)
|
||||
assert isinstance(
|
||||
TargetConfig.load(example_data_dir / "targets.yaml"),
|
||||
TargetConfig,
|
||||
)
|
||||
assert isinstance(
|
||||
TrainingConfig.load(configs_dir / "training.yaml"),
|
||||
TrainingConfig,
|
||||
)
|
||||
assert isinstance(
|
||||
EvaluationConfig.load(configs_dir / "evaluation.yaml"),
|
||||
EvaluationConfig,
|
||||
)
|
||||
assert isinstance(
|
||||
InferenceConfig.load(configs_dir / "inference.yaml"),
|
||||
InferenceConfig,
|
||||
)
|
||||
assert isinstance(
|
||||
OutputsConfig.load(configs_dir / "outputs.yaml"),
|
||||
OutputsConfig,
|
||||
)
|
||||
assert isinstance(
|
||||
AppLoggingConfig.load(configs_dir / "logging.yaml"),
|
||||
AppLoggingConfig,
|
||||
)
|
||||
assert isinstance(conf, BatDetect2Config)
|
||||
|
||||
@ -10,25 +10,42 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.models import ModelConfig, build_model
|
||||
from batdetect2.targets.classes import TargetClassConfig
|
||||
from batdetect2.models import (
|
||||
ModelConfig,
|
||||
build_model,
|
||||
build_model_with_new_targets,
|
||||
)
|
||||
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
||||
from batdetect2.train import (
|
||||
TrainingConfig,
|
||||
TrainingModule,
|
||||
load_model_from_checkpoint,
|
||||
run_train,
|
||||
)
|
||||
from batdetect2.train.logging import (
|
||||
DatasetConfigArtifact,
|
||||
DatasetConfigArtifactLogging,
|
||||
)
|
||||
from batdetect2.train.optimizers import AdamOptimizerConfig
|
||||
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
|
||||
from batdetect2.train.train import build_training_module
|
||||
|
||||
|
||||
def build_default_module(config: BatDetect2Config | None = None):
|
||||
config = config or BatDetect2Config()
|
||||
def build_default_module(
|
||||
target_config: TargetConfig | None = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
):
|
||||
target_config = target_config or TargetConfig()
|
||||
model_config = model_config or ModelConfig()
|
||||
train_config = train_config or TrainingConfig()
|
||||
targets = build_targets(target_config)
|
||||
roi_mapper = build_roi_mapping(target_config.roi)
|
||||
return build_training_module(
|
||||
model_config=config.model,
|
||||
train_config=config.train,
|
||||
model_config=model_config,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
train_config=train_config,
|
||||
)
|
||||
|
||||
|
||||
@ -64,7 +81,7 @@ def test_can_save_checkpoint(
|
||||
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
||||
|
||||
|
||||
def test_load_model_from_checkpoint_returns_model_and_config(
|
||||
def test_load_model_from_checkpoint_returns_model_and_configs(
|
||||
tmp_path: Path,
|
||||
):
|
||||
input_model_config = ModelConfig(samplerate=192_000)
|
||||
@ -72,8 +89,13 @@ def test_load_model_from_checkpoint_returns_model_and_config(
|
||||
input_model_config.model_dump(mode="json")
|
||||
)
|
||||
train_config = TrainingConfig()
|
||||
targets_config = TargetConfig()
|
||||
targets = build_targets(targets_config)
|
||||
roi_mapper = build_roi_mapping(targets_config.roi)
|
||||
module = build_training_module(
|
||||
model_config=input_model_config,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
train_config=train_config,
|
||||
)
|
||||
trainer = L.Trainer()
|
||||
@ -81,12 +103,20 @@ def test_load_model_from_checkpoint_returns_model_and_config(
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(path)
|
||||
|
||||
model, loaded_model_config = load_model_from_checkpoint(path)
|
||||
model, loaded_configs = load_model_from_checkpoint(path)
|
||||
|
||||
assert model is not None
|
||||
assert loaded_model_config.model_dump(
|
||||
assert loaded_configs.model.model_dump(
|
||||
mode="json"
|
||||
) == expected_model_config.model_dump(mode="json")
|
||||
assert loaded_configs.targets.model_dump(
|
||||
mode="json"
|
||||
) == targets_config.model_dump(mode="json")
|
||||
assert loaded_configs.train.model_dump(
|
||||
mode="json"
|
||||
) == train_config.model_dump(mode="json")
|
||||
assert model.class_names == targets.class_names
|
||||
assert model.dimension_names == roi_mapper.dimension_names
|
||||
|
||||
recovered = TrainingModule.load_from_checkpoint(path)
|
||||
assert recovered.train_config.model_dump(
|
||||
@ -100,6 +130,9 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
|
||||
model_config.model_dump(mode="json")
|
||||
)
|
||||
train_config = TrainingConfig()
|
||||
targets_config = TargetConfig()
|
||||
targets = build_targets(targets_config)
|
||||
roi_mapper = build_roi_mapping(targets_config.roi)
|
||||
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
|
||||
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123)
|
||||
train_config.trainer.max_epochs = 3
|
||||
@ -107,6 +140,8 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
|
||||
|
||||
module = build_training_module(
|
||||
model_config=model_config,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
train_config=train_config,
|
||||
)
|
||||
trainer = L.Trainer()
|
||||
@ -114,28 +149,56 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(path)
|
||||
|
||||
recovered = TrainingModule.load_from_checkpoint(path)
|
||||
_, recovered_configs = load_model_from_checkpoint(path)
|
||||
assert not DeepDiff(
|
||||
recovered.model_config.model_dump(mode="json"),
|
||||
recovered_configs.model.model_dump(mode="json"),
|
||||
expected_model_config.model_dump(mode="json"),
|
||||
)
|
||||
assert not DeepDiff(
|
||||
recovered.train_config.model_dump(mode="json"),
|
||||
recovered_configs.train.model_dump(mode="json"),
|
||||
train_config.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
|
||||
def test_load_model_from_checkpoint_includes_targets_config(tmp_path: Path):
|
||||
targets_config = TargetConfig()
|
||||
targets = build_targets(targets_config)
|
||||
roi_mapper = build_roi_mapping(targets_config.roi)
|
||||
module = build_training_module(
|
||||
model_config=ModelConfig(),
|
||||
targets_config=targets_config,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
train_config=TrainingConfig(),
|
||||
)
|
||||
trainer = L.Trainer()
|
||||
path = tmp_path / "example.ckpt"
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(path)
|
||||
|
||||
_, loaded_configs = load_model_from_checkpoint(path)
|
||||
|
||||
assert loaded_configs.targets.model_dump(
|
||||
mode="json"
|
||||
) == targets_config.model_dump(mode="json")
|
||||
|
||||
|
||||
def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
|
||||
model_config = ModelConfig()
|
||||
expected_model_config = ModelConfig.model_validate(
|
||||
model_config.model_dump(mode="json")
|
||||
)
|
||||
train_config = TrainingConfig()
|
||||
targets_config = TargetConfig()
|
||||
targets = build_targets(targets_config)
|
||||
roi_mapper = build_roi_mapping(targets_config.roi)
|
||||
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
|
||||
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321)
|
||||
|
||||
module = build_training_module(
|
||||
model_config=model_config,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
train_config=train_config,
|
||||
)
|
||||
|
||||
@ -153,14 +216,16 @@ def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(path)
|
||||
|
||||
recovered = TrainingModule.load_from_checkpoint(path)
|
||||
assert recovered.model_config.model_dump(
|
||||
_, recovered_configs = load_model_from_checkpoint(path)
|
||||
assert recovered_configs.model.model_dump(
|
||||
mode="json"
|
||||
) == expected_model_config.model_dump(mode="json")
|
||||
assert recovered.train_config.model_dump(
|
||||
assert recovered_configs.train.model_dump(
|
||||
mode="json"
|
||||
) == train_config.model_dump(mode="json")
|
||||
|
||||
recovered = TrainingModule.load_from_checkpoint(path)
|
||||
|
||||
loaded_optimization_config = recovered.configure_optimizers()
|
||||
loaded_optimizer = loaded_optimization_config["optimizer"]
|
||||
loaded_scheduler = loaded_optimization_config["lr_scheduler"]["scheduler"]
|
||||
@ -175,12 +240,28 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(path)
|
||||
|
||||
_, stored_configs = load_model_from_checkpoint(path)
|
||||
api = BatDetect2API.from_checkpoint(path)
|
||||
|
||||
assert api.model_config.model_dump(
|
||||
mode="json"
|
||||
) == module.model_config.model_dump(mode="json")
|
||||
assert api.audio_config.samplerate == module.model_config.samplerate
|
||||
) == stored_configs.model.model_dump(mode="json")
|
||||
assert api.audio_config.samplerate == stored_configs.model.samplerate
|
||||
|
||||
|
||||
def test_api_from_checkpoint_reconstructs_targets_from_checkpoint(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
targets_config = TargetConfig()
|
||||
module = build_default_module(target_config=targets_config)
|
||||
trainer = L.Trainer()
|
||||
path = tmp_path / "example.ckpt"
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(path)
|
||||
|
||||
api = BatDetect2API.from_checkpoint(path)
|
||||
|
||||
assert api.targets.get_config() == targets_config.model_dump(mode="json")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@ -189,19 +270,26 @@ def test_train_smoke_produces_loadable_checkpoint(
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
sample_audio_loader: AudioLoader,
|
||||
):
|
||||
config = BatDetect2Config()
|
||||
config.train.trainer.limit_train_batches = 1
|
||||
config.train.trainer.limit_val_batches = 1
|
||||
config.train.trainer.log_every_n_steps = 1
|
||||
config.train.train_loader.batch_size = 1
|
||||
config.train.train_loader.augmentations.enabled = False
|
||||
# Given
|
||||
train_config = TrainingConfig.model_validate(
|
||||
{
|
||||
"trainer": {
|
||||
"limit_train_batches": 1,
|
||||
"limit_val_batches": 1,
|
||||
"log_every_n_steps": 1,
|
||||
},
|
||||
"train_loader": {
|
||||
"batch_size": 1,
|
||||
"augmentations": {"enabled": False},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# When
|
||||
run_train(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=example_annotations[:1],
|
||||
train_config=config.train,
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
train_config=train_config,
|
||||
num_epochs=1,
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
@ -209,18 +297,11 @@ def test_train_smoke_produces_loadable_checkpoint(
|
||||
seed=0,
|
||||
)
|
||||
|
||||
# Then
|
||||
checkpoints = list(tmp_path.rglob("*.ckpt"))
|
||||
assert checkpoints
|
||||
|
||||
model, model_config = load_model_from_checkpoint(checkpoints[0])
|
||||
assert model_config.samplerate == config.model.samplerate
|
||||
assert model_config.architecture.name == config.model.architecture.name
|
||||
assert model_config.preprocess.model_dump(
|
||||
mode="json"
|
||||
) == config.model.preprocess.model_dump(mode="json")
|
||||
assert model_config.postprocess.model_dump(
|
||||
mode="json"
|
||||
) == config.model.postprocess.model_dump(mode="json")
|
||||
|
||||
wav = torch.tensor(
|
||||
sample_audio_loader.load_clip(example_annotations[0].clip)
|
||||
@ -230,10 +311,18 @@ def test_train_smoke_produces_loadable_checkpoint(
|
||||
|
||||
|
||||
def test_build_training_module_uses_provided_model() -> None:
|
||||
model = build_model(ModelConfig())
|
||||
targets = build_targets(TargetConfig())
|
||||
roi_mapper = build_roi_mapping(TargetConfig().roi)
|
||||
model = build_model(
|
||||
ModelConfig(),
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
)
|
||||
|
||||
module = build_training_module(
|
||||
model_config=ModelConfig(),
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
train_config=TrainingConfig(),
|
||||
model=model,
|
||||
)
|
||||
@ -241,18 +330,117 @@ def test_build_training_module_uses_provided_model() -> None:
|
||||
assert module.model is model
|
||||
|
||||
|
||||
def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
|
||||
None
|
||||
):
|
||||
source_targets_config = TargetConfig()
|
||||
source_targets = build_targets(source_targets_config)
|
||||
source_roi_mapper = build_roi_mapping(source_targets_config.roi)
|
||||
source_model = build_model(
|
||||
ModelConfig(),
|
||||
class_names=source_targets.class_names,
|
||||
dimension_names=source_roi_mapper.dimension_names,
|
||||
)
|
||||
|
||||
new_targets_config = TargetConfig.model_validate(
|
||||
{
|
||||
"classification_targets": [
|
||||
{
|
||||
"name": "single_class",
|
||||
"tags": [{"key": "class", "value": "single_class"}],
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
new_targets = build_targets(new_targets_config)
|
||||
new_roi_mapper = build_roi_mapping(new_targets_config.roi)
|
||||
|
||||
rebuilt_model = build_model_with_new_targets(
|
||||
model=source_model,
|
||||
targets=new_targets,
|
||||
roi_mapper=new_roi_mapper,
|
||||
)
|
||||
|
||||
source_detector = source_model.detector
|
||||
rebuilt_detector = rebuilt_model.detector
|
||||
|
||||
assert rebuilt_detector.backbone is source_detector.backbone
|
||||
assert (
|
||||
rebuilt_detector.classifier_head is not source_detector.classifier_head
|
||||
)
|
||||
assert rebuilt_detector.bbox_head is not source_detector.bbox_head
|
||||
assert rebuilt_model.class_names == ["single_class"]
|
||||
assert rebuilt_model.dimension_names == ["width", "height"]
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_run_train_logs_training_artifacts(
|
||||
tmp_path: Path,
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
example_dataset,
|
||||
) -> None:
|
||||
train_config = TrainingConfig.model_validate(
|
||||
{
|
||||
"trainer": {
|
||||
"limit_train_batches": 1,
|
||||
"limit_val_batches": 1,
|
||||
"log_every_n_steps": 1,
|
||||
},
|
||||
"train_loader": {
|
||||
"batch_size": 1,
|
||||
"augmentations": {"enabled": False},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
run_train(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=example_annotations[:1],
|
||||
train_config=train_config,
|
||||
num_epochs=1,
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
checkpoint_dir=tmp_path / "checkpoints",
|
||||
log_dir=tmp_path / "logs",
|
||||
seed=0,
|
||||
logging_callbacks=[
|
||||
DatasetConfigArtifactLogging(
|
||||
train_dataset_config=DatasetConfigArtifact(
|
||||
filename="train_dataset.yaml",
|
||||
config=example_dataset,
|
||||
),
|
||||
val_dataset_config=DatasetConfigArtifact(
|
||||
filename="val_dataset.yaml",
|
||||
config=example_dataset,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
artifact_root = next((tmp_path / "logs").rglob("training_artifacts"))
|
||||
|
||||
assert (artifact_root / "targets.yaml").exists()
|
||||
assert (artifact_root / "train_dataset.yaml").exists()
|
||||
assert (artifact_root / "val_dataset.yaml").exists()
|
||||
assert (artifact_root / "train_class_summary.csv").exists()
|
||||
assert (artifact_root / "val_class_summary.csv").exists()
|
||||
|
||||
|
||||
def test_run_train_rejects_incompatible_model_config(
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
) -> None:
|
||||
model = build_model(ModelConfig())
|
||||
# Given
|
||||
targets_config = TargetConfig()
|
||||
targets = build_targets(targets_config)
|
||||
roi_mapper = build_roi_mapping(targets_config.roi)
|
||||
incompatible_config = ModelConfig()
|
||||
incompatible_config.targets.classification_targets.append(
|
||||
TargetClassConfig(
|
||||
name="dummy_class",
|
||||
tags=[data.Tag(key="class", value="Dummy class")],
|
||||
)
|
||||
incompatible_model = build_model(
|
||||
incompatible_config,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=[*roi_mapper.dimension_names, "extra_dim"],
|
||||
)
|
||||
|
||||
# When/Then
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Provided model is incompatible with model_config",
|
||||
@ -260,7 +448,10 @@ def test_run_train_rejects_incompatible_model_config(
|
||||
run_train(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=None,
|
||||
model=model,
|
||||
model=incompatible_model,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
model_config=incompatible_config,
|
||||
targets_config=targets_config,
|
||||
train_config=TrainingConfig(),
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user