mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Compare commits
2 Commits
7d416e0f99
...
7b2699786f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b2699786f | ||
|
|
75e52cc548 |
@ -137,6 +137,7 @@ class BatDetect2API:
|
|||||||
def finetune(
|
def finetune(
|
||||||
self,
|
self,
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
targets_config: TargetConfig,
|
||||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
trainable: Literal[
|
trainable: Literal[
|
||||||
"all", "heads", "classifier_head", "bbox_head"
|
"all", "heads", "classifier_head", "bbox_head"
|
||||||
@ -149,25 +150,76 @@ class BatDetect2API:
|
|||||||
num_epochs: int | None = None,
|
num_epochs: int | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
seed: int | None = None,
|
seed: int | None = None,
|
||||||
model_config: ModelConfig | None = None,
|
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
"""Fine-tune the model with trainable-parameter selection."""
|
"""Fine-tune from a checkpoint using a new target definition."""
|
||||||
|
from batdetect2.evaluate import build_evaluator
|
||||||
|
from batdetect2.models import build_model_with_new_targets
|
||||||
|
from batdetect2.outputs import (
|
||||||
|
build_output_formatter,
|
||||||
|
build_output_transform,
|
||||||
|
)
|
||||||
|
from batdetect2.targets import (
|
||||||
|
TargetConfig,
|
||||||
|
build_roi_mapping,
|
||||||
|
build_targets,
|
||||||
|
)
|
||||||
from batdetect2.train import run_train
|
from batdetect2.train import run_train
|
||||||
|
|
||||||
self._set_trainable_parameters(trainable)
|
target_config = TargetConfig.model_validate(targets_config)
|
||||||
|
targets = build_targets(config=target_config)
|
||||||
|
roi_mapper = build_roi_mapping(config=target_config.roi)
|
||||||
|
model = build_model_with_new_targets(
|
||||||
|
model=self.model,
|
||||||
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
|
)
|
||||||
|
output_transform = build_output_transform(
|
||||||
|
config=self.outputs_config.transform,
|
||||||
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
|
)
|
||||||
|
api = BatDetect2API(
|
||||||
|
model_config=self.model_config,
|
||||||
|
audio_config=audio_config or self.audio_config,
|
||||||
|
train_config=train_config or self.train_config,
|
||||||
|
evaluation_config=self.evaluation_config,
|
||||||
|
inference_config=self.inference_config,
|
||||||
|
outputs_config=self.outputs_config,
|
||||||
|
logging_config=self.logging_config,
|
||||||
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
|
audio_loader=self.audio_loader,
|
||||||
|
preprocessor=self.preprocessor,
|
||||||
|
postprocessor=self.postprocessor,
|
||||||
|
evaluator=build_evaluator(
|
||||||
|
config=self.evaluation_config,
|
||||||
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
|
transform=output_transform,
|
||||||
|
),
|
||||||
|
formatter=build_output_formatter(
|
||||||
|
targets,
|
||||||
|
config=self.outputs_config.format,
|
||||||
|
),
|
||||||
|
output_transform=output_transform,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
api._set_trainable_parameters(trainable)
|
||||||
|
api.model.train()
|
||||||
|
|
||||||
run_train(
|
run_train(
|
||||||
train_annotations=train_annotations,
|
train_annotations=train_annotations,
|
||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
model=self.model,
|
model=api.model,
|
||||||
targets=self.targets,
|
targets=api.targets,
|
||||||
roi_mapper=self.roi_mapper,
|
roi_mapper=api.roi_mapper,
|
||||||
model_config=model_config or self.model_config,
|
model_config=api.model_config,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=api.preprocessor,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=api.audio_loader,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
val_workers=val_workers,
|
val_workers=val_workers,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
@ -176,11 +228,12 @@ class BatDetect2API:
|
|||||||
num_epochs=num_epochs,
|
num_epochs=num_epochs,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
audio_config=audio_config or self.audio_config,
|
audio_config=api.audio_config,
|
||||||
train_config=train_config or self.train_config,
|
train_config=api.train_config,
|
||||||
logger_config=logger_config or self.logging_config.train,
|
logger_config=logger_config or api.logging_config.train,
|
||||||
)
|
)
|
||||||
return self
|
api.model.eval()
|
||||||
|
return api
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
@ -592,7 +645,6 @@ class BatDetect2API:
|
|||||||
def from_checkpoint(
|
def from_checkpoint(
|
||||||
cls,
|
cls,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
targets_config: TargetConfig | None = None,
|
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
evaluation_config: EvaluationConfig | None = None,
|
evaluation_config: EvaluationConfig | None = None,
|
||||||
@ -616,21 +668,21 @@ class BatDetect2API:
|
|||||||
build_targets,
|
build_targets,
|
||||||
check_target_compatibility,
|
check_target_compatibility,
|
||||||
)
|
)
|
||||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
from batdetect2.train import load_model_from_checkpoint
|
||||||
|
|
||||||
model, configs = load_model_from_checkpoint(path)
|
model, configs = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
model_config = configs.model
|
model_config = configs.model
|
||||||
|
train_config = train_config or configs.train
|
||||||
|
|
||||||
audio_config = audio_config or AudioConfig(
|
audio_config = audio_config or AudioConfig(
|
||||||
samplerate=model_config.samplerate,
|
samplerate=model_config.samplerate,
|
||||||
)
|
)
|
||||||
train_config = train_config or TrainingConfig()
|
|
||||||
evaluation_config = evaluation_config or EvaluationConfig()
|
evaluation_config = evaluation_config or EvaluationConfig()
|
||||||
inference_config = inference_config or InferenceConfig()
|
inference_config = inference_config or InferenceConfig()
|
||||||
outputs_config = outputs_config or OutputsConfig()
|
outputs_config = outputs_config or OutputsConfig()
|
||||||
logging_config = logging_config or AppLoggingConfig()
|
logging_config = logging_config or AppLoggingConfig()
|
||||||
targets_config = targets_config or configs.targets
|
targets_config = configs.targets
|
||||||
|
|
||||||
targets = build_targets(config=targets_config)
|
targets = build_targets(config=targets_config)
|
||||||
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||||
|
|||||||
@ -2,6 +2,7 @@ from batdetect2.cli.base import cli
|
|||||||
from batdetect2.cli.compat import detect
|
from batdetect2.cli.compat import detect
|
||||||
from batdetect2.cli.data import data
|
from batdetect2.cli.data import data
|
||||||
from batdetect2.cli.evaluate import evaluate_command
|
from batdetect2.cli.evaluate import evaluate_command
|
||||||
|
from batdetect2.cli.finetune import finetune_command
|
||||||
from batdetect2.cli.inference import predict
|
from batdetect2.cli.inference import predict
|
||||||
from batdetect2.cli.train import train_command
|
from batdetect2.cli.train import train_command
|
||||||
|
|
||||||
@ -10,6 +11,7 @@ __all__ = [
|
|||||||
"detect",
|
"detect",
|
||||||
"data",
|
"data",
|
||||||
"train_command",
|
"train_command",
|
||||||
|
"finetune_command",
|
||||||
"evaluate_command",
|
"evaluate_command",
|
||||||
"predict",
|
"predict",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -106,7 +106,6 @@ def evaluate_command(
|
|||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
from batdetect2.logging import AppLoggingConfig
|
from batdetect2.logging import AppLoggingConfig
|
||||||
from batdetect2.outputs import OutputsConfig
|
from batdetect2.outputs import OutputsConfig
|
||||||
from batdetect2.targets import TargetConfig
|
|
||||||
|
|
||||||
logger.info("Initiating evaluation process...")
|
logger.info("Initiating evaluation process...")
|
||||||
|
|
||||||
@ -120,11 +119,6 @@ def evaluate_command(
|
|||||||
num_annotations=len(test_annotations),
|
num_annotations=len(test_annotations),
|
||||||
)
|
)
|
||||||
|
|
||||||
target_conf = (
|
|
||||||
TargetConfig.load(targets_config)
|
|
||||||
if targets_config is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
audio_conf = (
|
audio_conf = (
|
||||||
AudioConfig.load(audio_config) if audio_config is not None else None
|
AudioConfig.load(audio_config) if audio_config is not None else None
|
||||||
)
|
)
|
||||||
@ -151,7 +145,6 @@ def evaluate_command(
|
|||||||
|
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
model_path,
|
model_path,
|
||||||
targets_config=target_conf,
|
|
||||||
audio_config=audio_conf,
|
audio_config=audio_conf,
|
||||||
evaluation_config=eval_conf,
|
evaluation_config=eval_conf,
|
||||||
inference_config=inference_conf,
|
inference_config=inference_conf,
|
||||||
|
|||||||
188
src/batdetect2/cli/finetune.py
Normal file
188
src/batdetect2/cli/finetune.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal, cast
|
||||||
|
|
||||||
|
import click
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from batdetect2.cli.base import cli
|
||||||
|
|
||||||
|
__all__ = ["finetune_command"]
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command(
|
||||||
|
name="finetune", short_help="Fine-tune a checkpoint on new targets."
|
||||||
|
)
|
||||||
|
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||||
|
@click.option(
|
||||||
|
"--model",
|
||||||
|
"model_path",
|
||||||
|
required=True,
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help="Path to a checkpoint to fine-tune from.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--targets",
|
||||||
|
"targets_config",
|
||||||
|
required=True,
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help="Path to the new targets config file.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--val-dataset",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help="Path to validation dataset config file.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--base-dir",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help=(
|
||||||
|
"Base directory used to resolve relative paths inside the training "
|
||||||
|
"and validation dataset configs."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--training-config",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help="Path to training config file.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--audio-config",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help="Path to audio config file.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--logging-config",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help="Path to logging config file.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--trainable",
|
||||||
|
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
|
||||||
|
default="heads",
|
||||||
|
show_default=True,
|
||||||
|
help="Which model parameters remain trainable during fine-tuning.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--ckpt-dir",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help="Directory where checkpoints are saved.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--log-dir",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help="Directory where logs are written.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--train-workers",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Number of worker processes for training data loading.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--val-workers",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Number of worker processes for validation data loading.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--num-epochs",
|
||||||
|
type=int,
|
||||||
|
help="Maximum number of training epochs.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--experiment-name",
|
||||||
|
type=str,
|
||||||
|
help="Experiment name used for logging backends.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--run-name",
|
||||||
|
type=str,
|
||||||
|
help="Run name used for logging backends.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
help="Random seed used for reproducibility.",
|
||||||
|
)
|
||||||
|
def finetune_command(
|
||||||
|
train_dataset: Path,
|
||||||
|
model_path: Path,
|
||||||
|
targets_config: Path,
|
||||||
|
val_dataset: Path | None = None,
|
||||||
|
ckpt_dir: Path | None = None,
|
||||||
|
log_dir: Path | None = None,
|
||||||
|
base_dir: Path | None = None,
|
||||||
|
training_config: Path | None = None,
|
||||||
|
audio_config: Path | None = None,
|
||||||
|
logging_config: Path | None = None,
|
||||||
|
trainable: str = "heads",
|
||||||
|
seed: int | None = None,
|
||||||
|
num_epochs: int | None = None,
|
||||||
|
train_workers: int = 0,
|
||||||
|
val_workers: int = 0,
|
||||||
|
experiment_name: str | None = None,
|
||||||
|
run_name: str | None = None,
|
||||||
|
):
|
||||||
|
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
|
||||||
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
|
from batdetect2.audio import AudioConfig
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.logging import AppLoggingConfig
|
||||||
|
from batdetect2.targets import TargetConfig
|
||||||
|
from batdetect2.train import TrainingConfig
|
||||||
|
|
||||||
|
logger.info("Initiating fine-tuning process...")
|
||||||
|
|
||||||
|
target_conf = TargetConfig.load(targets_config)
|
||||||
|
train_conf = (
|
||||||
|
TrainingConfig.load(training_config)
|
||||||
|
if training_config is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
audio_conf = (
|
||||||
|
AudioConfig.load(audio_config) if audio_config is not None else None
|
||||||
|
)
|
||||||
|
logging_conf = (
|
||||||
|
AppLoggingConfig.load(logging_config)
|
||||||
|
if logging_config is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
train_annotations = load_dataset_from_config(
|
||||||
|
train_dataset,
|
||||||
|
base_dir=base_dir,
|
||||||
|
)
|
||||||
|
val_annotations = None
|
||||||
|
if val_dataset is not None:
|
||||||
|
val_annotations = load_dataset_from_config(
|
||||||
|
val_dataset,
|
||||||
|
base_dir=base_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
api = BatDetect2API.from_checkpoint(
|
||||||
|
model_path,
|
||||||
|
train_config=train_conf,
|
||||||
|
audio_config=audio_conf,
|
||||||
|
logging_config=logging_conf,
|
||||||
|
)
|
||||||
|
|
||||||
|
return api.finetune(
|
||||||
|
train_annotations=train_annotations,
|
||||||
|
val_annotations=val_annotations,
|
||||||
|
targets_config=target_conf,
|
||||||
|
trainable=cast(
|
||||||
|
Literal["all", "heads", "classifier_head", "bbox_head"],
|
||||||
|
trainable,
|
||||||
|
),
|
||||||
|
train_workers=train_workers,
|
||||||
|
val_workers=val_workers,
|
||||||
|
checkpoint_dir=ckpt_dir,
|
||||||
|
log_dir=log_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
|
num_epochs=num_epochs,
|
||||||
|
run_name=run_name,
|
||||||
|
seed=seed,
|
||||||
|
train_config=train_conf,
|
||||||
|
audio_config=audio_conf,
|
||||||
|
logger_config=logging_conf.train if logging_conf is not None else None,
|
||||||
|
)
|
||||||
@ -228,6 +228,12 @@ def train_command(
|
|||||||
"Checkpoint model configuration is loaded from the checkpoint."
|
"Checkpoint model configuration is loaded from the checkpoint."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_path is not None and target_conf is not None:
|
||||||
|
raise click.UsageError(
|
||||||
|
"--targets cannot be used with --model. "
|
||||||
|
"Checkpoint target configuration is loaded from the checkpoint."
|
||||||
|
)
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
api = BatDetect2API.from_config(
|
api = BatDetect2API.from_config(
|
||||||
model_config=model_conf,
|
model_config=model_conf,
|
||||||
@ -242,7 +248,6 @@ def train_command(
|
|||||||
else:
|
else:
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
model_path,
|
model_path,
|
||||||
targets_config=target_conf,
|
|
||||||
train_config=train_conf,
|
train_config=train_conf,
|
||||||
audio_config=audio_conf,
|
audio_config=audio_conf,
|
||||||
evaluation_config=eval_conf,
|
evaluation_config=eval_conf,
|
||||||
|
|||||||
0
tests/test_api_v2/__init__.py
Normal file
0
tests/test_api_v2/__init__.py
Normal file
@ -287,10 +287,7 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
|||||||
source_detector = cast(Detector, source_model.detector)
|
source_detector = cast(Detector, source_model.detector)
|
||||||
|
|
||||||
# When
|
# When
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(checkpoint_path)
|
||||||
checkpoint_path,
|
|
||||||
targets_config=example_targets_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
detector = cast(Detector, api.model.detector)
|
detector = cast(Detector, api.model.detector)
|
||||||
@ -310,42 +307,6 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_user_can_finetune_only_heads(
|
|
||||||
tmp_path: Path,
|
|
||||||
example_annotations,
|
|
||||||
) -> None:
|
|
||||||
"""User story: fine-tune only prediction heads."""
|
|
||||||
|
|
||||||
api = BatDetect2API.from_config()
|
|
||||||
finetune_dir = tmp_path / "heads_only"
|
|
||||||
|
|
||||||
api.finetune(
|
|
||||||
train_annotations=example_annotations[:1],
|
|
||||||
val_annotations=example_annotations[:1],
|
|
||||||
trainable="heads",
|
|
||||||
train_workers=0,
|
|
||||||
val_workers=0,
|
|
||||||
checkpoint_dir=finetune_dir,
|
|
||||||
log_dir=tmp_path / "logs",
|
|
||||||
num_epochs=1,
|
|
||||||
seed=0,
|
|
||||||
)
|
|
||||||
detector = cast(Detector, api.model.detector)
|
|
||||||
|
|
||||||
backbone_params = list(detector.backbone.parameters())
|
|
||||||
classifier_params = list(detector.classifier_head.parameters())
|
|
||||||
bbox_params = list(detector.bbox_head.parameters())
|
|
||||||
|
|
||||||
assert backbone_params
|
|
||||||
assert classifier_params
|
|
||||||
assert bbox_params
|
|
||||||
assert all(not parameter.requires_grad for parameter in backbone_params)
|
|
||||||
assert all(parameter.requires_grad for parameter in classifier_params)
|
|
||||||
assert all(parameter.requires_grad for parameter in bbox_params)
|
|
||||||
assert list(finetune_dir.rglob("*.ckpt"))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
||||||
api_v2: BatDetect2API,
|
api_v2: BatDetect2API,
|
||||||
|
|||||||
114
tests/test_api_v2/test_finetune.py
Normal file
114
tests/test_api_v2/test_finetune.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
|
from batdetect2.models.detectors import Detector
|
||||||
|
from batdetect2.targets import TargetConfig
|
||||||
|
from batdetect2.train import load_model_from_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_user_can_finetune_only_heads(
|
||||||
|
tmp_path: Path,
|
||||||
|
example_annotations,
|
||||||
|
) -> None:
|
||||||
|
"""User story: fine-tune only prediction heads."""
|
||||||
|
|
||||||
|
api = BatDetect2API.from_config()
|
||||||
|
source_classifier_head = api.model.detector.classifier_head
|
||||||
|
source_bbox_head = api.model.detector.bbox_head
|
||||||
|
source_backbone = api.model.detector.backbone
|
||||||
|
finetune_dir = tmp_path / "heads_only"
|
||||||
|
|
||||||
|
finetuned_api = api.finetune(
|
||||||
|
train_annotations=example_annotations[:1],
|
||||||
|
val_annotations=example_annotations[:1],
|
||||||
|
targets_config=TargetConfig(),
|
||||||
|
trainable="heads",
|
||||||
|
train_workers=0,
|
||||||
|
val_workers=0,
|
||||||
|
checkpoint_dir=finetune_dir,
|
||||||
|
log_dir=tmp_path / "logs",
|
||||||
|
num_epochs=1,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
detector = cast(Detector, finetuned_api.model.detector)
|
||||||
|
|
||||||
|
backbone_params = list(detector.backbone.parameters())
|
||||||
|
classifier_params = list(detector.classifier_head.parameters())
|
||||||
|
bbox_params = list(detector.bbox_head.parameters())
|
||||||
|
|
||||||
|
assert backbone_params
|
||||||
|
assert classifier_params
|
||||||
|
assert bbox_params
|
||||||
|
assert all(not parameter.requires_grad for parameter in backbone_params)
|
||||||
|
assert all(parameter.requires_grad for parameter in classifier_params)
|
||||||
|
assert all(parameter.requires_grad for parameter in bbox_params)
|
||||||
|
assert finetuned_api is not api
|
||||||
|
assert detector.backbone is source_backbone
|
||||||
|
assert detector.classifier_head is not source_classifier_head
|
||||||
|
assert detector.bbox_head is not source_bbox_head
|
||||||
|
assert list(finetune_dir.rglob("*.ckpt"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_finetune_replaces_targets_and_checkpoint_owns_new_targets(
|
||||||
|
tmp_path: Path,
|
||||||
|
example_annotations,
|
||||||
|
) -> None:
|
||||||
|
"""User story: fine-tuning writes checkpoints with the new targets."""
|
||||||
|
|
||||||
|
source_api = BatDetect2API.from_config()
|
||||||
|
source_evaluator = source_api.evaluator
|
||||||
|
source_formatter = source_api.formatter
|
||||||
|
source_output_transform = source_api.output_transform
|
||||||
|
new_targets = TargetConfig.model_validate(
|
||||||
|
{
|
||||||
|
"classification_targets": [
|
||||||
|
{
|
||||||
|
"name": "single_class",
|
||||||
|
"tags": [{"key": "class", "value": "single_class"}],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"roi": {"mapper": "top_left"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
finetune_dir = tmp_path / "new_targets"
|
||||||
|
|
||||||
|
finetuned_api = source_api.finetune(
|
||||||
|
train_annotations=example_annotations[:1],
|
||||||
|
val_annotations=example_annotations[:1],
|
||||||
|
targets_config=new_targets,
|
||||||
|
trainable="heads",
|
||||||
|
train_workers=0,
|
||||||
|
val_workers=0,
|
||||||
|
checkpoint_dir=finetune_dir,
|
||||||
|
log_dir=tmp_path / "logs",
|
||||||
|
num_epochs=1,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints = list(finetune_dir.rglob("*.ckpt"))
|
||||||
|
|
||||||
|
assert source_api.targets.get_config() != new_targets.model_dump(
|
||||||
|
mode="json"
|
||||||
|
)
|
||||||
|
assert finetuned_api.targets.get_config() == new_targets.model_dump(
|
||||||
|
mode="json"
|
||||||
|
)
|
||||||
|
assert finetuned_api.evaluator is not source_evaluator
|
||||||
|
assert finetuned_api.formatter is not source_formatter
|
||||||
|
assert finetuned_api.output_transform is not source_output_transform
|
||||||
|
assert finetuned_api.evaluator.targets is finetuned_api.targets
|
||||||
|
assert finetuned_api.evaluator.transform is finetuned_api.output_transform
|
||||||
|
assert finetuned_api.model.class_names == ["single_class"]
|
||||||
|
assert finetuned_api.model.dimension_names == ["width", "height"]
|
||||||
|
assert checkpoints
|
||||||
|
|
||||||
|
_, configs = load_model_from_checkpoint(checkpoints[0])
|
||||||
|
assert configs.targets.model_dump(mode="json") == new_targets.model_dump(
|
||||||
|
mode="json"
|
||||||
|
)
|
||||||
0
tests/test_cli/__init__.py
Normal file
0
tests/test_cli/__init__.py
Normal file
99
tests/test_cli/test_finetune.py
Normal file
99
tests/test_cli/test_finetune.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
"""CLI tests for finetune command."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_finetune_help() -> None:
|
||||||
|
"""User story: inspect finetune command interface and options."""
|
||||||
|
|
||||||
|
result = CliRunner().invoke(cli, ["finetune", "--help"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "TRAIN_DATASET" in result.output
|
||||||
|
assert "--model" in result.output
|
||||||
|
assert "--targets" in result.output
|
||||||
|
assert "--training-config" in result.output
|
||||||
|
assert "--audio-config" in result.output
|
||||||
|
assert "--logging-config" in result.output
|
||||||
|
assert "--evaluation-config" not in result.output
|
||||||
|
assert "--inference-config" not in result.output
|
||||||
|
assert "--outputs-config" not in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_finetune_requires_model() -> None:
|
||||||
|
"""User story: finetune requires a checkpoint argument."""
|
||||||
|
|
||||||
|
result = CliRunner().invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"finetune",
|
||||||
|
"example_data/dataset.yaml",
|
||||||
|
"--targets",
|
||||||
|
"example_data/targets.yaml",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code != 0
|
||||||
|
assert "--model" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
|
||||||
|
"""User story: finetune requires a new target definition."""
|
||||||
|
|
||||||
|
result = CliRunner().invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"finetune",
|
||||||
|
"example_data/dataset.yaml",
|
||||||
|
"--model",
|
||||||
|
str(tiny_checkpoint_path),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code != 0
|
||||||
|
assert "--targets" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_cli_finetune_from_checkpoint_runs_on_small_dataset(
|
||||||
|
tmp_path: Path,
|
||||||
|
tiny_checkpoint_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""User story: fine-tune a checkpoint via CLI with new targets."""
|
||||||
|
|
||||||
|
ckpt_dir = tmp_path / "checkpoints"
|
||||||
|
log_dir = tmp_path / "logs"
|
||||||
|
ckpt_dir.mkdir()
|
||||||
|
log_dir.mkdir()
|
||||||
|
|
||||||
|
result = CliRunner().invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"finetune",
|
||||||
|
"example_data/dataset.yaml",
|
||||||
|
"--val-dataset",
|
||||||
|
"example_data/dataset.yaml",
|
||||||
|
"--model",
|
||||||
|
str(tiny_checkpoint_path),
|
||||||
|
"--targets",
|
||||||
|
"example_data/targets.yaml",
|
||||||
|
"--num-epochs",
|
||||||
|
"1",
|
||||||
|
"--train-workers",
|
||||||
|
"0",
|
||||||
|
"--val-workers",
|
||||||
|
"0",
|
||||||
|
"--ckpt-dir",
|
||||||
|
str(ckpt_dir),
|
||||||
|
"--log-dir",
|
||||||
|
str(log_dir),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert len(list(ckpt_dir.rglob("*.ckpt"))) >= 1
|
||||||
@ -81,3 +81,24 @@ def test_cli_train_rejects_model_and_model_config_together(
|
|||||||
|
|
||||||
assert result.exit_code != 0
|
assert result.exit_code != 0
|
||||||
assert "--model-config cannot be used with --model" in result.output
|
assert "--model-config cannot be used with --model" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_train_rejects_model_and_targets_together(
|
||||||
|
tiny_checkpoint_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""User story: checkpoint training does not accept new targets."""
|
||||||
|
|
||||||
|
result = CliRunner().invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"train",
|
||||||
|
"example_data/dataset.yaml",
|
||||||
|
"--model",
|
||||||
|
str(tiny_checkpoint_path),
|
||||||
|
"--targets",
|
||||||
|
"example_data/targets.yaml",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code != 0
|
||||||
|
assert "--targets cannot be used with --model" in result.output
|
||||||
|
|||||||
@ -10,7 +10,11 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
|||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.audio.types import AudioLoader
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.models import ModelConfig, build_model
|
from batdetect2.models import (
|
||||||
|
ModelConfig,
|
||||||
|
build_model,
|
||||||
|
build_model_with_new_targets,
|
||||||
|
)
|
||||||
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
||||||
from batdetect2.train import (
|
from batdetect2.train import (
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
@ -322,6 +326,49 @@ def test_build_training_module_uses_provided_model() -> None:
|
|||||||
assert module.model is model
|
assert module.model is model
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
|
||||||
|
None
|
||||||
|
):
|
||||||
|
source_targets_config = TargetConfig()
|
||||||
|
source_targets = build_targets(source_targets_config)
|
||||||
|
source_roi_mapper = build_roi_mapping(source_targets_config.roi)
|
||||||
|
source_model = build_model(
|
||||||
|
ModelConfig(),
|
||||||
|
class_names=source_targets.class_names,
|
||||||
|
dimension_names=source_roi_mapper.dimension_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_targets_config = TargetConfig.model_validate(
|
||||||
|
{
|
||||||
|
"classification_targets": [
|
||||||
|
{
|
||||||
|
"name": "single_class",
|
||||||
|
"tags": [{"key": "class", "value": "single_class"}],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
new_targets = build_targets(new_targets_config)
|
||||||
|
new_roi_mapper = build_roi_mapping(new_targets_config.roi)
|
||||||
|
|
||||||
|
rebuilt_model = build_model_with_new_targets(
|
||||||
|
model=source_model,
|
||||||
|
targets=new_targets,
|
||||||
|
roi_mapper=new_roi_mapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
source_detector = source_model.detector
|
||||||
|
rebuilt_detector = rebuilt_model.detector
|
||||||
|
|
||||||
|
assert rebuilt_detector.backbone is source_detector.backbone
|
||||||
|
assert (
|
||||||
|
rebuilt_detector.classifier_head is not source_detector.classifier_head
|
||||||
|
)
|
||||||
|
assert rebuilt_detector.bbox_head is not source_detector.bbox_head
|
||||||
|
assert rebuilt_model.class_names == ["single_class"]
|
||||||
|
assert rebuilt_model.dimension_names == ["width", "height"]
|
||||||
|
|
||||||
|
|
||||||
def test_run_train_rejects_incompatible_model_config(
|
def test_run_train_rejects_incompatible_model_config(
|
||||||
example_annotations: list[data.ClipAnnotation],
|
example_annotations: list[data.ClipAnnotation],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user