From 0c8fae4a7286c2e5b1a86c68dbdded67f71f1f41 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 26 Jun 2025 17:39:50 -0600 Subject: [PATCH] Instantiate lightnign module from config --- src/batdetect2/train/__init__.py | 8 ++-- src/batdetect2/train/config.py | 26 ++++++++++++ src/batdetect2/train/lightning.py | 51 +++++++++++------------ src/batdetect2/train/train.py | 65 ++---------------------------- tests/test_train/test_lightning.py | 21 ++-------- 5 files changed, 60 insertions(+), 111 deletions(-) diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py index 44a0c58..df5dc27 100644 --- a/src/batdetect2/train/__init__.py +++ b/src/batdetect2/train/__init__.py @@ -15,8 +15,10 @@ from batdetect2.train.augmentations import ( ) from batdetect2.train.clips import build_clipper, select_subclip from batdetect2.train.config import ( + FullTrainingConfig, PLTrainerConfig, TrainingConfig, + load_full_training_config, load_train_config, ) from batdetect2.train.dataset import ( @@ -26,6 +28,7 @@ from batdetect2.train.dataset import ( list_preprocessed_files, ) from batdetect2.train.labels import build_clip_labeler, load_label_config +from batdetect2.train.lightning import TrainingModule from batdetect2.train.losses import ( ClassificationLossConfig, DetectionLossConfig, @@ -39,14 +42,11 @@ from batdetect2.train.preprocess import ( preprocess_annotations, ) from batdetect2.train.train import ( - FullTrainingConfig, build_train_dataset, build_train_loader, build_trainer, - build_training_module, build_val_dataset, build_val_loader, - load_full_training_config, train, ) @@ -66,6 +66,7 @@ __all__ = [ "TimeMaskAugmentationConfig", "TrainExample", "TrainingConfig", + "TrainingModule", "VolumeAugmentationConfig", "WarpAugmentationConfig", "add_echo", @@ -76,7 +77,6 @@ __all__ = [ "build_train_dataset", "build_train_loader", "build_trainer", - "build_training_module", "build_val_dataset", "build_val_loader", "generate_train_example", diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index d6a2a31..2f1ff42 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -4,6 +4,10 @@ from pydantic import Field from soundevent import data from batdetect2.configs import BaseConfig, load_config +from batdetect2.models import BackboneConfig +from batdetect2.postprocess import PostprocessConfig +from batdetect2.preprocess import PreprocessingConfig +from batdetect2.targets import TargetConfig from batdetect2.train.augmentations import ( DEFAULT_AUGMENTATION_CONFIG, AugmentationsConfig, @@ -15,6 +19,8 @@ from batdetect2.train.losses import LossConfig __all__ = [ "TrainingConfig", "load_train_config", + "FullTrainingConfig", + "load_full_training_config", ] @@ -57,3 +63,23 @@ def load_train_config( field: Optional[str] = None, ) -> TrainingConfig: return load_config(path, schema=TrainingConfig, field=field) + + +class FullTrainingConfig(BaseConfig): + """Full training configuration.""" + + train: TrainingConfig = Field(default_factory=TrainingConfig) + targets: TargetConfig = Field(default_factory=TargetConfig) + model: BackboneConfig = Field(default_factory=BackboneConfig) + preprocess: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) + + +def load_full_training_config( + path: data.PathLike, + field: Optional[str] = None, +) -> FullTrainingConfig: + """Load the full training configuration.""" + return load_config(path, schema=FullTrainingConfig, field=field) diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index f4f1959..373a868 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -3,15 +3,13 @@ import torch from torch.optim.adam import Adam from torch.optim.lr_scheduler import CosineAnnealingLR -from batdetect2.models import ( - DetectionModel, - ModelOutput, -) -from batdetect2.postprocess.types import PostprocessorProtocol -from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets.types import TargetProtocol +from batdetect2.models import ModelOutput, build_model +from batdetect2.postprocess import build_postprocessor +from batdetect2.preprocess import build_preprocessor +from batdetect2.targets import build_targets from batdetect2.train import TrainExample -from batdetect2.train.types import LossProtocol +from batdetect2.train.config import FullTrainingConfig +from batdetect2.train.losses import build_loss __all__ = [ "TrainingModule", @@ -19,26 +17,25 @@ __all__ = [ class TrainingModule(L.LightningModule): - def __init__( - self, - detector: DetectionModel, - loss: LossProtocol, - targets: TargetProtocol, - preprocessor: PreprocessorProtocol, - postprocessor: PostprocessorProtocol, - learning_rate: float = 0.001, - t_max: int = 100, - ): + def __init__(self, config: FullTrainingConfig): super().__init__() - self.loss = loss - self.detector = detector - self.preprocessor = preprocessor - self.targets = targets - self.postprocessor = postprocessor + self.save_hyperparameters() - self.learning_rate = learning_rate - self.t_max = t_max + self.loss = build_loss(config.train.loss) + self.targets = build_targets(config.targets) + self.detector = build_model( + num_classes=len(self.targets.class_names), + config=config.model, + ) + self.preprocessor = build_preprocessor(config.preprocess) + self.postprocessor = build_postprocessor( + self.targets, + min_freq=self.preprocessor.min_freq, + max_freq=self.preprocessor.max_freq, + ) + + self.config = config def forward(self, spec: torch.Tensor) -> ModelOutput: return self.detector(spec) @@ -68,6 +65,6 @@ class TrainingModule(L.LightningModule): return outputs def configure_optimizers(self): - optimizer = Adam(self.parameters(), lr=self.learning_rate) - scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max) + optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate) + scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max) return [optimizer], [scheduler] diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index c2810dd..7b339b7 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -3,28 +3,22 @@ from typing import List, Optional from lightning import Trainer from lightning.pytorch.callbacks import Callback -from pydantic import Field from soundevent import data from torch.utils.data import DataLoader -from batdetect2.configs import BaseConfig, load_config from batdetect2.evaluate.metrics import ( ClassificationAccuracy, ClassificationMeanAveragePrecision, DetectionAveragePrecision, ) -from batdetect2.models import BackboneConfig, build_model -from batdetect2.postprocess import PostprocessConfig, build_postprocessor from batdetect2.preprocess import ( - PreprocessingConfig, PreprocessorProtocol, - build_preprocessor, ) -from batdetect2.targets import TargetConfig, TargetProtocol, build_targets +from batdetect2.targets import TargetProtocol from batdetect2.train.augmentations import build_augmentations from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.clips import build_clipper -from batdetect2.train.config import TrainingConfig +from batdetect2.train.config import FullTrainingConfig, TrainingConfig from batdetect2.train.dataset import ( LabeledDataset, RandomExampleSource, @@ -32,41 +26,17 @@ from batdetect2.train.dataset import ( ) from batdetect2.train.lightning import TrainingModule from batdetect2.train.logging import build_logger -from batdetect2.train.losses import build_loss __all__ = [ - "FullTrainingConfig", "build_train_dataset", "build_train_loader", "build_trainer", - "build_training_module", "build_val_dataset", "build_val_loader", - "load_full_training_config", "train", ] -class FullTrainingConfig(BaseConfig): - """Full training configuration.""" - - train: TrainingConfig = Field(default_factory=TrainingConfig) - targets: TargetConfig = Field(default_factory=TargetConfig) - model: BackboneConfig = Field(default_factory=BackboneConfig) - preprocess: PreprocessingConfig = Field( - default_factory=PreprocessingConfig - ) - postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) - - -def load_full_training_config( - path: data.PathLike, - field: Optional[str] = None, -) -> FullTrainingConfig: - """Load the full training configuration.""" - return load_config(path, schema=FullTrainingConfig, field=field) - - def train( train_examples: Sequence[data.PathLike], val_examples: Optional[Sequence[data.PathLike]] = None, @@ -80,7 +50,7 @@ def train( if model_path is not None: module = TrainingModule.load_from_checkpoint(model_path) # type: ignore else: - module = build_training_module(conf) + module = TrainingModule(conf) trainer = build_trainer(conf, targets=module.targets) @@ -108,35 +78,6 @@ def train( ) -def build_training_module(conf: FullTrainingConfig) -> TrainingModule: - preprocessor = build_preprocessor(conf.preprocess) - - targets = build_targets(conf.targets) - - postprocessor = build_postprocessor( - targets, - min_freq=preprocessor.min_freq, - max_freq=preprocessor.max_freq, - ) - - model = build_model( - num_classes=len(targets.class_names), - config=conf.model, - ) - - loss = build_loss(conf.train.loss) - - return TrainingModule( - detector=model, - loss=loss, - targets=targets, - preprocessor=preprocessor, - postprocessor=postprocessor, - learning_rate=conf.train.learning_rate, - t_max=conf.train.t_max, - ) - - def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]: return [ ValidationMetrics( diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index fb635d4..55d9093 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -1,31 +1,16 @@ from pathlib import Path import lightning as L +import pytest import torch import xarray as xr from soundevent import data -from batdetect2.models import build_model -from batdetect2.postprocess import build_postprocessor -from batdetect2.preprocess import build_preprocessor -from batdetect2.targets import build_targets -from batdetect2.train.lightning import TrainingModule -from batdetect2.train.losses import build_loss +from batdetect2.train import FullTrainingConfig, TrainingModule def build_default_module(): - loss = build_loss() - targets = build_targets() - detector = build_model(num_classes=len(targets.class_names)) - preprocessor = build_preprocessor() - postprocessor = build_postprocessor(targets) - return TrainingModule( - detector=detector, - loss=loss, - targets=targets, - preprocessor=preprocessor, - postprocessor=postprocessor, - ) + return TrainingModule(FullTrainingConfig()) def test_can_initialize_default_module():