mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Fix hyperparameter saving
This commit is contained in:
parent
b396d4908a
commit
e8db1d4050
@ -1,5 +1,6 @@
|
|||||||
import lightning as L
|
import lightning as L
|
||||||
import torch
|
import torch
|
||||||
|
from pydantic import BaseModel
|
||||||
from torch.optim.adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
@ -20,23 +21,27 @@ class TrainingModule(L.LightningModule):
|
|||||||
def __init__(self, config: FullTrainingConfig):
|
def __init__(self, config: FullTrainingConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.save_hyperparameters()
|
# NOTE: Need to convert to vanilla python object so that DVCLive can
|
||||||
|
# store it.
|
||||||
|
self._config = (
|
||||||
|
config.model_dump() if isinstance(config, BaseModel) else config
|
||||||
|
)
|
||||||
|
self.save_hyperparameters({"config": self._config})
|
||||||
|
|
||||||
self.loss = build_loss(config.train.loss)
|
self.config = FullTrainingConfig.model_validate(self._config)
|
||||||
self.targets = build_targets(config.targets)
|
self.loss = build_loss(self.config.train.loss)
|
||||||
|
self.targets = build_targets(self.config.targets)
|
||||||
self.detector = build_model(
|
self.detector = build_model(
|
||||||
num_classes=len(self.targets.class_names),
|
num_classes=len(self.targets.class_names),
|
||||||
config=config.model,
|
config=self.config.model,
|
||||||
)
|
)
|
||||||
self.preprocessor = build_preprocessor(config.preprocess)
|
self.preprocessor = build_preprocessor(self.config.preprocess)
|
||||||
self.postprocessor = build_postprocessor(
|
self.postprocessor = build_postprocessor(
|
||||||
self.targets,
|
self.targets,
|
||||||
min_freq=self.preprocessor.min_freq,
|
min_freq=self.preprocessor.min_freq,
|
||||||
max_freq=self.preprocessor.max_freq,
|
max_freq=self.preprocessor.max_freq,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
return self.detector(spec)
|
return self.detector(spec)
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
import pytest
|
|
||||||
import torch
|
import torch
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
Loading…
Reference in New Issue
Block a user