fix: keep checkpoint targets immutable

This commit is contained in:
mbsantiago 2026-05-05 10:59:32 +01:00
parent 7d416e0f99
commit 75e52cc548
5 changed files with 31 additions and 16 deletions

View File

@ -592,7 +592,6 @@ class BatDetect2API:
def from_checkpoint( def from_checkpoint(
cls, cls,
path: data.PathLike, path: data.PathLike,
targets_config: TargetConfig | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None, evaluation_config: EvaluationConfig | None = None,
@ -616,21 +615,21 @@ class BatDetect2API:
build_targets, build_targets,
check_target_compatibility, check_target_compatibility,
) )
from batdetect2.train import TrainingConfig, load_model_from_checkpoint from batdetect2.train import load_model_from_checkpoint
model, configs = load_model_from_checkpoint(path) model, configs = load_model_from_checkpoint(path)
model_config = configs.model model_config = configs.model
train_config = train_config or configs.train
audio_config = audio_config or AudioConfig( audio_config = audio_config or AudioConfig(
samplerate=model_config.samplerate, samplerate=model_config.samplerate,
) )
train_config = train_config or TrainingConfig()
evaluation_config = evaluation_config or EvaluationConfig() evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig() inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig() outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig() logging_config = logging_config or AppLoggingConfig()
targets_config = targets_config or configs.targets targets_config = configs.targets
targets = build_targets(config=targets_config) targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi) roi_mapper = build_roi_mapping(config=targets_config.roi)

View File

@ -106,7 +106,6 @@ def evaluate_command(
from batdetect2.inference import InferenceConfig from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig from batdetect2.logging import AppLoggingConfig
from batdetect2.outputs import OutputsConfig from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
logger.info("Initiating evaluation process...") logger.info("Initiating evaluation process...")
@ -120,11 +119,6 @@ def evaluate_command(
num_annotations=len(test_annotations), num_annotations=len(test_annotations),
) )
target_conf = (
TargetConfig.load(targets_config)
if targets_config is not None
else None
)
audio_conf = ( audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None AudioConfig.load(audio_config) if audio_config is not None else None
) )
@ -151,7 +145,6 @@ def evaluate_command(
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
model_path, model_path,
targets_config=target_conf,
audio_config=audio_conf, audio_config=audio_conf,
evaluation_config=eval_conf, evaluation_config=eval_conf,
inference_config=inference_conf, inference_config=inference_conf,

View File

@ -228,6 +228,12 @@ def train_command(
"Checkpoint model configuration is loaded from the checkpoint." "Checkpoint model configuration is loaded from the checkpoint."
) )
if model_path is not None and target_conf is not None:
raise click.UsageError(
"--targets cannot be used with --model. "
"Checkpoint target configuration is loaded from the checkpoint."
)
if model_path is None: if model_path is None:
api = BatDetect2API.from_config( api = BatDetect2API.from_config(
model_config=model_conf, model_config=model_conf,
@ -242,7 +248,6 @@ def train_command(
else: else:
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
model_path, model_path,
targets_config=target_conf,
train_config=train_conf, train_config=train_conf,
audio_config=audio_conf, audio_config=audio_conf,
evaluation_config=eval_conf, evaluation_config=eval_conf,

View File

@ -287,10 +287,7 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
source_detector = cast(Detector, source_model.detector) source_detector = cast(Detector, source_model.detector)
# When # When
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(checkpoint_path)
checkpoint_path,
targets_config=example_targets_config,
)
# Then # Then
detector = cast(Detector, api.model.detector) detector = cast(Detector, api.model.detector)

View File

@ -81,3 +81,24 @@ def test_cli_train_rejects_model_and_model_config_together(
assert result.exit_code != 0 assert result.exit_code != 0
assert "--model-config cannot be used with --model" in result.output assert "--model-config cannot be used with --model" in result.output
def test_cli_train_rejects_model_and_targets_together(
tiny_checkpoint_path: Path,
) -> None:
"""User story: checkpoint training does not accept new targets."""
result = CliRunner().invoke(
cli,
[
"train",
"example_data/dataset.yaml",
"--model",
str(tiny_checkpoint_path),
"--targets",
"example_data/targets.yaml",
],
)
assert result.exit_code != 0
assert "--targets cannot be used with --model" in result.output