diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py
index a9412a0..a7fde38 100644
--- a/src/batdetect2/train/lightning.py
+++ b/src/batdetect2/train/lightning.py
@@ -1,17 +1,12 @@
-from typing import TYPE_CHECKING
-
 import lightning as L
-import torch
 from soundevent.data import PathLike
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
 from batdetect2.models import Model, ModelConfig, build_model
-from batdetect2.train.losses import build_loss
-from batdetect2.typing import ModelOutput, TrainExample
-
-if TYPE_CHECKING:
-    pass
+from batdetect2.train.config import TrainingConfig
+from batdetect2.train.losses import LossFunction, build_loss
+from batdetect2.typing import LossProtocol, ModelOutput, TrainExample
 
 __all__ = [
     "TrainingModule",
@@ -20,13 +15,13 @@ __all__ = [
 
 class TrainingModule(L.LightningModule):
     model: Model
+    loss: LossProtocol
 
     def __init__(
         self,
         model_config: dict | None = None,
-        t_max: int = 100,
-        learning_rate: float = 1e-3,
-        loss: torch.nn.Module | None = None,
+        train_config: dict | None = None,
+        loss: LossFunction | None = None,
         model: Model | None = None,
     ):
         super().__init__()
@@ -34,11 +29,10 @@ class TrainingModule(L.LightningModule):
         self.save_hyperparameters(ignore=["model", "loss"], logger=False)
 
         self.model_config = ModelConfig.model_validate(model_config or {})
-        self.learning_rate = learning_rate
-        self.t_max = t_max
+        self.train_config = TrainingConfig.model_validate(train_config or {})
 
         if loss is None:
-            loss = build_loss()
+            loss = build_loss(config=self.train_config.loss)
 
         if model is None:
             model = build_model(config=self.model_config)
@@ -50,9 +44,13 @@ class TrainingModule(L.LightningModule):
         outputs = self.model.detector(batch.spec)
         losses = self.loss(outputs, batch)
         self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
-        self.log("detection_loss/train", losses.total, logger=True)
-        self.log("size_loss/train", losses.total, logger=True)
-        self.log("classification_loss/train", losses.total, logger=True)
+        self.log("detection_loss/train", losses.detection, logger=True)
+        self.log("size_loss/train", losses.size, logger=True)
+        self.log(
+            "classification_loss/train",
+            losses.classification,
+            logger=True,
+        )
         return losses.total
 
     def validation_step(  # type: ignore
@@ -63,14 +61,15 @@ class TrainingModule(L.LightningModule):
        outputs = self.model.detector(batch.spec)
         losses = self.loss(outputs, batch)
         self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
-        self.log("detection_loss/val", losses.total, logger=True)
-        self.log("size_loss/val", losses.total, logger=True)
-        self.log("classification_loss/val", losses.total, logger=True)
+        self.log("detection_loss/val", losses.detection, logger=True)
+        self.log("size_loss/val", losses.size, logger=True)
+        self.log("classification_loss/val", losses.classification, logger=True)
         return outputs
 
     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=self.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
+        config = self.train_config.optimizer
+        optimizer = Adam(self.parameters(), lr=config.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=config.t_max)
         return [optimizer], [scheduler]
 
 
@@ -98,17 +97,6 @@ def load_model_from_checkpoint(
 
 def build_training_module(
     model_config: dict | None = None,
-    t_max: int = 200,
-    learning_rate: float = 1e-3,
-    loss_config: dict | None = None,
+    train_config: dict | None = None,
 ) -> TrainingModule:
-    from batdetect2.train.config import LossConfig
-    from batdetect2.train.losses import build_loss
-
-    loss = build_loss(LossConfig.model_validate(loss_config or {}))
-    return TrainingModule(
-        model_config=model_config,
-        t_max=t_max,
-        learning_rate=learning_rate,
-        loss=loss,
-    )
+    return TrainingModule(model_config=model_config, train_config=train_config)
diff --git a/src/batdetect2/train/losses.py b/src/batdetect2/train/losses.py
index 2adfea6..b98d2dd 100644
--- a/src/batdetect2/train/losses.py
+++ b/src/batdetect2/train/losses.py
@@ -422,7 +422,7 @@ class LossFunction(nn.Module, LossProtocol):
 def build_loss(
     config: LossConfig | None = None,
     class_weights: np.ndarray | None = None,
-) -> nn.Module:
+) -> LossFunction:
     """Factory function to build the main LossFunction from configuration.
 
     Instantiates the necessary loss components (`BBoxLoss`, `FocalLoss`) based
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index 66f59c1..6f1ef01 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -96,11 +96,13 @@ def train(
         else None
     )
 
+    train_config_dict = config.train.model_dump(mode="json")
+    if "optimizer" in train_config_dict:
+        train_config_dict["optimizer"]["t_max"] *= len(train_dataloader)
+
     module = build_training_module(
         model_config=config.model.model_dump(mode="json"),
-        t_max=config.train.optimizer.t_max * len(train_dataloader),
-        learning_rate=config.train.optimizer.learning_rate,
-        loss_config=config.train.loss.model_dump(mode="json"),
+        train_config=train_config_dict,
     )
 
     trainer = trainer or build_trainer(
diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py
index 928961e..b3ebfb0 100644
--- a/tests/test_train/test_lightning.py
+++ b/tests/test_train/test_lightning.py
@@ -14,6 +14,7 @@ def build_default_module():
     config = BatDetect2Config()
     return build_training_module(
         model_config=config.model.model_dump(mode="json"),
+        train_config=config.train.model_dump(mode="json"),
     )
 
 
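
A minimal usage sketch of the reworked build_training_module, assuming the TrainingConfig field names visible in the hunks above (optimizer.learning_rate, optimizer.t_max, loss); the concrete values shown are illustrative only, not the project's defaults:

    from batdetect2.train.lightning import build_training_module

    # Nested dict mirroring TrainingConfig; it is validated inside
    # TrainingModule via TrainingConfig.model_validate(train_config or {}).
    # Field names are assumed from the diff above.
    train_config = {
        "optimizer": {"learning_rate": 1e-3, "t_max": 100},
        "loss": {},  # forwarded as build_loss(config=self.train_config.loss)
    }

    module = build_training_module(
        model_config=None,  # falls back to ModelConfig defaults
        train_config=train_config,
    )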