mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
use train config in training module
This commit is contained in:
parent
65bd0dc6ae
commit
56f6affc72
@ -1,17 +1,12 @@
|
|||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
import torch
|
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
from torch.optim.adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
from batdetect2.models import Model, ModelConfig, build_model
|
from batdetect2.models import Model, ModelConfig, build_model
|
||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.config import TrainingConfig
|
||||||
from batdetect2.typing import ModelOutput, TrainExample
|
from batdetect2.train.losses import LossFunction, build_loss
|
||||||
|
from batdetect2.typing import LossProtocol, ModelOutput, TrainExample
|
||||||
if TYPE_CHECKING:
|
|
||||||
pass
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
@ -20,13 +15,13 @@ __all__ = [
|
|||||||
|
|
||||||
class TrainingModule(L.LightningModule):
|
class TrainingModule(L.LightningModule):
|
||||||
model: Model
|
model: Model
|
||||||
|
loss: LossProtocol
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config: dict | None = None,
|
model_config: dict | None = None,
|
||||||
t_max: int = 100,
|
train_config: dict | None = None,
|
||||||
learning_rate: float = 1e-3,
|
loss: LossFunction | None = None,
|
||||||
loss: torch.nn.Module | None = None,
|
|
||||||
model: Model | None = None,
|
model: Model | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -34,11 +29,10 @@ class TrainingModule(L.LightningModule):
|
|||||||
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
||||||
|
|
||||||
self.model_config = ModelConfig.model_validate(model_config or {})
|
self.model_config = ModelConfig.model_validate(model_config or {})
|
||||||
self.learning_rate = learning_rate
|
self.train_config = TrainingConfig.model_validate(train_config or {})
|
||||||
self.t_max = t_max
|
|
||||||
|
|
||||||
if loss is None:
|
if loss is None:
|
||||||
loss = build_loss()
|
loss = build_loss(config=self.train_config.loss)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
model = build_model(config=self.model_config)
|
model = build_model(config=self.model_config)
|
||||||
@ -50,9 +44,13 @@ class TrainingModule(L.LightningModule):
|
|||||||
outputs = self.model.detector(batch.spec)
|
outputs = self.model.detector(batch.spec)
|
||||||
losses = self.loss(outputs, batch)
|
losses = self.loss(outputs, batch)
|
||||||
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
||||||
self.log("detection_loss/train", losses.total, logger=True)
|
self.log("detection_loss/train", losses.detection, logger=True)
|
||||||
self.log("size_loss/train", losses.total, logger=True)
|
self.log("size_loss/train", losses.size, logger=True)
|
||||||
self.log("classification_loss/train", losses.total, logger=True)
|
self.log(
|
||||||
|
"classification_loss/train",
|
||||||
|
losses.classification,
|
||||||
|
logger=True,
|
||||||
|
)
|
||||||
return losses.total
|
return losses.total
|
||||||
|
|
||||||
def validation_step( # type: ignore
|
def validation_step( # type: ignore
|
||||||
@ -63,14 +61,15 @@ class TrainingModule(L.LightningModule):
|
|||||||
outputs = self.model.detector(batch.spec)
|
outputs = self.model.detector(batch.spec)
|
||||||
losses = self.loss(outputs, batch)
|
losses = self.loss(outputs, batch)
|
||||||
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
||||||
self.log("detection_loss/val", losses.total, logger=True)
|
self.log("detection_loss/val", losses.detection, logger=True)
|
||||||
self.log("size_loss/val", losses.total, logger=True)
|
self.log("size_loss/val", losses.size, logger=True)
|
||||||
self.log("classification_loss/val", losses.total, logger=True)
|
self.log("classification_loss/val", losses.classification, logger=True)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
config = self.train_config.optimizer
|
||||||
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
optimizer = Adam(self.parameters(), lr=config.learning_rate)
|
||||||
|
scheduler = CosineAnnealingLR(optimizer, T_max=config.t_max)
|
||||||
return [optimizer], [scheduler]
|
return [optimizer], [scheduler]
|
||||||
|
|
||||||
|
|
||||||
@ -98,17 +97,6 @@ def load_model_from_checkpoint(
|
|||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
model_config: dict | None = None,
|
model_config: dict | None = None,
|
||||||
t_max: int = 200,
|
train_config: dict | None = None,
|
||||||
learning_rate: float = 1e-3,
|
|
||||||
loss_config: dict | None = None,
|
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
from batdetect2.train.config import LossConfig
|
return TrainingModule(model_config=model_config, train_config=train_config)
|
||||||
from batdetect2.train.losses import build_loss
|
|
||||||
|
|
||||||
loss = build_loss(LossConfig.model_validate(loss_config or {}))
|
|
||||||
return TrainingModule(
|
|
||||||
model_config=model_config,
|
|
||||||
t_max=t_max,
|
|
||||||
learning_rate=learning_rate,
|
|
||||||
loss=loss,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -422,7 +422,7 @@ class LossFunction(nn.Module, LossProtocol):
|
|||||||
def build_loss(
|
def build_loss(
|
||||||
config: LossConfig | None = None,
|
config: LossConfig | None = None,
|
||||||
class_weights: np.ndarray | None = None,
|
class_weights: np.ndarray | None = None,
|
||||||
) -> nn.Module:
|
) -> LossFunction:
|
||||||
"""Factory function to build the main LossFunction from configuration.
|
"""Factory function to build the main LossFunction from configuration.
|
||||||
|
|
||||||
Instantiates the necessary loss components (`BBoxLoss`, `FocalLoss`) based
|
Instantiates the necessary loss components (`BBoxLoss`, `FocalLoss`) based
|
||||||
|
|||||||
@ -96,11 +96,13 @@ def train(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
train_config_dict = config.train.model_dump(mode="json")
|
||||||
|
if "optimizer" in train_config_dict:
|
||||||
|
train_config_dict["optimizer"]["t_max"] *= len(train_dataloader)
|
||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=config.model.model_dump(mode="json"),
|
model_config=config.model.model_dump(mode="json"),
|
||||||
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
train_config=train_config_dict,
|
||||||
learning_rate=config.train.optimizer.learning_rate,
|
|
||||||
loss_config=config.train.loss.model_dump(mode="json"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = trainer or build_trainer(
|
trainer = trainer or build_trainer(
|
||||||
|
|||||||
@ -14,6 +14,7 @@ def build_default_module():
|
|||||||
config = BatDetect2Config()
|
config = BatDetect2Config()
|
||||||
return build_training_module(
|
return build_training_module(
|
||||||
model_config=config.model.model_dump(mode="json"),
|
model_config=config.model.model_dump(mode="json"),
|
||||||
|
train_config=config.train.model_dump(mode="json"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user