Polish evaluate and train CLI

This commit is contained in:
mbsantiago 2026-03-18 19:15:57 +00:00
parent f9056eb19a
commit bf5b88016a
6 changed files with 240 additions and 98 deletions

View File

@ -8,7 +8,6 @@ from soundevent.audio.files import get_audio_files
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.data import Dataset, load_dataset_from_config
from batdetect2.evaluate import (
DEFAULT_EVAL_DIR,
@ -47,7 +46,7 @@ from batdetect2.postprocess import (
build_postprocessor,
)
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import TargetProtocol, build_targets
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR,
TrainingConfig,
@ -110,6 +109,7 @@ class BatDetect2API:
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
):
@ -118,7 +118,7 @@ class BatDetect2API:
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
model_config=self.model_config,
model_config=model_config or self.model_config,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
train_workers=train_workers,
@ -149,6 +149,7 @@ class BatDetect2API:
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
) -> "BatDetect2API":
@ -161,7 +162,7 @@ class BatDetect2API:
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
model_config=self.model_config,
model_config=model_config or self.model_config,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
train_workers=train_workers,
@ -499,76 +500,77 @@ class BatDetect2API:
def from_checkpoint(
cls,
path: data.PathLike,
config: BatDetect2Config | None = None,
targets: TargetProtocol | None = None,
targets_config: TargetConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
inference_config: InferenceConfig | None = None,
outputs_config: OutputsConfig | None = None,
) -> "BatDetect2API":
from batdetect2.audio import AudioConfig
model, model_config = load_model_from_checkpoint(path)
# Reconstruct a full BatDetect2Config from the checkpoint's
# ModelConfig, then overlay any caller-supplied overrides.
base = BatDetect2Config(
model=model_config,
audio=AudioConfig(samplerate=model_config.samplerate),
audio_config = audio_config or AudioConfig(
samplerate=model_config.samplerate,
)
config = merge_configs(base, config) if config else base
train_config = train_config or TrainingConfig()
evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig()
if targets is None:
targets = build_targets(config=config.model.targets)
else:
target_config = getattr(targets, "config", None)
if target_config is not None:
config.model.targets = target_config
if (
targets_config is not None
and targets_config != model_config.targets
):
targets = build_targets(config=targets_config)
model = build_model_with_new_targets(
model=model,
targets=targets,
)
model_config = model_config.model_copy(
update={"targets": targets_config}
)
audio_loader = build_audio_loader(config=config.audio)
targets = build_targets(config=model_config.targets)
audio_loader = build_audio_loader(config=audio_config)
preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.model.preprocess,
config=model_config.preprocess,
)
postprocessor = build_postprocessor(
preprocessor,
config=config.model.postprocess,
config=model_config.postprocess,
)
formatter = build_output_formatter(
targets,
config=config.outputs.format,
config=outputs_config.format,
)
output_transform = build_output_transform(
config=config.outputs.transform,
config=outputs_config.transform,
targets=targets,
)
evaluator = build_evaluator(
config=config.evaluation,
config=evaluation_config,
targets=targets,
transform=output_transform,
)
targets_changed = targets is not None or (
config.model.targets.model_dump(mode="json")
!= model_config.targets.model_dump(mode="json")
)
if targets_changed:
model = build_model_with_new_targets(
model=model,
targets=targets,
)
model.preprocessor = preprocessor
model.postprocessor = postprocessor
model.targets = targets
return cls(
model_config=config.model,
audio_config=config.audio,
train_config=config.train,
evaluation_config=config.evaluation,
inference_config=config.inference,
outputs_config=config.outputs,
model_config=model_config,
audio_config=audio_config,
train_config=train_config,
evaluation_config=evaluation_config,
inference_config=inference_config,
outputs_config=outputs_config,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,

View File

@ -12,9 +12,13 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--config", "config_path", type=click.Path())
@click.option("--targets", "targets_config", type=click.Path(exists=True))
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--base-dir", type=click.Path(), default=Path.cwd())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str)
@ -24,15 +28,23 @@ def evaluate_command(
model_path: Path,
test_dataset: Path,
base_dir: Path,
config_path: Path | None,
targets_config: Path | None,
audio_config: Path | None,
evaluation_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: int = 0,
experiment_name: str | None = None,
run_name: str | None = None,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config
from batdetect2.audio import AudioConfig
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import load_evaluation_config
from batdetect2.inference import InferenceConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import load_target_config
logger.info("Initiating evaluation process...")
@ -46,11 +58,38 @@ def evaluate_command(
num_annotations=len(test_annotations),
)
config = None
if config_path is not None:
config = load_full_config(config_path)
target_conf = (
load_target_config(targets_config)
if targets_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
eval_conf = (
load_evaluation_config(evaluation_config)
if evaluation_config is not None
else None
)
inference_conf = (
InferenceConfig.load(inference_config)
if inference_config is not None
else None
)
outputs_conf = (
OutputsConfig.load(outputs_config)
if outputs_config is not None
else None
)
api = BatDetect2API.from_checkpoint(model_path, config=config)
api = BatDetect2API.from_checkpoint(
model_path,
targets_config=target_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
)
api.evaluate(
test_annotations,

View File

@ -13,10 +13,14 @@ __all__ = ["train_command"]
@click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model", "model_path", type=click.Path(exists=True))
@click.option("--targets", "targets_config", type=click.Path(exists=True))
@click.option("--model-config", type=click.Path(exists=True))
@click.option("--training-config", type=click.Path(exists=True))
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int)
@click.option("--val-workers", type=int)
@click.option("--num-epochs", type=int)
@ -29,9 +33,13 @@ def train_command(
model_path: Path | None = None,
ckpt_dir: Path | None = None,
log_dir: Path | None = None,
config: Path | None = None,
targets_config: Path | None = None,
config_field: str | None = None,
model_config: Path | None = None,
training_config: Path | None = None,
audio_config: Path | None = None,
evaluation_config: Path | None = None,
inference_config: Path | None = None,
outputs_config: Path | None = None,
seed: int | None = None,
num_epochs: int | None = None,
train_workers: int = 0,
@ -40,27 +48,56 @@ def train_command(
run_name: str | None = None,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import (
BatDetect2Config,
load_full_config,
)
from batdetect2.audio import AudioConfig
from batdetect2.config import BatDetect2Config
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import load_evaluation_config
from batdetect2.inference import InferenceConfig
from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import load_target_config
from batdetect2.train import load_train_config
logger.info("Initiating training process...")
logger.info("Loading configuration...")
conf = (
load_full_config(config, field=config_field)
if config is not None
else BatDetect2Config()
target_conf = (
load_target_config(targets_config)
if targets_config is not None
else None
)
model_conf = (
ModelConfig.load(model_config) if model_config is not None else None
)
train_conf = (
load_train_config(training_config)
if training_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
eval_conf = (
load_evaluation_config(evaluation_config)
if evaluation_config is not None
else None
)
inference_conf = (
InferenceConfig.load(inference_config)
if inference_config is not None
else None
)
outputs_conf = (
OutputsConfig.load(outputs_config)
if outputs_config is not None
else None
)
if targets_config is not None:
logger.info("Loading targets configuration...")
conf = conf.model_copy(
update=dict(targets=load_target_config(targets_config))
)
if target_conf is not None:
logger.info("Loaded targets configuration.")
if model_conf is not None and target_conf is not None:
model_conf = model_conf.model_copy(update={"targets": target_conf})
logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(train_dataset)
@ -81,12 +118,40 @@ def train_command(
logger.info("Configuration and data loaded. Starting training...")
if model_path is not None and model_conf is not None:
raise click.UsageError(
"--model-config cannot be used with --model. "
"Checkpoint model configuration is loaded from the checkpoint."
)
if model_path is None:
conf = BatDetect2Config()
if model_conf is not None:
conf.model = model_conf
elif target_conf is not None:
conf.model = conf.model.model_copy(update={"targets": target_conf})
if train_conf is not None:
conf.train = train_conf
if audio_conf is not None:
conf.audio = audio_conf
if eval_conf is not None:
conf.evaluation = eval_conf
if inference_conf is not None:
conf.inference = inference_conf
if outputs_conf is not None:
conf.outputs = outputs_conf
api = BatDetect2API.from_config(conf)
else:
api = BatDetect2API.from_checkpoint(
model_path,
config=conf if config is not None else None,
targets_config=target_conf,
train_config=train_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
)
return api.train(

View File

@ -8,7 +8,6 @@ configuration data from files, with optional support for accessing nested
configuration sections.
"""
import sys
from typing import Any, Type, TypeVar, overload
import yaml
@ -16,17 +15,14 @@ from deepmerge.merger import Merger
from pydantic import BaseModel, ConfigDict, TypeAdapter
from soundevent.data import PathLike
if sys.version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self
__all__ = [
"BaseConfig",
"load_config",
"merge_configs",
]
C = TypeVar("C", bound="BaseConfig")
class BaseConfig(BaseModel):
"""Base class for all configuration models in BatDetect2.
@ -73,8 +69,8 @@ class BaseConfig(BaseModel):
return cls.model_validate(yaml.safe_load(yaml_str))
@classmethod
def load(cls: Self, path: PathLike, field: str | None = None) -> Self:
return load_config(path, schema=cls, field=field) # type: ignore
def load(cls: Type[C], path: PathLike, field: str | None = None) -> C:
return load_config(path, schema=cls, field=field)
T = TypeVar("T")

View File

@ -201,7 +201,10 @@ def test_user_can_load_checkpoint_and_finetune(
config.train.train_loader.batch_size = 1
config.train.train_loader.augmentations.enabled = False
api = BatDetect2API.from_checkpoint(checkpoint_path, config=config)
api = BatDetect2API.from_checkpoint(
checkpoint_path,
train_config=config.train,
)
finetune_dir = tmp_path / "finetuned"
api.train(
@ -234,13 +237,13 @@ def test_user_can_load_checkpoint_with_new_targets(
source_model, _ = load_model_from_checkpoint(checkpoint_path)
api = BatDetect2API.from_checkpoint(
checkpoint_path,
targets=sample_targets,
targets_config=sample_targets.config,
)
source_detector = cast(Detector, source_model.detector)
detector = cast(Detector, api.model.detector)
classifier_head = cast(ClassifierHead, detector.classifier_head)
assert api.targets is sample_targets
assert api.targets.config == sample_targets.config
assert detector.num_classes == len(sample_targets.class_names)
assert (
classifier_head.classifier.out_channels
@ -255,6 +258,43 @@ def test_user_can_load_checkpoint_with_new_targets(
torch.testing.assert_close(target_backbone[key], value)
def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
tmp_path: Path,
) -> None:
"""User story: same targets config does not rebuild prediction heads."""
module = build_training_module(model_config=BatDetect2Config().model)
trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "same_targets.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(checkpoint_path)
source_model, source_model_config = load_model_from_checkpoint(
checkpoint_path
)
source_detector = cast(Detector, source_model.detector)
api = BatDetect2API.from_checkpoint(
checkpoint_path,
targets_config=source_model_config.targets,
)
detector = cast(Detector, api.model.detector)
for key, value in source_detector.classifier_head.state_dict().items():
assert key in detector.classifier_head.state_dict()
torch.testing.assert_close(
detector.classifier_head.state_dict()[key],
value,
)
for key, value in source_detector.bbox_head.state_dict().items():
assert key in detector.bbox_head.state_dict()
torch.testing.assert_close(
detector.bbox_head.state_dict()[key],
value,
)
def test_user_can_finetune_only_heads(
tmp_path: Path,
example_annotations,

View File

@ -176,10 +176,10 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
api = BatDetect2API.from_checkpoint(path)
assert api.config.model.model_dump(
assert api.model_config.model_dump(
mode="json"
) == module.model_config.model_dump(mode="json")
assert api.config.audio.samplerate == module.model_config.samplerate
assert api.audio_config.samplerate == module.model_config.samplerate
def test_train_smoke_produces_loadable_checkpoint(