feat: add checkpoint finetuning workflow

This commit is contained in:
mbsantiago 2026-05-05 12:16:37 +01:00
parent 75e52cc548
commit 7b2699786f
9 changed files with 517 additions and 50 deletions

View File

@ -137,6 +137,7 @@ class BatDetect2API:
def finetune(
self,
train_annotations: Sequence[data.ClipAnnotation],
targets_config: TargetConfig,
val_annotations: Sequence[data.ClipAnnotation] | None = None,
trainable: Literal[
"all", "heads", "classifier_head", "bbox_head"
@ -149,25 +150,76 @@ class BatDetect2API:
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
) -> "BatDetect2API":
"""Fine-tune the model with trainable-parameter selection."""
"""Fine-tune from a checkpoint using a new target definition."""
from batdetect2.evaluate import build_evaluator
from batdetect2.models import build_model_with_new_targets
from batdetect2.outputs import (
build_output_formatter,
build_output_transform,
)
from batdetect2.targets import (
TargetConfig,
build_roi_mapping,
build_targets,
)
from batdetect2.train import run_train
self._set_trainable_parameters(trainable)
target_config = TargetConfig.model_validate(targets_config)
targets = build_targets(config=target_config)
roi_mapper = build_roi_mapping(config=target_config.roi)
model = build_model_with_new_targets(
model=self.model,
targets=targets,
roi_mapper=roi_mapper,
)
output_transform = build_output_transform(
config=self.outputs_config.transform,
targets=targets,
roi_mapper=roi_mapper,
)
api = BatDetect2API(
model_config=self.model_config,
audio_config=audio_config or self.audio_config,
train_config=train_config or self.train_config,
evaluation_config=self.evaluation_config,
inference_config=self.inference_config,
outputs_config=self.outputs_config,
logging_config=self.logging_config,
targets=targets,
roi_mapper=roi_mapper,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
postprocessor=self.postprocessor,
evaluator=build_evaluator(
config=self.evaluation_config,
targets=targets,
roi_mapper=roi_mapper,
transform=output_transform,
),
formatter=build_output_formatter(
targets,
config=self.outputs_config.format,
),
output_transform=output_transform,
model=model,
)
api._set_trainable_parameters(trainable)
api.model.train()
run_train(
train_annotations=train_annotations,
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
roi_mapper=self.roi_mapper,
model_config=model_config or self.model_config,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
model=api.model,
targets=api.targets,
roi_mapper=api.roi_mapper,
model_config=api.model_config,
preprocessor=api.preprocessor,
audio_loader=api.audio_loader,
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=checkpoint_dir,
@ -176,11 +228,12 @@ class BatDetect2API:
num_epochs=num_epochs,
run_name=run_name,
seed=seed,
audio_config=audio_config or self.audio_config,
train_config=train_config or self.train_config,
logger_config=logger_config or self.logging_config.train,
audio_config=api.audio_config,
train_config=api.train_config,
logger_config=logger_config or api.logging_config.train,
)
return self
api.model.eval()
return api
def evaluate(
self,

View File

@ -2,6 +2,7 @@ from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect
from batdetect2.cli.data import data
from batdetect2.cli.evaluate import evaluate_command
from batdetect2.cli.finetune import finetune_command
from batdetect2.cli.inference import predict
from batdetect2.cli.train import train_command
@ -10,6 +11,7 @@ __all__ = [
"detect",
"data",
"train_command",
"finetune_command",
"evaluate_command",
"predict",
]

View File

@ -0,0 +1,188 @@
from pathlib import Path
from typing import Literal, cast
import click
from loguru import logger
from batdetect2.cli.base import cli
__all__ = ["finetune_command"]
# All path-like parameters use ``path_type=Path`` so that the values click
# hands to the callback really are ``pathlib.Path`` objects, matching the
# annotations on ``finetune_command`` (click.Path yields ``str`` by default).
# Options documented as directories additionally reject regular files.
@cli.command(
    name="finetune",
    short_help="Fine-tune a checkpoint on new targets.",
)
@click.argument(
    "train_dataset",
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "--model",
    "model_path",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to a checkpoint to fine-tune from.",
)
@click.option(
    "--targets",
    "targets_config",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to the new targets config file.",
)
@click.option(
    "--val-dataset",
    type=click.Path(exists=True, path_type=Path),
    help="Path to validation dataset config file.",
)
@click.option(
    "--base-dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),
    help=(
        "Base directory used to resolve relative paths inside the training "
        "and validation dataset configs."
    ),
)
@click.option(
    "--training-config",
    type=click.Path(exists=True, path_type=Path),
    help="Path to training config file.",
)
@click.option(
    "--audio-config",
    type=click.Path(exists=True, path_type=Path),
    help="Path to audio config file.",
)
@click.option(
    "--logging-config",
    type=click.Path(exists=True, path_type=Path),
    help="Path to logging config file.",
)
@click.option(
    "--trainable",
    type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
    default="heads",
    show_default=True,
    help="Which model parameters remain trainable during fine-tuning.",
)
@click.option(
    "--ckpt-dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),
    help="Directory where checkpoints are saved.",
)
@click.option(
    "--log-dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),
    help="Directory where logs are written.",
)
@click.option(
    "--train-workers",
    type=int,
    default=0,
    help="Number of worker processes for training data loading.",
)
@click.option(
    "--val-workers",
    type=int,
    default=0,
    help="Number of worker processes for validation data loading.",
)
@click.option(
    "--num-epochs",
    type=int,
    help="Maximum number of training epochs.",
)
@click.option(
    "--experiment-name",
    type=str,
    help="Experiment name used for logging backends.",
)
@click.option(
    "--run-name",
    type=str,
    help="Run name used for logging backends.",
)
@click.option(
    "--seed",
    type=int,
    help="Random seed used for reproducibility.",
)
def finetune_command(
    train_dataset: Path,
    model_path: Path,
    targets_config: Path,
    val_dataset: Path | None = None,
    ckpt_dir: Path | None = None,
    log_dir: Path | None = None,
    base_dir: Path | None = None,
    training_config: Path | None = None,
    audio_config: Path | None = None,
    logging_config: Path | None = None,
    trainable: str = "heads",
    seed: int | None = None,
    num_epochs: int | None = None,
    train_workers: int = 0,
    val_workers: int = 0,
    experiment_name: str | None = None,
    run_name: str | None = None,
):
    """Fine-tune a BatDetect2 checkpoint on a new target definition."""
    # NOTE: click shows the docstring above as this command's ``--help``
    # text, so it is deliberately kept to a single line; the details live in
    # these comments instead.
    #
    # Flow: load the (required) target config and the optional training /
    # audio / logging configs, load the training and optional validation
    # datasets, build an API object from the checkpoint, then delegate to
    # ``BatDetect2API.finetune`` and return whatever it returns.

    # Project imports are deferred to call time (presumably to keep CLI
    # start-up and ``--help`` cheap -- TODO confirm).
    from batdetect2.api_v2 import BatDetect2API
    from batdetect2.audio import AudioConfig
    from batdetect2.data import load_dataset_from_config
    from batdetect2.logging import AppLoggingConfig
    from batdetect2.targets import TargetConfig
    from batdetect2.train import TrainingConfig

    logger.info("Initiating fine-tuning process...")

    target_conf = TargetConfig.load(targets_config)

    # Optional configs stay ``None`` when their flag was not supplied, so the
    # callee can fall back to its own defaults.
    train_conf = (
        TrainingConfig.load(training_config)
        if training_config is not None
        else None
    )
    audio_conf = (
        AudioConfig.load(audio_config) if audio_config is not None else None
    )
    logging_conf = (
        AppLoggingConfig.load(logging_config)
        if logging_config is not None
        else None
    )

    train_annotations = load_dataset_from_config(
        train_dataset,
        base_dir=base_dir,
    )

    val_annotations = None
    if val_dataset is not None:
        val_annotations = load_dataset_from_config(
            val_dataset,
            base_dir=base_dir,
        )

    api = BatDetect2API.from_checkpoint(
        model_path,
        train_config=train_conf,
        audio_config=audio_conf,
        logging_config=logging_conf,
    )

    return api.finetune(
        train_annotations=train_annotations,
        val_annotations=val_annotations,
        targets_config=target_conf,
        # ``click.Choice`` has already restricted the value to one of these
        # literals; the cast only narrows the static type.
        trainable=cast(
            Literal["all", "heads", "classifier_head", "bbox_head"],
            trainable,
        ),
        train_workers=train_workers,
        val_workers=val_workers,
        checkpoint_dir=ckpt_dir,
        log_dir=log_dir,
        experiment_name=experiment_name,
        num_epochs=num_epochs,
        run_name=run_name,
        seed=seed,
        train_config=train_conf,
        audio_config=audio_conf,
        # Only the ``train`` section of the logging config is forwarded here.
        logger_config=logging_conf.train if logging_conf is not None else None,
    )

View File

View File

@ -307,42 +307,6 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
)
@pytest.mark.slow
def test_user_can_finetune_only_heads(
tmp_path: Path,
example_annotations,
) -> None:
"""User story: fine-tune only prediction heads."""
api = BatDetect2API.from_config()
finetune_dir = tmp_path / "heads_only"
api.finetune(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
trainable="heads",
train_workers=0,
val_workers=0,
checkpoint_dir=finetune_dir,
log_dir=tmp_path / "logs",
num_epochs=1,
seed=0,
)
detector = cast(Detector, api.model.detector)
backbone_params = list(detector.backbone.parameters())
classifier_params = list(detector.classifier_head.parameters())
bbox_params = list(detector.bbox_head.parameters())
assert backbone_params
assert classifier_params
assert bbox_params
assert all(not parameter.requires_grad for parameter in backbone_params)
assert all(parameter.requires_grad for parameter in classifier_params)
assert all(parameter.requires_grad for parameter in bbox_params)
assert list(finetune_dir.rglob("*.ckpt"))
@pytest.mark.slow
def test_user_can_evaluate_small_dataset_and_get_metrics(
api_v2: BatDetect2API,

View File

@ -0,0 +1,114 @@
from pathlib import Path
from typing import cast
import pytest
from batdetect2.api_v2 import BatDetect2API
from batdetect2.models.detectors import Detector
from batdetect2.targets import TargetConfig
from batdetect2.train import load_model_from_checkpoint
@pytest.mark.slow
def test_user_can_finetune_only_heads(
    tmp_path: Path,
    example_annotations,
) -> None:
    """User story: fine-tune only prediction heads."""
    api = BatDetect2API.from_config()

    # Capture the source components up front so identity can be compared
    # against the fine-tuned model afterwards.
    original_detector = api.model.detector
    original_classifier = original_detector.classifier_head
    original_bbox = original_detector.bbox_head
    original_backbone = original_detector.backbone

    output_dir = tmp_path / "heads_only"
    single_clip = example_annotations[:1]

    finetuned_api = api.finetune(
        train_annotations=single_clip,
        val_annotations=example_annotations[:1],
        targets_config=TargetConfig(),
        trainable="heads",
        train_workers=0,
        val_workers=0,
        checkpoint_dir=output_dir,
        log_dir=tmp_path / "logs",
        num_epochs=1,
        seed=0,
    )

    tuned = cast(Detector, finetuned_api.model.detector)
    backbone_params = [*tuned.backbone.parameters()]
    classifier_params = [*tuned.classifier_head.parameters()]
    bbox_params = [*tuned.bbox_head.parameters()]

    # Every component must actually expose parameters, otherwise the
    # requires_grad checks below would pass vacuously.
    assert backbone_params
    assert classifier_params
    assert bbox_params

    # Backbone frozen, both heads trainable.
    assert not any(p.requires_grad for p in backbone_params)
    assert all(p.requires_grad for p in classifier_params)
    assert all(p.requires_grad for p in bbox_params)

    # A new API object is returned; it shares the backbone but not the heads.
    assert finetuned_api is not api
    assert tuned.backbone is original_backbone
    assert tuned.classifier_head is not original_classifier
    assert tuned.bbox_head is not original_bbox

    assert list(output_dir.rglob("*.ckpt"))
@pytest.mark.slow
def test_finetune_replaces_targets_and_checkpoint_owns_new_targets(
    tmp_path: Path,
    example_annotations,
) -> None:
    """User story: fine-tuning writes checkpoints with the new targets."""
    source_api = BatDetect2API.from_config()

    # Remember the source helpers so we can check they were rebuilt.
    original_evaluator = source_api.evaluator
    original_formatter = source_api.formatter
    original_transform = source_api.output_transform

    single_class_target = {
        "name": "single_class",
        "tags": [{"key": "class", "value": "single_class"}],
    }
    new_targets = TargetConfig.model_validate(
        {
            "classification_targets": [single_class_target],
            "roi": {"mapper": "top_left"},
        }
    )

    output_dir = tmp_path / "new_targets"
    finetuned_api = source_api.finetune(
        train_annotations=example_annotations[:1],
        val_annotations=example_annotations[:1],
        targets_config=new_targets,
        trainable="heads",
        train_workers=0,
        val_workers=0,
        checkpoint_dir=output_dir,
        log_dir=tmp_path / "logs",
        num_epochs=1,
        seed=0,
    )

    checkpoints = list(output_dir.rglob("*.ckpt"))
    expected_config = new_targets.model_dump(mode="json")

    # The source API keeps its targets; the returned API carries the new ones.
    assert source_api.targets.get_config() != expected_config
    assert finetuned_api.targets.get_config() == expected_config

    # Target-dependent helpers are rebuilt and wired to the new objects.
    assert finetuned_api.evaluator is not original_evaluator
    assert finetuned_api.formatter is not original_formatter
    assert finetuned_api.output_transform is not original_transform
    assert finetuned_api.evaluator.targets is finetuned_api.targets
    assert finetuned_api.evaluator.transform is finetuned_api.output_transform

    assert finetuned_api.model.class_names == ["single_class"]
    assert finetuned_api.model.dimension_names == ["width", "height"]

    # The checkpoint written to disk must carry the new target definition.
    assert checkpoints
    _, stored_configs = load_model_from_checkpoint(checkpoints[0])
    assert stored_configs.targets.model_dump(
        mode="json"
    ) == new_targets.model_dump(mode="json")

View File

View File

@ -0,0 +1,99 @@
"""CLI tests for finetune command."""
from pathlib import Path
import pytest
from click.testing import CliRunner
from batdetect2.cli import cli
def test_cli_finetune_help() -> None:
    """User story: inspect finetune command interface and options."""
    outcome = CliRunner().invoke(cli, ["finetune", "--help"])
    assert outcome.exit_code == 0

    help_text = outcome.output

    # Arguments and options the command is expected to advertise.
    advertised = (
        "TRAIN_DATASET",
        "--model",
        "--targets",
        "--training-config",
        "--audio-config",
        "--logging-config",
    )
    for token in advertised:
        assert token in help_text

    # Options that must not be exposed by the finetune command.
    hidden = (
        "--evaluation-config",
        "--inference-config",
        "--outputs-config",
    )
    for token in hidden:
        assert token not in help_text
def test_cli_finetune_requires_model() -> None:
    """User story: finetune requires a checkpoint argument."""
    # Deliberately omit ``--model`` from the invocation.
    argv = [
        "finetune",
        "example_data/dataset.yaml",
        "--targets",
        "example_data/targets.yaml",
    ]
    outcome = CliRunner().invoke(cli, argv)

    assert outcome.exit_code != 0
    assert "--model" in outcome.output
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
    """User story: finetune requires a new target definition."""
    # Deliberately omit ``--targets`` from the invocation.
    argv = [
        "finetune",
        "example_data/dataset.yaml",
        "--model",
        str(tiny_checkpoint_path),
    ]
    outcome = CliRunner().invoke(cli, argv)

    assert outcome.exit_code != 0
    assert "--targets" in outcome.output
@pytest.mark.slow
def test_cli_finetune_from_checkpoint_runs_on_small_dataset(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
) -> None:
    """User story: fine-tune a checkpoint via CLI with new targets."""
    ckpt_dir = tmp_path / "checkpoints"
    log_dir = tmp_path / "logs"
    # Both directories are created before invoking the command.
    for directory in (ckpt_dir, log_dir):
        directory.mkdir()

    dataset = "example_data/dataset.yaml"
    argv = [
        "finetune",
        dataset,
        "--val-dataset",
        dataset,
        "--model",
        str(tiny_checkpoint_path),
        "--targets",
        "example_data/targets.yaml",
        "--num-epochs",
        "1",
        "--train-workers",
        "0",
        "--val-workers",
        "0",
        "--ckpt-dir",
        str(ckpt_dir),
        "--log-dir",
        str(log_dir),
    ]
    outcome = CliRunner().invoke(cli, argv)

    assert outcome.exit_code == 0

    written = list(ckpt_dir.rglob("*.ckpt"))
    assert len(written) >= 1

View File

@ -10,7 +10,11 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio.types import AudioLoader
from batdetect2.models import ModelConfig, build_model
from batdetect2.models import (
ModelConfig,
build_model,
build_model_with_new_targets,
)
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
from batdetect2.train import (
TrainingConfig,
@ -322,6 +326,49 @@ def test_build_training_module_uses_provided_model() -> None:
assert module.model is model
def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
    None
):
    """Rebuilding a model for new targets keeps the backbone, swaps the heads."""
    # Source model built from the default target configuration.
    base_config = TargetConfig()
    base_targets = build_targets(base_config)
    base_roi_mapper = build_roi_mapping(base_config.roi)
    source_model = build_model(
        ModelConfig(),
        class_names=base_targets.class_names,
        dimension_names=base_roi_mapper.dimension_names,
    )

    # New target definition containing a single class.
    single_class = {
        "name": "single_class",
        "tags": [{"key": "class", "value": "single_class"}],
    }
    replacement_config = TargetConfig.model_validate(
        {"classification_targets": [single_class]}
    )
    replacement_targets = build_targets(replacement_config)
    replacement_roi_mapper = build_roi_mapping(replacement_config.roi)

    rebuilt_model = build_model_with_new_targets(
        model=source_model,
        targets=replacement_targets,
        roi_mapper=replacement_roi_mapper,
    )

    before = source_model.detector
    after = rebuilt_model.detector

    # Same backbone object, freshly built heads.
    assert after.backbone is before.backbone
    assert after.classifier_head is not before.classifier_head
    assert after.bbox_head is not before.bbox_head

    assert rebuilt_model.class_names == ["single_class"]
    assert rebuilt_model.dimension_names == ["width", "height"]
def test_run_train_rejects_incompatible_model_config(
example_annotations: list[data.ClipAnnotation],
) -> None: