diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py
index 5d9151e..481fd9f 100644
--- a/src/batdetect2/train/lightning.py
+++ b/src/batdetect2/train/lightning.py
@@ -1,16 +1,14 @@
 import lightning as L
 import torch
-from pydantic import BaseModel
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-from batdetect2.models import ModelOutput, build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets import build_targets
+from batdetect2.models import ModelOutput
+from batdetect2.models.types import DetectionModel
+from batdetect2.postprocess.types import PostprocessorProtocol
+from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets.types import TargetProtocol
 from batdetect2.train import TrainExample
-from batdetect2.train.config import FullTrainingConfig
-from batdetect2.train.losses import build_loss
 
 __all__ = [
     "TrainingModule",
@@ -18,29 +16,28 @@ __all__ = [
 
 
 class TrainingModule(L.LightningModule):
-    def __init__(self, config: FullTrainingConfig):
+    def __init__(
+        self,
+        detector: DetectionModel,
+        loss: torch.nn.Module,
+        targets: TargetProtocol,
+        preprocessor: PreprocessorProtocol,
+        postprocessor: PostprocessorProtocol,
+        learning_rate: float = 0.001,
+        t_max: int = 100,
+    ):
         super().__init__()
 
-        # NOTE: Need to convert to vanilla python object so that DVCLive can
-        # store it.
-        self._config = (
-            config.model_dump() if isinstance(config, BaseModel) else config
-        )
-        self.save_hyperparameters({"config": self._config})
+        self.learning_rate = learning_rate
+        self.t_max = t_max
 
-        self.config = FullTrainingConfig.model_validate(self._config)
-        self.loss = build_loss(self.config.train.loss)
-        self.targets = build_targets(self.config.targets)
-        self.detector = build_model(
-            num_classes=len(self.targets.class_names),
-            config=self.config.model,
-        )
-        self.preprocessor = build_preprocessor(self.config.preprocess)
-        self.postprocessor = build_postprocessor(
-            self.targets,
-            min_freq=self.preprocessor.min_freq,
-            max_freq=self.preprocessor.max_freq,
-        )
+        self.loss = loss
+        self.targets = targets
+        self.detector = detector
+        self.preprocessor = preprocessor
+        self.postprocessor = postprocessor
+
+        self.save_hyperparameters(logger=False)
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)
@@ -65,10 +62,9 @@ class TrainingModule(L.LightningModule):
         self.log("detection_loss/val", losses.total, logger=True)
         self.log("size_loss/val", losses.total, logger=True)
         self.log("classification_loss/val", losses.total, logger=True)
-
         return outputs
 
     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
+        optimizer = Adam(self.parameters(), lr=self.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
         return [optimizer], [scheduler]
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index 1828ede..fed22b5 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -13,10 +13,13 @@ from batdetect2.evaluate.metrics import (
     ClassificationMeanAveragePrecision,
     DetectionAveragePrecision,
 )
+from batdetect2.models import build_model
+from batdetect2.postprocess import build_postprocessor
 from batdetect2.preprocess import (
     PreprocessorProtocol,
+    build_preprocessor,
 )
-from batdetect2.targets import TargetProtocol
+from batdetect2.targets import TargetProtocol, build_targets
 from batdetect2.train.augmentations import build_augmentations
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
@@ -28,6 +31,7 @@ from batdetect2.train.dataset import (
 )
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
+from batdetect2.train.losses import build_loss
 
 __all__ = [
     "build_train_dataset",
@@ -47,27 +51,27 @@ def train(
     train_workers: Optional[int] = None,
     val_workers: Optional[int] = None,
 ):
-    conf = config or FullTrainingConfig()
+    config = config or FullTrainingConfig()
 
     if model_path is not None:
         logger.debug("Loading model from: {path}", path=model_path)
         module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
     else:
-        module = TrainingModule(conf)
+        module = build_training_module(config)
 
-    trainer = build_trainer(conf, targets=module.targets)
+    trainer = build_trainer(config, targets=module.targets)
 
     train_dataloader = build_train_loader(
         train_examples,
         preprocessor=module.preprocessor,
-        config=conf.train,
+        config=config.train,
         num_workers=train_workers,
     )
 
     val_dataloader = (
         build_val_loader(
             val_examples,
-            config=conf.train,
+            config=config.train,
             num_workers=val_workers,
         )
         if val_examples is not None
@@ -83,6 +87,31 @@ def train(
     logger.info("Training complete.")
 
 
+def build_training_module(config: FullTrainingConfig) -> TrainingModule:
+    targets = build_targets(config=config.targets)
+    loss = build_loss(config=config.train.loss)
+    preprocessor = build_preprocessor(config.preprocess)
+    postprocessor = build_postprocessor(
+        targets,
+        config=config.postprocess,
+        max_freq=preprocessor.max_freq,
+        min_freq=preprocessor.min_freq,
+    )
+    model = build_model(
+        num_classes=len(targets.class_names),
+        config=config.model,
+    )
+    return TrainingModule(
+        detector=model,
+        loss=loss,
+        preprocessor=preprocessor,
+        postprocessor=postprocessor,
+        targets=targets,
+        learning_rate=config.train.learning_rate,
+        t_max=config.train.t_max,
+    )
+
+
 def build_trainer_callbacks(
     targets: TargetProtocol, config: EvaluationConfig
 ) -> List[Callback]:
@@ -114,9 +143,13 @@ def build_trainer(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )
 
+    train_logger = build_logger(conf.train.logger)
+
+    train_logger.log_hyperparams(conf.model_dump(mode="json"))
+
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
-        logger=build_logger(conf.train.logger),
+        logger=train_logger,
         callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
     )
diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py
index 653afb1..828a464 100644
--- a/tests/test_train/test_lightning.py
+++ b/tests/test_train/test_lightning.py
@@ -6,10 +6,12 @@ import xarray as xr
 from soundevent import data
 
 from batdetect2.train import FullTrainingConfig, TrainingModule
+from batdetect2.train.train import build_training_module
 
 
 def build_default_module():
-    return TrainingModule(FullTrainingConfig())
+    config = FullTrainingConfig()
+    return build_training_module(config)
 
 
 def test_can_initialize_default_module():
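The refactor above replaces config-driven construction inside `TrainingModule.__init__` with plain dependency injection: the module now receives its `detector`, `loss`, `targets`, `preprocessor`, and `postprocessor` ready-made, and `build_training_module` becomes the single place where a `FullTrainingConfig` is turned into those components. Below is a minimal sketch of how the two entry points fit together; it assumes a default `FullTrainingConfig()` is constructible (as the updated test does), and the learning-rate and `t_max` values are arbitrary examples, not recommended settings.

```python
from batdetect2.train import FullTrainingConfig
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.train import build_training_module

config = FullTrainingConfig()

# One-stop construction: targets, loss, model, preprocessor and
# postprocessor are all built from the config and injected.
module = build_training_module(config)

# Because construction is now injection-based, individual components
# can be reused or swapped without re-parsing any config, e.g. to try
# different optimizer settings on otherwise identical parts:
variant = TrainingModule(
    detector=module.detector,
    loss=module.loss,
    targets=module.targets,
    preprocessor=module.preprocessor,
    postprocessor=module.postprocessor,
    learning_rate=3e-4,  # example override; default is 0.001
    t_max=50,            # example override; default is 100
)
```

Since `save_hyperparameters(logger=False)` now records the injected constructor arguments rather than a dumped config dict, `TrainingModule.load_from_checkpoint(model_path)` (still used in `train()`) can restore a module without the original config object, while hyperparameter logging moves to `build_trainer`, which calls `train_logger.log_hyperparams(conf.model_dump(mode="json"))` once per run.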