mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Instantiate lightnign module from config
This commit is contained in:
parent
6d57f96c07
commit
0c8fae4a72
@ -15,8 +15,10 @@ from batdetect2.train.augmentations import (
|
|||||||
)
|
)
|
||||||
from batdetect2.train.clips import build_clipper, select_subclip
|
from batdetect2.train.clips import build_clipper, select_subclip
|
||||||
from batdetect2.train.config import (
|
from batdetect2.train.config import (
|
||||||
|
FullTrainingConfig,
|
||||||
PLTrainerConfig,
|
PLTrainerConfig,
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
|
load_full_training_config,
|
||||||
load_train_config,
|
load_train_config,
|
||||||
)
|
)
|
||||||
from batdetect2.train.dataset import (
|
from batdetect2.train.dataset import (
|
||||||
@ -26,6 +28,7 @@ from batdetect2.train.dataset import (
|
|||||||
list_preprocessed_files,
|
list_preprocessed_files,
|
||||||
)
|
)
|
||||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||||
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.losses import (
|
from batdetect2.train.losses import (
|
||||||
ClassificationLossConfig,
|
ClassificationLossConfig,
|
||||||
DetectionLossConfig,
|
DetectionLossConfig,
|
||||||
@ -39,14 +42,11 @@ from batdetect2.train.preprocess import (
|
|||||||
preprocess_annotations,
|
preprocess_annotations,
|
||||||
)
|
)
|
||||||
from batdetect2.train.train import (
|
from batdetect2.train.train import (
|
||||||
FullTrainingConfig,
|
|
||||||
build_train_dataset,
|
build_train_dataset,
|
||||||
build_train_loader,
|
build_train_loader,
|
||||||
build_trainer,
|
build_trainer,
|
||||||
build_training_module,
|
|
||||||
build_val_dataset,
|
build_val_dataset,
|
||||||
build_val_loader,
|
build_val_loader,
|
||||||
load_full_training_config,
|
|
||||||
train,
|
train,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -66,6 +66,7 @@ __all__ = [
|
|||||||
"TimeMaskAugmentationConfig",
|
"TimeMaskAugmentationConfig",
|
||||||
"TrainExample",
|
"TrainExample",
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
|
"TrainingModule",
|
||||||
"VolumeAugmentationConfig",
|
"VolumeAugmentationConfig",
|
||||||
"WarpAugmentationConfig",
|
"WarpAugmentationConfig",
|
||||||
"add_echo",
|
"add_echo",
|
||||||
@ -76,7 +77,6 @@ __all__ = [
|
|||||||
"build_train_dataset",
|
"build_train_dataset",
|
||||||
"build_train_loader",
|
"build_train_loader",
|
||||||
"build_trainer",
|
"build_trainer",
|
||||||
"build_training_module",
|
|
||||||
"build_val_dataset",
|
"build_val_dataset",
|
||||||
"build_val_loader",
|
"build_val_loader",
|
||||||
"generate_train_example",
|
"generate_train_example",
|
||||||
|
@ -4,6 +4,10 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.models import BackboneConfig
|
||||||
|
from batdetect2.postprocess import PostprocessConfig
|
||||||
|
from batdetect2.preprocess import PreprocessingConfig
|
||||||
|
from batdetect2.targets import TargetConfig
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
DEFAULT_AUGMENTATION_CONFIG,
|
DEFAULT_AUGMENTATION_CONFIG,
|
||||||
AugmentationsConfig,
|
AugmentationsConfig,
|
||||||
@ -15,6 +19,8 @@ from batdetect2.train.losses import LossConfig
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"load_train_config",
|
"load_train_config",
|
||||||
|
"FullTrainingConfig",
|
||||||
|
"load_full_training_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -57,3 +63,23 @@ def load_train_config(
|
|||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
) -> TrainingConfig:
|
) -> TrainingConfig:
|
||||||
return load_config(path, schema=TrainingConfig, field=field)
|
return load_config(path, schema=TrainingConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
class FullTrainingConfig(BaseConfig):
|
||||||
|
"""Full training configuration."""
|
||||||
|
|
||||||
|
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||||
|
preprocess: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def load_full_training_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> FullTrainingConfig:
|
||||||
|
"""Load the full training configuration."""
|
||||||
|
return load_config(path, schema=FullTrainingConfig, field=field)
|
||||||
|
@ -3,15 +3,13 @@ import torch
|
|||||||
from torch.optim.adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
from batdetect2.models import (
|
from batdetect2.models import ModelOutput, build_model
|
||||||
DetectionModel,
|
from batdetect2.postprocess import build_postprocessor
|
||||||
ModelOutput,
|
from batdetect2.preprocess import build_preprocessor
|
||||||
)
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.postprocess.types import PostprocessorProtocol
|
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
from batdetect2.train import TrainExample
|
from batdetect2.train import TrainExample
|
||||||
from batdetect2.train.types import LossProtocol
|
from batdetect2.train.config import FullTrainingConfig
|
||||||
|
from batdetect2.train.losses import build_loss
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
@ -19,26 +17,25 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class TrainingModule(L.LightningModule):
|
class TrainingModule(L.LightningModule):
|
||||||
def __init__(
|
def __init__(self, config: FullTrainingConfig):
|
||||||
self,
|
|
||||||
detector: DetectionModel,
|
|
||||||
loss: LossProtocol,
|
|
||||||
targets: TargetProtocol,
|
|
||||||
preprocessor: PreprocessorProtocol,
|
|
||||||
postprocessor: PostprocessorProtocol,
|
|
||||||
learning_rate: float = 0.001,
|
|
||||||
t_max: int = 100,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.loss = loss
|
self.save_hyperparameters()
|
||||||
self.detector = detector
|
|
||||||
self.preprocessor = preprocessor
|
|
||||||
self.targets = targets
|
|
||||||
self.postprocessor = postprocessor
|
|
||||||
|
|
||||||
self.learning_rate = learning_rate
|
self.loss = build_loss(config.train.loss)
|
||||||
self.t_max = t_max
|
self.targets = build_targets(config.targets)
|
||||||
|
self.detector = build_model(
|
||||||
|
num_classes=len(self.targets.class_names),
|
||||||
|
config=config.model,
|
||||||
|
)
|
||||||
|
self.preprocessor = build_preprocessor(config.preprocess)
|
||||||
|
self.postprocessor = build_postprocessor(
|
||||||
|
self.targets,
|
||||||
|
min_freq=self.preprocessor.min_freq,
|
||||||
|
max_freq=self.preprocessor.max_freq,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
return self.detector(spec)
|
return self.detector(spec)
|
||||||
@ -68,6 +65,6 @@ class TrainingModule(L.LightningModule):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
|
||||||
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
|
||||||
return [optimizer], [scheduler]
|
return [optimizer], [scheduler]
|
||||||
|
@ -3,28 +3,22 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from lightning import Trainer
|
from lightning import Trainer
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch.callbacks import Callback
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.evaluate.metrics import (
|
from batdetect2.evaluate.metrics import (
|
||||||
ClassificationAccuracy,
|
ClassificationAccuracy,
|
||||||
ClassificationMeanAveragePrecision,
|
ClassificationMeanAveragePrecision,
|
||||||
DetectionAveragePrecision,
|
DetectionAveragePrecision,
|
||||||
)
|
)
|
||||||
from batdetect2.models import BackboneConfig, build_model
|
|
||||||
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
|
||||||
from batdetect2.preprocess import (
|
from batdetect2.preprocess import (
|
||||||
PreprocessingConfig,
|
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
build_preprocessor,
|
|
||||||
)
|
)
|
||||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
from batdetect2.targets import TargetProtocol
|
||||||
from batdetect2.train.augmentations import build_augmentations
|
from batdetect2.train.augmentations import build_augmentations
|
||||||
from batdetect2.train.callbacks import ValidationMetrics
|
from batdetect2.train.callbacks import ValidationMetrics
|
||||||
from batdetect2.train.clips import build_clipper
|
from batdetect2.train.clips import build_clipper
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
|
||||||
from batdetect2.train.dataset import (
|
from batdetect2.train.dataset import (
|
||||||
LabeledDataset,
|
LabeledDataset,
|
||||||
RandomExampleSource,
|
RandomExampleSource,
|
||||||
@ -32,41 +26,17 @@ from batdetect2.train.dataset import (
|
|||||||
)
|
)
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.logging import build_logger
|
from batdetect2.train.logging import build_logger
|
||||||
from batdetect2.train.losses import build_loss
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FullTrainingConfig",
|
|
||||||
"build_train_dataset",
|
"build_train_dataset",
|
||||||
"build_train_loader",
|
"build_train_loader",
|
||||||
"build_trainer",
|
"build_trainer",
|
||||||
"build_training_module",
|
|
||||||
"build_val_dataset",
|
"build_val_dataset",
|
||||||
"build_val_loader",
|
"build_val_loader",
|
||||||
"load_full_training_config",
|
|
||||||
"train",
|
"train",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class FullTrainingConfig(BaseConfig):
|
|
||||||
"""Full training configuration."""
|
|
||||||
|
|
||||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
|
||||||
preprocess: PreprocessingConfig = Field(
|
|
||||||
default_factory=PreprocessingConfig
|
|
||||||
)
|
|
||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def load_full_training_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> FullTrainingConfig:
|
|
||||||
"""Load the full training configuration."""
|
|
||||||
return load_config(path, schema=FullTrainingConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
train_examples: Sequence[data.PathLike],
|
train_examples: Sequence[data.PathLike],
|
||||||
val_examples: Optional[Sequence[data.PathLike]] = None,
|
val_examples: Optional[Sequence[data.PathLike]] = None,
|
||||||
@ -80,7 +50,7 @@ def train(
|
|||||||
if model_path is not None:
|
if model_path is not None:
|
||||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||||
else:
|
else:
|
||||||
module = build_training_module(conf)
|
module = TrainingModule(conf)
|
||||||
|
|
||||||
trainer = build_trainer(conf, targets=module.targets)
|
trainer = build_trainer(conf, targets=module.targets)
|
||||||
|
|
||||||
@ -108,35 +78,6 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_training_module(conf: FullTrainingConfig) -> TrainingModule:
|
|
||||||
preprocessor = build_preprocessor(conf.preprocess)
|
|
||||||
|
|
||||||
targets = build_targets(conf.targets)
|
|
||||||
|
|
||||||
postprocessor = build_postprocessor(
|
|
||||||
targets,
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
)
|
|
||||||
|
|
||||||
model = build_model(
|
|
||||||
num_classes=len(targets.class_names),
|
|
||||||
config=conf.model,
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = build_loss(conf.train.loss)
|
|
||||||
|
|
||||||
return TrainingModule(
|
|
||||||
detector=model,
|
|
||||||
loss=loss,
|
|
||||||
targets=targets,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
learning_rate=conf.train.learning_rate,
|
|
||||||
t_max=conf.train.t_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
||||||
return [
|
return [
|
||||||
ValidationMetrics(
|
ValidationMetrics(
|
||||||
|
@ -1,31 +1,16 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.models import build_model
|
from batdetect2.train import FullTrainingConfig, TrainingModule
|
||||||
from batdetect2.postprocess import build_postprocessor
|
|
||||||
from batdetect2.preprocess import build_preprocessor
|
|
||||||
from batdetect2.targets import build_targets
|
|
||||||
from batdetect2.train.lightning import TrainingModule
|
|
||||||
from batdetect2.train.losses import build_loss
|
|
||||||
|
|
||||||
|
|
||||||
def build_default_module():
|
def build_default_module():
|
||||||
loss = build_loss()
|
return TrainingModule(FullTrainingConfig())
|
||||||
targets = build_targets()
|
|
||||||
detector = build_model(num_classes=len(targets.class_names))
|
|
||||||
preprocessor = build_preprocessor()
|
|
||||||
postprocessor = build_postprocessor(targets)
|
|
||||||
return TrainingModule(
|
|
||||||
detector=detector,
|
|
||||||
loss=loss,
|
|
||||||
targets=targets,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_can_initialize_default_module():
|
def test_can_initialize_default_module():
|
||||||
|
Loading…
Reference in New Issue
Block a user