Compare commits

..

No commits in common. "7b2699786f4925f9b9df13ea91bb4f5e8fb441f9" and "7d416e0f991e69f0f4648bdcc9cd569581e9b0bd" have entirely different histories.

12 changed files with 66 additions and 548 deletions

View File

@ -137,7 +137,6 @@ class BatDetect2API:
def finetune( def finetune(
self, self,
train_annotations: Sequence[data.ClipAnnotation], train_annotations: Sequence[data.ClipAnnotation],
targets_config: TargetConfig,
val_annotations: Sequence[data.ClipAnnotation] | None = None, val_annotations: Sequence[data.ClipAnnotation] | None = None,
trainable: Literal[ trainable: Literal[
"all", "heads", "classifier_head", "bbox_head" "all", "heads", "classifier_head", "bbox_head"
@ -150,76 +149,25 @@ class BatDetect2API:
num_epochs: int | None = None, num_epochs: int | None = None,
run_name: str | None = None, run_name: str | None = None,
seed: int | None = None, seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None, logger_config: LoggerConfig | None = None,
) -> "BatDetect2API": ) -> "BatDetect2API":
"""Fine-tune from a checkpoint using a new target definition.""" """Fine-tune the model with trainable-parameter selection."""
from batdetect2.evaluate import build_evaluator
from batdetect2.models import build_model_with_new_targets
from batdetect2.outputs import (
build_output_formatter,
build_output_transform,
)
from batdetect2.targets import (
TargetConfig,
build_roi_mapping,
build_targets,
)
from batdetect2.train import run_train from batdetect2.train import run_train
target_config = TargetConfig.model_validate(targets_config) self._set_trainable_parameters(trainable)
targets = build_targets(config=target_config)
roi_mapper = build_roi_mapping(config=target_config.roi)
model = build_model_with_new_targets(
model=self.model,
targets=targets,
roi_mapper=roi_mapper,
)
output_transform = build_output_transform(
config=self.outputs_config.transform,
targets=targets,
roi_mapper=roi_mapper,
)
api = BatDetect2API(
model_config=self.model_config,
audio_config=audio_config or self.audio_config,
train_config=train_config or self.train_config,
evaluation_config=self.evaluation_config,
inference_config=self.inference_config,
outputs_config=self.outputs_config,
logging_config=self.logging_config,
targets=targets,
roi_mapper=roi_mapper,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
postprocessor=self.postprocessor,
evaluator=build_evaluator(
config=self.evaluation_config,
targets=targets,
roi_mapper=roi_mapper,
transform=output_transform,
),
formatter=build_output_formatter(
targets,
config=self.outputs_config.format,
),
output_transform=output_transform,
model=model,
)
api._set_trainable_parameters(trainable)
api.model.train()
run_train( run_train(
train_annotations=train_annotations, train_annotations=train_annotations,
val_annotations=val_annotations, val_annotations=val_annotations,
model=api.model, model=self.model,
targets=api.targets, targets=self.targets,
roi_mapper=api.roi_mapper, roi_mapper=self.roi_mapper,
model_config=api.model_config, model_config=model_config or self.model_config,
preprocessor=api.preprocessor, preprocessor=self.preprocessor,
audio_loader=api.audio_loader, audio_loader=self.audio_loader,
train_workers=train_workers, train_workers=train_workers,
val_workers=val_workers, val_workers=val_workers,
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
@ -228,12 +176,11 @@ class BatDetect2API:
num_epochs=num_epochs, num_epochs=num_epochs,
run_name=run_name, run_name=run_name,
seed=seed, seed=seed,
audio_config=api.audio_config, audio_config=audio_config or self.audio_config,
train_config=api.train_config, train_config=train_config or self.train_config,
logger_config=logger_config or api.logging_config.train, logger_config=logger_config or self.logging_config.train,
) )
api.model.eval() return self
return api
def evaluate( def evaluate(
self, self,
@ -645,6 +592,7 @@ class BatDetect2API:
def from_checkpoint( def from_checkpoint(
cls, cls,
path: data.PathLike, path: data.PathLike,
targets_config: TargetConfig | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None, evaluation_config: EvaluationConfig | None = None,
@ -668,21 +616,21 @@ class BatDetect2API:
build_targets, build_targets,
check_target_compatibility, check_target_compatibility,
) )
from batdetect2.train import load_model_from_checkpoint from batdetect2.train import TrainingConfig, load_model_from_checkpoint
model, configs = load_model_from_checkpoint(path) model, configs = load_model_from_checkpoint(path)
model_config = configs.model model_config = configs.model
train_config = train_config or configs.train
audio_config = audio_config or AudioConfig( audio_config = audio_config or AudioConfig(
samplerate=model_config.samplerate, samplerate=model_config.samplerate,
) )
train_config = train_config or TrainingConfig()
evaluation_config = evaluation_config or EvaluationConfig() evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig() inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig() outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig() logging_config = logging_config or AppLoggingConfig()
targets_config = configs.targets targets_config = targets_config or configs.targets
targets = build_targets(config=targets_config) targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi) roi_mapper = build_roi_mapping(config=targets_config.roi)

View File

@ -2,7 +2,6 @@ from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect from batdetect2.cli.compat import detect
from batdetect2.cli.data import data from batdetect2.cli.data import data
from batdetect2.cli.evaluate import evaluate_command from batdetect2.cli.evaluate import evaluate_command
from batdetect2.cli.finetune import finetune_command
from batdetect2.cli.inference import predict from batdetect2.cli.inference import predict
from batdetect2.cli.train import train_command from batdetect2.cli.train import train_command
@ -11,7 +10,6 @@ __all__ = [
"detect", "detect",
"data", "data",
"train_command", "train_command",
"finetune_command",
"evaluate_command", "evaluate_command",
"predict", "predict",
] ]

View File

@ -106,6 +106,7 @@ def evaluate_command(
from batdetect2.inference import InferenceConfig from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig from batdetect2.logging import AppLoggingConfig
from batdetect2.outputs import OutputsConfig from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
logger.info("Initiating evaluation process...") logger.info("Initiating evaluation process...")
@ -119,6 +120,11 @@ def evaluate_command(
num_annotations=len(test_annotations), num_annotations=len(test_annotations),
) )
target_conf = (
TargetConfig.load(targets_config)
if targets_config is not None
else None
)
audio_conf = ( audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None AudioConfig.load(audio_config) if audio_config is not None else None
) )
@ -145,6 +151,7 @@ def evaluate_command(
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
model_path, model_path,
targets_config=target_conf,
audio_config=audio_conf, audio_config=audio_conf,
evaluation_config=eval_conf, evaluation_config=eval_conf,
inference_config=inference_conf, inference_config=inference_conf,

View File

@ -1,188 +0,0 @@
from pathlib import Path
from typing import Literal, cast
import click
from loguru import logger
from batdetect2.cli.base import cli
__all__ = ["finetune_command"]
@cli.command(
name="finetune", short_help="Fine-tune a checkpoint on new targets."
)
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option(
"--model",
"model_path",
required=True,
type=click.Path(exists=True),
help="Path to a checkpoint to fine-tune from.",
)
@click.option(
"--targets",
"targets_config",
required=True,
type=click.Path(exists=True),
help="Path to the new targets config file.",
)
@click.option(
"--val-dataset",
type=click.Path(exists=True),
help="Path to validation dataset config file.",
)
@click.option(
"--base-dir",
type=click.Path(exists=True),
help=(
"Base directory used to resolve relative paths inside the training "
"and validation dataset configs."
),
)
@click.option(
"--training-config",
type=click.Path(exists=True),
help="Path to training config file.",
)
@click.option(
"--audio-config",
type=click.Path(exists=True),
help="Path to audio config file.",
)
@click.option(
"--logging-config",
type=click.Path(exists=True),
help="Path to logging config file.",
)
@click.option(
"--trainable",
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
default="heads",
show_default=True,
help="Which model parameters remain trainable during fine-tuning.",
)
@click.option(
"--ckpt-dir",
type=click.Path(exists=True),
help="Directory where checkpoints are saved.",
)
@click.option(
"--log-dir",
type=click.Path(exists=True),
help="Directory where logs are written.",
)
@click.option(
"--train-workers",
type=int,
default=0,
help="Number of worker processes for training data loading.",
)
@click.option(
"--val-workers",
type=int,
default=0,
help="Number of worker processes for validation data loading.",
)
@click.option(
"--num-epochs",
type=int,
help="Maximum number of training epochs.",
)
@click.option(
"--experiment-name",
type=str,
help="Experiment name used for logging backends.",
)
@click.option(
"--run-name",
type=str,
help="Run name used for logging backends.",
)
@click.option(
"--seed",
type=int,
help="Random seed used for reproducibility.",
)
def finetune_command(
train_dataset: Path,
model_path: Path,
targets_config: Path,
val_dataset: Path | None = None,
ckpt_dir: Path | None = None,
log_dir: Path | None = None,
base_dir: Path | None = None,
training_config: Path | None = None,
audio_config: Path | None = None,
logging_config: Path | None = None,
trainable: str = "heads",
seed: int | None = None,
num_epochs: int | None = None,
train_workers: int = 0,
val_workers: int = 0,
experiment_name: str | None = None,
run_name: str | None = None,
):
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.data import load_dataset_from_config
from batdetect2.logging import AppLoggingConfig
from batdetect2.targets import TargetConfig
from batdetect2.train import TrainingConfig
logger.info("Initiating fine-tuning process...")
target_conf = TargetConfig.load(targets_config)
train_conf = (
TrainingConfig.load(training_config)
if training_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
logging_conf = (
AppLoggingConfig.load(logging_config)
if logging_config is not None
else None
)
train_annotations = load_dataset_from_config(
train_dataset,
base_dir=base_dir,
)
val_annotations = None
if val_dataset is not None:
val_annotations = load_dataset_from_config(
val_dataset,
base_dir=base_dir,
)
api = BatDetect2API.from_checkpoint(
model_path,
train_config=train_conf,
audio_config=audio_conf,
logging_config=logging_conf,
)
return api.finetune(
train_annotations=train_annotations,
val_annotations=val_annotations,
targets_config=target_conf,
trainable=cast(
Literal["all", "heads", "classifier_head", "bbox_head"],
trainable,
),
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=ckpt_dir,
log_dir=log_dir,
experiment_name=experiment_name,
num_epochs=num_epochs,
run_name=run_name,
seed=seed,
train_config=train_conf,
audio_config=audio_conf,
logger_config=logging_conf.train if logging_conf is not None else None,
)

View File

@ -228,12 +228,6 @@ def train_command(
"Checkpoint model configuration is loaded from the checkpoint." "Checkpoint model configuration is loaded from the checkpoint."
) )
if model_path is not None and target_conf is not None:
raise click.UsageError(
"--targets cannot be used with --model. "
"Checkpoint target configuration is loaded from the checkpoint."
)
if model_path is None: if model_path is None:
api = BatDetect2API.from_config( api = BatDetect2API.from_config(
model_config=model_conf, model_config=model_conf,
@ -248,6 +242,7 @@ def train_command(
else: else:
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
model_path, model_path,
targets_config=target_conf,
train_config=train_conf, train_config=train_conf,
audio_config=audio_conf, audio_config=audio_conf,
evaluation_config=eval_conf, evaluation_config=eval_conf,

View File

@ -287,7 +287,10 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
source_detector = cast(Detector, source_model.detector) source_detector = cast(Detector, source_model.detector)
# When # When
api = BatDetect2API.from_checkpoint(checkpoint_path) api = BatDetect2API.from_checkpoint(
checkpoint_path,
targets_config=example_targets_config,
)
# Then # Then
detector = cast(Detector, api.model.detector) detector = cast(Detector, api.model.detector)
@ -307,6 +310,42 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
) )
@pytest.mark.slow
def test_user_can_finetune_only_heads(
tmp_path: Path,
example_annotations,
) -> None:
"""User story: fine-tune only prediction heads."""
api = BatDetect2API.from_config()
finetune_dir = tmp_path / "heads_only"
api.finetune(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
trainable="heads",
train_workers=0,
val_workers=0,
checkpoint_dir=finetune_dir,
log_dir=tmp_path / "logs",
num_epochs=1,
seed=0,
)
detector = cast(Detector, api.model.detector)
backbone_params = list(detector.backbone.parameters())
classifier_params = list(detector.classifier_head.parameters())
bbox_params = list(detector.bbox_head.parameters())
assert backbone_params
assert classifier_params
assert bbox_params
assert all(not parameter.requires_grad for parameter in backbone_params)
assert all(parameter.requires_grad for parameter in classifier_params)
assert all(parameter.requires_grad for parameter in bbox_params)
assert list(finetune_dir.rglob("*.ckpt"))
@pytest.mark.slow @pytest.mark.slow
def test_user_can_evaluate_small_dataset_and_get_metrics( def test_user_can_evaluate_small_dataset_and_get_metrics(
api_v2: BatDetect2API, api_v2: BatDetect2API,

View File

@ -1,114 +0,0 @@
from pathlib import Path
from typing import cast
import pytest
from batdetect2.api_v2 import BatDetect2API
from batdetect2.models.detectors import Detector
from batdetect2.targets import TargetConfig
from batdetect2.train import load_model_from_checkpoint
@pytest.mark.slow
def test_user_can_finetune_only_heads(
tmp_path: Path,
example_annotations,
) -> None:
"""User story: fine-tune only prediction heads."""
api = BatDetect2API.from_config()
source_classifier_head = api.model.detector.classifier_head
source_bbox_head = api.model.detector.bbox_head
source_backbone = api.model.detector.backbone
finetune_dir = tmp_path / "heads_only"
finetuned_api = api.finetune(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
targets_config=TargetConfig(),
trainable="heads",
train_workers=0,
val_workers=0,
checkpoint_dir=finetune_dir,
log_dir=tmp_path / "logs",
num_epochs=1,
seed=0,
)
detector = cast(Detector, finetuned_api.model.detector)
backbone_params = list(detector.backbone.parameters())
classifier_params = list(detector.classifier_head.parameters())
bbox_params = list(detector.bbox_head.parameters())
assert backbone_params
assert classifier_params
assert bbox_params
assert all(not parameter.requires_grad for parameter in backbone_params)
assert all(parameter.requires_grad for parameter in classifier_params)
assert all(parameter.requires_grad for parameter in bbox_params)
assert finetuned_api is not api
assert detector.backbone is source_backbone
assert detector.classifier_head is not source_classifier_head
assert detector.bbox_head is not source_bbox_head
assert list(finetune_dir.rglob("*.ckpt"))
@pytest.mark.slow
def test_finetune_replaces_targets_and_checkpoint_owns_new_targets(
tmp_path: Path,
example_annotations,
) -> None:
"""User story: fine-tuning writes checkpoints with the new targets."""
source_api = BatDetect2API.from_config()
source_evaluator = source_api.evaluator
source_formatter = source_api.formatter
source_output_transform = source_api.output_transform
new_targets = TargetConfig.model_validate(
{
"classification_targets": [
{
"name": "single_class",
"tags": [{"key": "class", "value": "single_class"}],
}
],
"roi": {"mapper": "top_left"},
}
)
finetune_dir = tmp_path / "new_targets"
finetuned_api = source_api.finetune(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
targets_config=new_targets,
trainable="heads",
train_workers=0,
val_workers=0,
checkpoint_dir=finetune_dir,
log_dir=tmp_path / "logs",
num_epochs=1,
seed=0,
)
checkpoints = list(finetune_dir.rglob("*.ckpt"))
assert source_api.targets.get_config() != new_targets.model_dump(
mode="json"
)
assert finetuned_api.targets.get_config() == new_targets.model_dump(
mode="json"
)
assert finetuned_api.evaluator is not source_evaluator
assert finetuned_api.formatter is not source_formatter
assert finetuned_api.output_transform is not source_output_transform
assert finetuned_api.evaluator.targets is finetuned_api.targets
assert finetuned_api.evaluator.transform is finetuned_api.output_transform
assert finetuned_api.model.class_names == ["single_class"]
assert finetuned_api.model.dimension_names == ["width", "height"]
assert checkpoints
_, configs = load_model_from_checkpoint(checkpoints[0])
assert configs.targets.model_dump(mode="json") == new_targets.model_dump(
mode="json"
)

View File

@ -1,99 +0,0 @@
"""CLI tests for finetune command."""
from pathlib import Path
import pytest
from click.testing import CliRunner
from batdetect2.cli import cli
def test_cli_finetune_help() -> None:
"""User story: inspect finetune command interface and options."""
result = CliRunner().invoke(cli, ["finetune", "--help"])
assert result.exit_code == 0
assert "TRAIN_DATASET" in result.output
assert "--model" in result.output
assert "--targets" in result.output
assert "--training-config" in result.output
assert "--audio-config" in result.output
assert "--logging-config" in result.output
assert "--evaluation-config" not in result.output
assert "--inference-config" not in result.output
assert "--outputs-config" not in result.output
def test_cli_finetune_requires_model() -> None:
"""User story: finetune requires a checkpoint argument."""
result = CliRunner().invoke(
cli,
[
"finetune",
"example_data/dataset.yaml",
"--targets",
"example_data/targets.yaml",
],
)
assert result.exit_code != 0
assert "--model" in result.output
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
"""User story: finetune requires a new target definition."""
result = CliRunner().invoke(
cli,
[
"finetune",
"example_data/dataset.yaml",
"--model",
str(tiny_checkpoint_path),
],
)
assert result.exit_code != 0
assert "--targets" in result.output
@pytest.mark.slow
def test_cli_finetune_from_checkpoint_runs_on_small_dataset(
tmp_path: Path,
tiny_checkpoint_path: Path,
) -> None:
"""User story: fine-tune a checkpoint via CLI with new targets."""
ckpt_dir = tmp_path / "checkpoints"
log_dir = tmp_path / "logs"
ckpt_dir.mkdir()
log_dir.mkdir()
result = CliRunner().invoke(
cli,
[
"finetune",
"example_data/dataset.yaml",
"--val-dataset",
"example_data/dataset.yaml",
"--model",
str(tiny_checkpoint_path),
"--targets",
"example_data/targets.yaml",
"--num-epochs",
"1",
"--train-workers",
"0",
"--val-workers",
"0",
"--ckpt-dir",
str(ckpt_dir),
"--log-dir",
str(log_dir),
],
)
assert result.exit_code == 0
assert len(list(ckpt_dir.rglob("*.ckpt"))) >= 1

View File

@ -81,24 +81,3 @@ def test_cli_train_rejects_model_and_model_config_together(
assert result.exit_code != 0 assert result.exit_code != 0
assert "--model-config cannot be used with --model" in result.output assert "--model-config cannot be used with --model" in result.output
def test_cli_train_rejects_model_and_targets_together(
tiny_checkpoint_path: Path,
) -> None:
"""User story: checkpoint training does not accept new targets."""
result = CliRunner().invoke(
cli,
[
"train",
"example_data/dataset.yaml",
"--model",
str(tiny_checkpoint_path),
"--targets",
"example_data/targets.yaml",
],
)
assert result.exit_code != 0
assert "--targets cannot be used with --model" in result.output

View File

@ -10,11 +10,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio.types import AudioLoader from batdetect2.audio.types import AudioLoader
from batdetect2.models import ( from batdetect2.models import ModelConfig, build_model
ModelConfig,
build_model,
build_model_with_new_targets,
)
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
from batdetect2.train import ( from batdetect2.train import (
TrainingConfig, TrainingConfig,
@ -326,49 +322,6 @@ def test_build_training_module_uses_provided_model() -> None:
assert module.model is model assert module.model is model
def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
None
):
source_targets_config = TargetConfig()
source_targets = build_targets(source_targets_config)
source_roi_mapper = build_roi_mapping(source_targets_config.roi)
source_model = build_model(
ModelConfig(),
class_names=source_targets.class_names,
dimension_names=source_roi_mapper.dimension_names,
)
new_targets_config = TargetConfig.model_validate(
{
"classification_targets": [
{
"name": "single_class",
"tags": [{"key": "class", "value": "single_class"}],
}
]
}
)
new_targets = build_targets(new_targets_config)
new_roi_mapper = build_roi_mapping(new_targets_config.roi)
rebuilt_model = build_model_with_new_targets(
model=source_model,
targets=new_targets,
roi_mapper=new_roi_mapper,
)
source_detector = source_model.detector
rebuilt_detector = rebuilt_model.detector
assert rebuilt_detector.backbone is source_detector.backbone
assert (
rebuilt_detector.classifier_head is not source_detector.classifier_head
)
assert rebuilt_detector.bbox_head is not source_detector.bbox_head
assert rebuilt_model.class_names == ["single_class"]
assert rebuilt_model.dimension_names == ["width", "height"]
def test_run_train_rejects_incompatible_model_config( def test_run_train_rejects_incompatible_model_config(
example_annotations: list[data.ClipAnnotation], example_annotations: list[data.ClipAnnotation],
) -> None: ) -> None: