mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
fix: keep checkpoint targets immutable
This commit is contained in:
parent
7d416e0f99
commit
75e52cc548
@ -592,7 +592,6 @@ class BatDetect2API:
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
path: data.PathLike,
|
||||
targets_config: TargetConfig | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
@ -616,21 +615,21 @@ class BatDetect2API:
|
||||
build_targets,
|
||||
check_target_compatibility,
|
||||
)
|
||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||
from batdetect2.train import load_model_from_checkpoint
|
||||
|
||||
model, configs = load_model_from_checkpoint(path)
|
||||
|
||||
model_config = configs.model
|
||||
train_config = train_config or configs.train
|
||||
|
||||
audio_config = audio_config or AudioConfig(
|
||||
samplerate=model_config.samplerate,
|
||||
)
|
||||
train_config = train_config or TrainingConfig()
|
||||
evaluation_config = evaluation_config or EvaluationConfig()
|
||||
inference_config = inference_config or InferenceConfig()
|
||||
outputs_config = outputs_config or OutputsConfig()
|
||||
logging_config = logging_config or AppLoggingConfig()
|
||||
targets_config = targets_config or configs.targets
|
||||
targets_config = configs.targets
|
||||
|
||||
targets = build_targets(config=targets_config)
|
||||
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||
|
||||
@ -106,7 +106,6 @@ def evaluate_command(
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
|
||||
logger.info("Initiating evaluation process...")
|
||||
|
||||
@ -120,11 +119,6 @@ def evaluate_command(
|
||||
num_annotations=len(test_annotations),
|
||||
)
|
||||
|
||||
target_conf = (
|
||||
TargetConfig.load(targets_config)
|
||||
if targets_config is not None
|
||||
else None
|
||||
)
|
||||
audio_conf = (
|
||||
AudioConfig.load(audio_config) if audio_config is not None else None
|
||||
)
|
||||
@ -151,7 +145,6 @@ def evaluate_command(
|
||||
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
targets_config=target_conf,
|
||||
audio_config=audio_conf,
|
||||
evaluation_config=eval_conf,
|
||||
inference_config=inference_conf,
|
||||
|
||||
@ -228,6 +228,12 @@ def train_command(
|
||||
"Checkpoint model configuration is loaded from the checkpoint."
|
||||
)
|
||||
|
||||
if model_path is not None and target_conf is not None:
|
||||
raise click.UsageError(
|
||||
"--targets cannot be used with --model. "
|
||||
"Checkpoint target configuration is loaded from the checkpoint."
|
||||
)
|
||||
|
||||
if model_path is None:
|
||||
api = BatDetect2API.from_config(
|
||||
model_config=model_conf,
|
||||
@ -242,7 +248,6 @@ def train_command(
|
||||
else:
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
targets_config=target_conf,
|
||||
train_config=train_conf,
|
||||
audio_config=audio_conf,
|
||||
evaluation_config=eval_conf,
|
||||
|
||||
@ -287,10 +287,7 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
||||
source_detector = cast(Detector, source_model.detector)
|
||||
|
||||
# When
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
checkpoint_path,
|
||||
targets_config=example_targets_config,
|
||||
)
|
||||
api = BatDetect2API.from_checkpoint(checkpoint_path)
|
||||
|
||||
# Then
|
||||
detector = cast(Detector, api.model.detector)
|
||||
|
||||
@ -81,3 +81,24 @@ def test_cli_train_rejects_model_and_model_config_together(
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "--model-config cannot be used with --model" in result.output
|
||||
|
||||
|
||||
def test_cli_train_rejects_model_and_targets_together(
|
||||
tiny_checkpoint_path: Path,
|
||||
) -> None:
|
||||
"""User story: checkpoint training does not accept new targets."""
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
"example_data/dataset.yaml",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
"--targets",
|
||||
"example_data/targets.yaml",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "--targets cannot be used with --model" in result.output
|
||||
|
||||
Loading…
Reference in New Issue
Block a user