Fix hyperparameter saving

This commit is contained in:
mbsantiago 2025-06-26 19:53:19 -06:00
parent b396d4908a
commit e8db1d4050
2 changed files with 12 additions and 8 deletions

View File

@@ -1,5 +1,6 @@
import lightning as L import lightning as L
import torch import torch
from pydantic import BaseModel
from torch.optim.adam import Adam from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
@@ -20,23 +21,27 @@ class TrainingModule(L.LightningModule):
def __init__(self, config: FullTrainingConfig): def __init__(self, config: FullTrainingConfig):
super().__init__() super().__init__()
self.save_hyperparameters() # NOTE: Need to convert to vanilla python object so that DVCLive can
# store it.
self._config = (
config.model_dump() if isinstance(config, BaseModel) else config
)
self.save_hyperparameters({"config": self._config})
self.loss = build_loss(config.train.loss) self.config = FullTrainingConfig.model_validate(self._config)
self.targets = build_targets(config.targets) self.loss = build_loss(self.config.train.loss)
self.targets = build_targets(self.config.targets)
self.detector = build_model( self.detector = build_model(
num_classes=len(self.targets.class_names), num_classes=len(self.targets.class_names),
config=config.model, config=self.config.model,
) )
self.preprocessor = build_preprocessor(config.preprocess) self.preprocessor = build_preprocessor(self.config.preprocess)
self.postprocessor = build_postprocessor( self.postprocessor = build_postprocessor(
self.targets, self.targets,
min_freq=self.preprocessor.min_freq, min_freq=self.preprocessor.min_freq,
max_freq=self.preprocessor.max_freq, max_freq=self.preprocessor.max_freq,
) )
self.config = config
def forward(self, spec: torch.Tensor) -> ModelOutput: def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.detector(spec) return self.detector(spec)

View File

@@ -1,7 +1,6 @@
from pathlib import Path from pathlib import Path
import lightning as L import lightning as L
import pytest
import torch import torch
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data