Better structure for training module

mbsantiago 2025-08-23 18:23:45 +01:00
parent d4f249366e
commit 02adc19070
3 changed files with 69 additions and 38 deletions

View File: batdetect2/train/lightning.py

@@ -1,16 +1,14 @@
 import lightning as L
 import torch
-from pydantic import BaseModel
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR

-from batdetect2.models import ModelOutput, build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets import build_targets
+from batdetect2.models import ModelOutput
+from batdetect2.models.types import DetectionModel
+from batdetect2.postprocess.types import PostprocessorProtocol
+from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets.types import TargetProtocol
 from batdetect2.train import TrainExample
-from batdetect2.train.config import FullTrainingConfig
-from batdetect2.train.losses import build_loss

 __all__ = [
     "TrainingModule",
@@ -18,29 +16,28 @@ __all__ = [
 ]

 class TrainingModule(L.LightningModule):
-    def __init__(self, config: FullTrainingConfig):
+    def __init__(
+        self,
+        detector: DetectionModel,
+        loss: torch.nn.Module,
+        targets: TargetProtocol,
+        preprocessor: PreprocessorProtocol,
+        postprocessor: PostprocessorProtocol,
+        learning_rate: float = 0.001,
+        t_max: int = 100,
+    ):
         super().__init__()

-        # NOTE: Need to convert to vanilla python object so that DVCLive can
-        # store it.
-        self._config = (
-            config.model_dump() if isinstance(config, BaseModel) else config
-        )
-        self.save_hyperparameters({"config": self._config})
+        self.learning_rate = learning_rate
+        self.t_max = t_max

-        self.config = FullTrainingConfig.model_validate(self._config)
-        self.loss = build_loss(self.config.train.loss)
-        self.targets = build_targets(self.config.targets)
-        self.detector = build_model(
-            num_classes=len(self.targets.class_names),
-            config=self.config.model,
-        )
-        self.preprocessor = build_preprocessor(self.config.preprocess)
-        self.postprocessor = build_postprocessor(
-            self.targets,
-            min_freq=self.preprocessor.min_freq,
-            max_freq=self.preprocessor.max_freq,
-        )
+        self.loss = loss
+        self.targets = targets
+        self.detector = detector
+        self.preprocessor = preprocessor
+        self.postprocessor = postprocessor
+
+        self.save_hyperparameters(logger=False)

     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)
@@ -65,10 +62,9 @@ class TrainingModule(L.LightningModule):
self.log("detection_loss/val", losses.total, logger=True)
self.log("size_loss/val", losses.total, logger=True)
self.log("classification_loss/val", losses.total, logger=True)

         return outputs

     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
+        optimizer = Adam(self.parameters(), lr=self.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
         return [optimizer], [scheduler]
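
Note on the new constructor: save_hyperparameters(logger=False) makes Lightning store the __init__ arguments (detector, loss, targets, and the rest) in every checkpoint while keeping them out of the experiment logger. That is what allows train() in the next file to restore a module with TrainingModule.load_from_checkpoint without rebuilding its components. A minimal sketch of the round trip, assuming a checkpoint already written by this module (the path is hypothetical):

    from batdetect2.train.lightning import TrainingModule

    # Lightning unpickles the stored hyperparameters and passes them back to
    # __init__, so detector, loss, targets, etc. need not be supplied again.
    module = TrainingModule.load_from_checkpoint("checkpoints/last.ckpt")
    assert module.learning_rate == module.hparams["learning_rate"]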

View File: batdetect2/train/train.py

@@ -13,10 +13,13 @@ from batdetect2.evaluate.metrics import (
     ClassificationMeanAveragePrecision,
     DetectionAveragePrecision,
 )
+from batdetect2.models import build_model
+from batdetect2.postprocess import build_postprocessor
 from batdetect2.preprocess import (
     PreprocessorProtocol,
     build_preprocessor,
 )
-from batdetect2.targets import TargetProtocol
+from batdetect2.targets import TargetProtocol, build_targets
 from batdetect2.train.augmentations import build_augmentations
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
@@ -28,6 +31,7 @@ from batdetect2.train.dataset import (
 )
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
+from batdetect2.train.losses import build_loss

 __all__ = [
     "build_train_dataset",
@@ -47,27 +51,27 @@ def train(
     train_workers: Optional[int] = None,
     val_workers: Optional[int] = None,
 ):
-    conf = config or FullTrainingConfig()
+    config = config or FullTrainingConfig()

     if model_path is not None:
         logger.debug("Loading model from: {path}", path=model_path)
         module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
     else:
-        module = TrainingModule(conf)
+        module = build_training_module(config)

-    trainer = build_trainer(conf, targets=module.targets)
+    trainer = build_trainer(config, targets=module.targets)

     train_dataloader = build_train_loader(
         train_examples,
         preprocessor=module.preprocessor,
-        config=conf.train,
+        config=config.train,
         num_workers=train_workers,
     )

     val_dataloader = (
         build_val_loader(
             val_examples,
-            config=conf.train,
+            config=config.train,
             num_workers=val_workers,
         )
         if val_examples is not None
@@ -83,6 +87,31 @@ def train(
     logger.info("Training complete.")


+def build_training_module(config: FullTrainingConfig) -> TrainingModule:
+    targets = build_targets(config=config.targets)
+    loss = build_loss(config=config.train.loss)
+    preprocessor = build_preprocessor(config.preprocess)
+    postprocessor = build_postprocessor(
+        targets,
+        config=config.postprocess,
+        max_freq=preprocessor.max_freq,
+        min_freq=preprocessor.min_freq,
+    )
+    model = build_model(
+        num_classes=len(targets.class_names),
+        config=config.model,
+    )
+
+    return TrainingModule(
+        detector=model,
+        loss=loss,
+        preprocessor=preprocessor,
+        postprocessor=postprocessor,
+        targets=targets,
+        learning_rate=config.train.learning_rate,
+        t_max=config.train.t_max,
+    )
+
+
 def build_trainer_callbacks(
     targets: TargetProtocol, config: EvaluationConfig
 ) -> List[Callback]:
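
build_training_module keeps config parsing at the edge of the system: the components are built once from FullTrainingConfig, and TrainingModule itself stays ignorant of configuration. One consequence is that variants can be assembled by swapping a single argument. A short sketch of that pattern (the 1e-4 learning rate is illustrative, not part of this commit):

    from batdetect2.train.config import FullTrainingConfig
    from batdetect2.train.lightning import TrainingModule
    from batdetect2.train.train import build_training_module

    config = FullTrainingConfig()
    module = build_training_module(config)

    # Reuse every built component but override one hyperparameter; no config
    # subclassing or rebuild of targets/preprocessor is needed.
    variant = TrainingModule(
        detector=module.detector,
        loss=module.loss,
        targets=module.targets,
        preprocessor=module.preprocessor,
        postprocessor=module.postprocessor,
        learning_rate=1e-4,
        t_max=config.train.t_max,
    )
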
@@ -114,9 +143,13 @@ def build_trainer(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )

+    train_logger = build_logger(conf.train.logger)
+    train_logger.log_hyperparams(conf.model_dump(mode="json"))
+
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
-        logger=build_logger(conf.train.logger),
+        logger=train_logger,
         callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
     )
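
The other change in build_trainer is that the logger is now built once, asked to record the full config up front, and only then handed to the Trainer. log_hyperparams is part of Lightning's generic Logger interface, so this works with whatever backend build_logger is configured to return. A minimal sketch with CSVLogger standing in for that backend (an assumption for illustration, not what build_logger necessarily returns):

    from lightning.pytorch.loggers import CSVLogger

    from batdetect2.train.config import FullTrainingConfig

    # model_dump(mode="json") yields only JSON-serializable values, which
    # logger backends (CSV, TensorBoard, DVCLive, ...) can store directly.
    logger = CSVLogger("logs", name="batdetect2")
    logger.log_hyperparams(FullTrainingConfig().model_dump(mode="json"))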

View File

@@ -6,10 +6,12 @@ import xarray as xr
 from soundevent import data

 from batdetect2.train import FullTrainingConfig, TrainingModule
+from batdetect2.train.train import build_training_module


 def build_default_module():
-    return TrainingModule(FullTrainingConfig())
+    config = FullTrainingConfig()
+    return build_training_module(config)


 def test_can_initialize_default_module():