diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 10e5169..bd226b8 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -8,7 +8,6 @@ from soundevent.audio.files import get_audio_files from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader from batdetect2.config import BatDetect2Config -from batdetect2.core import merge_configs from batdetect2.data import Dataset, load_dataset_from_config from batdetect2.evaluate import ( DEFAULT_EVAL_DIR, @@ -47,7 +46,7 @@ from batdetect2.postprocess import ( build_postprocessor, ) from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor -from batdetect2.targets import TargetProtocol, build_targets +from batdetect2.targets import TargetConfig, TargetProtocol, build_targets from batdetect2.train import ( DEFAULT_CHECKPOINT_DIR, TrainingConfig, @@ -110,6 +109,7 @@ class BatDetect2API: num_epochs: int | None = None, run_name: str | None = None, seed: int | None = None, + model_config: ModelConfig | None = None, audio_config: AudioConfig | None = None, train_config: TrainingConfig | None = None, ): @@ -118,7 +118,7 @@ class BatDetect2API: val_annotations=val_annotations, model=self.model, targets=self.targets, - model_config=self.model_config, + model_config=model_config or self.model_config, audio_loader=self.audio_loader, preprocessor=self.preprocessor, train_workers=train_workers, @@ -149,6 +149,7 @@ class BatDetect2API: num_epochs: int | None = None, run_name: str | None = None, seed: int | None = None, + model_config: ModelConfig | None = None, audio_config: AudioConfig | None = None, train_config: TrainingConfig | None = None, ) -> "BatDetect2API": @@ -161,7 +162,7 @@ class BatDetect2API: val_annotations=val_annotations, model=self.model, targets=self.targets, - model_config=self.model_config, + model_config=model_config or self.model_config, preprocessor=self.preprocessor, audio_loader=self.audio_loader, train_workers=train_workers, @@ -499,76 +500,77 @@ class BatDetect2API: def from_checkpoint( cls, path: data.PathLike, - config: BatDetect2Config | None = None, - targets: TargetProtocol | None = None, + targets_config: TargetConfig | None = None, + audio_config: AudioConfig | None = None, + train_config: TrainingConfig | None = None, + evaluation_config: EvaluationConfig | None = None, + inference_config: InferenceConfig | None = None, + outputs_config: OutputsConfig | None = None, ) -> "BatDetect2API": - from batdetect2.audio import AudioConfig - model, model_config = load_model_from_checkpoint(path) - # Reconstruct a full BatDetect2Config from the checkpoint's - # ModelConfig, then overlay any caller-supplied overrides. - base = BatDetect2Config( - model=model_config, - audio=AudioConfig(samplerate=model_config.samplerate), + audio_config = audio_config or AudioConfig( + samplerate=model_config.samplerate, ) - config = merge_configs(base, config) if config else base + train_config = train_config or TrainingConfig() + evaluation_config = evaluation_config or EvaluationConfig() + inference_config = inference_config or InferenceConfig() + outputs_config = outputs_config or OutputsConfig() - if targets is None: - targets = build_targets(config=config.model.targets) - else: - target_config = getattr(targets, "config", None) - if target_config is not None: - config.model.targets = target_config - - audio_loader = build_audio_loader(config=config.audio) - - preprocessor = build_preprocessor( - input_samplerate=audio_loader.samplerate, - config=config.model.preprocess, - ) - - postprocessor = build_postprocessor( - preprocessor, - config=config.model.postprocess, - ) - - formatter = build_output_formatter( - targets, - config=config.outputs.format, - ) - output_transform = build_output_transform( - config=config.outputs.transform, - targets=targets, - ) - - evaluator = build_evaluator( - config=config.evaluation, - targets=targets, - transform=output_transform, - ) - - targets_changed = targets is not None or ( - config.model.targets.model_dump(mode="json") - != model_config.targets.model_dump(mode="json") - ) - if targets_changed: + if ( + targets_config is not None + and targets_config != model_config.targets + ): + targets = build_targets(config=targets_config) model = build_model_with_new_targets( model=model, targets=targets, ) + model_config = model_config.model_copy( + update={"targets": targets_config} + ) + + targets = build_targets(config=model_config.targets) + + audio_loader = build_audio_loader(config=audio_config) + + preprocessor = build_preprocessor( + input_samplerate=audio_loader.samplerate, + config=model_config.preprocess, + ) + + postprocessor = build_postprocessor( + preprocessor, + config=model_config.postprocess, + ) + + formatter = build_output_formatter( + targets, + config=outputs_config.format, + ) + + output_transform = build_output_transform( + config=outputs_config.transform, + targets=targets, + ) + + evaluator = build_evaluator( + config=evaluation_config, + targets=targets, + transform=output_transform, + ) model.preprocessor = preprocessor model.postprocessor = postprocessor model.targets = targets return cls( - model_config=config.model, - audio_config=config.audio, - train_config=config.train, - evaluation_config=config.evaluation, - inference_config=config.inference, - outputs_config=config.outputs, + model_config=model_config, + audio_config=audio_config, + train_config=train_config, + evaluation_config=evaluation_config, + inference_config=inference_config, + outputs_config=outputs_config, targets=targets, audio_loader=audio_loader, preprocessor=preprocessor, diff --git a/src/batdetect2/cli/evaluate.py b/src/batdetect2/cli/evaluate.py index f0d6517..9c76a33 100644 --- a/src/batdetect2/cli/evaluate.py +++ b/src/batdetect2/cli/evaluate.py @@ -12,9 +12,13 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation" @cli.command(name="evaluate") -@click.argument("model-path", type=click.Path(exists=True)) +@click.argument("model_path", type=click.Path(exists=True)) @click.argument("test_dataset", type=click.Path(exists=True)) -@click.option("--config", "config_path", type=click.Path()) +@click.option("--targets", "targets_config", type=click.Path(exists=True)) +@click.option("--audio-config", type=click.Path(exists=True)) +@click.option("--evaluation-config", type=click.Path(exists=True)) +@click.option("--inference-config", type=click.Path(exists=True)) +@click.option("--outputs-config", type=click.Path(exists=True)) @click.option("--base-dir", type=click.Path(), default=Path.cwd()) @click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR) @click.option("--experiment-name", type=str) @@ -24,15 +28,23 @@ def evaluate_command( model_path: Path, test_dataset: Path, base_dir: Path, - config_path: Path | None, + targets_config: Path | None, + audio_config: Path | None, + evaluation_config: Path | None, + inference_config: Path | None, + outputs_config: Path | None, output_dir: Path = DEFAULT_OUTPUT_DIR, num_workers: int = 0, experiment_name: str | None = None, run_name: str | None = None, ): from batdetect2.api_v2 import BatDetect2API - from batdetect2.config import load_full_config + from batdetect2.audio import AudioConfig from batdetect2.data import load_dataset_from_config + from batdetect2.evaluate import load_evaluation_config + from batdetect2.inference import InferenceConfig + from batdetect2.outputs import OutputsConfig + from batdetect2.targets import load_target_config logger.info("Initiating evaluation process...") @@ -46,11 +58,38 @@ def evaluate_command( num_annotations=len(test_annotations), ) - config = None - if config_path is not None: - config = load_full_config(config_path) + target_conf = ( + load_target_config(targets_config) + if targets_config is not None + else None + ) + audio_conf = ( + AudioConfig.load(audio_config) if audio_config is not None else None + ) + eval_conf = ( + load_evaluation_config(evaluation_config) + if evaluation_config is not None + else None + ) + inference_conf = ( + InferenceConfig.load(inference_config) + if inference_config is not None + else None + ) + outputs_conf = ( + OutputsConfig.load(outputs_config) + if outputs_config is not None + else None + ) - api = BatDetect2API.from_checkpoint(model_path, config=config) + api = BatDetect2API.from_checkpoint( + model_path, + targets_config=target_conf, + audio_config=audio_conf, + evaluation_config=eval_conf, + inference_config=inference_conf, + outputs_config=outputs_conf, + ) api.evaluate( test_annotations, diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 8837ad8..b473a60 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -13,10 +13,14 @@ __all__ = ["train_command"] @click.option("--val-dataset", type=click.Path(exists=True)) @click.option("--model", "model_path", type=click.Path(exists=True)) @click.option("--targets", "targets_config", type=click.Path(exists=True)) +@click.option("--model-config", type=click.Path(exists=True)) +@click.option("--training-config", type=click.Path(exists=True)) +@click.option("--audio-config", type=click.Path(exists=True)) +@click.option("--evaluation-config", type=click.Path(exists=True)) +@click.option("--inference-config", type=click.Path(exists=True)) +@click.option("--outputs-config", type=click.Path(exists=True)) @click.option("--ckpt-dir", type=click.Path(exists=True)) @click.option("--log-dir", type=click.Path(exists=True)) -@click.option("--config", type=click.Path(exists=True)) -@click.option("--config-field", type=str) @click.option("--train-workers", type=int) @click.option("--val-workers", type=int) @click.option("--num-epochs", type=int) @@ -29,9 +33,13 @@ def train_command( model_path: Path | None = None, ckpt_dir: Path | None = None, log_dir: Path | None = None, - config: Path | None = None, targets_config: Path | None = None, - config_field: str | None = None, + model_config: Path | None = None, + training_config: Path | None = None, + audio_config: Path | None = None, + evaluation_config: Path | None = None, + inference_config: Path | None = None, + outputs_config: Path | None = None, seed: int | None = None, num_epochs: int | None = None, train_workers: int = 0, @@ -40,27 +48,56 @@ def train_command( run_name: str | None = None, ): from batdetect2.api_v2 import BatDetect2API - from batdetect2.config import ( - BatDetect2Config, - load_full_config, - ) + from batdetect2.audio import AudioConfig + from batdetect2.config import BatDetect2Config from batdetect2.data import load_dataset_from_config + from batdetect2.evaluate import load_evaluation_config + from batdetect2.inference import InferenceConfig + from batdetect2.models import ModelConfig + from batdetect2.outputs import OutputsConfig from batdetect2.targets import load_target_config + from batdetect2.train import load_train_config logger.info("Initiating training process...") logger.info("Loading configuration...") - conf = ( - load_full_config(config, field=config_field) - if config is not None - else BatDetect2Config() + target_conf = ( + load_target_config(targets_config) + if targets_config is not None + else None + ) + model_conf = ( + ModelConfig.load(model_config) if model_config is not None else None + ) + train_conf = ( + load_train_config(training_config) + if training_config is not None + else None + ) + audio_conf = ( + AudioConfig.load(audio_config) if audio_config is not None else None + ) + eval_conf = ( + load_evaluation_config(evaluation_config) + if evaluation_config is not None + else None + ) + inference_conf = ( + InferenceConfig.load(inference_config) + if inference_config is not None + else None + ) + outputs_conf = ( + OutputsConfig.load(outputs_config) + if outputs_config is not None + else None ) - if targets_config is not None: - logger.info("Loading targets configuration...") - conf = conf.model_copy( - update=dict(targets=load_target_config(targets_config)) - ) + if target_conf is not None: + logger.info("Loaded targets configuration.") + + if model_conf is not None and target_conf is not None: + model_conf = model_conf.model_copy(update={"targets": target_conf}) logger.info("Loading training dataset...") train_annotations = load_dataset_from_config(train_dataset) @@ -81,12 +118,40 @@ def train_command( logger.info("Configuration and data loaded. Starting training...") + if model_path is not None and model_conf is not None: + raise click.UsageError( + "--model-config cannot be used with --model. " + "Checkpoint model configuration is loaded from the checkpoint." + ) + if model_path is None: + conf = BatDetect2Config() + if model_conf is not None: + conf.model = model_conf + elif target_conf is not None: + conf.model = conf.model.model_copy(update={"targets": target_conf}) + + if train_conf is not None: + conf.train = train_conf + if audio_conf is not None: + conf.audio = audio_conf + if eval_conf is not None: + conf.evaluation = eval_conf + if inference_conf is not None: + conf.inference = inference_conf + if outputs_conf is not None: + conf.outputs = outputs_conf + api = BatDetect2API.from_config(conf) else: api = BatDetect2API.from_checkpoint( model_path, - config=conf if config is not None else None, + targets_config=target_conf, + train_config=train_conf, + audio_config=audio_conf, + evaluation_config=eval_conf, + inference_config=inference_conf, + outputs_config=outputs_conf, ) return api.train( diff --git a/src/batdetect2/core/configs.py b/src/batdetect2/core/configs.py index 118d366..49bb435 100644 --- a/src/batdetect2/core/configs.py +++ b/src/batdetect2/core/configs.py @@ -8,7 +8,6 @@ configuration data from files, with optional support for accessing nested configuration sections. """ -import sys from typing import Any, Type, TypeVar, overload import yaml @@ -16,17 +15,14 @@ from deepmerge.merger import Merger from pydantic import BaseModel, ConfigDict, TypeAdapter from soundevent.data import PathLike -if sys.version_info < (3, 11): - from typing_extensions import Self -else: - from typing import Self - __all__ = [ "BaseConfig", "load_config", "merge_configs", ] +C = TypeVar("C", bound="BaseConfig") + class BaseConfig(BaseModel): """Base class for all configuration models in BatDetect2. @@ -73,8 +69,8 @@ class BaseConfig(BaseModel): return cls.model_validate(yaml.safe_load(yaml_str)) @classmethod - def load(cls: Self, path: PathLike, field: str | None = None) -> Self: - return load_config(path, schema=cls, field=field) # type: ignore + def load(cls: Type[C], path: PathLike, field: str | None = None) -> C: + return load_config(path, schema=cls, field=field) T = TypeVar("T") diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index 4b3f71e..71a102a 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -201,7 +201,10 @@ def test_user_can_load_checkpoint_and_finetune( config.train.train_loader.batch_size = 1 config.train.train_loader.augmentations.enabled = False - api = BatDetect2API.from_checkpoint(checkpoint_path, config=config) + api = BatDetect2API.from_checkpoint( + checkpoint_path, + train_config=config.train, + ) finetune_dir = tmp_path / "finetuned" api.train( @@ -234,13 +237,13 @@ def test_user_can_load_checkpoint_with_new_targets( source_model, _ = load_model_from_checkpoint(checkpoint_path) api = BatDetect2API.from_checkpoint( checkpoint_path, - targets=sample_targets, + targets_config=sample_targets.config, ) source_detector = cast(Detector, source_model.detector) detector = cast(Detector, api.model.detector) classifier_head = cast(ClassifierHead, detector.classifier_head) - assert api.targets is sample_targets + assert api.targets.config == sample_targets.config assert detector.num_classes == len(sample_targets.class_names) assert ( classifier_head.classifier.out_channels @@ -255,6 +258,43 @@ def test_user_can_load_checkpoint_with_new_targets( torch.testing.assert_close(target_backbone[key], value) +def test_checkpoint_with_same_targets_config_keeps_heads_unchanged( + tmp_path: Path, +) -> None: + """User story: same targets config does not rebuild prediction heads.""" + + module = build_training_module(model_config=BatDetect2Config().model) + trainer = L.Trainer(enable_checkpointing=False, logger=False) + checkpoint_path = tmp_path / "same_targets.ckpt" + trainer.strategy.connect(module) + trainer.save_checkpoint(checkpoint_path) + + source_model, source_model_config = load_model_from_checkpoint( + checkpoint_path + ) + source_detector = cast(Detector, source_model.detector) + + api = BatDetect2API.from_checkpoint( + checkpoint_path, + targets_config=source_model_config.targets, + ) + detector = cast(Detector, api.model.detector) + + for key, value in source_detector.classifier_head.state_dict().items(): + assert key in detector.classifier_head.state_dict() + torch.testing.assert_close( + detector.classifier_head.state_dict()[key], + value, + ) + + for key, value in source_detector.bbox_head.state_dict().items(): + assert key in detector.bbox_head.state_dict() + torch.testing.assert_close( + detector.bbox_head.state_dict()[key], + value, + ) + + def test_user_can_finetune_only_heads( tmp_path: Path, example_annotations, diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index 05f5fea..1ea2e3c 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -176,10 +176,10 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path): api = BatDetect2API.from_checkpoint(path) - assert api.config.model.model_dump( + assert api.model_config.model_dump( mode="json" ) == module.model_config.model_dump(mode="json") - assert api.config.audio.samplerate == module.model_config.samplerate + assert api.audio_config.samplerate == module.model_config.samplerate def test_train_smoke_produces_loadable_checkpoint(