fix: keep checkpoint targets immutable

This commit is contained in:
mbsantiago 2026-05-05 10:59:32 +01:00
parent 7d416e0f99
commit 75e52cc548
5 changed files with 31 additions and 16 deletions

View File

@ -592,7 +592,6 @@ class BatDetect2API:
def from_checkpoint(
cls,
path: data.PathLike,
targets_config: TargetConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
@ -616,21 +615,21 @@ class BatDetect2API:
build_targets,
check_target_compatibility,
)
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
from batdetect2.train import load_model_from_checkpoint
model, configs = load_model_from_checkpoint(path)
model_config = configs.model
train_config = train_config or configs.train
audio_config = audio_config or AudioConfig(
samplerate=model_config.samplerate,
)
train_config = train_config or TrainingConfig()
evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig()
targets_config = targets_config or configs.targets
targets_config = configs.targets
targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi)

View File

@ -106,7 +106,6 @@ def evaluate_command(
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
logger.info("Initiating evaluation process...")
@ -120,11 +119,6 @@ def evaluate_command(
num_annotations=len(test_annotations),
)
target_conf = (
TargetConfig.load(targets_config)
if targets_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
@ -151,7 +145,6 @@ def evaluate_command(
api = BatDetect2API.from_checkpoint(
model_path,
targets_config=target_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,

View File

@ -228,6 +228,12 @@ def train_command(
"Checkpoint model configuration is loaded from the checkpoint."
)
if model_path is not None and target_conf is not None:
raise click.UsageError(
"--targets cannot be used with --model. "
"Checkpoint target configuration is loaded from the checkpoint."
)
if model_path is None:
api = BatDetect2API.from_config(
model_config=model_conf,
@ -242,7 +248,6 @@ def train_command(
else:
api = BatDetect2API.from_checkpoint(
model_path,
targets_config=target_conf,
train_config=train_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,

View File

@ -287,10 +287,7 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
source_detector = cast(Detector, source_model.detector)
# When
api = BatDetect2API.from_checkpoint(
checkpoint_path,
targets_config=example_targets_config,
)
api = BatDetect2API.from_checkpoint(checkpoint_path)
# Then
detector = cast(Detector, api.model.detector)

View File

@ -81,3 +81,24 @@ def test_cli_train_rejects_model_and_model_config_together(
assert result.exit_code != 0
assert "--model-config cannot be used with --model" in result.output
def test_cli_train_rejects_model_and_targets_together(
tiny_checkpoint_path: Path,
) -> None:
"""User story: checkpoint training does not accept new targets."""
result = CliRunner().invoke(
cli,
[
"train",
"example_data/dataset.yaml",
"--model",
str(tiny_checkpoint_path),
"--targets",
"example_data/targets.yaml",
],
)
assert result.exit_code != 0
assert "--targets cannot be used with --model" in result.output