From 7b2699786f4925f9b9df13ea91bb4f5e8fb441f9 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 5 May 2026 12:16:37 +0100 Subject: [PATCH] feat: add checkpoint finetuning workflow --- src/batdetect2/api_v2.py | 79 ++++++++++-- src/batdetect2/cli/__init__.py | 2 + src/batdetect2/cli/finetune.py | 188 +++++++++++++++++++++++++++++ tests/test_api_v2/__init__.py | 0 tests/test_api_v2/test_api_v2.py | 36 ------ tests/test_api_v2/test_finetune.py | 114 +++++++++++++++++ tests/test_cli/__init__.py | 0 tests/test_cli/test_finetune.py | 99 +++++++++++++++ tests/test_train/test_lightning.py | 49 +++++++- 9 files changed, 517 insertions(+), 50 deletions(-) create mode 100644 src/batdetect2/cli/finetune.py create mode 100644 tests/test_api_v2/__init__.py create mode 100644 tests/test_api_v2/test_finetune.py create mode 100644 tests/test_cli/__init__.py create mode 100644 tests/test_cli/test_finetune.py diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 0cee1d5..ce24d95 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -137,6 +137,7 @@ class BatDetect2API: def finetune( self, train_annotations: Sequence[data.ClipAnnotation], + targets_config: TargetConfig, val_annotations: Sequence[data.ClipAnnotation] | None = None, trainable: Literal[ "all", "heads", "classifier_head", "bbox_head" @@ -149,25 +150,76 @@ class BatDetect2API: num_epochs: int | None = None, run_name: str | None = None, seed: int | None = None, - model_config: ModelConfig | None = None, audio_config: AudioConfig | None = None, train_config: TrainingConfig | None = None, logger_config: LoggerConfig | None = None, ) -> "BatDetect2API": - """Fine-tune the model with trainable-parameter selection.""" + """Fine-tune from a checkpoint using a new target definition.""" + from batdetect2.evaluate import build_evaluator + from batdetect2.models import build_model_with_new_targets + from batdetect2.outputs import ( + build_output_formatter, + build_output_transform, + ) + from batdetect2.targets import ( + TargetConfig, + build_roi_mapping, + build_targets, + ) from batdetect2.train import run_train - self._set_trainable_parameters(trainable) + target_config = TargetConfig.model_validate(targets_config) + targets = build_targets(config=target_config) + roi_mapper = build_roi_mapping(config=target_config.roi) + model = build_model_with_new_targets( + model=self.model, + targets=targets, + roi_mapper=roi_mapper, + ) + output_transform = build_output_transform( + config=self.outputs_config.transform, + targets=targets, + roi_mapper=roi_mapper, + ) + api = BatDetect2API( + model_config=self.model_config, + audio_config=audio_config or self.audio_config, + train_config=train_config or self.train_config, + evaluation_config=self.evaluation_config, + inference_config=self.inference_config, + outputs_config=self.outputs_config, + logging_config=self.logging_config, + targets=targets, + roi_mapper=roi_mapper, + audio_loader=self.audio_loader, + preprocessor=self.preprocessor, + postprocessor=self.postprocessor, + evaluator=build_evaluator( + config=self.evaluation_config, + targets=targets, + roi_mapper=roi_mapper, + transform=output_transform, + ), + formatter=build_output_formatter( + targets, + config=self.outputs_config.format, + ), + output_transform=output_transform, + model=model, + ) + + api._set_trainable_parameters(trainable) + api.model.train() run_train( train_annotations=train_annotations, val_annotations=val_annotations, - model=self.model, - targets=self.targets, - roi_mapper=self.roi_mapper, - model_config=model_config or self.model_config, - preprocessor=self.preprocessor, - audio_loader=self.audio_loader, + model=api.model, + targets=api.targets, + roi_mapper=api.roi_mapper, + model_config=api.model_config, + preprocessor=api.preprocessor, + audio_loader=api.audio_loader, train_workers=train_workers, val_workers=val_workers, checkpoint_dir=checkpoint_dir, @@ -176,11 +228,12 @@ class BatDetect2API: num_epochs=num_epochs, run_name=run_name, seed=seed, - audio_config=audio_config or self.audio_config, - train_config=train_config or self.train_config, - logger_config=logger_config or self.logging_config.train, + audio_config=api.audio_config, + train_config=api.train_config, + logger_config=logger_config or api.logging_config.train, ) - return self + api.model.eval() + return api def evaluate( self, diff --git a/src/batdetect2/cli/__init__.py b/src/batdetect2/cli/__init__.py index dace40d..eadc98f 100644 --- a/src/batdetect2/cli/__init__.py +++ b/src/batdetect2/cli/__init__.py @@ -2,6 +2,7 @@ from batdetect2.cli.base import cli from batdetect2.cli.compat import detect from batdetect2.cli.data import data from batdetect2.cli.evaluate import evaluate_command +from batdetect2.cli.finetune import finetune_command from batdetect2.cli.inference import predict from batdetect2.cli.train import train_command @@ -10,6 +11,7 @@ __all__ = [ "detect", "data", "train_command", + "finetune_command", "evaluate_command", "predict", ] diff --git a/src/batdetect2/cli/finetune.py b/src/batdetect2/cli/finetune.py new file mode 100644 index 0000000..45b6efb --- /dev/null +++ b/src/batdetect2/cli/finetune.py @@ -0,0 +1,188 @@ +from pathlib import Path +from typing import Literal, cast + +import click +from loguru import logger + +from batdetect2.cli.base import cli + +__all__ = ["finetune_command"] + + +@cli.command( + name="finetune", short_help="Fine-tune a checkpoint on new targets." +) +@click.argument("train_dataset", type=click.Path(exists=True)) +@click.option( + "--model", + "model_path", + required=True, + type=click.Path(exists=True), + help="Path to a checkpoint to fine-tune from.", +) +@click.option( + "--targets", + "targets_config", + required=True, + type=click.Path(exists=True), + help="Path to the new targets config file.", +) +@click.option( + "--val-dataset", + type=click.Path(exists=True), + help="Path to validation dataset config file.", +) +@click.option( + "--base-dir", + type=click.Path(exists=True), + help=( + "Base directory used to resolve relative paths inside the training " + "and validation dataset configs." + ), +) +@click.option( + "--training-config", + type=click.Path(exists=True), + help="Path to training config file.", +) +@click.option( + "--audio-config", + type=click.Path(exists=True), + help="Path to audio config file.", +) +@click.option( + "--logging-config", + type=click.Path(exists=True), + help="Path to logging config file.", +) +@click.option( + "--trainable", + type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]), + default="heads", + show_default=True, + help="Which model parameters remain trainable during fine-tuning.", +) +@click.option( + "--ckpt-dir", + type=click.Path(exists=True), + help="Directory where checkpoints are saved.", +) +@click.option( + "--log-dir", + type=click.Path(exists=True), + help="Directory where logs are written.", +) +@click.option( + "--train-workers", + type=int, + default=0, + help="Number of worker processes for training data loading.", +) +@click.option( + "--val-workers", + type=int, + default=0, + help="Number of worker processes for validation data loading.", +) +@click.option( + "--num-epochs", + type=int, + help="Maximum number of training epochs.", +) +@click.option( + "--experiment-name", + type=str, + help="Experiment name used for logging backends.", +) +@click.option( + "--run-name", + type=str, + help="Run name used for logging backends.", +) +@click.option( + "--seed", + type=int, + help="Random seed used for reproducibility.", +) +def finetune_command( + train_dataset: Path, + model_path: Path, + targets_config: Path, + val_dataset: Path | None = None, + ckpt_dir: Path | None = None, + log_dir: Path | None = None, + base_dir: Path | None = None, + training_config: Path | None = None, + audio_config: Path | None = None, + logging_config: Path | None = None, + trainable: str = "heads", + seed: int | None = None, + num_epochs: int | None = None, + train_workers: int = 0, + val_workers: int = 0, + experiment_name: str | None = None, + run_name: str | None = None, +): + """Fine-tune a BatDetect2 checkpoint on a new target definition.""" + from batdetect2.api_v2 import BatDetect2API + from batdetect2.audio import AudioConfig + from batdetect2.data import load_dataset_from_config + from batdetect2.logging import AppLoggingConfig + from batdetect2.targets import TargetConfig + from batdetect2.train import TrainingConfig + + logger.info("Initiating fine-tuning process...") + + target_conf = TargetConfig.load(targets_config) + train_conf = ( + TrainingConfig.load(training_config) + if training_config is not None + else None + ) + audio_conf = ( + AudioConfig.load(audio_config) if audio_config is not None else None + ) + logging_conf = ( + AppLoggingConfig.load(logging_config) + if logging_config is not None + else None + ) + + train_annotations = load_dataset_from_config( + train_dataset, + base_dir=base_dir, + ) + val_annotations = None + if val_dataset is not None: + val_annotations = load_dataset_from_config( + val_dataset, + base_dir=base_dir, + ) + + api = BatDetect2API.from_checkpoint( + model_path, + train_config=train_conf, + audio_config=audio_conf, + logging_config=logging_conf, + ) + + return api.finetune( + train_annotations=train_annotations, + val_annotations=val_annotations, + targets_config=target_conf, + trainable=cast( + Literal["all", "heads", "classifier_head", "bbox_head"], + trainable, + ), + train_workers=train_workers, + val_workers=val_workers, + checkpoint_dir=ckpt_dir, + log_dir=log_dir, + experiment_name=experiment_name, + num_epochs=num_epochs, + run_name=run_name, + seed=seed, + train_config=train_conf, + audio_config=audio_conf, + logger_config=logging_conf.train if logging_conf is not None else None, + ) diff --git a/tests/test_api_v2/__init__.py b/tests/test_api_v2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index e2ffe31..19e0a40 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -307,42 +307,6 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged( ) -@pytest.mark.slow -def test_user_can_finetune_only_heads( - tmp_path: Path, - example_annotations, -) -> None: - """User story: fine-tune only prediction heads.""" - - api = BatDetect2API.from_config() - finetune_dir = tmp_path / "heads_only" - - api.finetune( - train_annotations=example_annotations[:1], - val_annotations=example_annotations[:1], - trainable="heads", - train_workers=0, - val_workers=0, - checkpoint_dir=finetune_dir, - log_dir=tmp_path / "logs", - num_epochs=1, - seed=0, - ) - detector = cast(Detector, api.model.detector) - - backbone_params = list(detector.backbone.parameters()) - classifier_params = list(detector.classifier_head.parameters()) - bbox_params = list(detector.bbox_head.parameters()) - - assert backbone_params - assert classifier_params - assert bbox_params - assert all(not parameter.requires_grad for parameter in backbone_params) - assert all(parameter.requires_grad for parameter in classifier_params) - assert all(parameter.requires_grad for parameter in bbox_params) - assert list(finetune_dir.rglob("*.ckpt")) - - @pytest.mark.slow def test_user_can_evaluate_small_dataset_and_get_metrics( api_v2: BatDetect2API, diff --git a/tests/test_api_v2/test_finetune.py b/tests/test_api_v2/test_finetune.py new file mode 100644 index 0000000..8d8c6a2 --- /dev/null +++ b/tests/test_api_v2/test_finetune.py @@ -0,0 +1,114 @@ +from pathlib import Path +from typing import cast + +import pytest + +from batdetect2.api_v2 import BatDetect2API +from batdetect2.models.detectors import Detector +from batdetect2.targets import TargetConfig +from batdetect2.train import load_model_from_checkpoint + + +@pytest.mark.slow +def test_user_can_finetune_only_heads( + tmp_path: Path, + example_annotations, +) -> None: + """User story: fine-tune only prediction heads.""" + + api = BatDetect2API.from_config() + source_classifier_head = api.model.detector.classifier_head + source_bbox_head = api.model.detector.bbox_head + source_backbone = api.model.detector.backbone + finetune_dir = tmp_path / "heads_only" + + finetuned_api = api.finetune( + train_annotations=example_annotations[:1], + val_annotations=example_annotations[:1], + targets_config=TargetConfig(), + trainable="heads", + train_workers=0, + val_workers=0, + checkpoint_dir=finetune_dir, + log_dir=tmp_path / "logs", + num_epochs=1, + seed=0, + ) + + detector = cast(Detector, finetuned_api.model.detector) + + backbone_params = list(detector.backbone.parameters()) + classifier_params = list(detector.classifier_head.parameters()) + bbox_params = list(detector.bbox_head.parameters()) + + assert backbone_params + assert classifier_params + assert bbox_params + assert all(not parameter.requires_grad for parameter in backbone_params) + assert all(parameter.requires_grad for parameter in classifier_params) + assert all(parameter.requires_grad for parameter in bbox_params) + assert finetuned_api is not api + assert detector.backbone is source_backbone + assert detector.classifier_head is not source_classifier_head + assert detector.bbox_head is not source_bbox_head + assert list(finetune_dir.rglob("*.ckpt")) + + +@pytest.mark.slow +def test_finetune_replaces_targets_and_checkpoint_owns_new_targets( + tmp_path: Path, + example_annotations, +) -> None: + """User story: fine-tuning writes checkpoints with the new targets.""" + + source_api = BatDetect2API.from_config() + source_evaluator = source_api.evaluator + source_formatter = source_api.formatter + source_output_transform = source_api.output_transform + new_targets = TargetConfig.model_validate( + { + "classification_targets": [ + { + "name": "single_class", + "tags": [{"key": "class", "value": "single_class"}], + } + ], + "roi": {"mapper": "top_left"}, + } + ) + finetune_dir = tmp_path / "new_targets" + + finetuned_api = source_api.finetune( + train_annotations=example_annotations[:1], + val_annotations=example_annotations[:1], + targets_config=new_targets, + trainable="heads", + train_workers=0, + val_workers=0, + checkpoint_dir=finetune_dir, + log_dir=tmp_path / "logs", + num_epochs=1, + seed=0, + ) + + checkpoints = list(finetune_dir.rglob("*.ckpt")) + + assert source_api.targets.get_config() != new_targets.model_dump( + mode="json" + ) + assert finetuned_api.targets.get_config() == new_targets.model_dump( + mode="json" + ) + assert finetuned_api.evaluator is not source_evaluator + assert finetuned_api.formatter is not source_formatter + assert finetuned_api.output_transform is not source_output_transform + assert finetuned_api.evaluator.targets is finetuned_api.targets + assert finetuned_api.evaluator.transform is finetuned_api.output_transform + assert finetuned_api.model.class_names == ["single_class"] + assert finetuned_api.model.dimension_names == ["width", "height"] + assert checkpoints + + _, configs = load_model_from_checkpoint(checkpoints[0]) + assert configs.targets.model_dump(mode="json") == new_targets.model_dump( + mode="json" + ) diff --git a/tests/test_cli/__init__.py b/tests/test_cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_cli/test_finetune.py b/tests/test_cli/test_finetune.py new file mode 100644 index 0000000..270ca58 --- /dev/null +++ b/tests/test_cli/test_finetune.py @@ -0,0 +1,99 @@ +"""CLI tests for finetune command.""" + +from pathlib import Path + +import pytest +from click.testing import CliRunner + +from batdetect2.cli import cli + + +def test_cli_finetune_help() -> None: + """User story: inspect finetune command interface and options.""" + + result = CliRunner().invoke(cli, ["finetune", "--help"]) + + assert result.exit_code == 0 + assert "TRAIN_DATASET" in result.output + assert "--model" in result.output + assert "--targets" in result.output + assert "--training-config" in result.output + assert "--audio-config" in result.output + assert "--logging-config" in result.output + assert "--evaluation-config" not in result.output + assert "--inference-config" not in result.output + assert "--outputs-config" not in result.output + + +def test_cli_finetune_requires_model() -> None: + """User story: finetune requires a checkpoint argument.""" + + result = CliRunner().invoke( + cli, + [ + "finetune", + "example_data/dataset.yaml", + "--targets", + "example_data/targets.yaml", + ], + ) + + assert result.exit_code != 0 + assert "--model" in result.output + + +def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None: + """User story: finetune requires a new target definition.""" + + result = CliRunner().invoke( + cli, + [ + "finetune", + "example_data/dataset.yaml", + "--model", + str(tiny_checkpoint_path), + ], + ) + + assert result.exit_code != 0 + assert "--targets" in result.output + + +@pytest.mark.slow +def test_cli_finetune_from_checkpoint_runs_on_small_dataset( + tmp_path: Path, + tiny_checkpoint_path: Path, +) -> None: + """User story: fine-tune a checkpoint via CLI with new targets.""" + + ckpt_dir = tmp_path / "checkpoints" + log_dir = tmp_path / "logs" + ckpt_dir.mkdir() + log_dir.mkdir() + + result = CliRunner().invoke( + cli, + [ + "finetune", + "example_data/dataset.yaml", + "--val-dataset", + "example_data/dataset.yaml", + "--model", + str(tiny_checkpoint_path), + "--targets", + "example_data/targets.yaml", + "--num-epochs", + "1", + "--train-workers", + "0", + "--val-workers", + "0", + "--ckpt-dir", + str(ckpt_dir), + "--log-dir", + str(log_dir), + ], + ) + + assert result.exit_code == 0 + assert len(list(ckpt_dir.rglob("*.ckpt"))) >= 1 diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index 6e9254a..bac0b50 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -10,7 +10,11 @@ from torch.optim.lr_scheduler import CosineAnnealingLR from batdetect2.api_v2 import BatDetect2API from batdetect2.audio.types import AudioLoader -from batdetect2.models import ModelConfig, build_model +from batdetect2.models import ( + ModelConfig, + build_model, + build_model_with_new_targets, +) from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets from batdetect2.train import ( TrainingConfig, @@ -322,6 +326,49 @@ def test_build_training_module_uses_provided_model() -> None: assert module.model is model +def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> ( + None +): + source_targets_config = TargetConfig() + source_targets = build_targets(source_targets_config) + source_roi_mapper = build_roi_mapping(source_targets_config.roi) + source_model = build_model( + ModelConfig(), + class_names=source_targets.class_names, + dimension_names=source_roi_mapper.dimension_names, + ) + + new_targets_config = TargetConfig.model_validate( + { + "classification_targets": [ + { + "name": "single_class", + "tags": [{"key": "class", "value": "single_class"}], + } + ] + } + ) + new_targets = build_targets(new_targets_config) + new_roi_mapper = build_roi_mapping(new_targets_config.roi) + + rebuilt_model = build_model_with_new_targets( + model=source_model, + targets=new_targets, + roi_mapper=new_roi_mapper, + ) + + source_detector = source_model.detector + rebuilt_detector = rebuilt_model.detector + + assert rebuilt_detector.backbone is source_detector.backbone + assert ( + rebuilt_detector.classifier_head is not source_detector.classifier_head + ) + assert rebuilt_detector.bbox_head is not source_detector.bbox_head + assert rebuilt_model.class_names == ["single_class"] + assert rebuilt_model.dimension_names == ["width", "height"] + + def test_run_train_rejects_incompatible_model_config( example_annotations: list[data.ClipAnnotation], ) -> None: