From 75e52cc548779847ddeb7019e188253bac262ab3 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 5 May 2026 10:59:32 +0100 Subject: [PATCH] fix: keep checkpoint targets immutable --- src/batdetect2/api_v2.py | 7 +++---- src/batdetect2/cli/evaluate.py | 7 ------- src/batdetect2/cli/train.py | 7 ++++++- tests/test_api_v2/test_api_v2.py | 5 +---- tests/test_cli/test_train.py | 21 +++++++++++++++++++++ 5 files changed, 31 insertions(+), 16 deletions(-) diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 0513d63..0cee1d5 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -592,7 +592,6 @@ class BatDetect2API: def from_checkpoint( cls, path: data.PathLike, - targets_config: TargetConfig | None = None, audio_config: AudioConfig | None = None, train_config: TrainingConfig | None = None, evaluation_config: EvaluationConfig | None = None, @@ -616,21 +615,21 @@ class BatDetect2API: build_targets, check_target_compatibility, ) - from batdetect2.train import TrainingConfig, load_model_from_checkpoint + from batdetect2.train import load_model_from_checkpoint model, configs = load_model_from_checkpoint(path) model_config = configs.model + train_config = train_config or configs.train audio_config = audio_config or AudioConfig( samplerate=model_config.samplerate, ) - train_config = train_config or TrainingConfig() evaluation_config = evaluation_config or EvaluationConfig() inference_config = inference_config or InferenceConfig() outputs_config = outputs_config or OutputsConfig() logging_config = logging_config or AppLoggingConfig() - targets_config = targets_config or configs.targets + targets_config = configs.targets targets = build_targets(config=targets_config) roi_mapper = build_roi_mapping(config=targets_config.roi) diff --git a/src/batdetect2/cli/evaluate.py b/src/batdetect2/cli/evaluate.py index 472be19..feaa9e1 100644 --- a/src/batdetect2/cli/evaluate.py +++ b/src/batdetect2/cli/evaluate.py @@ -106,7 +106,6 @@ def evaluate_command( from batdetect2.inference import InferenceConfig from batdetect2.logging import AppLoggingConfig from batdetect2.outputs import OutputsConfig - from batdetect2.targets import TargetConfig logger.info("Initiating evaluation process...") @@ -120,11 +119,6 @@ def evaluate_command( num_annotations=len(test_annotations), ) - target_conf = ( - TargetConfig.load(targets_config) - if targets_config is not None - else None - ) audio_conf = ( AudioConfig.load(audio_config) if audio_config is not None else None ) @@ -151,7 +145,6 @@ def evaluate_command( api = BatDetect2API.from_checkpoint( model_path, - targets_config=target_conf, audio_config=audio_conf, evaluation_config=eval_conf, inference_config=inference_conf, diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 4326e0c..1b428b7 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -228,6 +228,12 @@ def train_command( "Checkpoint model configuration is loaded from the checkpoint." ) + if model_path is not None and target_conf is not None: + raise click.UsageError( + "--targets cannot be used with --model. " + "Checkpoint target configuration is loaded from the checkpoint." + ) + if model_path is None: api = BatDetect2API.from_config( model_config=model_conf, @@ -242,7 +248,6 @@ def train_command( else: api = BatDetect2API.from_checkpoint( model_path, - targets_config=target_conf, train_config=train_conf, audio_config=audio_conf, evaluation_config=eval_conf, diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index e3fa01f..e2ffe31 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -287,10 +287,7 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged( source_detector = cast(Detector, source_model.detector) # When - api = BatDetect2API.from_checkpoint( - checkpoint_path, - targets_config=example_targets_config, - ) + api = BatDetect2API.from_checkpoint(checkpoint_path) # Then detector = cast(Detector, api.model.detector) diff --git a/tests/test_cli/test_train.py b/tests/test_cli/test_train.py index 2212a57..5fd0210 100644 --- a/tests/test_cli/test_train.py +++ b/tests/test_cli/test_train.py @@ -81,3 +81,24 @@ def test_cli_train_rejects_model_and_model_config_together( assert result.exit_code != 0 assert "--model-config cannot be used with --model" in result.output + + +def test_cli_train_rejects_model_and_targets_together( + tiny_checkpoint_path: Path, +) -> None: + """User story: checkpoint training does not accept new targets.""" + + result = CliRunner().invoke( + cli, + [ + "train", + "example_data/dataset.yaml", + "--model", + str(tiny_checkpoint_path), + "--targets", + "example_data/targets.yaml", + ], + ) + + assert result.exit_code != 0 + assert "--targets cannot be used with --model" in result.output