diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index c6a5492..9f872b5 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -610,17 +610,15 @@ class BatDetect2API: from batdetect2.postprocess import build_postprocessor from batdetect2.preprocess import build_preprocessor from batdetect2.targets import ( - build_default_target_config, build_roi_mapping, build_targets, check_target_compatibility, ) - from batdetect2.train import ( - TrainingConfig, - load_model_from_checkpoint, - ) + from batdetect2.train import TrainingConfig, load_model_from_checkpoint - model, model_config = load_model_from_checkpoint(path) + model, configs = load_model_from_checkpoint(path) + + model_config = configs.model audio_config = audio_config or AudioConfig( samplerate=model_config.samplerate, @@ -630,9 +628,7 @@ class BatDetect2API: inference_config = inference_config or InferenceConfig() outputs_config = outputs_config or OutputsConfig() logging_config = logging_config or AppLoggingConfig() - targets_config = targets_config or build_default_target_config( - class_names=model.class_names - ) + targets_config = targets_config or configs.targets targets = build_targets(config=targets_config) roi_mapper = build_roi_mapping(config=targets_config.roi) diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index f25a221..3060532 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -211,7 +211,7 @@ class Model(torch.nn.Module): def build_model( - config: ModelConfig | None = None, + config: ModelConfig | dict | None = None, class_names: list[str] | None = None, dimension_names: list[str] | None = None, preprocessor: PreprocessorProtocol | None = None, @@ -257,6 +257,9 @@ def build_model( config = config or ModelConfig() + if isinstance(config, dict): + config = ModelConfig.model_validate(config) + if class_names is None: raise ValueError("class_names must be provided when building a model.") diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 280dadd..80a17a4 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -1,15 +1,21 @@ +from dataclasses import dataclass + import lightning as L from soundevent.data import PathLike from batdetect2.models import Model, ModelConfig, build_model from batdetect2.models.types import ModelOutput +from batdetect2.targets import TargetConfig from batdetect2.train.config import TrainingConfig from batdetect2.train.losses import build_loss from batdetect2.train.optimizers import build_optimizer from batdetect2.train.schedulers import build_scheduler from batdetect2.train.types import LossProtocol, TrainExample -__all__ = ["TrainingModule"] +__all__ = [ + "TrainingModule", + "load_model_from_checkpoint", +] class TrainingModule(L.LightningModule): @@ -19,6 +25,7 @@ class TrainingModule(L.LightningModule): def __init__( self, model_config: dict | None = None, + targets_config: dict | None = None, class_names: list[str] | None = None, dimension_names: list[str] | None = None, train_config: dict | None = None, @@ -29,9 +36,11 @@ class TrainingModule(L.LightningModule): self.save_hyperparameters(ignore=["model", "loss"], logger=False) - self.model_config = ModelConfig.model_validate(model_config or {}) + self.model_config: dict = model_config or {} + self.targets_config: dict = targets_config or {} self.class_names = list(class_names or []) self.dimension_names = list(dimension_names or []) + self.train_config = TrainingConfig.model_validate(train_config or {}) if loss is None: @@ -113,9 +122,16 @@ class TrainingModule(L.LightningModule): } +@dataclass +class StoredConfig: + model: ModelConfig + targets: TargetConfig + train: TrainingConfig + + def load_model_from_checkpoint( path: PathLike, -) -> tuple[Model, ModelConfig]: +) -> tuple[Model, StoredConfig]: """Load a model and its configuration from a Lightning checkpoint. Parameters @@ -131,11 +147,19 @@ def load_model_from_checkpoint( describes its architecture, preprocessing, and postprocessing. """ module = TrainingModule.load_from_checkpoint(path) # type: ignore - return module.model, module.model_config + training_config = TrainingConfig.model_validate(module.train_config) + model_config = ModelConfig.model_validate(module.model_config) + targets_config = TargetConfig.model_validate(module.targets_config) + return module.model, StoredConfig( + model=model_config, + targets=targets_config, + train=training_config, + ) def build_training_module( model_config: ModelConfig | None = None, + targets_config: TargetConfig | dict | None = None, class_names: list[str] | None = None, dimension_names: list[str] | None = None, train_config: TrainingConfig | None = None, @@ -147,10 +171,16 @@ def build_training_module( if train_config is None: train_config = TrainingConfig() + if targets_config is None: + targets_config = TargetConfig() + + targets_config = TargetConfig.model_validate(targets_config) + return TrainingModule( model_config=model_config.model_dump(mode="json"), + targets_config=targets_config.model_dump(mode="json"), + train_config=train_config.model_dump(mode="json"), class_names=class_names, dimension_names=dimension_names, - train_config=train_config.model_dump(mode="json"), model=model, ) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index fb4743f..dcbcbb8 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -81,7 +81,10 @@ def run_train( "model." ) - targets = targets or build_targets(config=targets_config) + if targets is None: + targets = build_targets(config=targets_config) + else: + targets_config = TargetConfig.model_validate(targets.get_config()) roi_mapper = roi_mapper or build_roi_mapping(config=targets_config.roi) @@ -132,6 +135,7 @@ def run_train( module = build_training_module( model_config=model_config, + targets_config=targets_config, class_names=targets.class_names, dimension_names=roi_mapper.dimension_names, train_config=train_config, diff --git a/tests/conftest.py b/tests/conftest.py index 9a65b91..8a6a0a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -472,6 +472,7 @@ def tiny_checkpoint_path( tmp_path: Path, ) -> Path: module = build_training_module( + targets_config=sample_targets.get_config(), class_names=sample_targets.class_names, dimension_names=sample_roi_mapper.dimension_names, ) diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index dd29a24..e3fa01f 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -225,6 +225,7 @@ def test_user_can_load_checkpoint_and_finetune( ) module = build_training_module( model_config=api.model_config, + targets_config=example_targets_config, class_names=api.targets.class_names, dimension_names=api.roi_mapper.dimension_names, ) @@ -273,6 +274,7 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged( ) module = build_training_module( model_config=source_api.model_config, + targets_config=example_targets_config, class_names=source_api.targets.class_names, dimension_names=source_api.roi_mapper.dimension_names, ) diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index 566ae62..6e9254a 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -73,7 +73,7 @@ def test_can_save_checkpoint( torch.testing.assert_close(output1, output2, rtol=0, atol=0) -def test_load_model_from_checkpoint_returns_model_and_config( +def test_load_model_from_checkpoint_returns_model_and_configs( tmp_path: Path, ): input_model_config = ModelConfig(samplerate=192_000) @@ -95,12 +95,18 @@ def test_load_model_from_checkpoint_returns_model_and_config( trainer.strategy.connect(module) trainer.save_checkpoint(path) - model, loaded_model_config = load_model_from_checkpoint(path) + model, loaded_configs = load_model_from_checkpoint(path) assert model is not None - assert loaded_model_config.model_dump( + assert loaded_configs.model.model_dump( mode="json" ) == expected_model_config.model_dump(mode="json") + assert loaded_configs.targets.model_dump( + mode="json" + ) == targets_config.model_dump(mode="json") + assert loaded_configs.train.model_dump( + mode="json" + ) == train_config.model_dump(mode="json") assert model.class_names == targets.class_names assert model.dimension_names == roi_mapper.dimension_names @@ -135,17 +141,40 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path): trainer.strategy.connect(module) trainer.save_checkpoint(path) - recovered = TrainingModule.load_from_checkpoint(path) + _, recovered_configs = load_model_from_checkpoint(path) assert not DeepDiff( - recovered.model_config.model_dump(mode="json"), + recovered_configs.model.model_dump(mode="json"), expected_model_config.model_dump(mode="json"), ) assert not DeepDiff( - recovered.train_config.model_dump(mode="json"), + recovered_configs.train.model_dump(mode="json"), train_config.model_dump(mode="json"), ) +def test_load_model_from_checkpoint_includes_targets_config(tmp_path: Path): + targets_config = TargetConfig() + targets = build_targets(targets_config) + roi_mapper = build_roi_mapping(targets_config.roi) + module = build_training_module( + model_config=ModelConfig(), + targets_config=targets_config, + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, + train_config=TrainingConfig(), + ) + trainer = L.Trainer() + path = tmp_path / "example.ckpt" + trainer.strategy.connect(module) + trainer.save_checkpoint(path) + + _, loaded_configs = load_model_from_checkpoint(path) + + assert loaded_configs.targets.model_dump( + mode="json" + ) == targets_config.model_dump(mode="json") + + def test_configure_optimizers_uses_train_config_values(tmp_path: Path): model_config = ModelConfig() expected_model_config = ModelConfig.model_validate( @@ -179,14 +208,16 @@ def test_configure_optimizers_uses_train_config_values(tmp_path: Path): trainer.strategy.connect(module) trainer.save_checkpoint(path) - recovered = TrainingModule.load_from_checkpoint(path) - assert recovered.model_config.model_dump( + _, recovered_configs = load_model_from_checkpoint(path) + assert recovered_configs.model.model_dump( mode="json" ) == expected_model_config.model_dump(mode="json") - assert recovered.train_config.model_dump( + assert recovered_configs.train.model_dump( mode="json" ) == train_config.model_dump(mode="json") + recovered = TrainingModule.load_from_checkpoint(path) + loaded_optimization_config = recovered.configure_optimizers() loaded_optimizer = loaded_optimization_config["optimizer"] loaded_scheduler = loaded_optimization_config["lr_scheduler"]["scheduler"] @@ -201,12 +232,28 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path): trainer.strategy.connect(module) trainer.save_checkpoint(path) + _, stored_configs = load_model_from_checkpoint(path) api = BatDetect2API.from_checkpoint(path) assert api.model_config.model_dump( mode="json" - ) == module.model_config.model_dump(mode="json") - assert api.audio_config.samplerate == module.model_config.samplerate + ) == stored_configs.model.model_dump(mode="json") + assert api.audio_config.samplerate == stored_configs.model.samplerate + + +def test_api_from_checkpoint_reconstructs_targets_from_checkpoint( + tmp_path: Path, +) -> None: + targets_config = TargetConfig() + module = build_default_module(target_config=targets_config) + trainer = L.Trainer() + path = tmp_path / "example.ckpt" + trainer.strategy.connect(module) + trainer.save_checkpoint(path) + + api = BatDetect2API.from_checkpoint(path) + + assert api.targets.get_config() == targets_config.model_dump(mode="json") @pytest.mark.slow