mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
fix: keep checkpoint targets immutable
This commit is contained in:
parent
7d416e0f99
commit
75e52cc548
@ -592,7 +592,6 @@ class BatDetect2API:
|
|||||||
def from_checkpoint(
|
def from_checkpoint(
|
||||||
cls,
|
cls,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
targets_config: TargetConfig | None = None,
|
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
evaluation_config: EvaluationConfig | None = None,
|
evaluation_config: EvaluationConfig | None = None,
|
||||||
@ -616,21 +615,21 @@ class BatDetect2API:
|
|||||||
build_targets,
|
build_targets,
|
||||||
check_target_compatibility,
|
check_target_compatibility,
|
||||||
)
|
)
|
||||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
from batdetect2.train import load_model_from_checkpoint
|
||||||
|
|
||||||
model, configs = load_model_from_checkpoint(path)
|
model, configs = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
model_config = configs.model
|
model_config = configs.model
|
||||||
|
train_config = train_config or configs.train
|
||||||
|
|
||||||
audio_config = audio_config or AudioConfig(
|
audio_config = audio_config or AudioConfig(
|
||||||
samplerate=model_config.samplerate,
|
samplerate=model_config.samplerate,
|
||||||
)
|
)
|
||||||
train_config = train_config or TrainingConfig()
|
|
||||||
evaluation_config = evaluation_config or EvaluationConfig()
|
evaluation_config = evaluation_config or EvaluationConfig()
|
||||||
inference_config = inference_config or InferenceConfig()
|
inference_config = inference_config or InferenceConfig()
|
||||||
outputs_config = outputs_config or OutputsConfig()
|
outputs_config = outputs_config or OutputsConfig()
|
||||||
logging_config = logging_config or AppLoggingConfig()
|
logging_config = logging_config or AppLoggingConfig()
|
||||||
targets_config = targets_config or configs.targets
|
targets_config = configs.targets
|
||||||
|
|
||||||
targets = build_targets(config=targets_config)
|
targets = build_targets(config=targets_config)
|
||||||
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||||
|
|||||||
@ -106,7 +106,6 @@ def evaluate_command(
|
|||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
from batdetect2.logging import AppLoggingConfig
|
from batdetect2.logging import AppLoggingConfig
|
||||||
from batdetect2.outputs import OutputsConfig
|
from batdetect2.outputs import OutputsConfig
|
||||||
from batdetect2.targets import TargetConfig
|
|
||||||
|
|
||||||
logger.info("Initiating evaluation process...")
|
logger.info("Initiating evaluation process...")
|
||||||
|
|
||||||
@ -120,11 +119,6 @@ def evaluate_command(
|
|||||||
num_annotations=len(test_annotations),
|
num_annotations=len(test_annotations),
|
||||||
)
|
)
|
||||||
|
|
||||||
target_conf = (
|
|
||||||
TargetConfig.load(targets_config)
|
|
||||||
if targets_config is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
audio_conf = (
|
audio_conf = (
|
||||||
AudioConfig.load(audio_config) if audio_config is not None else None
|
AudioConfig.load(audio_config) if audio_config is not None else None
|
||||||
)
|
)
|
||||||
@ -151,7 +145,6 @@ def evaluate_command(
|
|||||||
|
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
model_path,
|
model_path,
|
||||||
targets_config=target_conf,
|
|
||||||
audio_config=audio_conf,
|
audio_config=audio_conf,
|
||||||
evaluation_config=eval_conf,
|
evaluation_config=eval_conf,
|
||||||
inference_config=inference_conf,
|
inference_config=inference_conf,
|
||||||
|
|||||||
@ -228,6 +228,12 @@ def train_command(
|
|||||||
"Checkpoint model configuration is loaded from the checkpoint."
|
"Checkpoint model configuration is loaded from the checkpoint."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_path is not None and target_conf is not None:
|
||||||
|
raise click.UsageError(
|
||||||
|
"--targets cannot be used with --model. "
|
||||||
|
"Checkpoint target configuration is loaded from the checkpoint."
|
||||||
|
)
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
api = BatDetect2API.from_config(
|
api = BatDetect2API.from_config(
|
||||||
model_config=model_conf,
|
model_config=model_conf,
|
||||||
@ -242,7 +248,6 @@ def train_command(
|
|||||||
else:
|
else:
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
model_path,
|
model_path,
|
||||||
targets_config=target_conf,
|
|
||||||
train_config=train_conf,
|
train_config=train_conf,
|
||||||
audio_config=audio_conf,
|
audio_config=audio_conf,
|
||||||
evaluation_config=eval_conf,
|
evaluation_config=eval_conf,
|
||||||
|
|||||||
@ -287,10 +287,7 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
|||||||
source_detector = cast(Detector, source_model.detector)
|
source_detector = cast(Detector, source_model.detector)
|
||||||
|
|
||||||
# When
|
# When
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(checkpoint_path)
|
||||||
checkpoint_path,
|
|
||||||
targets_config=example_targets_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
detector = cast(Detector, api.model.detector)
|
detector = cast(Detector, api.model.detector)
|
||||||
|
|||||||
@ -81,3 +81,24 @@ def test_cli_train_rejects_model_and_model_config_together(
|
|||||||
|
|
||||||
assert result.exit_code != 0
|
assert result.exit_code != 0
|
||||||
assert "--model-config cannot be used with --model" in result.output
|
assert "--model-config cannot be used with --model" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_train_rejects_model_and_targets_together(
|
||||||
|
tiny_checkpoint_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""User story: checkpoint training does not accept new targets."""
|
||||||
|
|
||||||
|
result = CliRunner().invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"train",
|
||||||
|
"example_data/dataset.yaml",
|
||||||
|
"--model",
|
||||||
|
str(tiny_checkpoint_path),
|
||||||
|
"--targets",
|
||||||
|
"example_data/targets.yaml",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code != 0
|
||||||
|
assert "--targets cannot be used with --model" in result.output
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user