mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Instantiate lightnign module from config
This commit is contained in:
parent
6d57f96c07
commit
0c8fae4a72
@ -15,8 +15,10 @@ from batdetect2.train.augmentations import (
|
||||
)
|
||||
from batdetect2.train.clips import build_clipper, select_subclip
|
||||
from batdetect2.train.config import (
|
||||
FullTrainingConfig,
|
||||
PLTrainerConfig,
|
||||
TrainingConfig,
|
||||
load_full_training_config,
|
||||
load_train_config,
|
||||
)
|
||||
from batdetect2.train.dataset import (
|
||||
@ -26,6 +28,7 @@ from batdetect2.train.dataset import (
|
||||
list_preprocessed_files,
|
||||
)
|
||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.train.losses import (
|
||||
ClassificationLossConfig,
|
||||
DetectionLossConfig,
|
||||
@ -39,14 +42,11 @@ from batdetect2.train.preprocess import (
|
||||
preprocess_annotations,
|
||||
)
|
||||
from batdetect2.train.train import (
|
||||
FullTrainingConfig,
|
||||
build_train_dataset,
|
||||
build_train_loader,
|
||||
build_trainer,
|
||||
build_training_module,
|
||||
build_val_dataset,
|
||||
build_val_loader,
|
||||
load_full_training_config,
|
||||
train,
|
||||
)
|
||||
|
||||
@ -66,6 +66,7 @@ __all__ = [
|
||||
"TimeMaskAugmentationConfig",
|
||||
"TrainExample",
|
||||
"TrainingConfig",
|
||||
"TrainingModule",
|
||||
"VolumeAugmentationConfig",
|
||||
"WarpAugmentationConfig",
|
||||
"add_echo",
|
||||
@ -76,7 +77,6 @@ __all__ = [
|
||||
"build_train_dataset",
|
||||
"build_train_loader",
|
||||
"build_trainer",
|
||||
"build_training_module",
|
||||
"build_val_dataset",
|
||||
"build_val_loader",
|
||||
"generate_train_example",
|
||||
|
@ -4,6 +4,10 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models import BackboneConfig
|
||||
from batdetect2.postprocess import PostprocessConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train.augmentations import (
|
||||
DEFAULT_AUGMENTATION_CONFIG,
|
||||
AugmentationsConfig,
|
||||
@ -15,6 +19,8 @@ from batdetect2.train.losses import LossConfig
|
||||
__all__ = [
|
||||
"TrainingConfig",
|
||||
"load_train_config",
|
||||
"FullTrainingConfig",
|
||||
"load_full_training_config",
|
||||
]
|
||||
|
||||
|
||||
@ -57,3 +63,23 @@ def load_train_config(
|
||||
field: Optional[str] = None,
|
||||
) -> TrainingConfig:
|
||||
return load_config(path, schema=TrainingConfig, field=field)
|
||||
|
||||
|
||||
class FullTrainingConfig(BaseConfig):
|
||||
"""Full training configuration."""
|
||||
|
||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||
preprocess: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||
|
||||
|
||||
def load_full_training_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> FullTrainingConfig:
|
||||
"""Load the full training configuration."""
|
||||
return load_config(path, schema=FullTrainingConfig, field=field)
|
||||
|
@ -3,15 +3,13 @@ import torch
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from batdetect2.models import (
|
||||
DetectionModel,
|
||||
ModelOutput,
|
||||
)
|
||||
from batdetect2.postprocess.types import PostprocessorProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.models import ModelOutput, build_model
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.train import TrainExample
|
||||
from batdetect2.train.types import LossProtocol
|
||||
from batdetect2.train.config import FullTrainingConfig
|
||||
from batdetect2.train.losses import build_loss
|
||||
|
||||
__all__ = [
|
||||
"TrainingModule",
|
||||
@ -19,26 +17,25 @@ __all__ = [
|
||||
|
||||
|
||||
class TrainingModule(L.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
detector: DetectionModel,
|
||||
loss: LossProtocol,
|
||||
targets: TargetProtocol,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
learning_rate: float = 0.001,
|
||||
t_max: int = 100,
|
||||
):
|
||||
def __init__(self, config: FullTrainingConfig):
|
||||
super().__init__()
|
||||
|
||||
self.loss = loss
|
||||
self.detector = detector
|
||||
self.preprocessor = preprocessor
|
||||
self.targets = targets
|
||||
self.postprocessor = postprocessor
|
||||
self.save_hyperparameters()
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self.t_max = t_max
|
||||
self.loss = build_loss(config.train.loss)
|
||||
self.targets = build_targets(config.targets)
|
||||
self.detector = build_model(
|
||||
num_classes=len(self.targets.class_names),
|
||||
config=config.model,
|
||||
)
|
||||
self.preprocessor = build_preprocessor(config.preprocess)
|
||||
self.postprocessor = build_postprocessor(
|
||||
self.targets,
|
||||
min_freq=self.preprocessor.min_freq,
|
||||
max_freq=self.preprocessor.max_freq,
|
||||
)
|
||||
|
||||
self.config = config
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
return self.detector(spec)
|
||||
@ -68,6 +65,6 @@ class TrainingModule(L.LightningModule):
|
||||
return outputs
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
||||
optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
|
||||
return [optimizer], [scheduler]
|
||||
|
@ -3,28 +3,22 @@ from typing import List, Optional
|
||||
|
||||
from lightning import Trainer
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate.metrics import (
|
||||
ClassificationAccuracy,
|
||||
ClassificationMeanAveragePrecision,
|
||||
DetectionAveragePrecision,
|
||||
)
|
||||
from batdetect2.models import BackboneConfig, build_model
|
||||
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
||||
from batdetect2.preprocess import (
|
||||
PreprocessingConfig,
|
||||
PreprocessorProtocol,
|
||||
build_preprocessor,
|
||||
)
|
||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||
from batdetect2.targets import TargetProtocol
|
||||
from batdetect2.train.augmentations import build_augmentations
|
||||
from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.clips import build_clipper
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
|
||||
from batdetect2.train.dataset import (
|
||||
LabeledDataset,
|
||||
RandomExampleSource,
|
||||
@ -32,41 +26,17 @@ from batdetect2.train.dataset import (
|
||||
)
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.train.logging import build_logger
|
||||
from batdetect2.train.losses import build_loss
|
||||
|
||||
__all__ = [
|
||||
"FullTrainingConfig",
|
||||
"build_train_dataset",
|
||||
"build_train_loader",
|
||||
"build_trainer",
|
||||
"build_training_module",
|
||||
"build_val_dataset",
|
||||
"build_val_loader",
|
||||
"load_full_training_config",
|
||||
"train",
|
||||
]
|
||||
|
||||
|
||||
class FullTrainingConfig(BaseConfig):
|
||||
"""Full training configuration."""
|
||||
|
||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||
preprocess: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||
|
||||
|
||||
def load_full_training_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> FullTrainingConfig:
|
||||
"""Load the full training configuration."""
|
||||
return load_config(path, schema=FullTrainingConfig, field=field)
|
||||
|
||||
|
||||
def train(
|
||||
train_examples: Sequence[data.PathLike],
|
||||
val_examples: Optional[Sequence[data.PathLike]] = None,
|
||||
@ -80,7 +50,7 @@ def train(
|
||||
if model_path is not None:
|
||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||
else:
|
||||
module = build_training_module(conf)
|
||||
module = TrainingModule(conf)
|
||||
|
||||
trainer = build_trainer(conf, targets=module.targets)
|
||||
|
||||
@ -108,35 +78,6 @@ def train(
|
||||
)
|
||||
|
||||
|
||||
def build_training_module(conf: FullTrainingConfig) -> TrainingModule:
|
||||
preprocessor = build_preprocessor(conf.preprocess)
|
||||
|
||||
targets = build_targets(conf.targets)
|
||||
|
||||
postprocessor = build_postprocessor(
|
||||
targets,
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
)
|
||||
|
||||
model = build_model(
|
||||
num_classes=len(targets.class_names),
|
||||
config=conf.model,
|
||||
)
|
||||
|
||||
loss = build_loss(conf.train.loss)
|
||||
|
||||
return TrainingModule(
|
||||
detector=model,
|
||||
loss=loss,
|
||||
targets=targets,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
learning_rate=conf.train.learning_rate,
|
||||
t_max=conf.train.t_max,
|
||||
)
|
||||
|
||||
|
||||
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
||||
return [
|
||||
ValidationMetrics(
|
||||
|
@ -1,31 +1,16 @@
|
||||
from pathlib import Path
|
||||
|
||||
import lightning as L
|
||||
import pytest
|
||||
import torch
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.models import build_model
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.train.losses import build_loss
|
||||
from batdetect2.train import FullTrainingConfig, TrainingModule
|
||||
|
||||
|
||||
def build_default_module():
|
||||
loss = build_loss()
|
||||
targets = build_targets()
|
||||
detector = build_model(num_classes=len(targets.class_names))
|
||||
preprocessor = build_preprocessor()
|
||||
postprocessor = build_postprocessor(targets)
|
||||
return TrainingModule(
|
||||
detector=detector,
|
||||
loss=loss,
|
||||
targets=targets,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
return TrainingModule(FullTrainingConfig())
|
||||
|
||||
|
||||
def test_can_initialize_default_module():
|
||||
|
Loading…
Reference in New Issue
Block a user