mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: persist target configs in checkpoints
This commit is contained in:
parent
20a7c058fc
commit
aa36df668f
@ -610,17 +610,15 @@ class BatDetect2API:
|
|||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
build_default_target_config,
|
|
||||||
build_roi_mapping,
|
build_roi_mapping,
|
||||||
build_targets,
|
build_targets,
|
||||||
check_target_compatibility,
|
check_target_compatibility,
|
||||||
)
|
)
|
||||||
from batdetect2.train import (
|
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||||
TrainingConfig,
|
|
||||||
load_model_from_checkpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
model, model_config = load_model_from_checkpoint(path)
|
model, configs = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
|
model_config = configs.model
|
||||||
|
|
||||||
audio_config = audio_config or AudioConfig(
|
audio_config = audio_config or AudioConfig(
|
||||||
samplerate=model_config.samplerate,
|
samplerate=model_config.samplerate,
|
||||||
@ -630,9 +628,7 @@ class BatDetect2API:
|
|||||||
inference_config = inference_config or InferenceConfig()
|
inference_config = inference_config or InferenceConfig()
|
||||||
outputs_config = outputs_config or OutputsConfig()
|
outputs_config = outputs_config or OutputsConfig()
|
||||||
logging_config = logging_config or AppLoggingConfig()
|
logging_config = logging_config or AppLoggingConfig()
|
||||||
targets_config = targets_config or build_default_target_config(
|
targets_config = targets_config or configs.targets
|
||||||
class_names=model.class_names
|
|
||||||
)
|
|
||||||
|
|
||||||
targets = build_targets(config=targets_config)
|
targets = build_targets(config=targets_config)
|
||||||
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||||
|
|||||||
@ -211,7 +211,7 @@ class Model(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def build_model(
|
def build_model(
|
||||||
config: ModelConfig | None = None,
|
config: ModelConfig | dict | None = None,
|
||||||
class_names: list[str] | None = None,
|
class_names: list[str] | None = None,
|
||||||
dimension_names: list[str] | None = None,
|
dimension_names: list[str] | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
@ -257,6 +257,9 @@ def build_model(
|
|||||||
|
|
||||||
config = config or ModelConfig()
|
config = config or ModelConfig()
|
||||||
|
|
||||||
|
if isinstance(config, dict):
|
||||||
|
config = ModelConfig.model_validate(config)
|
||||||
|
|
||||||
if class_names is None:
|
if class_names is None:
|
||||||
raise ValueError("class_names must be provided when building a model.")
|
raise ValueError("class_names must be provided when building a model.")
|
||||||
|
|
||||||
|
|||||||
@ -1,15 +1,21 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.models import Model, ModelConfig, build_model
|
from batdetect2.models import Model, ModelConfig, build_model
|
||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
|
from batdetect2.targets import TargetConfig
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.losses import build_loss
|
||||||
from batdetect2.train.optimizers import build_optimizer
|
from batdetect2.train.optimizers import build_optimizer
|
||||||
from batdetect2.train.schedulers import build_scheduler
|
from batdetect2.train.schedulers import build_scheduler
|
||||||
from batdetect2.train.types import LossProtocol, TrainExample
|
from batdetect2.train.types import LossProtocol, TrainExample
|
||||||
|
|
||||||
__all__ = ["TrainingModule"]
|
__all__ = [
|
||||||
|
"TrainingModule",
|
||||||
|
"load_model_from_checkpoint",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class TrainingModule(L.LightningModule):
|
class TrainingModule(L.LightningModule):
|
||||||
@ -19,6 +25,7 @@ class TrainingModule(L.LightningModule):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config: dict | None = None,
|
model_config: dict | None = None,
|
||||||
|
targets_config: dict | None = None,
|
||||||
class_names: list[str] | None = None,
|
class_names: list[str] | None = None,
|
||||||
dimension_names: list[str] | None = None,
|
dimension_names: list[str] | None = None,
|
||||||
train_config: dict | None = None,
|
train_config: dict | None = None,
|
||||||
@ -29,9 +36,11 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
||||||
|
|
||||||
self.model_config = ModelConfig.model_validate(model_config or {})
|
self.model_config: dict = model_config or {}
|
||||||
|
self.targets_config: dict = targets_config or {}
|
||||||
self.class_names = list(class_names or [])
|
self.class_names = list(class_names or [])
|
||||||
self.dimension_names = list(dimension_names or [])
|
self.dimension_names = list(dimension_names or [])
|
||||||
|
|
||||||
self.train_config = TrainingConfig.model_validate(train_config or {})
|
self.train_config = TrainingConfig.model_validate(train_config or {})
|
||||||
|
|
||||||
if loss is None:
|
if loss is None:
|
||||||
@ -113,9 +122,16 @@ class TrainingModule(L.LightningModule):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StoredConfig:
|
||||||
|
model: ModelConfig
|
||||||
|
targets: TargetConfig
|
||||||
|
train: TrainingConfig
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_checkpoint(
|
def load_model_from_checkpoint(
|
||||||
path: PathLike,
|
path: PathLike,
|
||||||
) -> tuple[Model, ModelConfig]:
|
) -> tuple[Model, StoredConfig]:
|
||||||
"""Load a model and its configuration from a Lightning checkpoint.
|
"""Load a model and its configuration from a Lightning checkpoint.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -131,11 +147,19 @@ def load_model_from_checkpoint(
|
|||||||
describes its architecture, preprocessing, and postprocessing.
|
describes its architecture, preprocessing, and postprocessing.
|
||||||
"""
|
"""
|
||||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||||
return module.model, module.model_config
|
training_config = TrainingConfig.model_validate(module.train_config)
|
||||||
|
model_config = ModelConfig.model_validate(module.model_config)
|
||||||
|
targets_config = TargetConfig.model_validate(module.targets_config)
|
||||||
|
return module.model, StoredConfig(
|
||||||
|
model=model_config,
|
||||||
|
targets=targets_config,
|
||||||
|
train=training_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
model_config: ModelConfig | None = None,
|
model_config: ModelConfig | None = None,
|
||||||
|
targets_config: TargetConfig | dict | None = None,
|
||||||
class_names: list[str] | None = None,
|
class_names: list[str] | None = None,
|
||||||
dimension_names: list[str] | None = None,
|
dimension_names: list[str] | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
@ -147,10 +171,16 @@ def build_training_module(
|
|||||||
if train_config is None:
|
if train_config is None:
|
||||||
train_config = TrainingConfig()
|
train_config = TrainingConfig()
|
||||||
|
|
||||||
|
if targets_config is None:
|
||||||
|
targets_config = TargetConfig()
|
||||||
|
|
||||||
|
targets_config = TargetConfig.model_validate(targets_config)
|
||||||
|
|
||||||
return TrainingModule(
|
return TrainingModule(
|
||||||
model_config=model_config.model_dump(mode="json"),
|
model_config=model_config.model_dump(mode="json"),
|
||||||
|
targets_config=targets_config.model_dump(mode="json"),
|
||||||
|
train_config=train_config.model_dump(mode="json"),
|
||||||
class_names=class_names,
|
class_names=class_names,
|
||||||
dimension_names=dimension_names,
|
dimension_names=dimension_names,
|
||||||
train_config=train_config.model_dump(mode="json"),
|
|
||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -81,7 +81,10 @@ def run_train(
|
|||||||
"model."
|
"model."
|
||||||
)
|
)
|
||||||
|
|
||||||
targets = targets or build_targets(config=targets_config)
|
if targets is None:
|
||||||
|
targets = build_targets(config=targets_config)
|
||||||
|
else:
|
||||||
|
targets_config = TargetConfig.model_validate(targets.get_config())
|
||||||
|
|
||||||
roi_mapper = roi_mapper or build_roi_mapping(config=targets_config.roi)
|
roi_mapper = roi_mapper or build_roi_mapping(config=targets_config.roi)
|
||||||
|
|
||||||
@ -132,6 +135,7 @@ def run_train(
|
|||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
targets_config=targets_config,
|
||||||
class_names=targets.class_names,
|
class_names=targets.class_names,
|
||||||
dimension_names=roi_mapper.dimension_names,
|
dimension_names=roi_mapper.dimension_names,
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
|
|||||||
@ -472,6 +472,7 @@ def tiny_checkpoint_path(
|
|||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
|
targets_config=sample_targets.get_config(),
|
||||||
class_names=sample_targets.class_names,
|
class_names=sample_targets.class_names,
|
||||||
dimension_names=sample_roi_mapper.dimension_names,
|
dimension_names=sample_roi_mapper.dimension_names,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -225,6 +225,7 @@ def test_user_can_load_checkpoint_and_finetune(
|
|||||||
)
|
)
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=api.model_config,
|
model_config=api.model_config,
|
||||||
|
targets_config=example_targets_config,
|
||||||
class_names=api.targets.class_names,
|
class_names=api.targets.class_names,
|
||||||
dimension_names=api.roi_mapper.dimension_names,
|
dimension_names=api.roi_mapper.dimension_names,
|
||||||
)
|
)
|
||||||
@ -273,6 +274,7 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
|||||||
)
|
)
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=source_api.model_config,
|
model_config=source_api.model_config,
|
||||||
|
targets_config=example_targets_config,
|
||||||
class_names=source_api.targets.class_names,
|
class_names=source_api.targets.class_names,
|
||||||
dimension_names=source_api.roi_mapper.dimension_names,
|
dimension_names=source_api.roi_mapper.dimension_names,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -73,7 +73,7 @@ def test_can_save_checkpoint(
|
|||||||
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
def test_load_model_from_checkpoint_returns_model_and_config(
|
def test_load_model_from_checkpoint_returns_model_and_configs(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
):
|
):
|
||||||
input_model_config = ModelConfig(samplerate=192_000)
|
input_model_config = ModelConfig(samplerate=192_000)
|
||||||
@ -95,12 +95,18 @@ def test_load_model_from_checkpoint_returns_model_and_config(
|
|||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
trainer.save_checkpoint(path)
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
model, loaded_model_config = load_model_from_checkpoint(path)
|
model, loaded_configs = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
assert model is not None
|
assert model is not None
|
||||||
assert loaded_model_config.model_dump(
|
assert loaded_configs.model.model_dump(
|
||||||
mode="json"
|
mode="json"
|
||||||
) == expected_model_config.model_dump(mode="json")
|
) == expected_model_config.model_dump(mode="json")
|
||||||
|
assert loaded_configs.targets.model_dump(
|
||||||
|
mode="json"
|
||||||
|
) == targets_config.model_dump(mode="json")
|
||||||
|
assert loaded_configs.train.model_dump(
|
||||||
|
mode="json"
|
||||||
|
) == train_config.model_dump(mode="json")
|
||||||
assert model.class_names == targets.class_names
|
assert model.class_names == targets.class_names
|
||||||
assert model.dimension_names == roi_mapper.dimension_names
|
assert model.dimension_names == roi_mapper.dimension_names
|
||||||
|
|
||||||
@ -135,17 +141,40 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
|
|||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
trainer.save_checkpoint(path)
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
recovered = TrainingModule.load_from_checkpoint(path)
|
_, recovered_configs = load_model_from_checkpoint(path)
|
||||||
assert not DeepDiff(
|
assert not DeepDiff(
|
||||||
recovered.model_config.model_dump(mode="json"),
|
recovered_configs.model.model_dump(mode="json"),
|
||||||
expected_model_config.model_dump(mode="json"),
|
expected_model_config.model_dump(mode="json"),
|
||||||
)
|
)
|
||||||
assert not DeepDiff(
|
assert not DeepDiff(
|
||||||
recovered.train_config.model_dump(mode="json"),
|
recovered_configs.train.model_dump(mode="json"),
|
||||||
train_config.model_dump(mode="json"),
|
train_config.model_dump(mode="json"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_from_checkpoint_includes_targets_config(tmp_path: Path):
|
||||||
|
targets_config = TargetConfig()
|
||||||
|
targets = build_targets(targets_config)
|
||||||
|
roi_mapper = build_roi_mapping(targets_config.roi)
|
||||||
|
module = build_training_module(
|
||||||
|
model_config=ModelConfig(),
|
||||||
|
targets_config=targets_config,
|
||||||
|
class_names=targets.class_names,
|
||||||
|
dimension_names=roi_mapper.dimension_names,
|
||||||
|
train_config=TrainingConfig(),
|
||||||
|
)
|
||||||
|
trainer = L.Trainer()
|
||||||
|
path = tmp_path / "example.ckpt"
|
||||||
|
trainer.strategy.connect(module)
|
||||||
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
|
_, loaded_configs = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
|
assert loaded_configs.targets.model_dump(
|
||||||
|
mode="json"
|
||||||
|
) == targets_config.model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
|
def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
|
||||||
model_config = ModelConfig()
|
model_config = ModelConfig()
|
||||||
expected_model_config = ModelConfig.model_validate(
|
expected_model_config = ModelConfig.model_validate(
|
||||||
@ -179,14 +208,16 @@ def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
|
|||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
trainer.save_checkpoint(path)
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
recovered = TrainingModule.load_from_checkpoint(path)
|
_, recovered_configs = load_model_from_checkpoint(path)
|
||||||
assert recovered.model_config.model_dump(
|
assert recovered_configs.model.model_dump(
|
||||||
mode="json"
|
mode="json"
|
||||||
) == expected_model_config.model_dump(mode="json")
|
) == expected_model_config.model_dump(mode="json")
|
||||||
assert recovered.train_config.model_dump(
|
assert recovered_configs.train.model_dump(
|
||||||
mode="json"
|
mode="json"
|
||||||
) == train_config.model_dump(mode="json")
|
) == train_config.model_dump(mode="json")
|
||||||
|
|
||||||
|
recovered = TrainingModule.load_from_checkpoint(path)
|
||||||
|
|
||||||
loaded_optimization_config = recovered.configure_optimizers()
|
loaded_optimization_config = recovered.configure_optimizers()
|
||||||
loaded_optimizer = loaded_optimization_config["optimizer"]
|
loaded_optimizer = loaded_optimization_config["optimizer"]
|
||||||
loaded_scheduler = loaded_optimization_config["lr_scheduler"]["scheduler"]
|
loaded_scheduler = loaded_optimization_config["lr_scheduler"]["scheduler"]
|
||||||
@ -201,12 +232,28 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
|
|||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
trainer.save_checkpoint(path)
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
|
_, stored_configs = load_model_from_checkpoint(path)
|
||||||
api = BatDetect2API.from_checkpoint(path)
|
api = BatDetect2API.from_checkpoint(path)
|
||||||
|
|
||||||
assert api.model_config.model_dump(
|
assert api.model_config.model_dump(
|
||||||
mode="json"
|
mode="json"
|
||||||
) == module.model_config.model_dump(mode="json")
|
) == stored_configs.model.model_dump(mode="json")
|
||||||
assert api.audio_config.samplerate == module.model_config.samplerate
|
assert api.audio_config.samplerate == stored_configs.model.samplerate
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_from_checkpoint_reconstructs_targets_from_checkpoint(
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
targets_config = TargetConfig()
|
||||||
|
module = build_default_module(target_config=targets_config)
|
||||||
|
trainer = L.Trainer()
|
||||||
|
path = tmp_path / "example.ckpt"
|
||||||
|
trainer.strategy.connect(module)
|
||||||
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
|
api = BatDetect2API.from_checkpoint(path)
|
||||||
|
|
||||||
|
assert api.targets.get_config() == targets_config.model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user