From e8db1d40508b74b389586b71550d41ac0a807c3f Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 26 Jun 2025 19:53:19 -0600 Subject: [PATCH] Fix hyperparameter saving --- src/batdetect2/train/lightning.py | 19 ++++++++++++------- tests/test_train/test_lightning.py | 1 - 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 373a868..74dd282 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -1,5 +1,6 @@ import lightning as L import torch +from pydantic import BaseModel from torch.optim.adam import Adam from torch.optim.lr_scheduler import CosineAnnealingLR @@ -20,23 +21,27 @@ class TrainingModule(L.LightningModule): def __init__(self, config: FullTrainingConfig): super().__init__() - self.save_hyperparameters() + # NOTE: Need to convert to vanilla python object so that DVCLive can + # store it. + self._config = ( + config.model_dump() if isinstance(config, BaseModel) else config + ) + self.save_hyperparameters({"config": self._config}) - self.loss = build_loss(config.train.loss) - self.targets = build_targets(config.targets) + self.config = FullTrainingConfig.model_validate(self._config) + self.loss = build_loss(self.config.train.loss) + self.targets = build_targets(self.config.targets) self.detector = build_model( num_classes=len(self.targets.class_names), - config=config.model, + config=self.config.model, ) - self.preprocessor = build_preprocessor(config.preprocess) + self.preprocessor = build_preprocessor(self.config.preprocess) self.postprocessor = build_postprocessor( self.targets, min_freq=self.preprocessor.min_freq, max_freq=self.preprocessor.max_freq, ) - self.config = config - def forward(self, spec: torch.Tensor) -> ModelOutput: return self.detector(spec) diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index 55d9093..653afb1 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -1,7 +1,6 @@ from pathlib import Path import lightning as L -import pytest import torch import xarray as xr from soundevent import data