Instantiate lightning module from config

This commit is contained in:
mbsantiago 2025-06-26 17:39:50 -06:00
parent 6d57f96c07
commit 0c8fae4a72
5 changed files with 60 additions and 111 deletions

View File

@ -15,8 +15,10 @@ from batdetect2.train.augmentations import (
)
from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import (
FullTrainingConfig,
PLTrainerConfig,
TrainingConfig,
load_full_training_config,
load_train_config,
)
from batdetect2.train.dataset import (
@ -26,6 +28,7 @@ from batdetect2.train.dataset import (
list_preprocessed_files,
)
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import (
ClassificationLossConfig,
DetectionLossConfig,
@ -39,14 +42,11 @@ from batdetect2.train.preprocess import (
preprocess_annotations,
)
from batdetect2.train.train import (
FullTrainingConfig,
build_train_dataset,
build_train_loader,
build_trainer,
build_training_module,
build_val_dataset,
build_val_loader,
load_full_training_config,
train,
)
@ -66,6 +66,7 @@ __all__ = [
"TimeMaskAugmentationConfig",
"TrainExample",
"TrainingConfig",
"TrainingModule",
"VolumeAugmentationConfig",
"WarpAugmentationConfig",
"add_echo",
@ -76,7 +77,6 @@ __all__ = [
"build_train_dataset",
"build_train_loader",
"build_trainer",
"build_training_module",
"build_val_dataset",
"build_val_loader",
"generate_train_example",

View File

@ -4,6 +4,10 @@ from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models import BackboneConfig
from batdetect2.postprocess import PostprocessConfig
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.targets import TargetConfig
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
@ -15,6 +19,8 @@ from batdetect2.train.losses import LossConfig
__all__ = [
"TrainingConfig",
"load_train_config",
"FullTrainingConfig",
"load_full_training_config",
]
@ -57,3 +63,23 @@ def load_train_config(
field: Optional[str] = None,
) -> TrainingConfig:
return load_config(path, schema=TrainingConfig, field=field)
class FullTrainingConfig(BaseConfig):
"""Full training configuration."""
train: TrainingConfig = Field(default_factory=TrainingConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
def load_full_training_config(
path: data.PathLike,
field: Optional[str] = None,
) -> FullTrainingConfig:
"""Load the full training configuration."""
return load_config(path, schema=FullTrainingConfig, field=field)

View File

@ -3,15 +3,13 @@ import torch
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.models import (
DetectionModel,
ModelOutput,
)
from batdetect2.postprocess.types import PostprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
from batdetect2.models import ModelOutput, build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train import TrainExample
from batdetect2.train.types import LossProtocol
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.losses import build_loss
__all__ = [
"TrainingModule",
@ -19,26 +17,25 @@ __all__ = [
class TrainingModule(L.LightningModule):
def __init__(
self,
detector: DetectionModel,
loss: LossProtocol,
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
learning_rate: float = 0.001,
t_max: int = 100,
):
def __init__(self, config: FullTrainingConfig):
super().__init__()
self.loss = loss
self.detector = detector
self.preprocessor = preprocessor
self.targets = targets
self.postprocessor = postprocessor
self.save_hyperparameters()
self.learning_rate = learning_rate
self.t_max = t_max
self.loss = build_loss(config.train.loss)
self.targets = build_targets(config.targets)
self.detector = build_model(
num_classes=len(self.targets.class_names),
config=config.model,
)
self.preprocessor = build_preprocessor(config.preprocess)
self.postprocessor = build_postprocessor(
self.targets,
min_freq=self.preprocessor.min_freq,
max_freq=self.preprocessor.max_freq,
)
self.config = config
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.detector(spec)
@ -68,6 +65,6 @@ class TrainingModule(L.LightningModule):
return outputs
def configure_optimizers(self):
optimizer = Adam(self.parameters(), lr=self.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
return [optimizer], [scheduler]

View File

@ -3,28 +3,22 @@ from typing import List, Optional
from lightning import Trainer
from lightning.pytorch.callbacks import Callback
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.models import BackboneConfig, build_model
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import (
PreprocessingConfig,
PreprocessorProtocol,
build_preprocessor,
)
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
from batdetect2.targets import TargetProtocol
from batdetect2.train.augmentations import build_augmentations
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import TrainingConfig
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
@ -32,41 +26,17 @@ from batdetect2.train.dataset import (
)
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger
from batdetect2.train.losses import build_loss
__all__ = [
"FullTrainingConfig",
"build_train_dataset",
"build_train_loader",
"build_trainer",
"build_training_module",
"build_val_dataset",
"build_val_loader",
"load_full_training_config",
"train",
]
class FullTrainingConfig(BaseConfig):
"""Full training configuration."""
train: TrainingConfig = Field(default_factory=TrainingConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
def load_full_training_config(
path: data.PathLike,
field: Optional[str] = None,
) -> FullTrainingConfig:
"""Load the full training configuration."""
return load_config(path, schema=FullTrainingConfig, field=field)
def train(
train_examples: Sequence[data.PathLike],
val_examples: Optional[Sequence[data.PathLike]] = None,
@ -80,7 +50,7 @@ def train(
if model_path is not None:
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = build_training_module(conf)
module = TrainingModule(conf)
trainer = build_trainer(conf, targets=module.targets)
@ -108,35 +78,6 @@ def train(
)
def build_training_module(conf: FullTrainingConfig) -> TrainingModule:
preprocessor = build_preprocessor(conf.preprocess)
targets = build_targets(conf.targets)
postprocessor = build_postprocessor(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
model = build_model(
num_classes=len(targets.class_names),
config=conf.model,
)
loss = build_loss(conf.train.loss)
return TrainingModule(
detector=model,
loss=loss,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
learning_rate=conf.train.learning_rate,
t_max=conf.train.t_max,
)
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
return [
ValidationMetrics(

View File

@ -1,31 +1,16 @@
from pathlib import Path
import lightning as L
import pytest
import torch
import xarray as xr
from soundevent import data
from batdetect2.models import build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import build_loss
from batdetect2.train import FullTrainingConfig, TrainingModule
def build_default_module():
loss = build_loss()
targets = build_targets()
detector = build_model(num_classes=len(targets.class_names))
preprocessor = build_preprocessor()
postprocessor = build_postprocessor(targets)
return TrainingModule(
detector=detector,
loss=loss,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
return TrainingModule(FullTrainingConfig())
def test_can_initialize_default_module():