use train config in training module

This commit is contained in:
mbsantiago 2026-03-17 14:16:40 +00:00
parent 65bd0dc6ae
commit 56f6affc72
4 changed files with 30 additions and 39 deletions

View File

@@ -1,17 +1,12 @@
-from typing import TYPE_CHECKING
-
 import lightning as L
-import torch
 from soundevent.data import PathLike
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR

 from batdetect2.models import Model, ModelConfig, build_model
-from batdetect2.train.losses import build_loss
-from batdetect2.typing import ModelOutput, TrainExample
+from batdetect2.train.config import TrainingConfig
+from batdetect2.train.losses import LossFunction, build_loss
+from batdetect2.typing import LossProtocol, ModelOutput, TrainExample

-if TYPE_CHECKING:
-    pass
-
 __all__ = [
     "TrainingModule",
@@ -20,13 +15,13 @@ __all__ = [
 class TrainingModule(L.LightningModule):
     model: Model
+    loss: LossProtocol

     def __init__(
         self,
         model_config: dict | None = None,
-        t_max: int = 100,
-        learning_rate: float = 1e-3,
-        loss: torch.nn.Module | None = None,
+        train_config: dict | None = None,
+        loss: LossFunction | None = None,
         model: Model | None = None,
     ):
         super().__init__()
@@ -34,11 +29,10 @@ class TrainingModule(L.LightningModule):
         self.save_hyperparameters(ignore=["model", "loss"], logger=False)
         self.model_config = ModelConfig.model_validate(model_config or {})
-        self.learning_rate = learning_rate
-        self.t_max = t_max
+        self.train_config = TrainingConfig.model_validate(train_config or {})

         if loss is None:
-            loss = build_loss()
+            loss = build_loss(config=self.train_config.loss)

         if model is None:
             model = build_model(config=self.model_config)
@@ -50,9 +44,13 @@
         outputs = self.model.detector(batch.spec)
         losses = self.loss(outputs, batch)
         self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
-        self.log("detection_loss/train", losses.total, logger=True)
-        self.log("size_loss/train", losses.total, logger=True)
-        self.log("classification_loss/train", losses.total, logger=True)
+        self.log("detection_loss/train", losses.detection, logger=True)
+        self.log("size_loss/train", losses.size, logger=True)
+        self.log(
+            "classification_loss/train",
+            losses.classification,
+            logger=True,
+        )
         return losses.total

     def validation_step(  # type: ignore
@@ -63,14 +61,15 @@
         outputs = self.model.detector(batch.spec)
         losses = self.loss(outputs, batch)
         self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
-        self.log("detection_loss/val", losses.total, logger=True)
-        self.log("size_loss/val", losses.total, logger=True)
-        self.log("classification_loss/val", losses.total, logger=True)
+        self.log("detection_loss/val", losses.detection, logger=True)
+        self.log("size_loss/val", losses.size, logger=True)
+        self.log("classification_loss/val", losses.classification, logger=True)
         return outputs

     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=self.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
+        config = self.train_config.optimizer
+        optimizer = Adam(self.parameters(), lr=config.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=config.t_max)
         return [optimizer], [scheduler]
@@ -98,17 +97,6 @@ def load_model_from_checkpoint(
 def build_training_module(
     model_config: dict | None = None,
-    t_max: int = 200,
-    learning_rate: float = 1e-3,
-    loss_config: dict | None = None,
+    train_config: dict | None = None,
 ) -> TrainingModule:
-    from batdetect2.train.config import LossConfig
-    from batdetect2.train.losses import build_loss
-
-    loss = build_loss(LossConfig.model_validate(loss_config or {}))
-    return TrainingModule(
-        model_config=model_config,
-        t_max=t_max,
-        learning_rate=learning_rate,
-        loss=loss,
-    )
+    return TrainingModule(model_config=model_config, train_config=train_config)
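
Note: the module now derives everything it previously took as loose arguments (t_max, learning_rate, loss_config) from a single validated TrainingConfig. The sketch below is hypothetical, inferred purely from the attribute accesses in this diff (train_config.loss, train_config.optimizer.learning_rate, train_config.optimizer.t_max); the real definitions live in batdetect2.train.config and may differ.

    # Sketch only: field names inferred from the diff; defaults are guesses
    # based on the old keyword defaults (t_max=100, learning_rate=1e-3).
    from pydantic import BaseModel, Field

    class OptimizerConfig(BaseModel):
        learning_rate: float = 1e-3  # forwarded to Adam(lr=...)
        t_max: int = 100             # forwarded to CosineAnnealingLR(T_max=...)

    class LossConfig(BaseModel):
        ...  # detection/size/classification loss settings (not shown here)

    class TrainingConfig(BaseModel):
        optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
        loss: LossConfig = Field(default_factory=LossConfig)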

View File

@@ -422,7 +422,7 @@ class LossFunction(nn.Module, LossProtocol):
 def build_loss(
     config: LossConfig | None = None,
     class_weights: np.ndarray | None = None,
-) -> nn.Module:
+) -> LossFunction:
     """Factory function to build the main LossFunction from configuration.

     Instantiates the necessary loss components (`BBoxLoss`, `FocalLoss`) based
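
The narrowed return type matters because TrainingModule now annotates its loss attribute as LossProtocol, which LossFunction implements. A minimal sketch of the contract, assuming LossProtocol is a typing.Protocol and inferring the result's field names from the logging calls above (the real definitions are in batdetect2.typing and the container name is a placeholder):

    from typing import NamedTuple, Protocol

    import torch

    from batdetect2.typing import ModelOutput, TrainExample

    class Losses(NamedTuple):  # hypothetical container name
        total: torch.Tensor
        detection: torch.Tensor
        size: torch.Tensor
        classification: torch.Tensor

    class LossProtocol(Protocol):
        def __call__(self, outputs: ModelOutput, batch: TrainExample) -> Losses: ...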

View File

@@ -96,11 +96,13 @@ def train(
         else None
     )

+    train_config_dict = config.train.model_dump(mode="json")
+    if "optimizer" in train_config_dict:
+        train_config_dict["optimizer"]["t_max"] *= len(train_dataloader)
+
     module = build_training_module(
         model_config=config.model.model_dump(mode="json"),
-        t_max=config.train.optimizer.t_max * len(train_dataloader),
-        learning_rate=config.train.optimizer.learning_rate,
-        loss_config=config.train.loss.model_dump(mode="json"),
+        train_config=train_config_dict,
     )

     trainer = trainer or build_trainer(
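
The multiplication preserves the old call site's behaviour: the configured t_max is seemingly expressed in epochs, while CosineAnnealingLR counts T_max in scheduler steps, so it is scaled by the number of batches per epoch. Because build_training_module now takes only a serialized dict, that scaling has to happen on the dict before TrainingConfig re-validates it inside the module. A toy sketch of the round-trip (values illustrative):

    # Illustrative only: mirrors the dict mutation train() performs above.
    train_config_dict = {"optimizer": {"learning_rate": 1e-3, "t_max": 100}}
    steps_per_epoch = 250  # stand-in for len(train_dataloader)
    train_config_dict["optimizer"]["t_max"] *= steps_per_epoch
    assert train_config_dict["optimizer"]["t_max"] == 25_000
    # TrainingConfig.model_validate(train_config_dict) then sees t_max in steps.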

View File

@@ -14,6 +14,7 @@ def build_default_module():
     config = BatDetect2Config()
     return build_training_module(
         model_config=config.model.model_dump(mode="json"),
+        train_config=config.train.model_dump(mode="json"),
     )
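
Since both configs are passed as JSON-ready dicts and captured by save_hyperparameters, a checkpoint should be restorable without re-supplying them. A hedged usage sketch following the standard Lightning pattern (load_model_from_checkpoint appears in the diff context above, but its body is not shown; BatDetect2Config is assumed to bundle `model` and `train` sections as this test helper suggests):

    config = BatDetect2Config()
    module = build_training_module(
        model_config=config.model.model_dump(mode="json"),
        train_config=config.train.model_dump(mode="json"),
    )
    # ...after training and saving a checkpoint...
    restored = TrainingModule.load_from_checkpoint("path/to/last.ckpt")
    # model_config and train_config were stored as hyperparameters, so the
    # restored module re-validates them and rebuilds its loss and model.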