Better structure for training module

mbsantiago 2025-08-23 18:23:45 +01:00
parent d4f249366e
commit 02adc19070
3 changed files with 69 additions and 38 deletions

View File

@@ -1,16 +1,14 @@
 import lightning as L
 import torch
-from pydantic import BaseModel
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-from batdetect2.models import ModelOutput, build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets import build_targets
+from batdetect2.models import ModelOutput
+from batdetect2.models.types import DetectionModel
+from batdetect2.postprocess.types import PostprocessorProtocol
+from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets.types import TargetProtocol
 from batdetect2.train import TrainExample
-from batdetect2.train.config import FullTrainingConfig
-from batdetect2.train.losses import build_loss
 
 __all__ = [
     "TrainingModule",
@@ -18,29 +16,28 @@ __all__ = [
 
 class TrainingModule(L.LightningModule):
-    def __init__(self, config: FullTrainingConfig):
+    def __init__(
+        self,
+        detector: DetectionModel,
+        loss: torch.nn.Module,
+        targets: TargetProtocol,
+        preprocessor: PreprocessorProtocol,
+        postprocessor: PostprocessorProtocol,
+        learning_rate: float = 0.001,
+        t_max: int = 100,
+    ):
         super().__init__()
-        # NOTE: Need to convert to vanilla python object so that DVCLive can
-        # store it.
-        self._config = (
-            config.model_dump() if isinstance(config, BaseModel) else config
-        )
-        self.save_hyperparameters({"config": self._config})
-
-        self.config = FullTrainingConfig.model_validate(self._config)
-        self.loss = build_loss(self.config.train.loss)
-        self.targets = build_targets(self.config.targets)
-        self.detector = build_model(
-            num_classes=len(self.targets.class_names),
-            config=self.config.model,
-        )
-        self.preprocessor = build_preprocessor(self.config.preprocess)
-        self.postprocessor = build_postprocessor(
-            self.targets,
-            min_freq=self.preprocessor.min_freq,
-            max_freq=self.preprocessor.max_freq,
-        )
+        self.learning_rate = learning_rate
+        self.t_max = t_max
+
+        self.loss = loss
+        self.targets = targets
+        self.detector = detector
+        self.preprocessor = preprocessor
+        self.postprocessor = postprocessor
+
+        self.save_hyperparameters(logger=False)
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)
@@ -65,10 +62,9 @@ class TrainingModule(L.LightningModule):
         self.log("detection_loss/val", losses.total, logger=True)
         self.log("size_loss/val", losses.total, logger=True)
         self.log("classification_loss/val", losses.total, logger=True)
         return outputs
 
     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
+        optimizer = Adam(self.parameters(), lr=self.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
         return [optimizer], [scheduler]
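
With this change TrainingModule no longer builds its components from a FullTrainingConfig; callers inject a ready-made detector, loss, targets, preprocessor and postprocessor. A minimal construction sketch, not part of the commit, assembled from the build_* helpers imported in the next file (their signatures are assumed from how they are called there):

# Sketch only: assembles a TrainingModule by hand, mirroring the
# build_training_module helper introduced in the next file.
from batdetect2.models import build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train import FullTrainingConfig
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import build_loss

config = FullTrainingConfig()

# Components shared between the model and the postprocessor.
targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config.preprocess)

module = TrainingModule(
    detector=build_model(
        num_classes=len(targets.class_names),
        config=config.model,
    ),
    loss=build_loss(config=config.train.loss),
    targets=targets,
    preprocessor=preprocessor,
    postprocessor=build_postprocessor(
        targets,
        config=config.postprocess,
        max_freq=preprocessor.max_freq,
        min_freq=preprocessor.min_freq,
    ),
    learning_rate=config.train.learning_rate,
    t_max=config.train.t_max,
)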

View File

@@ -13,10 +13,13 @@ from batdetect2.evaluate.metrics import (
     ClassificationMeanAveragePrecision,
     DetectionAveragePrecision,
 )
+from batdetect2.models import build_model
+from batdetect2.postprocess import build_postprocessor
 from batdetect2.preprocess import (
     PreprocessorProtocol,
+    build_preprocessor,
 )
-from batdetect2.targets import TargetProtocol
+from batdetect2.targets import TargetProtocol, build_targets
 from batdetect2.train.augmentations import build_augmentations
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
@@ -28,6 +31,7 @@ from batdetect2.train.dataset import (
 )
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
+from batdetect2.train.losses import build_loss
 
 __all__ = [
     "build_train_dataset",
@@ -47,27 +51,27 @@ def train(
     train_workers: Optional[int] = None,
     val_workers: Optional[int] = None,
 ):
-    conf = config or FullTrainingConfig()
+    config = config or FullTrainingConfig()
 
     if model_path is not None:
         logger.debug("Loading model from: {path}", path=model_path)
         module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
     else:
-        module = TrainingModule(conf)
+        module = build_training_module(config)
 
-    trainer = build_trainer(conf, targets=module.targets)
+    trainer = build_trainer(config, targets=module.targets)
 
     train_dataloader = build_train_loader(
         train_examples,
         preprocessor=module.preprocessor,
-        config=conf.train,
+        config=config.train,
         num_workers=train_workers,
     )
 
     val_dataloader = (
         build_val_loader(
             val_examples,
-            config=conf.train,
+            config=config.train,
             num_workers=val_workers,
         )
         if val_examples is not None
@@ -83,6 +87,31 @@ def train(
     logger.info("Training complete.")
 
 
+def build_training_module(config: FullTrainingConfig) -> TrainingModule:
+    targets = build_targets(config=config.targets)
+    loss = build_loss(config=config.train.loss)
+    preprocessor = build_preprocessor(config.preprocess)
+    postprocessor = build_postprocessor(
+        targets,
+        config=config.postprocess,
+        max_freq=preprocessor.max_freq,
+        min_freq=preprocessor.min_freq,
+    )
+    model = build_model(
+        num_classes=len(targets.class_names),
+        config=config.model,
+    )
+    return TrainingModule(
+        detector=model,
+        loss=loss,
+        preprocessor=preprocessor,
+        postprocessor=postprocessor,
+        targets=targets,
+        learning_rate=config.train.learning_rate,
+        t_max=config.train.t_max,
+    )
+
+
 def build_trainer_callbacks(
     targets: TargetProtocol, config: EvaluationConfig
 ) -> List[Callback]:
@@ -114,9 +143,13 @@ def build_trainer(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )
+
+    train_logger = build_logger(conf.train.logger)
+    train_logger.log_hyperparams(conf.model_dump(mode="json"))
+
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
-        logger=build_logger(conf.train.logger),
+        logger=train_logger,
         callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
     )
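
With build_training_module in place, callers only need a FullTrainingConfig. A rough usage sketch, not part of the commit; the type of the example arguments is not shown in this diff, and parameter names other than config, train_workers and val_workers are inferred from the function body:

from batdetect2.train import FullTrainingConfig
from batdetect2.train.train import build_training_module, train

config = FullTrainingConfig()

# Build the LightningModule directly from the config sections...
module = build_training_module(config)

# ...or let train() construct it when no checkpoint path is given.
train_examples = ...  # placeholder: examples accepted by build_train_loader
train(
    train_examples,
    val_examples=None,
    config=config,
    train_workers=4,
    val_workers=2,
)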

View File

@@ -6,10 +6,12 @@ import xarray as xr
 from soundevent import data
 
 from batdetect2.train import FullTrainingConfig, TrainingModule
+from batdetect2.train.train import build_training_module
 
 
 def build_default_module():
-    return TrainingModule(FullTrainingConfig())
+    config = FullTrainingConfig()
+    return build_training_module(config)
 
 
 def test_can_initialize_default_module():
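
For reference, a hypothetical shape of the updated test; its real body is not part of this diff:

def test_can_initialize_default_module():
    # Should build every component from the default FullTrainingConfig
    # without raising.
    module = build_default_module()
    assert isinstance(module, TrainingModule)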