Polish evaluate and train CLI

This commit is contained in:
mbsantiago 2026-03-18 19:15:57 +00:00
parent f9056eb19a
commit bf5b88016a
6 changed files with 240 additions and 98 deletions

View File

@ -8,7 +8,6 @@ from soundevent.audio.files import get_audio_files
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.data import Dataset, load_dataset_from_config from batdetect2.data import Dataset, load_dataset_from_config
from batdetect2.evaluate import ( from batdetect2.evaluate import (
DEFAULT_EVAL_DIR, DEFAULT_EVAL_DIR,
@ -47,7 +46,7 @@ from batdetect2.postprocess import (
build_postprocessor, build_postprocessor,
) )
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import TargetProtocol, build_targets from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
from batdetect2.train import ( from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR, DEFAULT_CHECKPOINT_DIR,
TrainingConfig, TrainingConfig,
@ -110,6 +109,7 @@ class BatDetect2API:
num_epochs: int | None = None, num_epochs: int | None = None,
run_name: str | None = None, run_name: str | None = None,
seed: int | None = None, seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
): ):
@ -118,7 +118,7 @@ class BatDetect2API:
val_annotations=val_annotations, val_annotations=val_annotations,
model=self.model, model=self.model,
targets=self.targets, targets=self.targets,
model_config=self.model_config, model_config=model_config or self.model_config,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
train_workers=train_workers, train_workers=train_workers,
@ -149,6 +149,7 @@ class BatDetect2API:
num_epochs: int | None = None, num_epochs: int | None = None,
run_name: str | None = None, run_name: str | None = None,
seed: int | None = None, seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
) -> "BatDetect2API": ) -> "BatDetect2API":
@ -161,7 +162,7 @@ class BatDetect2API:
val_annotations=val_annotations, val_annotations=val_annotations,
model=self.model, model=self.model,
targets=self.targets, targets=self.targets,
model_config=self.model_config, model_config=model_config or self.model_config,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
train_workers=train_workers, train_workers=train_workers,
@ -499,76 +500,77 @@ class BatDetect2API:
def from_checkpoint( def from_checkpoint(
cls, cls,
path: data.PathLike, path: data.PathLike,
config: BatDetect2Config | None = None, targets_config: TargetConfig | None = None,
targets: TargetProtocol | None = None, audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
inference_config: InferenceConfig | None = None,
outputs_config: OutputsConfig | None = None,
) -> "BatDetect2API": ) -> "BatDetect2API":
from batdetect2.audio import AudioConfig
model, model_config = load_model_from_checkpoint(path) model, model_config = load_model_from_checkpoint(path)
# Reconstruct a full BatDetect2Config from the checkpoint's audio_config = audio_config or AudioConfig(
# ModelConfig, then overlay any caller-supplied overrides. samplerate=model_config.samplerate,
base = BatDetect2Config(
model=model_config,
audio=AudioConfig(samplerate=model_config.samplerate),
) )
config = merge_configs(base, config) if config else base train_config = train_config or TrainingConfig()
evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig()
if targets is None: if (
targets = build_targets(config=config.model.targets) targets_config is not None
else: and targets_config != model_config.targets
target_config = getattr(targets, "config", None) ):
if target_config is not None: targets = build_targets(config=targets_config)
config.model.targets = target_config model = build_model_with_new_targets(
model=model,
targets=targets,
)
model_config = model_config.model_copy(
update={"targets": targets_config}
)
audio_loader = build_audio_loader(config=config.audio) targets = build_targets(config=model_config.targets)
audio_loader = build_audio_loader(config=audio_config)
preprocessor = build_preprocessor( preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=config.model.preprocess, config=model_config.preprocess,
) )
postprocessor = build_postprocessor( postprocessor = build_postprocessor(
preprocessor, preprocessor,
config=config.model.postprocess, config=model_config.postprocess,
) )
formatter = build_output_formatter( formatter = build_output_formatter(
targets, targets,
config=config.outputs.format, config=outputs_config.format,
) )
output_transform = build_output_transform( output_transform = build_output_transform(
config=config.outputs.transform, config=outputs_config.transform,
targets=targets, targets=targets,
) )
evaluator = build_evaluator( evaluator = build_evaluator(
config=config.evaluation, config=evaluation_config,
targets=targets, targets=targets,
transform=output_transform, transform=output_transform,
) )
targets_changed = targets is not None or (
config.model.targets.model_dump(mode="json")
!= model_config.targets.model_dump(mode="json")
)
if targets_changed:
model = build_model_with_new_targets(
model=model,
targets=targets,
)
model.preprocessor = preprocessor model.preprocessor = preprocessor
model.postprocessor = postprocessor model.postprocessor = postprocessor
model.targets = targets model.targets = targets
return cls( return cls(
model_config=config.model, model_config=model_config,
audio_config=config.audio, audio_config=audio_config,
train_config=config.train, train_config=train_config,
evaluation_config=config.evaluation, evaluation_config=evaluation_config,
inference_config=config.inference, inference_config=inference_config,
outputs_config=config.outputs, outputs_config=outputs_config,
targets=targets, targets=targets,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,

View File

@ -12,9 +12,13 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@cli.command(name="evaluate") @cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True)) @click.argument("model_path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True)) @click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--config", "config_path", type=click.Path()) @click.option("--targets", "targets_config", type=click.Path(exists=True))
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--base-dir", type=click.Path(), default=Path.cwd()) @click.option("--base-dir", type=click.Path(), default=Path.cwd())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR) @click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str) @click.option("--experiment-name", type=str)
@ -24,15 +28,23 @@ def evaluate_command(
model_path: Path, model_path: Path,
test_dataset: Path, test_dataset: Path,
base_dir: Path, base_dir: Path,
config_path: Path | None, targets_config: Path | None,
audio_config: Path | None,
evaluation_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
output_dir: Path = DEFAULT_OUTPUT_DIR, output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: int = 0, num_workers: int = 0,
experiment_name: str | None = None, experiment_name: str | None = None,
run_name: str | None = None, run_name: str | None = None,
): ):
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config from batdetect2.audio import AudioConfig
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import load_evaluation_config
from batdetect2.inference import InferenceConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import load_target_config
logger.info("Initiating evaluation process...") logger.info("Initiating evaluation process...")
@ -46,11 +58,38 @@ def evaluate_command(
num_annotations=len(test_annotations), num_annotations=len(test_annotations),
) )
config = None target_conf = (
if config_path is not None: load_target_config(targets_config)
config = load_full_config(config_path) if targets_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
eval_conf = (
load_evaluation_config(evaluation_config)
if evaluation_config is not None
else None
)
inference_conf = (
InferenceConfig.load(inference_config)
if inference_config is not None
else None
)
outputs_conf = (
OutputsConfig.load(outputs_config)
if outputs_config is not None
else None
)
api = BatDetect2API.from_checkpoint(model_path, config=config) api = BatDetect2API.from_checkpoint(
model_path,
targets_config=target_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
)
api.evaluate( api.evaluate(
test_annotations, test_annotations,

View File

@ -13,10 +13,14 @@ __all__ = ["train_command"]
@click.option("--val-dataset", type=click.Path(exists=True)) @click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model", "model_path", type=click.Path(exists=True)) @click.option("--model", "model_path", type=click.Path(exists=True))
@click.option("--targets", "targets_config", type=click.Path(exists=True)) @click.option("--targets", "targets_config", type=click.Path(exists=True))
@click.option("--model-config", type=click.Path(exists=True))
@click.option("--training-config", type=click.Path(exists=True))
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True)) @click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True)) @click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int) @click.option("--train-workers", type=int)
@click.option("--val-workers", type=int) @click.option("--val-workers", type=int)
@click.option("--num-epochs", type=int) @click.option("--num-epochs", type=int)
@ -29,9 +33,13 @@ def train_command(
model_path: Path | None = None, model_path: Path | None = None,
ckpt_dir: Path | None = None, ckpt_dir: Path | None = None,
log_dir: Path | None = None, log_dir: Path | None = None,
config: Path | None = None,
targets_config: Path | None = None, targets_config: Path | None = None,
config_field: str | None = None, model_config: Path | None = None,
training_config: Path | None = None,
audio_config: Path | None = None,
evaluation_config: Path | None = None,
inference_config: Path | None = None,
outputs_config: Path | None = None,
seed: int | None = None, seed: int | None = None,
num_epochs: int | None = None, num_epochs: int | None = None,
train_workers: int = 0, train_workers: int = 0,
@ -40,27 +48,56 @@ def train_command(
run_name: str | None = None, run_name: str | None = None,
): ):
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import ( from batdetect2.audio import AudioConfig
BatDetect2Config, from batdetect2.config import BatDetect2Config
load_full_config,
)
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import load_evaluation_config
from batdetect2.inference import InferenceConfig
from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import load_target_config from batdetect2.targets import load_target_config
from batdetect2.train import load_train_config
logger.info("Initiating training process...") logger.info("Initiating training process...")
logger.info("Loading configuration...") logger.info("Loading configuration...")
conf = ( target_conf = (
load_full_config(config, field=config_field) load_target_config(targets_config)
if config is not None if targets_config is not None
else BatDetect2Config() else None
)
model_conf = (
ModelConfig.load(model_config) if model_config is not None else None
)
train_conf = (
load_train_config(training_config)
if training_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
eval_conf = (
load_evaluation_config(evaluation_config)
if evaluation_config is not None
else None
)
inference_conf = (
InferenceConfig.load(inference_config)
if inference_config is not None
else None
)
outputs_conf = (
OutputsConfig.load(outputs_config)
if outputs_config is not None
else None
) )
if targets_config is not None: if target_conf is not None:
logger.info("Loading targets configuration...") logger.info("Loaded targets configuration.")
conf = conf.model_copy(
update=dict(targets=load_target_config(targets_config)) if model_conf is not None and target_conf is not None:
) model_conf = model_conf.model_copy(update={"targets": target_conf})
logger.info("Loading training dataset...") logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(train_dataset) train_annotations = load_dataset_from_config(train_dataset)
@ -81,12 +118,40 @@ def train_command(
logger.info("Configuration and data loaded. Starting training...") logger.info("Configuration and data loaded. Starting training...")
if model_path is not None and model_conf is not None:
raise click.UsageError(
"--model-config cannot be used with --model. "
"Checkpoint model configuration is loaded from the checkpoint."
)
if model_path is None: if model_path is None:
conf = BatDetect2Config()
if model_conf is not None:
conf.model = model_conf
elif target_conf is not None:
conf.model = conf.model.model_copy(update={"targets": target_conf})
if train_conf is not None:
conf.train = train_conf
if audio_conf is not None:
conf.audio = audio_conf
if eval_conf is not None:
conf.evaluation = eval_conf
if inference_conf is not None:
conf.inference = inference_conf
if outputs_conf is not None:
conf.outputs = outputs_conf
api = BatDetect2API.from_config(conf) api = BatDetect2API.from_config(conf)
else: else:
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
model_path, model_path,
config=conf if config is not None else None, targets_config=target_conf,
train_config=train_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
) )
return api.train( return api.train(

View File

@ -8,7 +8,6 @@ configuration data from files, with optional support for accessing nested
configuration sections. configuration sections.
""" """
import sys
from typing import Any, Type, TypeVar, overload from typing import Any, Type, TypeVar, overload
import yaml import yaml
@ -16,17 +15,14 @@ from deepmerge.merger import Merger
from pydantic import BaseModel, ConfigDict, TypeAdapter from pydantic import BaseModel, ConfigDict, TypeAdapter
from soundevent.data import PathLike from soundevent.data import PathLike
if sys.version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self
__all__ = [ __all__ = [
"BaseConfig", "BaseConfig",
"load_config", "load_config",
"merge_configs", "merge_configs",
] ]
C = TypeVar("C", bound="BaseConfig")
class BaseConfig(BaseModel): class BaseConfig(BaseModel):
"""Base class for all configuration models in BatDetect2. """Base class for all configuration models in BatDetect2.
@ -73,8 +69,8 @@ class BaseConfig(BaseModel):
return cls.model_validate(yaml.safe_load(yaml_str)) return cls.model_validate(yaml.safe_load(yaml_str))
@classmethod @classmethod
def load(cls: Self, path: PathLike, field: str | None = None) -> Self: def load(cls: Type[C], path: PathLike, field: str | None = None) -> C:
return load_config(path, schema=cls, field=field) # type: ignore return load_config(path, schema=cls, field=field)
T = TypeVar("T") T = TypeVar("T")

View File

@ -201,7 +201,10 @@ def test_user_can_load_checkpoint_and_finetune(
config.train.train_loader.batch_size = 1 config.train.train_loader.batch_size = 1
config.train.train_loader.augmentations.enabled = False config.train.train_loader.augmentations.enabled = False
api = BatDetect2API.from_checkpoint(checkpoint_path, config=config) api = BatDetect2API.from_checkpoint(
checkpoint_path,
train_config=config.train,
)
finetune_dir = tmp_path / "finetuned" finetune_dir = tmp_path / "finetuned"
api.train( api.train(
@ -234,13 +237,13 @@ def test_user_can_load_checkpoint_with_new_targets(
source_model, _ = load_model_from_checkpoint(checkpoint_path) source_model, _ = load_model_from_checkpoint(checkpoint_path)
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
checkpoint_path, checkpoint_path,
targets=sample_targets, targets_config=sample_targets.config,
) )
source_detector = cast(Detector, source_model.detector) source_detector = cast(Detector, source_model.detector)
detector = cast(Detector, api.model.detector) detector = cast(Detector, api.model.detector)
classifier_head = cast(ClassifierHead, detector.classifier_head) classifier_head = cast(ClassifierHead, detector.classifier_head)
assert api.targets is sample_targets assert api.targets.config == sample_targets.config
assert detector.num_classes == len(sample_targets.class_names) assert detector.num_classes == len(sample_targets.class_names)
assert ( assert (
classifier_head.classifier.out_channels classifier_head.classifier.out_channels
@ -255,6 +258,43 @@ def test_user_can_load_checkpoint_with_new_targets(
torch.testing.assert_close(target_backbone[key], value) torch.testing.assert_close(target_backbone[key], value)
def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
    tmp_path: Path,
) -> None:
    """Reloading with the checkpoint's own targets config keeps head weights.

    A freshly built training module is saved to a checkpoint. The checkpoint
    is then loaded twice: once directly, and once through
    ``BatDetect2API.from_checkpoint`` with ``targets_config`` set to the
    targets config stored in that same checkpoint. Every tensor in the
    classifier and bbox heads of the directly loaded model must also be
    present, with close values, in the API-loaded model.
    """
    checkpoint_path = tmp_path / "same_targets.ckpt"

    # Write a checkpoint without running any training.
    module = build_training_module(model_config=BatDetect2Config().model)
    trainer = L.Trainer(enable_checkpointing=False, logger=False)
    trainer.strategy.connect(module)
    trainer.save_checkpoint(checkpoint_path)

    source_model, source_model_config = load_model_from_checkpoint(
        checkpoint_path
    )
    source_detector = cast(Detector, source_model.detector)

    api = BatDetect2API.from_checkpoint(
        checkpoint_path,
        targets_config=source_model_config.targets,
    )
    detector = cast(Detector, api.model.detector)

    # Compare the classifier head first, then the bbox head.
    head_pairs = (
        (source_detector.classifier_head, detector.classifier_head),
        (source_detector.bbox_head, detector.bbox_head),
    )
    for source_head, loaded_head in head_pairs:
        loaded_state = loaded_head.state_dict()
        for key, value in source_head.state_dict().items():
            assert key in loaded_state
            torch.testing.assert_close(loaded_state[key], value)
def test_user_can_finetune_only_heads( def test_user_can_finetune_only_heads(
tmp_path: Path, tmp_path: Path,
example_annotations, example_annotations,

View File

@ -176,10 +176,10 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
api = BatDetect2API.from_checkpoint(path) api = BatDetect2API.from_checkpoint(path)
assert api.config.model.model_dump( assert api.model_config.model_dump(
mode="json" mode="json"
) == module.model_config.model_dump(mode="json") ) == module.model_config.model_dump(mode="json")
assert api.config.audio.samplerate == module.model_config.samplerate assert api.audio_config.samplerate == module.model_config.samplerate
def test_train_smoke_produces_loadable_checkpoint( def test_train_smoke_produces_loadable_checkpoint(