mirror of https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Better structure for training module

This commit is contained in:
parent d4f249366e
commit 02adc19070
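This commit refactors TrainingModule so that it no longer builds its own collaborators from a FullTrainingConfig. The loss, targets, detector, preprocessor, and postprocessor are now injected through the constructor, and a new build_training_module() factory in batdetect2.train.train performs the config-to-object construction that previously lived inside the module.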
batdetect2/train/lightning.py

@@ -1,16 +1,14 @@
 import lightning as L
 import torch
-from pydantic import BaseModel
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-from batdetect2.models import ModelOutput, build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets import build_targets
+from batdetect2.models import ModelOutput
+from batdetect2.models.types import DetectionModel
+from batdetect2.postprocess.types import PostprocessorProtocol
+from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets.types import TargetProtocol
 from batdetect2.train import TrainExample
-from batdetect2.train.config import FullTrainingConfig
-from batdetect2.train.losses import build_loss
 
 __all__ = [
     "TrainingModule",
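With this change the lightning module depends only on the protocol types (DetectionModel, TargetProtocol, PreprocessorProtocol, PostprocessorProtocol) rather than on the concrete builder functions, which now sit behind the build_training_module() factory shown further down.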
@@ -18,29 +16,28 @@ __all__ = [
 
 
 class TrainingModule(L.LightningModule):
-    def __init__(self, config: FullTrainingConfig):
+    def __init__(
+        self,
+        detector: DetectionModel,
+        loss: torch.nn.Module,
+        targets: TargetProtocol,
+        preprocessor: PreprocessorProtocol,
+        postprocessor: PostprocessorProtocol,
+        learning_rate: float = 0.001,
+        t_max: int = 100,
+    ):
         super().__init__()
 
-        # NOTE: Need to convert to vanilla python object so that DVCLive can
-        # store it.
-        self._config = (
-            config.model_dump() if isinstance(config, BaseModel) else config
-        )
-        self.save_hyperparameters({"config": self._config})
+        self.learning_rate = learning_rate
+        self.t_max = t_max
 
-        self.config = FullTrainingConfig.model_validate(self._config)
-        self.loss = build_loss(self.config.train.loss)
-        self.targets = build_targets(self.config.targets)
-        self.detector = build_model(
-            num_classes=len(self.targets.class_names),
-            config=self.config.model,
-        )
-        self.preprocessor = build_preprocessor(self.config.preprocess)
-        self.postprocessor = build_postprocessor(
-            self.targets,
-            min_freq=self.preprocessor.min_freq,
-            max_freq=self.preprocessor.max_freq,
-        )
+        self.loss = loss
+        self.targets = targets
+        self.detector = detector
+        self.preprocessor = preprocessor
+        self.postprocessor = postprocessor
+
+        self.save_hyperparameters(logger=False)
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)
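Since the module now receives its collaborators instead of constructing them, it can be instantiated in isolation with lightweight stand-ins, which is the main payoff of the refactor. A minimal sketch, not from the commit; all stand-in objects below are illustrative assumptions:

    from unittest.mock import MagicMock

    import torch

    from batdetect2.train.lightning import TrainingModule

    module = TrainingModule(
        detector=torch.nn.Linear(1, 1),  # stand-in with parameters, not a real DetectionModel
        loss=torch.nn.MSELoss(),
        targets=MagicMock(),        # TargetProtocol stand-in
        preprocessor=MagicMock(),   # PreprocessorProtocol stand-in
        postprocessor=MagicMock(),  # PostprocessorProtocol stand-in
        learning_rate=1e-3,
        t_max=100,
    )

    # configure_optimizers now reads self.learning_rate and self.t_max directly.
    optimizers, schedulers = module.configure_optimizers()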
@@ -65,10 +62,9 @@ class TrainingModule(L.LightningModule):
         self.log("detection_loss/val", losses.total, logger=True)
         self.log("size_loss/val", losses.total, logger=True)
         self.log("classification_loss/val", losses.total, logger=True)
 
         return outputs
 
     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
+        optimizer = Adam(self.parameters(), lr=self.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
         return [optimizer], [scheduler]
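Note that the three validation log calls kept as context above all record losses.total. Presumably the individual loss components were intended, roughly as below; the component field names are assumptions, not confirmed by this diff:

    self.log("total_loss/val", losses.total, logger=True)
    self.log("detection_loss/val", losses.detection, logger=True)
    self.log("size_loss/val", losses.size, logger=True)
    self.log("classification_loss/val", losses.classification, logger=True)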
batdetect2/train/train.py

@@ -13,10 +13,13 @@ from batdetect2.evaluate.metrics import (
     ClassificationMeanAveragePrecision,
     DetectionAveragePrecision,
 )
+from batdetect2.models import build_model
+from batdetect2.postprocess import build_postprocessor
 from batdetect2.preprocess import (
     PreprocessorProtocol,
     build_preprocessor,
 )
-from batdetect2.targets import TargetProtocol
+from batdetect2.targets import TargetProtocol, build_targets
 from batdetect2.train.augmentations import build_augmentations
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
@@ -28,6 +31,7 @@ from batdetect2.train.dataset import (
 )
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
+from batdetect2.train.losses import build_loss
 
 __all__ = [
     "build_train_dataset",
@@ -47,27 +51,27 @@ def train(
     train_workers: Optional[int] = None,
     val_workers: Optional[int] = None,
 ):
-    conf = config or FullTrainingConfig()
+    config = config or FullTrainingConfig()
 
     if model_path is not None:
         logger.debug("Loading model from: {path}", path=model_path)
         module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
     else:
-        module = TrainingModule(conf)
+        module = build_training_module(config)
 
-    trainer = build_trainer(conf, targets=module.targets)
+    trainer = build_trainer(config, targets=module.targets)
 
     train_dataloader = build_train_loader(
         train_examples,
         preprocessor=module.preprocessor,
-        config=conf.train,
+        config=config.train,
         num_workers=train_workers,
     )
 
     val_dataloader = (
         build_val_loader(
             val_examples,
-            config=conf.train,
+            config=config.train,
             num_workers=val_workers,
         )
         if val_examples is not None
@@ -83,6 +87,31 @@ def train(
     logger.info("Training complete.")
 
 
+def build_training_module(config: FullTrainingConfig) -> TrainingModule:
+    targets = build_targets(config=config.targets)
+    loss = build_loss(config=config.train.loss)
+    preprocessor = build_preprocessor(config.preprocess)
+    postprocessor = build_postprocessor(
+        targets,
+        config=config.postprocess,
+        max_freq=preprocessor.max_freq,
+        min_freq=preprocessor.min_freq,
+    )
+    model = build_model(
+        num_classes=len(targets.class_names),
+        config=config.model,
+    )
+    return TrainingModule(
+        detector=model,
+        loss=loss,
+        preprocessor=preprocessor,
+        postprocessor=postprocessor,
+        targets=targets,
+        learning_rate=config.train.learning_rate,
+        t_max=config.train.t_max,
+    )
+
+
 def build_trainer_callbacks(
     targets: TargetProtocol, config: EvaluationConfig
 ) -> List[Callback]:
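build_training_module() is now the single place where configuration is turned into objects; the test change at the bottom of this commit exercises it the same way. A minimal usage sketch, assuming the default FullTrainingConfig validates out of the box:

    from batdetect2.train import FullTrainingConfig
    from batdetect2.train.train import build_training_module

    config = FullTrainingConfig()
    module = build_training_module(config)

    # The factory forwards the optimizer settings from the config.
    assert module.learning_rate == config.train.learning_rate
    assert module.t_max == config.train.t_max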
@@ -114,9 +143,13 @@ def build_trainer(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )
+    train_logger = build_logger(conf.train.logger)
+
+    train_logger.log_hyperparams(conf.model_dump(mode="json"))
+
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
-        logger=build_logger(conf.train.logger),
+        logger=train_logger,
         callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
     )
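Recording the config moves along with its construction: the old constructor saved it via save_hyperparameters({"config": ...}), converting it to a vanilla Python object so DVCLive could store it, whereas build_trainer now builds the logger once and records the config with train_logger.log_hyperparams(conf.model_dump(mode="json")).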
@@ -6,10 +6,12 @@ import xarray as xr
 from soundevent import data
 
 from batdetect2.train import FullTrainingConfig, TrainingModule
+from batdetect2.train.train import build_training_module
 
 
 def build_default_module():
-    return TrainingModule(FullTrainingConfig())
+    config = FullTrainingConfig()
+    return build_training_module(config)
 
 
 def test_can_initialize_default_module():