mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
use train config in training module
This commit is contained in:
parent
65bd0dc6ae
commit
56f6affc72
@ -1,17 +1,12 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import lightning as L
|
||||
import torch
|
||||
from soundevent.data import PathLike
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.train.losses import build_loss
|
||||
from batdetect2.typing import ModelOutput, TrainExample
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.losses import LossFunction, build_loss
|
||||
from batdetect2.typing import LossProtocol, ModelOutput, TrainExample
|
||||
|
||||
__all__ = [
|
||||
"TrainingModule",
|
||||
@ -20,13 +15,13 @@ __all__ = [
|
||||
|
||||
class TrainingModule(L.LightningModule):
|
||||
model: Model
|
||||
loss: LossProtocol
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: dict | None = None,
|
||||
t_max: int = 100,
|
||||
learning_rate: float = 1e-3,
|
||||
loss: torch.nn.Module | None = None,
|
||||
train_config: dict | None = None,
|
||||
loss: LossFunction | None = None,
|
||||
model: Model | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
@ -34,11 +29,10 @@ class TrainingModule(L.LightningModule):
|
||||
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
||||
|
||||
self.model_config = ModelConfig.model_validate(model_config or {})
|
||||
self.learning_rate = learning_rate
|
||||
self.t_max = t_max
|
||||
self.train_config = TrainingConfig.model_validate(train_config or {})
|
||||
|
||||
if loss is None:
|
||||
loss = build_loss()
|
||||
loss = build_loss(config=self.train_config.loss)
|
||||
|
||||
if model is None:
|
||||
model = build_model(config=self.model_config)
|
||||
@ -50,9 +44,13 @@ class TrainingModule(L.LightningModule):
|
||||
outputs = self.model.detector(batch.spec)
|
||||
losses = self.loss(outputs, batch)
|
||||
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
||||
self.log("detection_loss/train", losses.total, logger=True)
|
||||
self.log("size_loss/train", losses.total, logger=True)
|
||||
self.log("classification_loss/train", losses.total, logger=True)
|
||||
self.log("detection_loss/train", losses.detection, logger=True)
|
||||
self.log("size_loss/train", losses.size, logger=True)
|
||||
self.log(
|
||||
"classification_loss/train",
|
||||
losses.classification,
|
||||
logger=True,
|
||||
)
|
||||
return losses.total
|
||||
|
||||
def validation_step( # type: ignore
|
||||
@ -63,14 +61,15 @@ class TrainingModule(L.LightningModule):
|
||||
outputs = self.model.detector(batch.spec)
|
||||
losses = self.loss(outputs, batch)
|
||||
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
||||
self.log("detection_loss/val", losses.total, logger=True)
|
||||
self.log("size_loss/val", losses.total, logger=True)
|
||||
self.log("classification_loss/val", losses.total, logger=True)
|
||||
self.log("detection_loss/val", losses.detection, logger=True)
|
||||
self.log("size_loss/val", losses.size, logger=True)
|
||||
self.log("classification_loss/val", losses.classification, logger=True)
|
||||
return outputs
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
||||
config = self.train_config.optimizer
|
||||
optimizer = Adam(self.parameters(), lr=config.learning_rate)
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=config.t_max)
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
|
||||
@ -98,17 +97,6 @@ def load_model_from_checkpoint(
|
||||
|
||||
def build_training_module(
|
||||
model_config: dict | None = None,
|
||||
t_max: int = 200,
|
||||
learning_rate: float = 1e-3,
|
||||
loss_config: dict | None = None,
|
||||
train_config: dict | None = None,
|
||||
) -> TrainingModule:
|
||||
from batdetect2.train.config import LossConfig
|
||||
from batdetect2.train.losses import build_loss
|
||||
|
||||
loss = build_loss(LossConfig.model_validate(loss_config or {}))
|
||||
return TrainingModule(
|
||||
model_config=model_config,
|
||||
t_max=t_max,
|
||||
learning_rate=learning_rate,
|
||||
loss=loss,
|
||||
)
|
||||
return TrainingModule(model_config=model_config, train_config=train_config)
|
||||
|
||||
@ -422,7 +422,7 @@ class LossFunction(nn.Module, LossProtocol):
|
||||
def build_loss(
|
||||
config: LossConfig | None = None,
|
||||
class_weights: np.ndarray | None = None,
|
||||
) -> nn.Module:
|
||||
) -> LossFunction:
|
||||
"""Factory function to build the main LossFunction from configuration.
|
||||
|
||||
Instantiates the necessary loss components (`BBoxLoss`, `FocalLoss`) based
|
||||
|
||||
@ -96,11 +96,13 @@ def train(
|
||||
else None
|
||||
)
|
||||
|
||||
train_config_dict = config.train.model_dump(mode="json")
|
||||
if "optimizer" in train_config_dict:
|
||||
train_config_dict["optimizer"]["t_max"] *= len(train_dataloader)
|
||||
|
||||
module = build_training_module(
|
||||
model_config=config.model.model_dump(mode="json"),
|
||||
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
||||
learning_rate=config.train.optimizer.learning_rate,
|
||||
loss_config=config.train.loss.model_dump(mode="json"),
|
||||
train_config=train_config_dict,
|
||||
)
|
||||
|
||||
trainer = trainer or build_trainer(
|
||||
|
||||
@ -14,6 +14,7 @@ def build_default_module():
|
||||
config = BatDetect2Config()
|
||||
return build_training_module(
|
||||
model_config=config.model.model_dump(mode="json"),
|
||||
train_config=config.train.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user