Instantiate lightning module from config

This commit is contained in:
mbsantiago 2025-06-26 17:39:50 -06:00
parent 6d57f96c07
commit 0c8fae4a72
5 changed files with 60 additions and 111 deletions

View File

@@ -15,8 +15,10 @@ from batdetect2.train.augmentations import (
) )
from batdetect2.train.clips import build_clipper, select_subclip from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import ( from batdetect2.train.config import (
FullTrainingConfig,
PLTrainerConfig, PLTrainerConfig,
TrainingConfig, TrainingConfig,
load_full_training_config,
load_train_config, load_train_config,
) )
from batdetect2.train.dataset import ( from batdetect2.train.dataset import (
@@ -26,6 +28,7 @@ from batdetect2.train.dataset import (
list_preprocessed_files, list_preprocessed_files,
) )
from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import ( from batdetect2.train.losses import (
ClassificationLossConfig, ClassificationLossConfig,
DetectionLossConfig, DetectionLossConfig,
@@ -39,14 +42,11 @@ from batdetect2.train.preprocess import (
preprocess_annotations, preprocess_annotations,
) )
from batdetect2.train.train import ( from batdetect2.train.train import (
FullTrainingConfig,
build_train_dataset, build_train_dataset,
build_train_loader, build_train_loader,
build_trainer, build_trainer,
build_training_module,
build_val_dataset, build_val_dataset,
build_val_loader, build_val_loader,
load_full_training_config,
train, train,
) )
@@ -66,6 +66,7 @@ __all__ = [
"TimeMaskAugmentationConfig", "TimeMaskAugmentationConfig",
"TrainExample", "TrainExample",
"TrainingConfig", "TrainingConfig",
"TrainingModule",
"VolumeAugmentationConfig", "VolumeAugmentationConfig",
"WarpAugmentationConfig", "WarpAugmentationConfig",
"add_echo", "add_echo",
@@ -76,7 +77,6 @@ __all__ = [
"build_train_dataset", "build_train_dataset",
"build_train_loader", "build_train_loader",
"build_trainer", "build_trainer",
"build_training_module",
"build_val_dataset", "build_val_dataset",
"build_val_loader", "build_val_loader",
"generate_train_example", "generate_train_example",

View File

@@ -4,6 +4,10 @@ from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.models import BackboneConfig
from batdetect2.postprocess import PostprocessConfig
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.targets import TargetConfig
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG, DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig, AugmentationsConfig,
@@ -15,6 +19,8 @@ from batdetect2.train.losses import LossConfig
__all__ = [ __all__ = [
"TrainingConfig", "TrainingConfig",
"load_train_config", "load_train_config",
"FullTrainingConfig",
"load_full_training_config",
] ]
@@ -57,3 +63,23 @@ def load_train_config(
field: Optional[str] = None, field: Optional[str] = None,
) -> TrainingConfig: ) -> TrainingConfig:
return load_config(path, schema=TrainingConfig, field=field) return load_config(path, schema=TrainingConfig, field=field)
class FullTrainingConfig(BaseConfig):
"""Full training configuration."""
train: TrainingConfig = Field(default_factory=TrainingConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
def load_full_training_config(
path: data.PathLike,
field: Optional[str] = None,
) -> FullTrainingConfig:
"""Load the full training configuration."""
return load_config(path, schema=FullTrainingConfig, field=field)

View File

@@ -3,15 +3,13 @@ import torch
from torch.optim.adam import Adam from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.models import ( from batdetect2.models import ModelOutput, build_model
DetectionModel, from batdetect2.postprocess import build_postprocessor
ModelOutput, from batdetect2.preprocess import build_preprocessor
) from batdetect2.targets import build_targets
from batdetect2.postprocess.types import PostprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
from batdetect2.train import TrainExample from batdetect2.train import TrainExample
from batdetect2.train.types import LossProtocol from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.losses import build_loss
__all__ = [ __all__ = [
"TrainingModule", "TrainingModule",
@@ -19,26 +17,25 @@ __all__ = [
class TrainingModule(L.LightningModule): class TrainingModule(L.LightningModule):
def __init__( def __init__(self, config: FullTrainingConfig):
self,
detector: DetectionModel,
loss: LossProtocol,
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
learning_rate: float = 0.001,
t_max: int = 100,
):
super().__init__() super().__init__()
self.loss = loss self.save_hyperparameters()
self.detector = detector
self.preprocessor = preprocessor
self.targets = targets
self.postprocessor = postprocessor
self.learning_rate = learning_rate self.loss = build_loss(config.train.loss)
self.t_max = t_max self.targets = build_targets(config.targets)
self.detector = build_model(
num_classes=len(self.targets.class_names),
config=config.model,
)
self.preprocessor = build_preprocessor(config.preprocess)
self.postprocessor = build_postprocessor(
self.targets,
min_freq=self.preprocessor.min_freq,
max_freq=self.preprocessor.max_freq,
)
self.config = config
def forward(self, spec: torch.Tensor) -> ModelOutput: def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.detector(spec) return self.detector(spec)
@@ -68,6 +65,6 @@ class TrainingModule(L.LightningModule):
return outputs return outputs
def configure_optimizers(self): def configure_optimizers(self):
optimizer = Adam(self.parameters(), lr=self.learning_rate) optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max) scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
return [optimizer], [scheduler] return [optimizer], [scheduler]

View File

@@ -3,28 +3,22 @@ from typing import List, Optional
from lightning import Trainer from lightning import Trainer
from lightning.pytorch.callbacks import Callback from lightning.pytorch.callbacks import Callback
from pydantic import Field
from soundevent import data from soundevent import data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.metrics import ( from batdetect2.evaluate.metrics import (
ClassificationAccuracy, ClassificationAccuracy,
ClassificationMeanAveragePrecision, ClassificationMeanAveragePrecision,
DetectionAveragePrecision, DetectionAveragePrecision,
) )
from batdetect2.models import BackboneConfig, build_model
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import ( from batdetect2.preprocess import (
PreprocessingConfig,
PreprocessorProtocol, PreprocessorProtocol,
build_preprocessor,
) )
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets from batdetect2.targets import TargetProtocol
from batdetect2.train.augmentations import build_augmentations from batdetect2.train.augmentations import build_augmentations
from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper from batdetect2.train.clips import build_clipper
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import FullTrainingConfig, TrainingConfig
from batdetect2.train.dataset import ( from batdetect2.train.dataset import (
LabeledDataset, LabeledDataset,
RandomExampleSource, RandomExampleSource,
@@ -32,41 +26,17 @@ from batdetect2.train.dataset import (
) )
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger from batdetect2.train.logging import build_logger
from batdetect2.train.losses import build_loss
__all__ = [ __all__ = [
"FullTrainingConfig",
"build_train_dataset", "build_train_dataset",
"build_train_loader", "build_train_loader",
"build_trainer", "build_trainer",
"build_training_module",
"build_val_dataset", "build_val_dataset",
"build_val_loader", "build_val_loader",
"load_full_training_config",
"train", "train",
] ]
class FullTrainingConfig(BaseConfig):
"""Full training configuration."""
train: TrainingConfig = Field(default_factory=TrainingConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
def load_full_training_config(
path: data.PathLike,
field: Optional[str] = None,
) -> FullTrainingConfig:
"""Load the full training configuration."""
return load_config(path, schema=FullTrainingConfig, field=field)
def train( def train(
train_examples: Sequence[data.PathLike], train_examples: Sequence[data.PathLike],
val_examples: Optional[Sequence[data.PathLike]] = None, val_examples: Optional[Sequence[data.PathLike]] = None,
@@ -80,7 +50,7 @@ def train(
if model_path is not None: if model_path is not None:
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else: else:
module = build_training_module(conf) module = TrainingModule(conf)
trainer = build_trainer(conf, targets=module.targets) trainer = build_trainer(conf, targets=module.targets)
@@ -108,35 +78,6 @@ def train(
) )
def build_training_module(conf: FullTrainingConfig) -> TrainingModule:
preprocessor = build_preprocessor(conf.preprocess)
targets = build_targets(conf.targets)
postprocessor = build_postprocessor(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
model = build_model(
num_classes=len(targets.class_names),
config=conf.model,
)
loss = build_loss(conf.train.loss)
return TrainingModule(
detector=model,
loss=loss,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
learning_rate=conf.train.learning_rate,
t_max=conf.train.t_max,
)
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]: def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
return [ return [
ValidationMetrics( ValidationMetrics(

View File

@@ -1,31 +1,16 @@
from pathlib import Path from pathlib import Path
import lightning as L import lightning as L
import pytest
import torch import torch
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from batdetect2.models import build_model from batdetect2.train import FullTrainingConfig, TrainingModule
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import build_loss
def build_default_module(): def build_default_module():
loss = build_loss() return TrainingModule(FullTrainingConfig())
targets = build_targets()
detector = build_model(num_classes=len(targets.class_names))
preprocessor = build_preprocessor()
postprocessor = build_postprocessor(targets)
return TrainingModule(
detector=detector,
loss=loss,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
def test_can_initialize_default_module(): def test_can_initialize_default_module():