mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Fix hyperparameter saving
This commit is contained in:
parent
b396d4908a
commit
e8db1d4050
@ -1,5 +1,6 @@
|
||||
import lightning as L
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
@ -20,23 +21,27 @@ class TrainingModule(L.LightningModule):
|
||||
def __init__(self, config: FullTrainingConfig):
|
||||
super().__init__()
|
||||
|
||||
self.save_hyperparameters()
|
||||
# NOTE: Need to convert to vanilla python object so that DVCLive can
|
||||
# store it.
|
||||
self._config = (
|
||||
config.model_dump() if isinstance(config, BaseModel) else config
|
||||
)
|
||||
self.save_hyperparameters({"config": self._config})
|
||||
|
||||
self.loss = build_loss(config.train.loss)
|
||||
self.targets = build_targets(config.targets)
|
||||
self.config = FullTrainingConfig.model_validate(self._config)
|
||||
self.loss = build_loss(self.config.train.loss)
|
||||
self.targets = build_targets(self.config.targets)
|
||||
self.detector = build_model(
|
||||
num_classes=len(self.targets.class_names),
|
||||
config=config.model,
|
||||
config=self.config.model,
|
||||
)
|
||||
self.preprocessor = build_preprocessor(config.preprocess)
|
||||
self.preprocessor = build_preprocessor(self.config.preprocess)
|
||||
self.postprocessor = build_postprocessor(
|
||||
self.targets,
|
||||
min_freq=self.preprocessor.min_freq,
|
||||
max_freq=self.preprocessor.max_freq,
|
||||
)
|
||||
|
||||
self.config = config
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
return self.detector(spec)
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
from pathlib import Path
|
||||
|
||||
import lightning as L
|
||||
import pytest
|
||||
import torch
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
Loading…
Reference in New Issue
Block a user