Compare commits

17 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| mbsantiago | 5a974711b0 | feat: log training provenance artifacts | 2026-05-05 14:09:53 +01:00 |
| mbsantiago | 7b2699786f | feat: add checkpoint finetuning workflow | 2026-05-05 12:16:37 +01:00 |
| mbsantiago | 75e52cc548 | fix: keep checkpoint targets immutable | 2026-05-05 10:59:32 +01:00 |
| mbsantiago | 7d416e0f99 | Set model mode during training | 2026-05-05 00:55:27 +01:00 |
| mbsantiago | 2d0b810ed3 | Log scheduler and optimizer ocnfigs | 2026-05-05 00:42:49 +01:00 |
| mbsantiago | 7a46fa021b | Use base config for import config | 2026-05-05 00:42:32 +01:00 |
| mbsantiago | cbe428fc3f | Set default num workers to 0 | 2026-05-04 23:15:51 +01:00 |
| mbsantiago | 7a10b7ffff | refactor: remove aggregate app config | 2026-05-04 23:10:31 +01:00 |
| Santiago Martinez Balvanera | c27e7f9f52 | fix: ensure checkpoint is path | 2026-05-04 22:56:34 +01:00 |
| mbsantiago | aa36df668f | feat: persist target configs in checkpoints | 2026-05-04 22:54:32 +01:00 |
| mbsantiago | 20a7c058fc | feat: support target config roundtrips | 2026-05-04 22:31:32 +01:00 |
| mbsantiago | eec126a502 | fix: align cli and helpers with model refactor | 2026-05-04 21:20:02 +01:00 |
| mbsantiago | 57236fc82a | refactor: decouple model metadata from target configs | 2026-05-04 21:18:17 +01:00 |
| mbsantiago | e33053614a | update gitignore | 2026-05-04 17:13:18 +01:00 |
| mbsantiago | ae4f742345 | Modify max epochs with override before logging | 2026-05-04 17:12:55 +01:00 |
| mbsantiago | 44f9870e9e | fix: have 0 default workers for eval | 2026-05-04 17:12:35 +01:00 |
| mbsantiago | d7e61ccd43 | feat(plotting): add size heatmap label plotting | 2026-05-04 16:48:49 +01:00 |
61 changed files with 2053 additions and 696 deletions

.gitignore

@ -132,3 +132,4 @@ notebooks/tmp
# Assets
!assets/*
/models


@ -2,7 +2,8 @@
`BatDetect2API` is the main entry point for the current Python workflow.
It wraps model loading, inference, evaluation, output formatting, and training-related entry points behind one object.
It wraps model loading, inference, evaluation, output formatting, and
training-related entry points behind one object.
Defined in `batdetect2.api_v2`.
@ -10,8 +11,8 @@ Defined in `batdetect2.api_v2`.
- `BatDetect2API.from_checkpoint(path, ...)`
- load a trained checkpoint and optional config overrides.
- `BatDetect2API.from_config(config)`
- build a full stack from a `BatDetect2Config` object.
- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)`
- build a full stack from separate config objects.
## Inference methods
@ -46,10 +47,12 @@ Defined in `batdetect2.api_v2`.
## Output persistence helpers
- `save_predictions(predictions, path, audio_dir=None, format=None, config=None)`
- `save_predictions(predictions, path, audio_dir=None, format=None,
config=None)`
- `load_predictions(path, format=None, config=None)`
Use these when you want to save programmatic predictions without going through the CLI.
Use these when you want to save programmatic predictions without going through
the CLI.
## Training and evaluation entry points
@ -60,6 +63,9 @@ Use these when you want to save programmatic predictions without going through t
## Related pages
- Python tutorial: {doc}`../tutorials/integrate-with-a-python-pipeline`
- Outputs config reference: {doc}`outputs-config`
- Output formats reference: {doc}`output-formats`
- Python tutorial:
{doc}`../tutorials/integrate-with-a-python-pipeline`
- Outputs config reference:
{doc}`outputs-config`
- Output formats reference:
{doc}`output-formats`
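
The `from_config(model_config=..., targets_config=..., ...)` entry above takes separate, individually optional config objects instead of one aggregate config. The following is a minimal sketch of that fallback pattern only. The dataclasses, their fields, and the `Stack` container are hypothetical stand-ins for illustration; they are not the real `batdetect2` classes.

```python
from __future__ import annotations

from dataclasses import dataclass


# Hypothetical stand-ins for illustration only; the real config classes
# live in batdetect2 and carry many more fields.
@dataclass
class AudioConfig:
    samplerate: int = 256_000


@dataclass
class OutputsConfig:
    format_name: str = "raw"


@dataclass
class Stack:
    audio: AudioConfig
    outputs: OutputsConfig


def from_config(
    audio_config: AudioConfig | None = None,
    outputs_config: OutputsConfig | None = None,
) -> Stack:
    # Any config that is not supplied falls back to its default instance.
    audio_config = audio_config or AudioConfig()
    outputs_config = outputs_config or OutputsConfig()
    return Stack(audio=audio_config, outputs=outputs_config)


# Override only the audio config; outputs keep their defaults.
stack = from_config(audio_config=AudioConfig(samplerate=384_000))
```

With this shape a caller overrides only the surfaces it cares about and leaves the rest at their defaults.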


@ -1,38 +0,0 @@
# Top-level app config reference
The top-level config object is `BatDetect2Config`.
Defined in `batdetect2.config`.
It combines the main configuration surfaces used across training, inference, evaluation, outputs, and logging.
## Fields
- `config_version`
- `train`
- training-specific config.
- `evaluation`
- evaluation task and plot config.
- `model`
- model architecture, preprocessing, postprocessing, and targets.
- `audio`
- audio loading and resampling config.
- `inference`
- clipping and loader config for prediction-time workflows.
- `outputs`
- output format and output transform config.
- `logging`
- logging backend and formatting config.
## Mental model
Think of `BatDetect2Config` as the complete application wiring for the current stack.
Use it when you want one reproducible config that describes the whole workflow.
## Related pages
- Inference config: {doc}`inference-config`
- Evaluation config: {doc}`evaluation-config`
- Outputs config: {doc}`outputs-config`
- General config reference: {doc}`configs`


@ -24,8 +24,8 @@ for full options and argument details.
- Global CLI options are documented in {doc}`base`.
- Paths with spaces should be wrapped in quotes.
- Input audio is expected to be mono.
- Legacy `detect` uses a required threshold argument, while `predict` uses
the optional `--detection-threshold` override.
- Legacy `detect` uses a required threshold argument, while `predict` uses the
optional `--detection-threshold` override.
```{warning}
`batdetect2 detect` is a legacy command.


@ -1,5 +1,15 @@
Config reference
================
.. automodule:: batdetect2.config
:members:
BatDetect2 uses separate config objects for different workflow surfaces.
Use the dedicated reference pages for each config family:
- inference config
- evaluation config
- outputs config
- preprocessing config
- postprocess config
- targets config workflow
Example config files live under `example_data/configs/`.


@ -2,14 +2,14 @@
Reference pages are the detailed lookup pages.
Use this section when you need exact command options, setting names, output details, or Python API entries.
Use this section when you need exact command options, setting names, output
details, or Python API entries.
```{toctree}
:maxdepth: 1
cli/index
api
app-config
inference-config
evaluation-config
outputs-config


@ -1,192 +0,0 @@
config_version: v1
audio:
samplerate: 256000
resample:
enabled: true
method: poly
model:
samplerate: 256000
preprocess:
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
max_freq: 120000
min_freq: 10000
size:
height: 128
resize_factor: 0.5
spectrogram_transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
- name: spectral_mean_subtraction
architecture:
name: UNetBackbone
input_height: 128
in_channels: 1
encoder:
layers:
- name: FreqCoordConvDown
out_channels: 32
- name: FreqCoordConvDown
out_channels: 64
- name: LayerGroup
layers:
- name: FreqCoordConvDown
out_channels: 128
- name: ConvBlock
out_channels: 256
bottleneck:
channels: 256
layers:
- name: SelfAttention
attention_channels: 256
decoder:
layers:
- name: FreqCoordConvUp
out_channels: 64
- name: FreqCoordConvUp
out_channels: 32
- name: LayerGroup
layers:
- name: FreqCoordConvUp
out_channels: 32
- name: ConvBlock
out_channels: 32
postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
top_k_per_sec: 200
train:
optimizer:
name: adam
learning_rate: 0.001
scheduler:
name: cosine_annealing
t_max: 100
labels:
sigma: 3
trainer:
max_epochs: 10
check_val_every_n_epoch: 5
train_loader:
batch_size: 8
shuffle: true
clipping_strategy:
name: random_subclip
duration: 0.256
augmentations:
enabled: true
audio:
- name: mix_audio
probability: 0.2
min_weight: 0.3
max_weight: 0.7
- name: add_echo
probability: 0.2
max_delay: 0.005
min_weight: 0.0
max_weight: 1.0
spectrogram:
- name: scale_volume
probability: 0.2
min_scaling: 0.0
max_scaling: 2.0
- name: warp
probability: 0.2
delta: 0.04
- name: mask_time
probability: 0.2
max_perc: 0.05
max_masks: 3
- name: mask_freq
probability: 0.2
max_perc: 0.10
max_masks: 3
val_loader:
clipping_strategy:
name: whole_audio_padded
chunk_size: 0.256
loss:
detection:
weight: 1.0
focal:
beta: 4
alpha: 2
classification:
weight: 2.0
focal:
beta: 4
alpha: 2
size:
weight: 0.1
validation:
tasks:
- name: sound_event_detection
metrics:
- name: average_precision
- name: sound_event_classification
metrics:
- name: average_precision
logging:
train:
name: csv
evaluation:
tasks:
- name: sound_event_detection
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: score_distribution
- name: example_detection
- name: sound_event_classification
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: top_class_detection
metrics:
- name: average_precision
plots:
- name: pr_curve
- name: confusion_matrix
- name: example_classification
- name: clip_detection
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: roc_curve
- name: score_distribution
- name: clip_classification
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: roc_curve


@ -0,0 +1,4 @@
samplerate: 256000
resample:
enabled: true
method: poly


@ -0,0 +1,37 @@
tasks:
- name: sound_event_detection
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: score_distribution
- name: example_detection
- name: sound_event_classification
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: top_class_detection
metrics:
- name: average_precision
plots:
- name: pr_curve
- name: confusion_matrix
- name: example_classification
- name: clip_detection
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: roc_curve
- name: score_distribution
- name: clip_classification
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: roc_curve


@ -0,0 +1,9 @@
loader:
batch_size: 8
clipping:
enabled: true
duration: 0.5
overlap: 0.0
max_empty: 0.0
discard_empty: true


@ -0,0 +1,2 @@
train:
name: csv


@ -0,0 +1,59 @@
samplerate: 256000
preprocess:
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
max_freq: 120000
min_freq: 10000
size:
height: 128
resize_factor: 0.5
spectrogram_transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
- name: spectral_mean_subtraction
architecture:
name: UNetBackbone
input_height: 128
in_channels: 1
encoder:
layers:
- name: FreqCoordConvDown
out_channels: 32
- name: FreqCoordConvDown
out_channels: 64
- name: LayerGroup
layers:
- name: FreqCoordConvDown
out_channels: 128
- name: ConvBlock
out_channels: 256
bottleneck:
channels: 256
layers:
- name: SelfAttention
attention_channels: 256
decoder:
layers:
- name: FreqCoordConvUp
out_channels: 64
- name: FreqCoordConvUp
out_channels: 32
- name: LayerGroup
layers:
- name: FreqCoordConvUp
out_channels: 32
- name: ConvBlock
out_channels: 32
postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
top_k_per_sec: 200


@ -0,0 +1,9 @@
format:
name: raw
include_class_scores: true
include_features: true
include_geometry: true
transform:
detection_transforms: []
clip_transforms: []


@ -0,0 +1,79 @@
optimizer:
name: adam
learning_rate: 0.001
scheduler:
name: cosine_annealing
t_max: 100
labels:
sigma: 3
trainer:
max_epochs: 10
check_val_every_n_epoch: 5
train_loader:
batch_size: 8
shuffle: true
clipping_strategy:
name: random_subclip
duration: 0.256
augmentations:
enabled: true
audio:
- name: mix_audio
probability: 0.2
min_weight: 0.3
max_weight: 0.7
- name: add_echo
probability: 0.2
max_delay: 0.005
min_weight: 0.0
max_weight: 1.0
spectrogram:
- name: scale_volume
probability: 0.2
min_scaling: 0.0
max_scaling: 2.0
- name: warp
probability: 0.2
delta: 0.04
- name: mask_time
probability: 0.2
max_perc: 0.05
max_masks: 3
- name: mask_freq
probability: 0.2
max_perc: 0.10
max_masks: 3
val_loader:
clipping_strategy:
name: whole_audio_padded
chunk_size: 0.256
loss:
detection:
weight: 1.0
focal:
beta: 4
alpha: 2
classification:
weight: 2.0
focal:
beta: 4
alpha: 2
size:
weight: 0.1
validation:
tasks:
- name: sound_event_detection
metrics:
- name: average_precision
- name: sound_event_classification
metrics:
- name: average_precision


@ -112,6 +112,12 @@ clean: clean-build clean-pyc clean-test clean-docs
example-train OPTIONS="":
uv run batdetect2 train \
--val-dataset example_data/dataset.yaml \
--config example_data/config.yaml \
--base-dir . \
--targets example_data/targets.yaml \
--model-config example_data/configs/model.yaml \
--training-config example_data/configs/training.yaml \
--audio-config example_data/configs/audio.yaml \
--evaluation-config example_data/configs/evaluation.yaml \
--logging-config example_data/configs/logging.yaml \
{{OPTIONS}} \
example_data/dataset.yaml


@ -12,11 +12,14 @@ if TYPE_CHECKING:
import torch
from batdetect2.audio import AudioConfig, AudioLoader
from batdetect2.config import BatDetect2Config
from batdetect2.data import Dataset
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig, LoggerConfig
from batdetect2.logging import (
AppLoggingConfig,
LoggerConfig,
LoggingCallback,
)
from batdetect2.models import Model, ModelConfig
from batdetect2.outputs import (
OutputFormatConfig,
@ -36,6 +39,7 @@ if TYPE_CHECKING:
TargetProtocol,
)
from batdetect2.train import TrainingConfig
from batdetect2.train.logging import TrainLoggingContext
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
@ -107,9 +111,11 @@ class BatDetect2API:
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
):
from batdetect2.train import run_train
self.model.train()
run_train(
train_annotations=train_annotations,
val_annotations=val_annotations,
@ -130,12 +136,15 @@ class BatDetect2API:
train_config=train_config or self.train_config,
audio_config=audio_config or self.audio_config,
logger_config=logger_config or self.logging_config.train,
logging_callbacks=logging_callbacks,
)
self.model.eval()
return self
def finetune(
self,
train_annotations: Sequence[data.ClipAnnotation],
targets_config: TargetConfig,
val_annotations: Sequence[data.ClipAnnotation] | None = None,
trainable: Literal[
"all", "heads", "classifier_head", "bbox_head"
@ -148,25 +157,77 @@ class BatDetect2API:
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
) -> "BatDetect2API":
"""Fine-tune the model with trainable-parameter selection."""
"""Fine-tune from a checkpoint using a new target definition."""
from batdetect2.evaluate import build_evaluator
from batdetect2.models import build_model_with_new_targets
from batdetect2.outputs import (
build_output_formatter,
build_output_transform,
)
from batdetect2.targets import (
TargetConfig,
build_roi_mapping,
build_targets,
)
from batdetect2.train import run_train
self._set_trainable_parameters(trainable)
target_config = TargetConfig.model_validate(targets_config)
targets = build_targets(config=target_config)
roi_mapper = build_roi_mapping(config=target_config.roi)
model = build_model_with_new_targets(
model=self.model,
targets=targets,
roi_mapper=roi_mapper,
)
output_transform = build_output_transform(
config=self.outputs_config.transform,
targets=targets,
roi_mapper=roi_mapper,
)
api = BatDetect2API(
model_config=self.model_config,
audio_config=audio_config or self.audio_config,
train_config=train_config or self.train_config,
evaluation_config=self.evaluation_config,
inference_config=self.inference_config,
outputs_config=self.outputs_config,
logging_config=self.logging_config,
targets=targets,
roi_mapper=roi_mapper,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
postprocessor=self.postprocessor,
evaluator=build_evaluator(
config=self.evaluation_config,
targets=targets,
roi_mapper=roi_mapper,
transform=output_transform,
),
formatter=build_output_formatter(
targets,
config=self.outputs_config.format,
),
output_transform=output_transform,
model=model,
)
api._set_trainable_parameters(trainable)
api.model.train()
run_train(
train_annotations=train_annotations,
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
roi_mapper=self.roi_mapper,
model_config=model_config or self.model_config,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
model=api.model,
targets=api.targets,
roi_mapper=api.roi_mapper,
model_config=api.model_config,
preprocessor=api.preprocessor,
audio_loader=api.audio_loader,
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=checkpoint_dir,
@ -175,11 +236,13 @@ class BatDetect2API:
num_epochs=num_epochs,
run_name=run_name,
seed=seed,
audio_config=audio_config or self.audio_config,
train_config=train_config or self.train_config,
logger_config=logger_config or self.logging_config.train,
audio_config=api.audio_config,
train_config=api.train_config,
logger_config=logger_config or api.logging_config.train,
logging_callbacks=logging_callbacks,
)
return self
api.model.eval()
return api
def evaluate(
self,
@ -483,46 +546,70 @@ class BatDetect2API:
@classmethod
def from_config(
cls,
config: BatDetect2Config,
model_config: ModelConfig | None = None,
targets_config: TargetConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
inference_config: InferenceConfig | None = None,
outputs_config: OutputsConfig | None = None,
logging_config: AppLoggingConfig | None = None,
) -> "BatDetect2API":
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate import build_evaluator
from batdetect2.models import build_model
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.evaluate import EvaluationConfig, build_evaluator
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig, build_model
from batdetect2.outputs import (
OutputsConfig,
build_output_formatter,
build_output_transform,
)
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets
from batdetect2.targets import (
TargetConfig,
build_roi_mapping,
build_targets,
)
from batdetect2.train import TrainingConfig
targets = build_targets(config=config.model.targets)
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
model_config = model_config or ModelConfig()
targets_config = targets_config or TargetConfig()
audio_config = audio_config or AudioConfig()
train_config = train_config or TrainingConfig()
evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig()
audio_loader = build_audio_loader(config=config.audio)
targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi)
audio_loader = build_audio_loader(config=audio_config)
preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.model.preprocess,
config=model_config.preprocess,
)
postprocessor = build_postprocessor(
preprocessor,
config=config.model.postprocess,
config=model_config.postprocess,
)
formatter = build_output_formatter(
targets,
config=config.outputs.format,
config=outputs_config.format,
)
output_transform = build_output_transform(
config=config.outputs.transform,
config=outputs_config.transform,
targets=targets,
roi_mapper=roi_mapper,
)
evaluator = build_evaluator(
config=config.evaluation,
config=evaluation_config,
targets=targets,
roi_mapper=roi_mapper,
transform=output_transform,
@ -531,27 +618,27 @@ class BatDetect2API:
# NOTE: Build separate instances of preprocessor and postprocessor
# to avoid device mismatch errors
model = build_model(
config=config.model,
targets=targets,
roi_mapper=roi_mapper,
config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.model.preprocess,
config=model_config.preprocess,
),
postprocessor=build_postprocessor(
preprocessor,
config=config.model.postprocess,
config=model_config.postprocess,
),
)
return cls(
model_config=config.model,
audio_config=config.audio,
train_config=config.train,
evaluation_config=config.evaluation,
inference_config=config.inference,
outputs_config=config.outputs,
logging_config=config.logging,
model_config=model_config,
audio_config=audio_config,
train_config=train_config,
evaluation_config=evaluation_config,
inference_config=inference_config,
outputs_config=outputs_config,
logging_config=logging_config,
targets=targets,
roi_mapper=roi_mapper,
audio_loader=audio_loader,
@ -567,7 +654,6 @@ class BatDetect2API:
def from_checkpoint(
cls,
path: data.PathLike,
targets_config: TargetConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
@ -579,7 +665,6 @@ class BatDetect2API:
from batdetect2.evaluate import EvaluationConfig, build_evaluator
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import build_model_with_new_targets
from batdetect2.outputs import (
OutputsConfig,
build_output_formatter,
@ -587,37 +672,41 @@ class BatDetect2API:
)
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
from batdetect2.targets import (
build_roi_mapping,
build_targets,
check_target_compatibility,
)
from batdetect2.train import load_model_from_checkpoint
model, model_config = load_model_from_checkpoint(path)
model, configs = load_model_from_checkpoint(path)
model_config = configs.model
train_config = train_config or configs.train
audio_config = audio_config or AudioConfig(
samplerate=model_config.samplerate,
)
train_config = train_config or TrainingConfig()
evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig()
targets_config = configs.targets
if (
targets_config is not None
and targets_config != model_config.targets
):
targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi)
model = build_model_with_new_targets(
model=model,
targets=targets,
roi_mapper=roi_mapper,
)
model_config = model_config.model_copy(
update={"targets": targets_config}
targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi)
if not check_target_compatibility(targets, model.class_names):
raise ValueError(
"Provided targets_config is incompatible with the "
"checkpoint model: missing one or more model classes."
)
targets = build_targets(config=model_config.targets)
roi_mapper = build_roi_mapping(config=model_config.targets.roi)
if model.dimension_names != roi_mapper.dimension_names:
raise ValueError(
"Provided targets_config is incompatible with the "
"checkpoint model: mismatched dimension names."
)
audio_loader = build_audio_loader(config=audio_config)
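
The two `ValueError` checks above reject a target definition that does not match the checkpoint model. Below is a minimal sketch of that idea, assuming (from the error message alone) that compatibility means the targets cover every class name the model knows. `covers_model_classes` and `validate_targets` are hypothetical helpers written for illustration; they are not the library's `check_target_compatibility`.

```python
from __future__ import annotations

from collections.abc import Sequence


def covers_model_classes(
    target_classes: Sequence[str],
    model_classes: Sequence[str],
) -> bool:
    """Return True if every model class appears among the target classes."""
    return set(model_classes) <= set(target_classes)


def validate_targets(
    target_classes: Sequence[str],
    target_dimensions: Sequence[str],
    model_classes: Sequence[str],
    model_dimensions: Sequence[str],
) -> None:
    """Raise ValueError when the targets do not match the checkpoint model."""
    if not covers_model_classes(target_classes, model_classes):
        raise ValueError("missing one or more model classes")
    # Dimension names are compared in order, as a plain equality check.
    if list(target_dimensions) != list(model_dimensions):
        raise ValueError("mismatched dimension names")
```

Failing early here surfaces a mismatched targets file at load time rather than as a shape error during inference.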


@ -2,6 +2,7 @@ from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect
from batdetect2.cli.data import data
from batdetect2.cli.evaluate import evaluate_command
from batdetect2.cli.finetune import finetune_command
from batdetect2.cli.inference import predict
from batdetect2.cli.train import train_command
@ -10,6 +11,7 @@ __all__ = [
"detect",
"data",
"train_command",
"finetune_command",
"evaluate_command",
"predict",
]


@ -77,6 +77,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
"num_workers",
type=int,
help="Number of worker processes for dataset loading.",
default=0,
)
def evaluate_command(
model_path: Path,
@ -105,7 +106,6 @@ def evaluate_command(
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
logger.info("Initiating evaluation process...")
@ -119,11 +119,6 @@ def evaluate_command(
num_annotations=len(test_annotations),
)
target_conf = (
TargetConfig.load(targets_config)
if targets_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
@ -150,7 +145,6 @@ def evaluate_command(
api = BatDetect2API.from_checkpoint(
model_path,
targets_config=target_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,


@ -0,0 +1,211 @@
from pathlib import Path
from typing import Literal, cast
import click
from loguru import logger
from batdetect2.cli.base import cli
__all__ = ["finetune_command"]
@cli.command(
name="finetune", short_help="Fine-tune a checkpoint on new targets."
)
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option(
"--model",
"model_path",
required=True,
type=click.Path(exists=True),
help="Path to a checkpoint to fine-tune from.",
)
@click.option(
"--targets",
"targets_config",
required=True,
type=click.Path(exists=True),
help="Path to the new targets config file.",
)
@click.option(
"--val-dataset",
type=click.Path(exists=True),
help="Path to validation dataset config file.",
)
@click.option(
"--base-dir",
type=click.Path(exists=True),
help=(
"Base directory used to resolve relative paths inside the training "
"and validation dataset configs."
),
)
@click.option(
"--training-config",
type=click.Path(exists=True),
help="Path to training config file.",
)
@click.option(
"--audio-config",
type=click.Path(exists=True),
help="Path to audio config file.",
)
@click.option(
"--logging-config",
type=click.Path(exists=True),
help="Path to logging config file.",
)
@click.option(
"--trainable",
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
default="heads",
show_default=True,
help="Which model parameters remain trainable during fine-tuning.",
)
@click.option(
"--ckpt-dir",
type=click.Path(exists=True),
help="Directory where checkpoints are saved.",
)
@click.option(
"--log-dir",
type=click.Path(exists=True),
help="Directory where logs are written.",
)
@click.option(
"--train-workers",
type=int,
default=0,
help="Number of worker processes for training data loading.",
)
@click.option(
"--val-workers",
type=int,
default=0,
help="Number of worker processes for validation data loading.",
)
@click.option(
"--num-epochs",
type=int,
help="Maximum number of training epochs.",
)
@click.option(
"--experiment-name",
type=str,
help="Experiment name used for logging backends.",
)
@click.option(
"--run-name",
type=str,
help="Run name used for logging backends.",
)
@click.option(
"--seed",
type=int,
help="Random seed used for reproducibility.",
)
def finetune_command(
train_dataset: Path,
model_path: Path,
targets_config: Path,
val_dataset: Path | None = None,
ckpt_dir: Path | None = None,
log_dir: Path | None = None,
base_dir: Path | None = None,
training_config: Path | None = None,
audio_config: Path | None = None,
logging_config: Path | None = None,
trainable: str = "heads",
seed: int | None = None,
num_epochs: int | None = None,
train_workers: int = 0,
val_workers: int = 0,
experiment_name: str | None = None,
run_name: str | None = None,
):
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.data import load_dataset, load_dataset_config
from batdetect2.logging import AppLoggingConfig
from batdetect2.targets import TargetConfig
from batdetect2.train import TrainingConfig
from batdetect2.train.logging import (
DatasetConfigArtifact,
DatasetConfigArtifactLogging,
)
logger.info("Initiating fine-tuning process...")
target_conf = TargetConfig.load(targets_config)
train_conf = (
TrainingConfig.load(training_config)
if training_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
logging_conf = (
AppLoggingConfig.load(logging_config)
if logging_config is not None
else None
)
train_dataset_conf = load_dataset_config(train_dataset)
train_annotations = load_dataset(train_dataset_conf, base_dir=base_dir)
val_dataset_conf = (
load_dataset_config(val_dataset) if val_dataset else None
)
val_annotations = (
load_dataset(val_dataset_conf, base_dir=base_dir)
if val_dataset_conf
else None
)
logging_callbacks = [
DatasetConfigArtifactLogging(
train_dataset_config=DatasetConfigArtifact(
filename="train_dataset.yaml",
config=train_dataset_conf,
),
val_dataset_config=(
DatasetConfigArtifact(
filename="val_dataset.yaml",
config=val_dataset_conf,
)
if val_dataset_conf
else None
),
)
]
api = BatDetect2API.from_checkpoint(
model_path,
train_config=train_conf,
audio_config=audio_conf,
logging_config=logging_conf,
)
return api.finetune(
train_annotations=train_annotations,
val_annotations=val_annotations,
targets_config=target_conf,
trainable=cast(
Literal["all", "heads", "classifier_head", "bbox_head"],
trainable,
),
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=ckpt_dir,
log_dir=log_dir,
experiment_name=experiment_name,
num_epochs=num_epochs,
run_name=run_name,
seed=seed,
train_config=train_conf,
audio_config=audio_conf,
logger_config=logging_conf.train if logging_conf is not None else None,
logging_callbacks=logging_callbacks,
)
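
The command returns whatever `api.finetune(...)` returns, and in this change `finetune` builds and returns a new API object around the new targets rather than returning `self`. A minimal sketch of that copy-on-finetune pattern follows. The frozen dataclass `Api` and its `class_names` field are hypothetical stand-ins, not the real `BatDetect2API`.

```python
from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Api:
    # Hypothetical stand-in: the real object holds a model, configs,
    # targets, loaders and so on.
    class_names: tuple[str, ...]

    def finetune(self, new_class_names: tuple[str, ...]) -> Api:
        # Build a new object around the new targets; self is left untouched.
        return replace(self, class_names=new_class_names)


base = Api(class_names=("bat",))
tuned = base.finetune(("myotis", "pipistrellus"))
```

The object loaded from the checkpoint keeps its original targets, so it can still be used for inference or as the starting point of another fine-tuning run.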


@ -86,11 +86,13 @@ __all__ = ["train_command"]
@click.option(
"--train-workers",
type=int,
default=0,
help="Number of worker processes for training data loading.",
)
@click.option(
"--val-workers",
type=int,
default=0,
help="Number of worker processes for validation data loading.",
)
@click.option(
@ -143,8 +145,7 @@ def train_command(
"""
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.config import BatDetect2Config
from batdetect2.data import load_dataset_from_config
from batdetect2.data import load_dataset_config, load_dataset_from_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
@ -152,6 +153,10 @@ def train_command(
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
from batdetect2.train import TrainingConfig
from batdetect2.train.logging import (
DatasetConfigArtifact,
DatasetConfigArtifactLogging,
)
logger.info("Initiating training process...")
@ -196,9 +201,6 @@ def train_command(
if target_conf is not None:
logger.info("Loaded targets configuration.")
if model_conf is not None and target_conf is not None:
model_conf = model_conf.model_copy(update={"targets": target_conf})
logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(
train_dataset,
@ -224,37 +226,49 @@ def train_command(
logger.info("Configuration and data loaded. Starting training...")
logging_callbacks = [
DatasetConfigArtifactLogging(
train_dataset_config=DatasetConfigArtifact(
filename="train_dataset.yaml",
config=load_dataset_config(train_dataset),
),
val_dataset_config=(
DatasetConfigArtifact(
filename="val_dataset.yaml",
config=load_dataset_config(val_dataset),
)
if val_dataset is not None
else None
),
)
]
if model_path is not None and model_conf is not None:
raise click.UsageError(
"--model-config cannot be used with --model. "
"Checkpoint model configuration is loaded from the checkpoint."
)
if model_path is not None and target_conf is not None:
raise click.UsageError(
"--targets cannot be used with --model. "
"Checkpoint target configuration is loaded from the checkpoint."
)
if model_path is None:
conf = BatDetect2Config()
if model_conf is not None:
conf.model = model_conf
elif target_conf is not None:
conf.model = conf.model.model_copy(update={"targets": target_conf})
if train_conf is not None:
conf.train = train_conf
if audio_conf is not None:
conf.audio = audio_conf
if eval_conf is not None:
conf.evaluation = eval_conf
if inference_conf is not None:
conf.inference = inference_conf
if outputs_conf is not None:
conf.outputs = outputs_conf
if logging_conf is not None:
conf.logging = logging_conf
api = BatDetect2API.from_config(conf)
api = BatDetect2API.from_config(
model_config=model_conf,
targets_config=target_conf,
train_config=train_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
logging_config=logging_conf,
)
else:
api = BatDetect2API.from_checkpoint(
model_path,
targets_config=target_conf,
train_config=train_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
@ -274,4 +288,5 @@ def train_command(
experiment_name=experiment_name,
run_name=run_name,
seed=seed,
logging_callbacks=logging_callbacks,
)
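
Review note: the two `click.UsageError` guards above make `--model` mutually exclusive with `--model-config` and `--targets`, so a checkpoint always supplies its own model and target configuration. A minimal stand-alone sketch of that rule follows. The helper name `check_train_options` is hypothetical, and a plain `ValueError` stands in for `click.UsageError` so the snippet needs only the standard library.

```python
from __future__ import annotations

from pathlib import Path


def check_train_options(
    model_path: Path | None,
    model_conf: object | None,
    target_conf: object | None,
) -> str:
    """Return which construction path applies, or raise on a conflict.

    Hypothetical helper: the command in the diff raises click.UsageError
    rather than ValueError.
    """
    if model_path is not None and model_conf is not None:
        raise ValueError("--model-config cannot be used with --model.")
    if model_path is not None and target_conf is not None:
        raise ValueError("--targets cannot be used with --model.")
    # Without a checkpoint, every config is passed to the from_config path.
    return "from_config" if model_path is None else "from_checkpoint"
```

Running the checks before any API object is built means a conflicting invocation fails fast, before datasets or checkpoints are touched further.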

View File

@ -1,31 +0,0 @@
from typing import Literal
from pydantic import Field
from batdetect2.audio import AudioConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.evaluate.config import (
EvaluationConfig,
get_default_eval_config,
)
from batdetect2.inference.config import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.train.config import TrainingConfig
__all__ = ["BatDetect2Config"]
class BatDetect2Config(BaseConfig):
config_version: Literal["v1"] = "v1"
train: TrainingConfig = Field(default_factory=TrainingConfig)
evaluation: EvaluationConfig = Field(
default_factory=get_default_eval_config
)
model: ModelConfig = Field(default_factory=ModelConfig)
audio: AudioConfig = Field(default_factory=AudioConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig)
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)

View File

@ -12,6 +12,8 @@ from typing import (
from hydra.utils import instantiate
from pydantic import BaseModel, Field
from batdetect2.core.configs import BaseConfig
__all__ = [
"add_import_config",
"ImportConfig",
@ -120,7 +122,7 @@ class Registry(Generic[T_Type, P_Type]):
return self._registry[name](config, *args, **kwargs)
class ImportConfig(BaseModel):
class ImportConfig(BaseConfig):
"""Base config for dynamic instantiation via Hydra.
Subclass this to create a registry-specific import escape hatch.

View File

@ -16,6 +16,7 @@ from batdetect2.data.conditions.common import (
IdInListConfig,
JsonList,
ListFormatConfig,
TagInfo,
TxtList,
)
from batdetect2.data.conditions.recordings import (
@ -63,16 +64,17 @@ __all__ = [
"NotConfig",
"Operator",
"PathInListConfig",
"RecordingAllOfConfig",
"RecordingAnyOfConfig",
"RecordingCondition",
"RecordingConditionConfig",
"RecordingConditionImportConfig",
"RecordingAllOfConfig",
"RecordingAnyOfConfig",
"RecordingNotConfig",
"RecordingSatisfiesConfig",
"SoundEventCondition",
"SoundEventConditionConfig",
"SoundEventConditionImportConfig",
"TagInfo",
"TxtList",
"build_clip_annotation_condition",
"build_recording_condition",

View File

@ -2,10 +2,23 @@ import csv
import json
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
from typing import (
Annotated,
Any,
Generic,
Literal,
ParamSpec,
Protocol,
TypeVar,
)
from uuid import UUID
from pydantic import BaseModel, Field, model_validator
from pydantic import (
BaseModel,
Field,
PlainSerializer,
model_validator,
)
from soundevent import data
from batdetect2.core.configs import BaseConfig
@ -138,19 +151,26 @@ class IdInList(Generic[UUIDObject]):
return obj.uuid in self.ids
def dump_tag(tag: data.Tag) -> dict[str, Any]:
return {"key": tag.term.name, "value": tag.value}
TagInfo = Annotated[data.Tag, PlainSerializer(dump_tag)]
class HasTagConfig(BaseConfig):
name: Literal["has_tag"] = "has_tag"
tag: data.Tag
tag: TagInfo
class HasAllTagsConfig(BaseConfig):
name: Literal["has_all_tags"] = "has_all_tags"
tags: list[data.Tag]
tags: list[TagInfo]
class HasAnyTagConfig(BaseConfig):
name: Literal["has_any_tag"] = "has_any_tag"
tags: list[data.Tag]
tags: list[TagInfo]
class JsonList(BaseConfig):

View File

@ -94,7 +94,7 @@ def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
return partial(operator.ge, value)
if op == "eq":
return partial(operator.eq, b=value)
return partial(operator.eq, value)
raise ValueError(f"Invalid operator {op}")
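
Review note: the change from `partial(operator.eq, b=value)` to `partial(operator.eq, value)` matters because, on CPython, the functions in `operator` are C built-ins that accept positional arguments only. The sketch below reproduces both forms; `make_eq` and `keyword_binding_fails` are illustrative names, not part of the codebase.

```python
import operator
from functools import partial


def make_eq(value: float):
    """Bind ``value`` positionally, as in the updated line."""
    return partial(operator.eq, value)


def keyword_binding_fails(value: float) -> bool:
    """Return True if binding ``b=value`` raises TypeError when called."""
    try:
        partial(operator.eq, b=value)(value)
    except TypeError:
        # CPython's operator.eq takes no keyword arguments.
        return True
    return False


# Positional binding fixes the FIRST operand of the operator.
ge_five = partial(operator.ge, 5.0)  # ge_five(x) evaluates 5.0 >= x
```

Because equality is symmetric, the operand order is harmless for `eq`. For the ordered comparators the bound value becomes the left-hand operand, so `partial(operator.ge, value)(x)` evaluates `value >= x`, which is worth keeping in mind when reading the other branches.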

View File

@ -24,8 +24,8 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def run_evaluate(
model: Model,
test_annotations: Sequence[data.ClipAnnotation],
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
audio_config: AudioConfig | None = None,
@ -46,8 +46,6 @@ def run_evaluate(
audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
roi_mapper = roi_mapper or model.roi_mapper
loader = build_test_loader(
test_annotations,

View File

@ -45,8 +45,16 @@ def run_batch_inference(
audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
roi_mapper = roi_mapper or model.roi_mapper
if targets is None:
raise ValueError(
"targets must be provided when running batch inference."
)
if roi_mapper is None:
raise ValueError(
"roi_mapper must be provided when running batch inference."
)
output_transform = output_transform or build_output_transform(
config=output_config.transform,

View File

@ -7,21 +7,37 @@ from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
class InferenceModule(LightningModule):
def __init__(
self,
model: Model,
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
output_transform: OutputTransformProtocol | None = None,
detection_threshold: float | None = None,
):
super().__init__()
self.model = model
self.detection_threshold = detection_threshold
if output_transform is None and targets is None:
raise ValueError(
"targets must be provided when building inference output "
"transforms."
)
if output_transform is None and roi_mapper is None:
raise ValueError(
"roi_mapper must be provided when building inference output "
"transforms."
)
self.output_transform = output_transform or build_output_transform(
targets=model.targets,
roi_mapper=model.roi_mapper,
targets=targets,
roi_mapper=roi_mapper,
)
def predict_step(

View File

@ -24,12 +24,7 @@ from batdetect2.core.configs import BaseConfig
if TYPE_CHECKING:
import numpy as np
import pandas as pd
from lightning.pytorch.loggers import (
CSVLogger,
Logger,
MLFlowLogger,
TensorBoardLogger,
)
from lightning.pytorch.loggers import Logger
from matplotlib.figure import Figure
from soundevent import data
@ -43,11 +38,15 @@ __all__ = [
"DVCLiveConfig",
"LoggerConfig",
"MLFlowLoggerConfig",
"LoggingCallback",
"TensorBoardLoggerConfig",
"build_logger",
"enable_logging",
"get_image_logger",
"get_table_logger",
"log_artifact_file",
"log_config_artifact",
"log_csv_artifact",
]
@ -123,6 +122,18 @@ class LoggerBuilder(Protocol, Generic[T]):
) -> Logger: ...
LoggingContext = TypeVar("LoggingContext", contravariant=True)
class LoggingCallback(Protocol, Generic[LoggingContext]):
def run(
self,
logger: Logger,
artifact_path: Path,
context: LoggingContext,
) -> None: ...
def create_dvclive_logger(
config: DVCLiveConfig,
log_dir: Path | None = None,
@ -276,6 +287,71 @@ def build_logger(
)
def log_artifact_file(
runtime_logger: Logger,
path: Path,
artifact_path: str = "artifacts",
) -> None:
from lightning.pytorch.loggers import (
CSVLogger,
MLFlowLogger,
TensorBoardLogger,
)
if isinstance(runtime_logger, MLFlowLogger):
runtime_logger.experiment.log_artifact( # type: ignore[call-arg]
local_path=str(path),
artifact_path=artifact_path,
run_id=runtime_logger.run_id,
)
return
experiment = getattr(runtime_logger, "experiment", None)
if experiment is not None and hasattr(experiment, "log_artifact"):
experiment.log_artifact(path=path, name=path.name, copy=True)
return
if isinstance(runtime_logger, (CSVLogger, TensorBoardLogger)):
return
logger.warning(
"Skipping artifact logging for unsupported logger type {logger_type}",
logger_type=type(runtime_logger).__name__,
)
def log_config_artifact(
logger: Logger,
config: BaseConfig,
filename: str,
artifact_path: Path,
) -> None:
artifact_path.mkdir(parents=True, exist_ok=True)
path = artifact_path / filename
path.write_text(config.to_yaml_string())
log_artifact_file(
logger,
path,
artifact_path=artifact_path.name,
)
def log_csv_artifact(
logger: Logger,
df: pd.DataFrame,
filename: str,
artifact_path: Path,
) -> None:
artifact_path.mkdir(parents=True, exist_ok=True)
path = artifact_path / filename
df.to_csv(path, index=False)
log_artifact_file(
logger,
path,
artifact_path=artifact_path.name,
)
PlotLogger = Callable[[str, "Figure", int], None]
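
Review note: `LoggingCallback` is a structural protocol, so any object with a matching `run(logger, artifact_path, context)` method qualifies. The sketch below shows that shape with standard-library stand-ins only. `NoteContext`, `NoteArtifactLogging`, `run_callbacks`, and `demo` are hypothetical names, and the logger is typed as `Any` in place of the Lightning `Logger`.

```python
from __future__ import annotations

import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any


@dataclass(frozen=True)
class NoteContext:
    """Hypothetical context; the diff uses TrainLoggingContext."""

    note: str


class NoteArtifactLogging:
    """Hypothetical callback with the run(logger, artifact_path, context) shape."""

    def run(self, logger: Any, artifact_path: Path, context: NoteContext) -> None:
        out_dir = artifact_path / "training_artifacts"
        out_dir.mkdir(parents=True, exist_ok=True)
        # Write the artifact to disk first; a real callback would then
        # hand the file to the logger.
        (out_dir / "note.txt").write_text(context.note)


def run_callbacks(callbacks, logger: Any, artifact_path: Path, context) -> None:
    """Assumed driver loop: call each callback with the same context."""
    for callback in callbacks:
        callback.run(logger, artifact_path, context)


def demo() -> str:
    with tempfile.TemporaryDirectory() as tmp:
        run_callbacks([NoteArtifactLogging()], None, Path(tmp), NoteContext("hello"))
        return (Path(tmp) / "training_artifacts" / "note.txt").read_text()
```

Keeping the context in a frozen dataclass means callbacks can read shared training state without being able to mutate it.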

View File

@ -26,11 +26,8 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
is the ``build_model`` factory function exported from this module.
"""
from typing import Literal
import torch
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig
@ -73,7 +70,6 @@ from batdetect2.postprocess.types import (
)
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.config import TargetConfig
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [
@ -131,10 +127,6 @@ class ModelConfig(BaseConfig):
Parameters for converting raw model outputs into detections (NMS
kernel, thresholds, top-k limit). Defaults to
``PostprocessConfig()``.
targets : TargetConfig
Detection and classification target definitions (class list,
detection target, bounding-box mapper). Defaults to
``TargetConfig()``.
"""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
@ -143,23 +135,6 @@ class ModelConfig(BaseConfig):
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
@classmethod
def load(
cls,
path: PathLike,
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
targets: TargetConfig | None = None,
) -> "ModelConfig":
config = super().load(path, field, extra, strict)
if targets is None:
return config
return config.model_copy(update={"targets": targets})
class Model(torch.nn.Module):
@ -183,33 +158,32 @@ class Model(torch.nn.Module):
postprocessor : PostprocessorProtocol
Converts the raw ``ModelOutput`` from ``detector`` into a list of
per-clip detection tensors.
targets : TargetProtocol
Describes the set of target classes; used when building heads and
during training target construction.
roi_mapper : ROIMapperProtocol
Maps geometries to target-size channels and back.
class_names : list[str]
Class names corresponding to the model classification outputs.
dimension_names : list[str]
Size-dimension names corresponding to the model size outputs.
"""
detector: DetectionModel
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
targets: TargetProtocol
roi_mapper: ROIMapperProtocol
class_names: list[str]
dimension_names: list[str]
def __init__(
self,
detector: DetectionModel,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
class_names: list[str],
dimension_names: list[str],
):
super().__init__()
self.detector = detector
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.targets = targets
self.roi_mapper = roi_mapper
self.class_names = class_names
self.dimension_names = dimension_names
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor.
@ -237,9 +211,9 @@ class Model(torch.nn.Module):
def build_model(
config: ModelConfig | None = None,
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
config: ModelConfig | dict | None = None,
class_names: list[str] | None = None,
dimension_names: list[str] | None = None,
preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None,
) -> Model:
@ -254,11 +228,13 @@ def build_model(
----------
config : ModelConfig, optional
Full model configuration (samplerate, architecture, preprocessing,
postprocessing, targets). Defaults to ``ModelConfig()`` if not
provided.
targets : TargetProtocol, optional
Pre-built targets object. If given, overrides
``config.targets``.
postprocessing). Defaults to ``ModelConfig()`` if not provided.
class_names : list[str], optional
Class names used to size the classifier head. Required when building
a new model.
dimension_names : list[str], optional
Dimension names used to size the bbox head. Required when building a
new model.
preprocessor : PreprocessorProtocol, optional
Pre-built preprocessor. If given, overrides
``config.preprocess`` and ``config.samplerate`` for the
@ -278,19 +254,20 @@ def build_model(
"""
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets
config = config or ModelConfig()
targets = targets or build_targets(config=config.targets)
targets_config = getattr(targets, "config", None)
roi_config = (
targets_config.roi
if isinstance(targets_config, TargetConfig)
else config.targets.roi
)
if isinstance(config, dict):
config = ModelConfig.model_validate(config)
if class_names is None:
raise ValueError("class_names must be provided when building a model.")
if dimension_names is None:
raise ValueError(
"dimension_names must be provided when building a model."
)
roi_mapper = roi_mapper or build_roi_mapping(config=roi_config)
preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=config.samplerate,
@ -300,16 +277,16 @@ def build_model(
config=config.postprocess,
)
detector = build_detector(
num_classes=len(targets.class_names),
num_sizes=len(roi_mapper.dimension_names),
num_classes=len(class_names),
num_sizes=len(dimension_names),
config=config.architecture,
)
return Model(
detector=detector,
postprocessor=postprocessor,
preprocessor=preprocessor,
targets=targets,
roi_mapper=roi_mapper,
class_names=class_names,
dimension_names=dimension_names,
)
@ -329,6 +306,6 @@ def build_model_with_new_targets(
detector=detector,
postprocessor=model.postprocessor,
preprocessor=model.preprocessor,
targets=targets,
roi_mapper=roi_mapper,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
)

View File

@ -53,8 +53,12 @@ import torch.nn.functional as F
from pydantic import Field
from torch import nn
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.core.configs import BaseConfig
from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
__all__ = [
"BlockImportConfig",

View File

@ -6,6 +6,7 @@ from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.heatmaps import (
plot_classification_heatmap,
plot_detection_heatmap,
plot_size_heatmap,
)
from batdetect2.plotting.matches import (
plot_cross_trigger_match,
@ -25,5 +26,6 @@ __all__ = [
"plot_true_positive_match",
"plot_detection_heatmap",
"plot_classification_heatmap",
"plot_size_heatmap",
"plot_match_gallery",
]

View File

@ -1,4 +1,4 @@
"""Plot heatmaps"""
"""Plot heatmaps."""
import numpy as np
import torch
@ -8,6 +8,12 @@ from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba
from batdetect2.plotting.common import create_ax
__all__ = [
"plot_detection_heatmap",
"plot_classification_heatmap",
"plot_size_heatmap",
]
def plot_detection_heatmap(
heatmap: torch.Tensor | np.ndarray,
@ -108,7 +114,91 @@ def plot_classification_heatmap(
return ax
def create_colormap(color: str) -> Colormap:
def plot_size_heatmap(
heatmap: torch.Tensor | np.ndarray,
dimension_names: list[str],
ax: axes.Axes | None = None,
figsize: tuple[int, int] = (10, 10),
color: str = "crimson",
size: float = 20,
fontsize: float = 8,
) -> axes.Axes:
"""Plot sparse size labels from a size heatmap.
Parameters
----------
heatmap : torch.Tensor | np.ndarray
Size heatmap with shape ``[num_dims, height, width]``. Entries are
expected to be zero everywhere except at labelled positions.
dimension_names : list[str]
Names corresponding to the first heatmap dimension.
ax : matplotlib.axes.Axes | None, default=None
Axis to plot on. If ``None``, a new axis is created.
figsize : tuple[int, int], default=(10, 10)
Figure size used when creating a new axis.
color : str, default="crimson"
Color used for scatter points and text labels.
size : float, default=20
Marker size for plotted points.
fontsize : float, default=8
Font size used for the text labels.
Returns
-------
matplotlib.axes.Axes
Axis containing the plotted size labels.
"""
ax = create_ax(ax, figsize=figsize)
if isinstance(heatmap, torch.Tensor):
heatmap = heatmap.detach().cpu().numpy()
if heatmap.ndim == 4:
heatmap = heatmap[0]
if heatmap.ndim != 3:
raise ValueError("Expecting a 3-dimensional array")
if len(dimension_names) != heatmap.shape[0]:
raise ValueError("Inconsistent number of dimension names")
point_mask = np.any(heatmap != 0, axis=0)
rows, cols = np.nonzero(point_mask)
if len(rows) == 0:
return ax
ax.scatter(cols, rows, c=color, s=size)
for row, col in zip(rows, cols, strict=False):
values = heatmap[:, row, col]
labels = [
f"{name}={value:.2f}"
for name, value in zip(
dimension_names,
values,
strict=False,
)
if value != 0
]
ax.text(
float(col),
float(row),
"\n".join(labels),
fontsize=fontsize,
color=color,
va="bottom",
ha="left",
)
ax.set_xlim(0, heatmap.shape[2])
ax.set_ylim(0, heatmap.shape[1])
return ax
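
Review note: the label extraction in `plot_size_heatmap` can be checked without matplotlib. The sketch below repeats the masking and formatting steps on a small array and returns the labels as a dictionary; `sparse_size_labels` is a hypothetical helper written for this note, and it assumes NumPy is installed.

```python
import numpy as np


def sparse_size_labels(
    heatmap: np.ndarray,
    dimension_names: list,
) -> dict:
    """Map (row, col) positions to the text label that would be drawn."""
    if heatmap.ndim != 3:
        raise ValueError("Expecting a 3-dimensional array")
    if len(dimension_names) != heatmap.shape[0]:
        raise ValueError("Inconsistent number of dimension names")

    # A position is labelled if ANY size dimension is non-zero there.
    point_mask = np.any(heatmap != 0, axis=0)
    rows, cols = np.nonzero(point_mask)

    labels = {}
    for row, col in zip(rows, cols):
        values = heatmap[:, row, col]
        # Zero-valued dimensions are left out of the label text.
        parts = [
            f"{name}={value:.2f}"
            for name, value in zip(dimension_names, values)
            if value != 0
        ]
        labels[(int(row), int(col))] = "\n".join(parts)
    return labels
```

A position with only one non-zero dimension still gets a marker, but its label lists only that dimension.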
def create_colormap(
color: str | tuple[float, float, float, float],
) -> Colormap:
(r, g, b, a) = to_rgba(color)
return LinearSegmentedColormap.from_list(
"cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]

View File

@ -6,7 +6,7 @@ from batdetect2.targets.classes import (
build_sound_event_encoder,
get_class_names_from_config,
)
from batdetect2.targets.config import TargetConfig
from batdetect2.targets.config import TargetConfig, build_default_target_config
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
ROIMapperConfig,
@ -36,13 +36,14 @@ from batdetect2.targets.types import (
SoundEventFilter,
TargetProtocol,
)
from batdetect2.targets.utils import check_target_compatibility
__all__ = [
"AnchorBBoxMapperConfig",
"Position",
"ROIMappingConfig",
"ROIMapperProtocol",
"ROIMapperConfig",
"ROIMapperProtocol",
"ROIMappingConfig",
"ROITargetMapper",
"Size",
"SoundEventDecoder",
@ -52,12 +53,14 @@ __all__ = [
"TargetConfig",
"TargetProtocol",
"Targets",
"build_roi_mapping",
"build_default_target_config",
"build_roi_mapper",
"build_roi_mapping",
"build_sound_event_decoder",
"build_sound_event_encoder",
"build_targets",
"call_type",
"check_target_compatibility",
"data_source",
"generic_class",
"get_class_names_from_config",

View File

@ -12,6 +12,7 @@ from batdetect2.data.conditions import (
NotConfig,
SoundEventCondition,
SoundEventConditionConfig,
TagInfo,
build_sound_event_condition,
)
from batdetect2.targets.terms import call_type, generic_class
@ -32,11 +33,12 @@ class TargetClassConfig(BaseConfig):
condition_input: SoundEventConditionConfig | None = Field(
alias="match_if",
default=None,
exclude=True,
)
tags: List[data.Tag] | None = Field(default=None, exclude=True)
assign_tags: List[data.Tag] = Field(default_factory=list)
assign_tags: List[TagInfo] = Field(default_factory=list)
_match_if: SoundEventConditionConfig = PrivateAttr()

View File

@ -2,6 +2,7 @@ from collections import Counter
from typing import List
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.targets.classes import (
@ -13,6 +14,7 @@ from batdetect2.targets.rois import ROIMappingConfig
__all__ = [
"TargetConfig",
"build_default_target_config",
]
@ -42,3 +44,20 @@ class TargetConfig(BaseConfig):
f"{', '.join(duplicates)}"
)
return v
def build_default_target_config(class_names: list[str]) -> TargetConfig:
"""Build a default target configuration object."""
return TargetConfig(
detection_target=DEFAULT_DETECTION_CLASS,
classification_targets=[
TargetClassConfig(
name=class_name,
tags=[
data.Tag(key="class", value=class_name),
],
)
for class_name in class_names
],
roi=ROIMappingConfig(),
)

View File

@ -50,21 +50,31 @@ class Targets(TargetProtocol):
self.config = config
self._filter_fn = build_sound_event_condition(
config.detection_target.match_if
self.config.detection_target.match_if
)
self._encode_fn = build_sound_event_encoder(
config.classification_targets
self.config.classification_targets
)
self._decode_fn = build_sound_event_decoder(
config.classification_targets
self.config.classification_targets
)
self.class_names = get_class_names_from_config(
config.classification_targets
self.config.classification_targets
)
self.detection_class_name = config.detection_target.name
self.detection_class_tags = config.detection_target.assign_tags
self.detection_class_name = self.config.detection_target.name
self.detection_class_tags = self.config.detection_target.assign_tags
@classmethod
def from_config(cls, config: dict) -> "Targets":
"""Build a Targets object from a serialized config dictionary."""
validated_config = TargetConfig.model_validate(config)
return cls(config=validated_config)
def get_config(self) -> dict:
"""Return the serialized target config used to build this object."""
return self.config.model_dump(mode="json")
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation.
@ -131,7 +141,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
)
def build_targets(config: TargetConfig | None = None) -> Targets:
def build_targets(config: TargetConfig | dict | None = None) -> Targets:
"""Build a Targets object from a TargetConfig or a serialized config dict.
Parameters
@ -153,6 +163,10 @@ def build_targets(config: TargetConfig | None = None) -> Targets:
If dynamic import of a derivation function fails (when configured).
"""
config = config or DEFAULT_TARGET_CONFIG
if not isinstance(config, TargetConfig):
config = TargetConfig.model_validate(config)
logger.opt(lazy=True).debug(
"Building targets with config: \n{}",
lambda: config.to_yaml_string(),
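
Review note: `Targets.from_config` and `Targets.get_config` are meant to round-trip through a plain, JSON-friendly dictionary so the target definition can be stored in a checkpoint. The sketch below shows the same contract with standard-library stand-ins. `MiniTargetConfig`, `MiniTargets`, and `roundtrip` are hypothetical, and the diff itself uses pydantic's `model_validate` and `model_dump(mode="json")` instead.

```python
from __future__ import annotations

import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class MiniTargetConfig:
    """Hypothetical stand-in for TargetConfig (a pydantic model in the diff)."""

    detection_target: str = "bat"
    classification_targets: tuple[str, ...] = ()


class MiniTargets:
    def __init__(self, config: MiniTargetConfig):
        self.config = config
        self.class_names = list(config.classification_targets)

    @classmethod
    def from_config(cls, config: dict) -> "MiniTargets":
        """Rebuild the object from a serialized config dictionary."""
        return cls(
            MiniTargetConfig(
                detection_target=config["detection_target"],
                classification_targets=tuple(config["classification_targets"]),
            )
        )

    def get_config(self) -> dict:
        """Return a JSON-friendly dictionary (tuples become lists)."""
        raw = asdict(self.config)
        raw["classification_targets"] = list(raw["classification_targets"])
        return raw


def roundtrip(targets: MiniTargets) -> MiniTargets:
    """Serialize to JSON text and back, as a checkpoint store would."""
    payload = json.loads(json.dumps(targets.get_config()))
    return MiniTargets.from_config(payload)
```

A test of the form `from_config(get_config())` compared against the original configuration is a cheap guard against fields that serialize but do not validate back.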

View File

@ -28,6 +28,11 @@ class TargetProtocol(Protocol):
detection_class_tags: list[data.Tag]
detection_class_name: str
@classmethod
def from_config(cls, config: dict) -> "TargetProtocol": ...
def get_config(self) -> dict: ...
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
def encode_class(

View File

@ -0,0 +1,29 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from batdetect2.targets.types import TargetProtocol
def check_target_compatibility(
targets: "TargetProtocol",
class_names: list[str],
) -> bool:
"""Check if a target definition can decode a model's outputs.
Parameters
----------
targets : TargetProtocol
Target definition that would be used with the model outputs.
class_names : list[str]
Class names produced by the model checkpoint.
Returns
-------
bool
True when every model class name exists in the provided targets,
False otherwise.
"""
target_class_names = set(targets.class_names)
model_class_names = set(class_names)
return model_class_names.issubset(target_class_names)
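
Review note: the compatibility rule is a plain subset test on class names. The sketch below exercises it with a stub object; `StubTargets` and `is_compatible` are hypothetical stand-ins that expose only the `class_names` attribute the check reads.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class StubTargets:
    """Hypothetical stand-in exposing only ``class_names``."""

    class_names: list[str] = field(default_factory=list)


def is_compatible(targets: StubTargets, class_names: list[str]) -> bool:
    """True when every model class name exists in the targets."""
    return set(class_names).issubset(targets.class_names)
```

Since the comparison is set-based, it ignores ordering and duplicates: targets that list the same names in a different order still pass, so any index alignment between model outputs and target classes has to be handled separately.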

View File

@ -4,10 +4,24 @@ from batdetect2.train.lightning import (
TrainingModule,
load_model_from_checkpoint,
)
from batdetect2.train.logging import (
ConfigHyperparameterLogging,
DatasetConfigArtifact,
DatasetConfigArtifactLogging,
DataSummaryArtifactLogging,
TargetConfigArtifactLogging,
TrainLoggingContext,
)
from batdetect2.train.train import build_trainer, run_train
__all__ = [
"ConfigHyperparameterLogging",
"DataSummaryArtifactLogging",
"DEFAULT_CHECKPOINT_DIR",
"DatasetConfigArtifact",
"DatasetConfigArtifactLogging",
"TargetConfigArtifactLogging",
"TrainLoggingContext",
"TrainingConfig",
"TrainingModule",
"build_trainer",

View File

@ -10,6 +10,7 @@ from batdetect2.logging import get_image_logger
from batdetect2.models.types import ModelOutput
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.types import TrainExample
@ -19,11 +20,15 @@ class ValidationMetrics(Callback):
def __init__(
self,
evaluator: EvaluatorProtocol,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
output_transform: OutputTransformProtocol | None = None,
):
super().__init__()
self.evaluator = evaluator
self.targets = targets
self.roi_mapper = roi_mapper
self.output_transform = output_transform
self._clip_annotations: List[data.ClipAnnotation] = []
@ -93,8 +98,8 @@ class ValidationMetrics(Callback):
model = pl_module.model
if self.output_transform is None:
self.output_transform = build_output_transform(
targets=model.targets,
roi_mapper=model.roi_mapper,
targets=self.targets,
roi_mapper=self.roi_mapper,
)
output_transform = self.output_transform

View File

@ -34,6 +34,8 @@ def build_checkpoint_callback(
if checkpoint_dir is None:
checkpoint_dir = config.checkpoint_dir
checkpoint_dir = Path(checkpoint_dir)
if experiment_name is not None:
checkpoint_dir = checkpoint_dir / experiment_name

View File

@ -1,8 +1,11 @@
from dataclasses import dataclass
import lightning as L
from soundevent.data import PathLike
from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.models.types import ModelOutput
from batdetect2.targets import TargetConfig
from batdetect2.train.config import TrainingConfig
from batdetect2.train.losses import build_loss
from batdetect2.train.optimizers import build_optimizer
@ -11,6 +14,7 @@ from batdetect2.train.types import LossProtocol, TrainExample
__all__ = [
"TrainingModule",
"load_model_from_checkpoint",
]
@ -21,6 +25,9 @@ class TrainingModule(L.LightningModule):
def __init__(
self,
model_config: dict | None = None,
targets_config: dict | None = None,
class_names: list[str] | None = None,
dimension_names: list[str] | None = None,
train_config: dict | None = None,
loss: LossProtocol | None = None,
model: Model | None = None,
@ -29,14 +36,34 @@ class TrainingModule(L.LightningModule):
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
self.model_config = ModelConfig.model_validate(model_config or {})
self.model_config: dict = model_config or {}
self.targets_config: dict = targets_config or {}
self.class_names = list(class_names or [])
self.dimension_names = list(dimension_names or [])
self.train_config = TrainingConfig.model_validate(train_config or {})
if loss is None:
loss = build_loss(config=self.train_config.loss)
if model is None:
model = build_model(config=self.model_config)
if not self.class_names:
raise ValueError(
"class_names must be provided when rebuilding a training "
"module without a model."
)
if not self.dimension_names:
raise ValueError(
"dimension_names must be provided when rebuilding a "
"training module without a model."
)
model = build_model(
config=self.model_config,
class_names=self.class_names,
dimension_names=self.dimension_names,
)
self.loss = loss
self.model = model
@ -95,9 +122,16 @@ class TrainingModule(L.LightningModule):
}
@dataclass
class StoredConfig:
model: ModelConfig
targets: TargetConfig
train: TrainingConfig
def load_model_from_checkpoint(
path: PathLike,
) -> tuple[Model, ModelConfig]:
) -> tuple[Model, StoredConfig]:
"""Load a model and its configuration from a Lightning checkpoint.
Parameters
@ -110,15 +144,24 @@ def load_model_from_checkpoint(
-------
tuple[Model, StoredConfig]
The restored ``Model`` instance and a ``StoredConfig`` bundling the
model, targets, and training configurations stored in the checkpoint.
"""
module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.model_config
training_config = TrainingConfig.model_validate(module.train_config)
model_config = ModelConfig.model_validate(module.model_config)
targets_config = TargetConfig.model_validate(module.targets_config)
return module.model, StoredConfig(
model=model_config,
targets=targets_config,
train=training_config,
)
def build_training_module(
model_config: ModelConfig | None = None,
targets_config: TargetConfig | dict | None = None,
class_names: list[str] | None = None,
dimension_names: list[str] | None = None,
train_config: TrainingConfig | None = None,
model: Model | None = None,
) -> TrainingModule:
@ -128,8 +171,16 @@ def build_training_module(
if train_config is None:
train_config = TrainingConfig()
if targets_config is None:
targets_config = TargetConfig()
targets_config = TargetConfig.model_validate(targets_config)
return TrainingModule(
model_config=model_config.model_dump(mode="json"),
targets_config=targets_config.model_dump(mode="json"),
train_config=train_config.model_dump(mode="json"),
class_names=class_names,
dimension_names=dimension_names,
model=model,
)

View File

@ -0,0 +1,164 @@
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
from lightning.pytorch.loggers import Logger
from soundevent import data

from batdetect2.audio import AudioConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data import Dataset, compute_class_summary
from batdetect2.logging import log_config_artifact, log_csv_artifact
from batdetect2.models import ModelConfig
from batdetect2.targets import TargetConfig, TargetProtocol
from batdetect2.train.config import TrainingConfig

__all__ = [
    "ConfigHyperparameterLogging",
    "DataSummaryArtifactLogging",
    "DatasetConfigArtifact",
    "DatasetConfigArtifactLogging",
    "TargetConfigArtifactLogging",
    "TrainLoggingContext",
]


@dataclass(frozen=True)
class TrainLoggingContext:
    model_config: ModelConfig
    train_config: TrainingConfig
    audio_config: AudioConfig
    targets: TargetProtocol
    train_dataset: Dataset
    val_dataset: Dataset | None


@dataclass(frozen=True)
class DatasetConfigArtifact:
    filename: str
    config: BaseConfig


class ConfigHyperparameterLogging:
    def run(
        self,
        logger: Logger,
        artifact_path: Path,
        context: TrainLoggingContext,
    ) -> None:
        logger.log_hyperparams(
            {
                "model": context.model_config.model_dump(
                    mode="json",
                    exclude_none=True,
                ),
                "training": context.train_config.model_dump(
                    mode="json",
                    exclude_none=True,
                ),
                "audio": context.audio_config.model_dump(
                    mode="json",
                    exclude_none=True,
                ),
            }
        )


class TargetConfigArtifactLogging:
    def run(
        self,
        logger: Logger,
        artifact_path: Path,
        context: TrainLoggingContext,
    ) -> None:
        targets_config = TargetConfig.model_validate(
            context.targets.get_config()
        )
        log_config_artifact(
            logger,
            targets_config,
            filename="targets.yaml",
            artifact_path=artifact_path / "training_artifacts",
        )


class DatasetConfigArtifactLogging:
    def __init__(
        self,
        train_dataset_config: DatasetConfigArtifact,
        val_dataset_config: DatasetConfigArtifact | None = None,
    ):
        self.train_dataset_config = train_dataset_config
        self.val_dataset_config = val_dataset_config

    def run(
        self,
        logger: Logger,
        artifact_path: Path,
        context: TrainLoggingContext,
    ) -> None:
        training_artifact_path = artifact_path / "training_artifacts"
        log_config_artifact(
            logger,
            self.train_dataset_config.config,
            filename=self.train_dataset_config.filename,
            artifact_path=training_artifact_path,
        )
        if self.val_dataset_config is not None:
            log_config_artifact(
                logger,
                self.val_dataset_config.config,
                filename=self.val_dataset_config.filename,
                artifact_path=training_artifact_path,
            )


class DataSummaryArtifactLogging:
    def run(
        self,
        logger: Logger,
        artifact_path: Path,
        context: TrainLoggingContext,
    ) -> None:
        training_artifact_path = artifact_path / "training_artifacts"
        log_csv_artifact(
            logger,
            _compute_class_summary_or_empty(
                context.train_dataset,
                context.targets,
            ),
            filename="train_class_summary.csv",
            artifact_path=training_artifact_path,
        )
        if context.val_dataset is not None:
            log_csv_artifact(
                logger,
                _compute_class_summary_or_empty(
                    context.val_dataset,
                    context.targets,
                ),
                filename="val_class_summary.csv",
                artifact_path=training_artifact_path,
            )


def _compute_class_summary_or_empty(
    dataset: Sequence[data.ClipAnnotation],
    targets: TargetProtocol,
) -> pd.DataFrame:
    try:
        return compute_class_summary(dataset, targets)
    except KeyError as error:
        if error.args != ("class_name",):
            raise
        return pd.DataFrame(
            columns=["num calls", "num recordings", "duration", "call_rate"]
        )
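
Review note: every callback in this new file exposes the same `run(logger, artifact_path, context)` method, and `run_train` invokes them in order before the trainer is built. The pattern can be sketched with the standard library alone. `Context`, `RecordingLogger`, `ClassListArtifactLogging`, and `run_callbacks` below are hypothetical stand-ins for illustration and are not part of batdetect2 or Lightning.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Protocol, Sequence


@dataclass(frozen=True)
class Context:
    """Hypothetical stand-in for TrainLoggingContext."""

    train_config: dict
    class_names: tuple


@dataclass
class RecordingLogger:
    """Hypothetical logger that only records what it was asked to log."""

    hyperparams: dict = field(default_factory=dict)
    artifacts: list = field(default_factory=list)


class LoggingCallback(Protocol):
    """Anything with a matching ``run`` method satisfies the protocol."""

    def run(
        self,
        logger: RecordingLogger,
        artifact_path: Path,
        context: Context,
    ) -> None: ...


class HyperparameterLogging:
    def run(self, logger, artifact_path, context):
        # Copy the config so later mutation does not change the log.
        logger.hyperparams["training"] = dict(context.train_config)


class ClassListArtifactLogging:
    def run(self, logger, artifact_path, context):
        # All artifacts share one subdirectory under the root path.
        logger.artifacts.append(
            artifact_path / "training_artifacts" / "classes.txt"
        )


def run_callbacks(
    callbacks: Sequence[LoggingCallback],
    logger: RecordingLogger,
    artifact_path: Path,
    context: Context,
) -> None:
    """Run each callback once, in the order given."""
    for callback in callbacks:
        callback.run(logger, artifact_path, context)
```

Because the protocol is structural, user-supplied callbacks passed through `logging_callbacks` need no base class; they only have to provide a compatible `run` method.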

View File

@ -3,6 +3,7 @@
from collections.abc import Iterable
from typing import Annotated, Literal
from loguru import logger
from pydantic import Field
from torch import nn
from torch.optim import Adam, Optimizer
@ -84,4 +85,10 @@ def build_optimizer(
Optimizer configuration. Defaults to ``AdamOptimizerConfig``.
"""
config = config or AdamOptimizerConfig()
logger.opt(lazy=True).debug(
"Building optimizer with config: \n{}",
lambda: config.to_yaml_string(),
)
return optimizer_registry.build(config, parameters)
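
Review note: `logger.opt(lazy=True)` defers the `to_yaml_string()` call until a sink actually accepts DEBUG messages, so the dump costs nothing at higher log levels. The same idea can be sketched with the standard `logging` module instead of loguru (a swapped-in technique, shown only for illustration). `LazyStr` and `expensive_dump` are hypothetical names.

```python
import io
import logging


class LazyStr:
    """Wrap a callable so it only runs when the message is formatted."""

    def __init__(self, func):
        self._func = func

    def __str__(self):
        return self._func()


calls = []


def expensive_dump():
    # Stand-in for config.to_yaml_string(); records each invocation.
    calls.append("dumped")
    return "lr: 0.001"


stream = io.StringIO()
log = logging.getLogger("lazy_example")
log.propagate = False
log.addHandler(logging.StreamHandler(stream))

# DEBUG disabled: no record is created, so expensive_dump never runs.
log.setLevel(logging.INFO)
log.debug("Building optimizer with config:\n%s", LazyStr(expensive_dump))
calls_while_disabled = list(calls)

# DEBUG enabled: the handler formats the record, which calls __str__.
log.setLevel(logging.DEBUG)
log.debug("Building optimizer with config:\n%s", LazyStr(expensive_dump))
```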

View File

@ -2,6 +2,7 @@
from typing import Annotated, Literal
from loguru import logger
from pydantic import Field
from torch.optim import Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler
@ -78,4 +79,9 @@ def build_scheduler(
"""Build a scheduler from configuration."""
config = config or CosineAnnealingSchedulerConfig()
logger.opt(lazy=True).debug(
"Building scheduler with config: \n{}",
lambda: config.to_yaml_string(),
)
return scheduler_registry.build(config, optimizer)

View File

@ -3,6 +3,7 @@ from pathlib import Path
from typing import Optional
from lightning import Trainer, seed_everything
from lightning.pytorch.loggers import Logger
from loguru import logger
from soundevent import data
@ -10,6 +11,7 @@ from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
from batdetect2.logging import (
LoggerConfig,
LoggingCallback,
TensorBoardLoggerConfig,
build_logger,
)
@ -17,6 +19,7 @@ from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import (
ROIMapperProtocol,
TargetConfig,
TargetProtocol,
build_roi_mapping,
build_targets,
@ -27,6 +30,12 @@ from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import build_training_module
from batdetect2.train.logging import (
ConfigHyperparameterLogging,
DataSummaryArtifactLogging,
TargetConfigArtifactLogging,
TrainLoggingContext,
)
from batdetect2.train.types import ClipLabeller
__all__ = [
@ -35,6 +44,9 @@ __all__ = [
]
DEFAULT_LOG_DIR = Path("outputs") / "logs"
def run_train(
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Sequence[data.ClipAnnotation] | None = None,
@ -46,6 +58,7 @@ def run_train(
labeller: Optional["ClipLabeller"] = None,
audio_config: Optional[AudioConfig] = None,
model_config: Optional[ModelConfig] = None,
targets_config: TargetConfig | None = None,
train_config: Optional[TrainingConfig] = None,
logger_config: LoggerConfig | None = None,
trainer: Trainer | None = None,
@ -57,28 +70,43 @@ def run_train(
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
):
if seed is not None:
seed_everything(seed)
model_config = model_config or ModelConfig()
targets_config = targets_config or TargetConfig()
audio_config = audio_config or AudioConfig()
train_config = train_config or TrainingConfig()
if model is not None:
_validate_model_compatibility(model=model, model_config=model_config)
if targets is None:
raise ValueError(
"targets must be provided when training with an existing "
"model."
)
if roi_mapper is None:
raise ValueError(
"roi_mapper must be provided when training with an existing "
"model."
)
if targets is None:
targets = build_targets(config=targets_config)
else:
targets_config = TargetConfig.model_validate(targets.get_config())
roi_mapper = roi_mapper or build_roi_mapping(config=targets_config.roi)
if model is not None:
targets = targets or model.targets
if roi_mapper is None and targets is model.targets:
roi_mapper = model.roi_mapper
targets = targets or build_targets(config=model_config.targets)
roi_mapper = roi_mapper or build_roi_mapping(
config=model_config.targets.roi
)
_validate_model_compatibility(
model=model,
model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
)
audio_loader = audio_loader or build_audio_loader(config=audio_config)
@ -119,21 +147,57 @@ def run_train(
module = build_training_module(
model_config=model_config,
targets_config=targets_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config,
model=model,
)
evaluator = build_evaluator(
train_config.validation,
targets=targets,
roi_mapper=roi_mapper,
)
train_logger = build_logger(
logger_config or TensorBoardLoggerConfig(),
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
)
root_artifact_path = (
Path(log_dir) if log_dir is not None else DEFAULT_LOG_DIR
)
root_artifact_path.mkdir(parents=True, exist_ok=True)
logging_context = TrainLoggingContext(
model_config=model_config,
train_config=train_config,
audio_config=audio_config,
targets=targets,
train_dataset=train_annotations,
val_dataset=val_annotations,
)
resolved_logging_callbacks = (
ConfigHyperparameterLogging(),
TargetConfigArtifactLogging(),
DataSummaryArtifactLogging(),
*logging_callbacks,
)
for callback in resolved_logging_callbacks:
callback.run(train_logger, root_artifact_path, logging_context)
trainer = trainer or build_trainer(
train_config,
logger_config=logger_config,
evaluator=build_evaluator(
train_config.validation,
targets=targets,
roi_mapper=roi_mapper,
),
train_logger=train_logger,
evaluator=evaluator,
targets=targets,
roi_mapper=roi_mapper,
checkpoint_dir=checkpoint_dir,
num_epochs=num_epochs,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
)
@ -152,8 +216,14 @@ def run_train(
def _validate_model_compatibility(
model: Model,
model_config: ModelConfig,
class_names: list[str],
dimension_names: list[str],
) -> None:
reference_model = build_model(config=model_config)
reference_model = build_model(
config=model_config,
class_names=class_names,
dimension_names=dimension_names,
)
expected_shapes = {
key: tuple(value.shape)
@ -194,10 +264,11 @@ def _validate_model_compatibility(
def build_trainer(
config: TrainingConfig,
logger_config: LoggerConfig | None,
train_logger: Logger,
evaluator: "EvaluatorProtocol",
targets: "TargetProtocol",
roi_mapper: "ROIMapperProtocol",
checkpoint_dir: Path | None = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
num_epochs: int | None = None,
@ -208,25 +279,11 @@ def build_trainer(
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
)
train_logger = build_logger(
logger_config or TensorBoardLoggerConfig(),
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
)
train_logger.log_hyperparams(
config.model_dump(
mode="json",
exclude_none=True,
)
)
if num_epochs is not None:
trainer_conf.max_epochs = num_epochs
train_config = trainer_conf.model_dump(exclude_none=True)
if num_epochs is not None:
train_config["max_epochs"] = num_epochs
return Trainer(
**train_config,
logger=train_logger,
@ -237,6 +294,6 @@ def build_trainer(
experiment_name=experiment_name,
run_name=run_name,
),
ValidationMetrics(evaluator),
ValidationMetrics(evaluator, targets, roi_mapper),
],
)
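
Review note: `_validate_model_compatibility` now builds its reference model with explicit `class_names` and `dimension_names` and compares parameter shapes against the supplied model. The comparison step can be sketched on plain dictionaries, assuming each state is a mapping from parameter name to a shape tuple. `find_shape_mismatches` is a hypothetical helper, not the function in this diff.

```python
def find_shape_mismatches(expected, actual):
    """Return human-readable differences between two shape mappings."""
    problems = []
    for key in sorted(expected.keys() | actual.keys()):
        if key not in actual:
            problems.append(f"missing parameter: {key}")
        elif key not in expected:
            problems.append(f"unexpected parameter: {key}")
        elif expected[key] != actual[key]:
            problems.append(
                f"shape mismatch for {key}: "
                f"expected {expected[key]}, got {actual[key]}"
            )
    return problems


# Hypothetical shapes: the classifier head differs in its class count.
expected_shapes = {"backbone.w": (8, 1, 3, 3), "classifier.w": (3, 8, 1, 1)}
actual_shapes = {"backbone.w": (8, 1, 3, 3), "classifier.w": (2, 8, 1, 1)}
mismatches = find_shape_mismatches(expected_shapes, actual_shapes)
```

Collecting every mismatch before raising lets a single error message list all offending parameters.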

View File

@ -13,13 +13,14 @@ from soundevent import data, terms
from batdetect2.audio import build_audio_loader
from batdetect2.audio.clips import build_clipper
from batdetect2.audio.types import AudioLoader, ClipperProtocol
from batdetect2.config import BatDetect2Config
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import (
ROIMapperProtocol,
TargetConfig,
build_roi_mapping,
build_targets,
call_type,
)
@ -404,6 +405,13 @@ def sample_targets(
return build_targets(sample_target_config)
@pytest.fixture
def sample_roi_mapper(
sample_target_config: TargetConfig,
) -> ROIMapperProtocol:
return build_roi_mapping(sample_target_config.roi)
@pytest.fixture
def sample_labeller(
sample_targets: TargetProtocol,
@ -458,8 +466,16 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
@pytest.fixture
def tiny_checkpoint_path(tmp_path: Path) -> Path:
module = build_training_module(model_config=BatDetect2Config().model)
def tiny_checkpoint_path(
sample_targets: TargetProtocol,
sample_roi_mapper: ROIMapperProtocol,
tmp_path: Path,
) -> Path:
module = build_training_module(
targets_config=sample_targets.get_config(),
class_names=sample_targets.class_names,
dimension_names=sample_roi_mapper.dimension_names,
)
trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "model.ckpt"
trainer.strategy.connect(module)

View File

@ -8,20 +8,43 @@ import torch
from soundevent.geometry import compute_bounds
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.inference import InferenceConfig
from batdetect2.models.detectors import Detector
from batdetect2.models.heads import ClassifierHead
from batdetect2.train import load_model_from_checkpoint
from batdetect2.targets import TargetConfig
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
from batdetect2.train.lightning import build_training_module
@pytest.fixture
def api_v2() -> BatDetect2API:
def train_config() -> TrainingConfig:
"""Train config with a small batch size for testing."""
return TrainingConfig.model_validate({"train_loader": {"batch_size": 2}})
@pytest.fixture
def inference_config() -> InferenceConfig:
"""Inference config with a small batch size for testing."""
return InferenceConfig.model_validate({"loader": {"batch_size": 2}})
@pytest.fixture
def example_targets_config(example_data_dir: Path) -> TargetConfig:
return TargetConfig.load(example_data_dir / "targets.yaml")
@pytest.fixture
def api_v2(
train_config: TrainingConfig,
inference_config: InferenceConfig,
) -> BatDetect2API:
"""User story: users can create a ready-to-use API from config."""
config = BatDetect2Config()
config.inference.loader.batch_size = 2
return BatDetect2API.from_config(config)
api = BatDetect2API.from_config(
train_config=train_config,
inference_config=inference_config,
)
assert api.inference_config.loader.batch_size == 2
return api
def test_process_file_returns_recording_level_predictions(
@ -30,8 +53,10 @@ def test_process_file_returns_recording_level_predictions(
) -> None:
"""User story: process a file and get detections in recording time."""
# When
prediction = api_v2.process_file(example_audio_files[0])
# Then
assert prediction.clip.recording.path == example_audio_files[0]
assert prediction.clip.start_time == 0
assert prediction.clip.end_time == prediction.clip.recording.duration
@ -53,9 +78,11 @@ def test_process_files_is_batch_size_invariant(
) -> None:
"""User story: changing batch size should not change predictions."""
# When
preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1)
preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3)
# Then
assert len(preds_batch_1) == len(preds_batch_3)
by_key_1 = {
@ -91,12 +118,14 @@ def test_process_audio_matches_process_spectrogram(
) -> None:
"""User story: users can call either audio or spectrogram entrypoint."""
# When
audio = api_v2.load_audio(example_audio_files[0])
from_audio = api_v2.process_audio(audio)
spec = api_v2.generate_spectrogram(audio)
from_spec = api_v2.process_spectrogram(spec)
# Then
assert len(from_audio) == len(from_spec)
for det_audio, det_spec in zip(from_audio, from_spec, strict=True):
@ -116,8 +145,10 @@ def test_process_spectrogram_rejects_batched_input(
) -> None:
"""User story: invalid batched input gives a clear error."""
# Given
spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32)
# When/Then
with pytest.raises(ValueError, match="Batched spectrograms not supported"):
api_v2.process_spectrogram(spec)
@ -184,26 +215,35 @@ def test_user_can_read_extracted_features_per_detection(
@pytest.mark.slow
def test_user_can_load_checkpoint_and_finetune(
tmp_path: Path,
example_targets_config: TargetConfig,
example_annotations,
) -> None:
"""User story: load a checkpoint and continue training from it."""
module = build_training_module(model_config=BatDetect2Config().model)
api = BatDetect2API.from_config(
targets_config=example_targets_config,
)
module = build_training_module(
model_config=api.model_config,
targets_config=example_targets_config,
class_names=api.targets.class_names,
dimension_names=api.roi_mapper.dimension_names,
)
trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "base.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(checkpoint_path)
config = BatDetect2Config()
config.train.trainer.limit_train_batches = 1
config.train.trainer.limit_val_batches = 1
config.train.trainer.log_every_n_steps = 1
config.train.train_loader.batch_size = 1
config.train.train_loader.augmentations.enabled = False
train_config = api.train_config.model_copy(deep=True)
train_config.trainer.limit_train_batches = 1
train_config.trainer.limit_val_batches = 1
train_config.trainer.log_every_n_steps = 1
train_config.train_loader.batch_size = 1
train_config.train_loader.augmentations.enabled = False
api = BatDetect2API.from_checkpoint(
checkpoint_path,
train_config=config.train,
train_config=train_config,
)
finetune_dir = tmp_path / "finetuned"
@ -222,62 +262,34 @@ def test_user_can_load_checkpoint_and_finetune(
assert checkpoints
def test_user_can_load_checkpoint_with_new_targets(
tmp_path: Path,
sample_targets,
) -> None:
"""User story: start from checkpoint with a new target definition."""
module = build_training_module(model_config=BatDetect2Config().model)
trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "base_transfer.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(checkpoint_path)
source_model, _ = load_model_from_checkpoint(checkpoint_path)
api = BatDetect2API.from_checkpoint(
checkpoint_path,
targets_config=sample_targets.config,
)
source_detector = cast(Detector, source_model.detector)
detector = cast(Detector, api.model.detector)
classifier_head = cast(ClassifierHead, detector.classifier_head)
assert api.targets.config == sample_targets.config # type: ignore
assert detector.num_classes == len(sample_targets.class_names)
assert (
classifier_head.classifier.out_channels
== len(sample_targets.class_names) + 1
)
source_backbone = source_detector.backbone.state_dict()
target_backbone = detector.backbone.state_dict()
assert source_backbone
for key, value in source_backbone.items():
assert key in target_backbone
torch.testing.assert_close(target_backbone[key], value)
def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
example_targets_config: TargetConfig,
tmp_path: Path,
) -> None:
"""User story: same targets config does not rebuild prediction heads."""
module = build_training_module(model_config=BatDetect2Config().model)
# Given
source_api = BatDetect2API.from_config(
targets_config=example_targets_config
)
module = build_training_module(
model_config=source_api.model_config,
targets_config=example_targets_config,
class_names=source_api.targets.class_names,
dimension_names=source_api.roi_mapper.dimension_names,
)
trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "same_targets.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(checkpoint_path)
source_model, source_model_config = load_model_from_checkpoint(
checkpoint_path
)
source_model, _ = load_model_from_checkpoint(checkpoint_path)
source_detector = cast(Detector, source_model.detector)
api = BatDetect2API.from_checkpoint(
checkpoint_path,
targets_config=source_model_config.targets,
)
# When
api = BatDetect2API.from_checkpoint(checkpoint_path)
# Then
detector = cast(Detector, api.model.detector)
for key, value in source_detector.classifier_head.state_dict().items():
@ -295,42 +307,6 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
)
@pytest.mark.slow
def test_user_can_finetune_only_heads(
tmp_path: Path,
example_annotations,
) -> None:
"""User story: fine-tune only prediction heads."""
api = BatDetect2API.from_config(BatDetect2Config())
finetune_dir = tmp_path / "heads_only"
api.finetune(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
trainable="heads",
train_workers=0,
val_workers=0,
checkpoint_dir=finetune_dir,
log_dir=tmp_path / "logs",
num_epochs=1,
seed=0,
)
detector = cast(Detector, api.model.detector)
backbone_params = list(detector.backbone.parameters())
classifier_params = list(detector.classifier_head.parameters())
bbox_params = list(detector.bbox_head.parameters())
assert backbone_params
assert classifier_params
assert bbox_params
assert all(not parameter.requires_grad for parameter in backbone_params)
assert all(parameter.requires_grad for parameter in classifier_params)
assert all(parameter.requires_grad for parameter in bbox_params)
assert list(finetune_dir.rglob("*.ckpt"))
@pytest.mark.slow
def test_user_can_evaluate_small_dataset_and_get_metrics(
api_v2: BatDetect2API,
@ -348,8 +324,6 @@ def test_user_can_evaluate_small_dataset_and_get_metrics(
assert isinstance(metrics, list)
assert len(metrics) == 1
assert isinstance(metrics[0], dict)
assert len(metrics[0]) > 0
assert isinstance(predictions, list)
assert len(predictions) == 1
@ -450,8 +424,17 @@ def test_detection_threshold_override_changes_spectrogram_results(
spec = api_v2.generate_spectrogram(audio)
default_detections = api_v2.process_spectrogram(spec)
strict_detections = api_v2.process_spectrogram(
spec,
detection_threshold=1.0,
spec, detection_threshold=1.0
)
assert len(strict_detections) <= len(default_detections)
def test_user_can_create_api_with_custom_targets_and_model_metadata_matches(
sample_targets,
) -> None:
"""User story: custom targets define model output names for a new API."""
api = BatDetect2API.from_config(targets_config=sample_targets.config)
assert api.model.class_names == sample_targets.class_names

View File

@ -0,0 +1,114 @@
from pathlib import Path
from typing import cast

import pytest

from batdetect2.api_v2 import BatDetect2API
from batdetect2.models.detectors import Detector
from batdetect2.targets import TargetConfig
from batdetect2.train import load_model_from_checkpoint


@pytest.mark.slow
def test_user_can_finetune_only_heads(
    tmp_path: Path,
    example_annotations,
) -> None:
    """User story: fine-tune only prediction heads."""
    api = BatDetect2API.from_config()
    source_classifier_head = api.model.detector.classifier_head
    source_bbox_head = api.model.detector.bbox_head
    source_backbone = api.model.detector.backbone
    finetune_dir = tmp_path / "heads_only"

    finetuned_api = api.finetune(
        train_annotations=example_annotations[:1],
        val_annotations=example_annotations[:1],
        targets_config=TargetConfig(),
        trainable="heads",
        train_workers=0,
        val_workers=0,
        checkpoint_dir=finetune_dir,
        log_dir=tmp_path / "logs",
        num_epochs=1,
        seed=0,
    )

    detector = cast(Detector, finetuned_api.model.detector)
    backbone_params = list(detector.backbone.parameters())
    classifier_params = list(detector.classifier_head.parameters())
    bbox_params = list(detector.bbox_head.parameters())

    assert backbone_params
    assert classifier_params
    assert bbox_params
    assert all(not parameter.requires_grad for parameter in backbone_params)
    assert all(parameter.requires_grad for parameter in classifier_params)
    assert all(parameter.requires_grad for parameter in bbox_params)
    assert finetuned_api is not api
    assert detector.backbone is source_backbone
    assert detector.classifier_head is not source_classifier_head
    assert detector.bbox_head is not source_bbox_head
    assert list(finetune_dir.rglob("*.ckpt"))


@pytest.mark.slow
def test_finetune_replaces_targets_and_checkpoint_owns_new_targets(
    tmp_path: Path,
    example_annotations,
) -> None:
    """User story: fine-tuning writes checkpoints with the new targets."""
    source_api = BatDetect2API.from_config()
    source_evaluator = source_api.evaluator
    source_formatter = source_api.formatter
    source_output_transform = source_api.output_transform
    new_targets = TargetConfig.model_validate(
        {
            "classification_targets": [
                {
                    "name": "single_class",
                    "tags": [{"key": "class", "value": "single_class"}],
                }
            ],
            "roi": {"mapper": "top_left"},
        }
    )
    finetune_dir = tmp_path / "new_targets"

    finetuned_api = source_api.finetune(
        train_annotations=example_annotations[:1],
        val_annotations=example_annotations[:1],
        targets_config=new_targets,
        trainable="heads",
        train_workers=0,
        val_workers=0,
        checkpoint_dir=finetune_dir,
        log_dir=tmp_path / "logs",
        num_epochs=1,
        seed=0,
    )

    checkpoints = list(finetune_dir.rglob("*.ckpt"))

    assert source_api.targets.get_config() != new_targets.model_dump(
        mode="json"
    )
    assert finetuned_api.targets.get_config() == new_targets.model_dump(
        mode="json"
    )
    assert finetuned_api.evaluator is not source_evaluator
    assert finetuned_api.formatter is not source_formatter
    assert finetuned_api.output_transform is not source_output_transform
    assert finetuned_api.evaluator.targets is finetuned_api.targets
    assert finetuned_api.evaluator.transform is finetuned_api.output_transform
    assert finetuned_api.model.class_names == ["single_class"]
    assert finetuned_api.model.dimension_names == ["width", "height"]
    assert checkpoints

    _, configs = load_model_from_checkpoint(checkpoints[0])
    assert configs.targets.model_dump(mode="json") == new_targets.model_dump(
        mode="json"
    )

View File

@ -5,7 +5,6 @@ import numpy as np
import pytest
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.outputs import build_output_formatter
from batdetect2.outputs.formats import (
BatDetect2OutputConfig,
@ -18,7 +17,7 @@ from batdetect2.postprocess.types import ClipDetections
def api_v2() -> BatDetect2API:
"""User story: API object manages prediction IO formats."""
return BatDetect2API.from_config(BatDetect2Config())
return BatDetect2API.from_config()
@pytest.fixture

View File

@ -0,0 +1,99 @@
"""CLI tests for finetune command."""

from pathlib import Path

import pytest
from click.testing import CliRunner

from batdetect2.cli import cli


def test_cli_finetune_help() -> None:
    """User story: inspect finetune command interface and options."""
    result = CliRunner().invoke(cli, ["finetune", "--help"])

    assert result.exit_code == 0
    assert "TRAIN_DATASET" in result.output
    assert "--model" in result.output
    assert "--targets" in result.output
    assert "--training-config" in result.output
    assert "--audio-config" in result.output
    assert "--logging-config" in result.output
    assert "--evaluation-config" not in result.output
    assert "--inference-config" not in result.output
    assert "--outputs-config" not in result.output


def test_cli_finetune_requires_model() -> None:
    """User story: finetune requires a checkpoint argument."""
    result = CliRunner().invoke(
        cli,
        [
            "finetune",
            "example_data/dataset.yaml",
            "--targets",
            "example_data/targets.yaml",
        ],
    )

    assert result.exit_code != 0
    assert "--model" in result.output


def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
    """User story: finetune requires a new target definition."""
    result = CliRunner().invoke(
        cli,
        [
            "finetune",
            "example_data/dataset.yaml",
            "--model",
            str(tiny_checkpoint_path),
        ],
    )

    assert result.exit_code != 0
    assert "--targets" in result.output


@pytest.mark.slow
def test_cli_finetune_from_checkpoint_runs_on_small_dataset(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
) -> None:
    """User story: fine-tune a checkpoint via CLI with new targets."""
    ckpt_dir = tmp_path / "checkpoints"
    log_dir = tmp_path / "logs"
    ckpt_dir.mkdir()
    log_dir.mkdir()

    result = CliRunner().invoke(
        cli,
        [
            "finetune",
            "example_data/dataset.yaml",
            "--val-dataset",
            "example_data/dataset.yaml",
            "--model",
            str(tiny_checkpoint_path),
            "--targets",
            "example_data/targets.yaml",
            "--num-epochs",
            "1",
            "--train-workers",
            "0",
            "--val-workers",
            "0",
            "--ckpt-dir",
            str(ckpt_dir),
            "--log-dir",
            str(log_dir),
        ],
    )

    assert result.exit_code == 0
    assert len(list(ckpt_dir.rglob("*.ckpt"))) >= 1

View File

@ -81,3 +81,24 @@ def test_cli_train_rejects_model_and_model_config_together(
assert result.exit_code != 0
assert "--model-config cannot be used with --model" in result.output
def test_cli_train_rejects_model_and_targets_together(
tiny_checkpoint_path: Path,
) -> None:
"""User story: checkpoint training does not accept new targets."""
result = CliRunner().invoke(
cli,
[
"train",
"example_data/dataset.yaml",
"--model",
str(tiny_checkpoint_path),
"--targets",
"example_data/targets.yaml",
],
)
assert result.exit_code != 0
assert "--targets cannot be used with --model" in result.output
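
Review note: the new test pins the error text `--targets cannot be used with --model` and a non-zero exit code. A minimal sketch of that kind of post-parse validation follows, using the standard `argparse` module rather than click (a swapped-in technique, since the real CLI code is not shown in this hunk). `parse_train_args` is a hypothetical function.

```python
import argparse


def parse_train_args(argv):
    """Parse a reduced, hypothetical version of the train command."""
    parser = argparse.ArgumentParser(prog="train")
    parser.add_argument("train_dataset")
    parser.add_argument("--model")
    parser.add_argument("--targets")
    args = parser.parse_args(argv)

    # A checkpoint already carries its targets, so reject new ones.
    if args.model is not None and args.targets is not None:
        # parser.error prints the message and exits with status 2.
        parser.error("--targets cannot be used with --model")
    return args


ok_args = parse_train_args(["dataset.yaml", "--model", "model.ckpt"])
```

Validating after parsing keeps both options independently usable while still refusing the ambiguous combination.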

View File

@ -203,8 +203,8 @@ in_channels: 1
def test_load_backbone_config_from_example_data(example_data_dir: Path):
"""load_backbone_config loads the real example config correctly."""
config = load_backbone_config(
example_data_dir / "config.yaml",
field="model.architecture",
example_data_dir / "configs" / "model.yaml",
field="architecture",
)
assert isinstance(config, UNetBackboneConfig)

View File

@ -1,9 +1,34 @@
import json
from collections.abc import Callable
from pathlib import Path
from soundevent import data, terms
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
from batdetect2.targets import (
TargetConfig,
Targets,
build_roi_mapping,
build_targets,
)
def test_targets_get_config_returns_a_json_serializable_dict() -> None:
targets = build_targets(TargetConfig())
config_dict = targets.get_config()
assert isinstance(config_dict, dict)
assert json.dumps(config_dict)
def test_targets_from_config_rebuilds_equivalent_targets() -> None:
original = build_targets(TargetConfig())
rebuilt = Targets.from_config(original.get_config())
assert rebuilt.class_names == original.class_names
assert rebuilt.detection_class_name == original.detection_class_name
assert rebuilt.detection_class_tags == original.detection_class_tags
assert rebuilt.get_config() == original.get_config()
def test_can_override_default_roi_mapper_per_class(

View File

@ -0,0 +1,40 @@
from soundevent import data

from batdetect2.targets import (
    TargetClassConfig,
    TargetConfig,
    build_targets,
    check_target_compatibility,
)


def _target_class(name: str) -> TargetClassConfig:
    return TargetClassConfig(
        name=name,
        tags=[data.Tag(key="class", value=name)],
    )


def test_check_target_compatibility_accepts_superset_targets() -> None:
    config = TargetConfig(
        classification_targets=[
            _target_class("pip35"),
            _target_class("myo"),
            _target_class("extra"),
        ]
    )
    targets = build_targets(config)

    assert check_target_compatibility(targets, ["pip35", "myo"])


def test_check_target_compatibility_rejects_missing_model_classes() -> None:
    config = TargetConfig(
        classification_targets=[
            _target_class("pip35"),
            _target_class("myo"),
        ]
    )
    targets = build_targets(config)

    assert not check_target_compatibility(targets, ["pip35", "nyc"])
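
Review note: taken together, the two tests imply that targets are compatible when they define every class the model outputs, and that extra target classes are allowed. Under that assumption the rule reduces to a subset check. `covers_model_classes` is a hypothetical stand-in, and the real `check_target_compatibility` may verify more than class names.

```python
def covers_model_classes(target_class_names, model_class_names):
    """Return True when every model class is defined by the targets."""
    return set(model_class_names) <= set(target_class_names)


# Mirrors the two test cases above using plain class-name lists.
superset_ok = covers_model_classes(["pip35", "myo", "extra"], ["pip35", "myo"])
missing_class = covers_model_classes(["pip35", "myo"], ["pip35", "nyc"])
```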

View File

@ -3,20 +3,19 @@ from pathlib import Path
import pytest
from soundevent import data
from batdetect2.config import BatDetect2Config
from batdetect2.train import run_train
from batdetect2.train import TrainingConfig, run_train
pytestmark = pytest.mark.slow
def _build_fast_train_config() -> BatDetect2Config:
config = BatDetect2Config()
config.train.trainer.limit_train_batches = 1
config.train.trainer.limit_val_batches = 1
config.train.trainer.log_every_n_steps = 1
config.train.trainer.check_val_every_n_epoch = 1
config.train.train_loader.batch_size = 1
config.train.train_loader.augmentations.enabled = False
def _build_fast_train_config() -> TrainingConfig:
config = TrainingConfig()
config.trainer.limit_train_batches = 1
config.trainer.limit_val_batches = 1
config.trainer.log_every_n_steps = 1
config.trainer.check_val_every_n_epoch = 1
config.train_loader.batch_size = 1
config.train_loader.augmentations.enabled = False
return config
@ -29,9 +28,7 @@ def test_train_saves_checkpoint_in_requested_experiment_run_dir(
run_train(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
train_config=config.train,
model_config=config.model,
audio_config=config.audio,
train_config=config,
num_epochs=1,
train_workers=0,
val_workers=0,
@ -50,14 +47,12 @@ def test_train_without_validation_can_still_save_last_checkpoint(
example_annotations: list[data.ClipAnnotation],
) -> None:
config = _build_fast_train_config()
config.train.checkpoints.save_last = True
config.checkpoints.save_last = True
run_train(
train_annotations=example_annotations[:1],
val_annotations=None,
train_config=config.train,
model_config=config.model,
audio_config=config.audio,
train_config=config,
num_epochs=1,
train_workers=0,
val_workers=0,
@ -73,16 +68,14 @@ def test_train_controls_which_checkpoints_are_kept(
example_annotations: list[data.ClipAnnotation],
) -> None:
config = _build_fast_train_config()
config.train.checkpoints.save_top_k = 1
config.train.checkpoints.save_last = True
config.train.checkpoints.filename = "epoch{epoch}"
config.checkpoints.save_top_k = 1
config.checkpoints.save_last = True
config.checkpoints.filename = "epoch{epoch}"
run_train(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
train_config=config.train,
model_config=config.model,
audio_config=config.audio,
train_config=config,
num_epochs=3,
train_workers=0,
val_workers=0,

View File

@ -1,12 +1,43 @@
from batdetect2.config import BatDetect2Config
from batdetect2.core import load_config
from batdetect2.audio import AudioConfig
from batdetect2.evaluate import EvaluationConfig
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
from batdetect2.train import TrainingConfig
def test_example_config_is_valid(example_data_dir):
conf = load_config(
example_data_dir / "config.yaml",
schema=BatDetect2Config,
extra="forbid",
strict=True,
def test_example_split_configs_are_valid(example_data_dir):
configs_dir = example_data_dir / "configs"
assert isinstance(
AudioConfig.load(configs_dir / "audio.yaml"), AudioConfig
)
assert isinstance(
ModelConfig.load(configs_dir / "model.yaml"), ModelConfig
)
assert isinstance(
TargetConfig.load(example_data_dir / "targets.yaml"),
TargetConfig,
)
assert isinstance(
TrainingConfig.load(configs_dir / "training.yaml"),
TrainingConfig,
)
assert isinstance(
EvaluationConfig.load(configs_dir / "evaluation.yaml"),
EvaluationConfig,
)
assert isinstance(
InferenceConfig.load(configs_dir / "inference.yaml"),
InferenceConfig,
)
assert isinstance(
OutputsConfig.load(configs_dir / "outputs.yaml"),
OutputsConfig,
)
assert isinstance(
AppLoggingConfig.load(configs_dir / "logging.yaml"),
AppLoggingConfig,
)
assert isinstance(conf, BatDetect2Config)
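Note on the hunk above: the rewritten test loads each split config file into its own schema rather than validating one aggregate `BatDetect2Config`. A minimal sketch of that pattern follows. It uses stand-in dataclasses and JSON files in place of the project's real config classes and YAML files. `SplitConfig`, `AudioSettings`, `ModelSettings`, `SCHEMAS`, and `load_split_configs` are hypothetical names for illustration only and are not part of batdetect2, and the default values are arbitrary.

```python
import json
import tempfile
from dataclasses import dataclass, fields
from pathlib import Path


@dataclass
class SplitConfig:
    """Hypothetical base class: one file on disk maps to one schema."""

    @classmethod
    def load(cls, path: Path) -> "SplitConfig":
        raw = json.loads(Path(path).read_text())
        unknown = set(raw) - {f.name for f in fields(cls)}
        if unknown:
            # Reject unexpected keys instead of silently ignoring them.
            raise ValueError(
                f"unknown keys in {Path(path).name}: {sorted(unknown)}"
            )
        return cls(**raw)


@dataclass
class AudioSettings(SplitConfig):
    samplerate: int = 256_000  # arbitrary illustrative default


@dataclass
class ModelSettings(SplitConfig):
    samplerate: int = 256_000  # arbitrary illustrative default
    architecture: str = "example"  # arbitrary illustrative default


# Hypothetical registry: file stem -> schema used to load that file.
SCHEMAS = {"audio": AudioSettings, "model": ModelSettings}


def load_split_configs(configs_dir: Path) -> dict:
    """Load every known config file from a directory into its own schema."""
    return {
        name: schema.load(configs_dir / f"{name}.json")
        for name, schema in SCHEMAS.items()
    }


# Write two small files to a temporary directory and load them back.
with tempfile.TemporaryDirectory() as tmp:
    configs_dir = Path(tmp)
    (configs_dir / "audio.json").write_text(json.dumps({"samplerate": 192_000}))
    (configs_dir / "model.json").write_text(json.dumps({}))
    configs = load_split_configs(configs_dir)

print(configs)
```

With one schema per file, a mistake in one file fails on its own and cannot hide behind a large aggregate object. This appears to be the motivation for the per-file `isinstance` assertions in the test.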

@ -10,25 +10,42 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio.types import AudioLoader
from batdetect2.config import BatDetect2Config
from batdetect2.models import ModelConfig, build_model
from batdetect2.targets.classes import TargetClassConfig
from batdetect2.models import (
ModelConfig,
build_model,
build_model_with_new_targets,
)
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
from batdetect2.train import (
TrainingConfig,
TrainingModule,
load_model_from_checkpoint,
run_train,
)
from batdetect2.train.logging import (
DatasetConfigArtifact,
DatasetConfigArtifactLogging,
)
from batdetect2.train.optimizers import AdamOptimizerConfig
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
from batdetect2.train.train import build_training_module
def build_default_module(config: BatDetect2Config | None = None):
config = config or BatDetect2Config()
def build_default_module(
target_config: TargetConfig | None = None,
model_config: ModelConfig | None = None,
train_config: TrainingConfig | None = None,
):
target_config = target_config or TargetConfig()
model_config = model_config or ModelConfig()
train_config = train_config or TrainingConfig()
targets = build_targets(target_config)
roi_mapper = build_roi_mapping(target_config.roi)
return build_training_module(
model_config=config.model,
train_config=config.train,
model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config,
)
@ -64,7 +81,7 @@ def test_can_save_checkpoint(
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
def test_load_model_from_checkpoint_returns_model_and_config(
def test_load_model_from_checkpoint_returns_model_and_configs(
tmp_path: Path,
):
input_model_config = ModelConfig(samplerate=192_000)
@ -72,8 +89,13 @@ def test_load_model_from_checkpoint_returns_model_and_config(
input_model_config.model_dump(mode="json")
)
train_config = TrainingConfig()
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
module = build_training_module(
model_config=input_model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config,
)
trainer = L.Trainer()
@ -81,12 +103,20 @@ def test_load_model_from_checkpoint_returns_model_and_config(
trainer.strategy.connect(module)
trainer.save_checkpoint(path)
model, loaded_model_config = load_model_from_checkpoint(path)
model, loaded_configs = load_model_from_checkpoint(path)
assert model is not None
assert loaded_model_config.model_dump(
assert loaded_configs.model.model_dump(
mode="json"
) == expected_model_config.model_dump(mode="json")
assert loaded_configs.targets.model_dump(
mode="json"
) == targets_config.model_dump(mode="json")
assert loaded_configs.train.model_dump(
mode="json"
) == train_config.model_dump(mode="json")
assert model.class_names == targets.class_names
assert model.dimension_names == roi_mapper.dimension_names
recovered = TrainingModule.load_from_checkpoint(path)
assert recovered.train_config.model_dump(
@ -100,6 +130,9 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
model_config.model_dump(mode="json")
)
train_config = TrainingConfig()
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123)
train_config.trainer.max_epochs = 3
@ -107,6 +140,8 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
module = build_training_module(
model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config,
)
trainer = L.Trainer()
@ -114,28 +149,56 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
trainer.strategy.connect(module)
trainer.save_checkpoint(path)
recovered = TrainingModule.load_from_checkpoint(path)
_, recovered_configs = load_model_from_checkpoint(path)
assert not DeepDiff(
recovered.model_config.model_dump(mode="json"),
recovered_configs.model.model_dump(mode="json"),
expected_model_config.model_dump(mode="json"),
)
assert not DeepDiff(
recovered.train_config.model_dump(mode="json"),
recovered_configs.train.model_dump(mode="json"),
train_config.model_dump(mode="json"),
)
def test_load_model_from_checkpoint_includes_targets_config(tmp_path: Path):
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
module = build_training_module(
model_config=ModelConfig(),
targets_config=targets_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=TrainingConfig(),
)
trainer = L.Trainer()
path = tmp_path / "example.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(path)
_, loaded_configs = load_model_from_checkpoint(path)
assert loaded_configs.targets.model_dump(
mode="json"
) == targets_config.model_dump(mode="json")
def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
model_config = ModelConfig()
expected_model_config = ModelConfig.model_validate(
model_config.model_dump(mode="json")
)
train_config = TrainingConfig()
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321)
module = build_training_module(
model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config,
)
@ -153,14 +216,16 @@ def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
trainer.strategy.connect(module)
trainer.save_checkpoint(path)
recovered = TrainingModule.load_from_checkpoint(path)
assert recovered.model_config.model_dump(
_, recovered_configs = load_model_from_checkpoint(path)
assert recovered_configs.model.model_dump(
mode="json"
) == expected_model_config.model_dump(mode="json")
assert recovered.train_config.model_dump(
assert recovered_configs.train.model_dump(
mode="json"
) == train_config.model_dump(mode="json")
recovered = TrainingModule.load_from_checkpoint(path)
loaded_optimization_config = recovered.configure_optimizers()
loaded_optimizer = loaded_optimization_config["optimizer"]
loaded_scheduler = loaded_optimization_config["lr_scheduler"]["scheduler"]
@ -175,12 +240,28 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
trainer.strategy.connect(module)
trainer.save_checkpoint(path)
_, stored_configs = load_model_from_checkpoint(path)
api = BatDetect2API.from_checkpoint(path)
assert api.model_config.model_dump(
mode="json"
) == module.model_config.model_dump(mode="json")
assert api.audio_config.samplerate == module.model_config.samplerate
) == stored_configs.model.model_dump(mode="json")
assert api.audio_config.samplerate == stored_configs.model.samplerate
def test_api_from_checkpoint_reconstructs_targets_from_checkpoint(
tmp_path: Path,
) -> None:
targets_config = TargetConfig()
module = build_default_module(target_config=targets_config)
trainer = L.Trainer()
path = tmp_path / "example.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(path)
api = BatDetect2API.from_checkpoint(path)
assert api.targets.get_config() == targets_config.model_dump(mode="json")
@pytest.mark.slow
@ -189,19 +270,26 @@ def test_train_smoke_produces_loadable_checkpoint(
example_annotations: list[data.ClipAnnotation],
sample_audio_loader: AudioLoader,
):
config = BatDetect2Config()
config.train.trainer.limit_train_batches = 1
config.train.trainer.limit_val_batches = 1
config.train.trainer.log_every_n_steps = 1
config.train.train_loader.batch_size = 1
config.train.train_loader.augmentations.enabled = False
# Given
train_config = TrainingConfig.model_validate(
{
"trainer": {
"limit_train_batches": 1,
"limit_val_batches": 1,
"log_every_n_steps": 1,
},
"train_loader": {
"batch_size": 1,
"augmentations": {"enabled": False},
},
}
)
# When
run_train(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
train_config=config.train,
model_config=config.model,
audio_config=config.audio,
train_config=train_config,
num_epochs=1,
train_workers=0,
val_workers=0,
@ -209,18 +297,11 @@ def test_train_smoke_produces_loadable_checkpoint(
seed=0,
)
# Then
checkpoints = list(tmp_path.rglob("*.ckpt"))
assert checkpoints
model, model_config = load_model_from_checkpoint(checkpoints[0])
assert model_config.samplerate == config.model.samplerate
assert model_config.architecture.name == config.model.architecture.name
assert model_config.preprocess.model_dump(
mode="json"
) == config.model.preprocess.model_dump(mode="json")
assert model_config.postprocess.model_dump(
mode="json"
) == config.model.postprocess.model_dump(mode="json")
wav = torch.tensor(
sample_audio_loader.load_clip(example_annotations[0].clip)
@ -230,10 +311,18 @@ def test_train_smoke_produces_loadable_checkpoint(
def test_build_training_module_uses_provided_model() -> None:
model = build_model(ModelConfig())
targets = build_targets(TargetConfig())
roi_mapper = build_roi_mapping(TargetConfig().roi)
model = build_model(
ModelConfig(),
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
)
module = build_training_module(
model_config=ModelConfig(),
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=TrainingConfig(),
model=model,
)
@ -241,18 +330,117 @@ def test_build_training_module_uses_provided_model() -> None:
assert module.model is model
def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
None
):
source_targets_config = TargetConfig()
source_targets = build_targets(source_targets_config)
source_roi_mapper = build_roi_mapping(source_targets_config.roi)
source_model = build_model(
ModelConfig(),
class_names=source_targets.class_names,
dimension_names=source_roi_mapper.dimension_names,
)
new_targets_config = TargetConfig.model_validate(
{
"classification_targets": [
{
"name": "single_class",
"tags": [{"key": "class", "value": "single_class"}],
}
]
}
)
new_targets = build_targets(new_targets_config)
new_roi_mapper = build_roi_mapping(new_targets_config.roi)
rebuilt_model = build_model_with_new_targets(
model=source_model,
targets=new_targets,
roi_mapper=new_roi_mapper,
)
source_detector = source_model.detector
rebuilt_detector = rebuilt_model.detector
assert rebuilt_detector.backbone is source_detector.backbone
assert (
rebuilt_detector.classifier_head is not source_detector.classifier_head
)
assert rebuilt_detector.bbox_head is not source_detector.bbox_head
assert rebuilt_model.class_names == ["single_class"]
assert rebuilt_model.dimension_names == ["width", "height"]
@pytest.mark.slow
def test_run_train_logs_training_artifacts(
tmp_path: Path,
example_annotations: list[data.ClipAnnotation],
example_dataset,
) -> None:
train_config = TrainingConfig.model_validate(
{
"trainer": {
"limit_train_batches": 1,
"limit_val_batches": 1,
"log_every_n_steps": 1,
},
"train_loader": {
"batch_size": 1,
"augmentations": {"enabled": False},
},
}
)
run_train(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
train_config=train_config,
num_epochs=1,
train_workers=0,
val_workers=0,
checkpoint_dir=tmp_path / "checkpoints",
log_dir=tmp_path / "logs",
seed=0,
logging_callbacks=[
DatasetConfigArtifactLogging(
train_dataset_config=DatasetConfigArtifact(
filename="train_dataset.yaml",
config=example_dataset,
),
val_dataset_config=DatasetConfigArtifact(
filename="val_dataset.yaml",
config=example_dataset,
),
)
],
)
artifact_root = next((tmp_path / "logs").rglob("training_artifacts"))
assert (artifact_root / "targets.yaml").exists()
assert (artifact_root / "train_dataset.yaml").exists()
assert (artifact_root / "val_dataset.yaml").exists()
assert (artifact_root / "train_class_summary.csv").exists()
assert (artifact_root / "val_class_summary.csv").exists()
def test_run_train_rejects_incompatible_model_config(
example_annotations: list[data.ClipAnnotation],
) -> None:
model = build_model(ModelConfig())
# Given
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
incompatible_config = ModelConfig()
incompatible_config.targets.classification_targets.append(
TargetClassConfig(
name="dummy_class",
tags=[data.Tag(key="class", value="Dummy class")],
)
incompatible_model = build_model(
incompatible_config,
class_names=targets.class_names,
dimension_names=[*roi_mapper.dimension_names, "extra_dim"],
)
# When/Then
with pytest.raises(
ValueError,
match="Provided model is incompatible with model_config",
@ -260,7 +448,10 @@ def test_run_train_rejects_incompatible_model_config(
run_train(
train_annotations=example_annotations[:1],
val_annotations=None,
model=model,
model=incompatible_model,
targets=targets,
roi_mapper=roi_mapper,
model_config=incompatible_config,
targets_config=targets_config,
train_config=TrainingConfig(),
)
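
Note on the tests in this file: two behaviours are exercised above. Finetuning on new targets keeps the backbone object while rebuilding the heads, and a model whose outputs do not match the requested targets is rejected with a `ValueError`. The sketch below illustrates both ideas with plain stand-in classes and no torch. `Backbone`, `Head`, `Detector`, `with_new_targets`, and `check_compatible` are hypothetical names, not the batdetect2 API, and the real compatibility check may compare different attributes.

```python
from dataclasses import dataclass, field


@dataclass
class Backbone:
    """Stand-in for a trained feature extractor."""

    weights: list = field(default_factory=lambda: [0.1, 0.2, 0.3])


@dataclass
class Head:
    """Stand-in for an output head with one output per name."""

    output_names: list


@dataclass
class Detector:
    backbone: Backbone
    classifier_head: Head
    bbox_head: Head


def with_new_targets(source, class_names, dimension_names):
    """Keep the backbone object and build fresh heads for the new targets."""
    return Detector(
        backbone=source.backbone,  # same object, so trained weights carry over
        classifier_head=Head(list(class_names)),
        bbox_head=Head(list(dimension_names)),
    )


def check_compatible(detector, class_names, dimension_names):
    """Raise ValueError if the heads do not match the expected names."""
    if (
        detector.classifier_head.output_names != list(class_names)
        or detector.bbox_head.output_names != list(dimension_names)
    ):
        raise ValueError("model heads do not match the requested targets")


source = Detector(
    backbone=Backbone(),
    classifier_head=Head(["class_a", "class_b"]),
    bbox_head=Head(["width", "height"]),
)
rebuilt = with_new_targets(source, ["single_class"], ["width", "height"])

# The rebuilt model passes the check for the new targets.
check_compatible(rebuilt, ["single_class"], ["width", "height"])
```

Calling `check_compatible(source, ["single_class"], ["width", "height"])` raises `ValueError` in this sketch, because the source model still has heads for its original two classes.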