batdetect2/tests/test_api_v2/test_finetune.py
2026-05-05 12:16:37 +01:00

115 lines
3.8 KiB
Python

from pathlib import Path
from typing import cast
import pytest
from batdetect2.api_v2 import BatDetect2API
from batdetect2.models.detectors import Detector
from batdetect2.targets import TargetConfig
from batdetect2.train import load_model_from_checkpoint
@pytest.mark.slow
def test_user_can_finetune_only_heads(
tmp_path: Path,
example_annotations,
) -> None:
"""User story: fine-tune only prediction heads."""
api = BatDetect2API.from_config()
source_classifier_head = api.model.detector.classifier_head
source_bbox_head = api.model.detector.bbox_head
source_backbone = api.model.detector.backbone
finetune_dir = tmp_path / "heads_only"
finetuned_api = api.finetune(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
targets_config=TargetConfig(),
trainable="heads",
train_workers=0,
val_workers=0,
checkpoint_dir=finetune_dir,
log_dir=tmp_path / "logs",
num_epochs=1,
seed=0,
)
detector = cast(Detector, finetuned_api.model.detector)
backbone_params = list(detector.backbone.parameters())
classifier_params = list(detector.classifier_head.parameters())
bbox_params = list(detector.bbox_head.parameters())
assert backbone_params
assert classifier_params
assert bbox_params
assert all(not parameter.requires_grad for parameter in backbone_params)
assert all(parameter.requires_grad for parameter in classifier_params)
assert all(parameter.requires_grad for parameter in bbox_params)
assert finetuned_api is not api
assert detector.backbone is source_backbone
assert detector.classifier_head is not source_classifier_head
assert detector.bbox_head is not source_bbox_head
assert list(finetune_dir.rglob("*.ckpt"))
@pytest.mark.slow
def test_finetune_replaces_targets_and_checkpoint_owns_new_targets(
tmp_path: Path,
example_annotations,
) -> None:
"""User story: fine-tuning writes checkpoints with the new targets."""
source_api = BatDetect2API.from_config()
source_evaluator = source_api.evaluator
source_formatter = source_api.formatter
source_output_transform = source_api.output_transform
new_targets = TargetConfig.model_validate(
{
"classification_targets": [
{
"name": "single_class",
"tags": [{"key": "class", "value": "single_class"}],
}
],
"roi": {"mapper": "top_left"},
}
)
finetune_dir = tmp_path / "new_targets"
finetuned_api = source_api.finetune(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
targets_config=new_targets,
trainable="heads",
train_workers=0,
val_workers=0,
checkpoint_dir=finetune_dir,
log_dir=tmp_path / "logs",
num_epochs=1,
seed=0,
)
checkpoints = list(finetune_dir.rglob("*.ckpt"))
assert source_api.targets.get_config() != new_targets.model_dump(
mode="json"
)
assert finetuned_api.targets.get_config() == new_targets.model_dump(
mode="json"
)
assert finetuned_api.evaluator is not source_evaluator
assert finetuned_api.formatter is not source_formatter
assert finetuned_api.output_transform is not source_output_transform
assert finetuned_api.evaluator.targets is finetuned_api.targets
assert finetuned_api.evaluator.transform is finetuned_api.output_transform
assert finetuned_api.model.class_names == ["single_class"]
assert finetuned_api.model.dimension_names == ["width", "height"]
assert checkpoints
_, configs = load_model_from_checkpoint(checkpoints[0])
assert configs.targets.model_dump(mode="json") == new_targets.model_dump(
mode="json"
)