commit 02adc19070
parent d4f249366e
repo:   https://github.com/macaodha/batdetect2.git

    Better structure for training module
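In brief: TrainingModule no longer reads a FullTrainingConfig and builds its own components; a new build_training_module factory constructs the detector, loss, targets, preprocessor, and postprocessor and injects them. Condensed from the hunks below:

    from batdetect2.train import FullTrainingConfig, TrainingModule
    from batdetect2.train.train import build_training_module

    # Before: the module built all of its components from the config
    # itself (old constructor signature, removed by this commit).
    module = TrainingModule(FullTrainingConfig())

    # After: a factory builds the components and injects them.
    module = build_training_module(FullTrainingConfig())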
batdetect2/train/lightning.py
@@ -1,16 +1,14 @@
 import lightning as L
 import torch
-from pydantic import BaseModel
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-from batdetect2.models import ModelOutput, build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets import build_targets
+from batdetect2.models import ModelOutput
+from batdetect2.models.types import DetectionModel
+from batdetect2.postprocess.types import PostprocessorProtocol
+from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets.types import TargetProtocol
 from batdetect2.train import TrainExample
-from batdetect2.train.config import FullTrainingConfig
-from batdetect2.train.losses import build_loss
 
 __all__ = [
     "TrainingModule",
@@ -18,29 +16,28 @@ __all__ = [
 
 
 class TrainingModule(L.LightningModule):
-    def __init__(self, config: FullTrainingConfig):
+    def __init__(
+        self,
+        detector: DetectionModel,
+        loss: torch.nn.Module,
+        targets: TargetProtocol,
+        preprocessor: PreprocessorProtocol,
+        postprocessor: PostprocessorProtocol,
+        learning_rate: float = 0.001,
+        t_max: int = 100,
+    ):
         super().__init__()
 
-        # NOTE: Need to convert to vanilla python object so that DVCLive can
-        # store it.
-        self._config = (
-            config.model_dump() if isinstance(config, BaseModel) else config
-        )
-        self.save_hyperparameters({"config": self._config})
-
-        self.config = FullTrainingConfig.model_validate(self._config)
-        self.loss = build_loss(self.config.train.loss)
-        self.targets = build_targets(self.config.targets)
-        self.detector = build_model(
-            num_classes=len(self.targets.class_names),
-            config=self.config.model,
-        )
-        self.preprocessor = build_preprocessor(self.config.preprocess)
-        self.postprocessor = build_postprocessor(
-            self.targets,
-            min_freq=self.preprocessor.min_freq,
-            max_freq=self.preprocessor.max_freq,
-        )
+        self.learning_rate = learning_rate
+        self.t_max = t_max
+
+        self.loss = loss
+        self.targets = targets
+        self.detector = detector
+        self.preprocessor = preprocessor
+        self.postprocessor = postprocessor
+
+        self.save_hyperparameters(logger=False)
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)
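The old constructor dumped the pydantic config to a plain dict so DVCLive could store it; the new one simply records its init arguments with Lightning's save_hyperparameters(logger=False), which captures them in the checkpoint without also pushing them to the experiment logger (that now happens once in build_trainer, further down). A minimal standalone sketch of the mechanism; the Demo class is illustrative only, not part of the diff:

    import lightning as L

    class Demo(L.LightningModule):
        def __init__(self, learning_rate: float = 0.001, t_max: int = 100):
            super().__init__()
            # Init args land in self.hparams and in checkpoints under
            # "hyper_parameters"; logger=False keeps them out of the
            # experiment logger's hyperparameter table.
            self.save_hyperparameters(logger=False)

    demo = Demo()
    print(demo.hparams.learning_rate)  # 0.001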
@@ -65,10 +62,9 @@ class TrainingModule(L.LightningModule):
         self.log("detection_loss/val", losses.total, logger=True)
         self.log("size_loss/val", losses.total, logger=True)
         self.log("classification_loss/val", losses.total, logger=True)
 
         return outputs
 
     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
+        optimizer = Adam(self.parameters(), lr=self.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
         return [optimizer], [scheduler]
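Two things worth flagging in this hunk. First, the optimizer now reads the plain learning_rate/t_max attributes instead of digging through the stored config. Second, all three validation metrics log losses.total in the context lines; if the loss output exposes per-component values (the field names would be an assumption), the size and classification lines were presumably meant to log those instead. A standalone sketch of the schedule configure_optimizers() sets up:

    import torch
    from torch.optim.adam import Adam
    from torch.optim.lr_scheduler import CosineAnnealingLR

    # The cosine schedule anneals the learning rate from its initial
    # value towards 0 over T_max scheduler steps (one step per epoch
    # under Lightning's defaults).
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = Adam(params, lr=0.001)
    scheduler = CosineAnnealingLR(optimizer, T_max=100)

    for epoch in range(3):
        optimizer.step()
        scheduler.step()
        print(epoch, scheduler.get_last_lr())  # lr shrinks each epoch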
batdetect2/train/train.py
@@ -13,10 +13,13 @@ from batdetect2.evaluate.metrics import (
     ClassificationMeanAveragePrecision,
     DetectionAveragePrecision,
 )
+from batdetect2.models import build_model
+from batdetect2.postprocess import build_postprocessor
 from batdetect2.preprocess import (
     PreprocessorProtocol,
+    build_preprocessor,
 )
-from batdetect2.targets import TargetProtocol
+from batdetect2.targets import TargetProtocol, build_targets
 from batdetect2.train.augmentations import build_augmentations
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
@@ -28,6 +31,7 @@ from batdetect2.train.dataset import (
 )
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
+from batdetect2.train.losses import build_loss
 
 __all__ = [
     "build_train_dataset",
@@ -47,27 +51,27 @@ def train(
     train_workers: Optional[int] = None,
     val_workers: Optional[int] = None,
 ):
-    conf = config or FullTrainingConfig()
+    config = config or FullTrainingConfig()
 
     if model_path is not None:
         logger.debug("Loading model from: {path}", path=model_path)
         module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
     else:
-        module = TrainingModule(conf)
+        module = build_training_module(config)
 
-    trainer = build_trainer(conf, targets=module.targets)
+    trainer = build_trainer(config, targets=module.targets)
 
     train_dataloader = build_train_loader(
         train_examples,
         preprocessor=module.preprocessor,
-        config=conf.train,
+        config=config.train,
         num_workers=train_workers,
     )
 
     val_dataloader = (
         build_val_loader(
             val_examples,
-            config=conf.train,
+            config=config.train,
             num_workers=val_workers,
         )
         if val_examples is not None
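The entry point now has two clearly separated paths: restore a module from a Lightning checkpoint, or assemble a fresh one from the config via the new factory. Hypothetical call sites, assuming train() takes the example collections positionally as its body suggests; the checkpoint path is a placeholder:

    from batdetect2.train.config import FullTrainingConfig
    from batdetect2.train.train import train

    # Fresh run: everything is built from the config.
    train(train_examples, val_examples, config=FullTrainingConfig())

    # Resume: the module is restored from a checkpoint instead.
    train(train_examples, val_examples, model_path="checkpoints/last.ckpt")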
@@ -83,6 +87,31 @@ def train(
     logger.info("Training complete.")
 
 
+def build_training_module(config: FullTrainingConfig) -> TrainingModule:
+    targets = build_targets(config=config.targets)
+    loss = build_loss(config=config.train.loss)
+    preprocessor = build_preprocessor(config.preprocess)
+    postprocessor = build_postprocessor(
+        targets,
+        config=config.postprocess,
+        max_freq=preprocessor.max_freq,
+        min_freq=preprocessor.min_freq,
+    )
+    model = build_model(
+        num_classes=len(targets.class_names),
+        config=config.model,
+    )
+    return TrainingModule(
+        detector=model,
+        loss=loss,
+        preprocessor=preprocessor,
+        postprocessor=postprocessor,
+        targets=targets,
+        learning_rate=config.train.learning_rate,
+        t_max=config.train.t_max,
+    )
+
+
 def build_trainer_callbacks(
     targets: TargetProtocol, config: EvaluationConfig
 ) -> List[Callback]:
||||||
@ -114,9 +143,13 @@ def build_trainer(
|
|||||||
"Building trainer with config: \n{config}",
|
"Building trainer with config: \n{config}",
|
||||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||||
)
|
)
|
||||||
|
train_logger = build_logger(conf.train.logger)
|
||||||
|
|
||||||
|
train_logger.log_hyperparams(conf.model_dump(mode="json"))
|
||||||
|
|
||||||
return Trainer(
|
return Trainer(
|
||||||
**trainer_conf.model_dump(exclude_none=True),
|
**trainer_conf.model_dump(exclude_none=True),
|
||||||
logger=build_logger(conf.train.logger),
|
logger=train_logger,
|
||||||
callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
|
callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
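Hyperparameter logging moves out of the module's __init__ and into build_trainer: the logger is created once, the full config is sent to it explicitly, and the same instance is handed to the Trainer. model_dump(mode="json") matters here because it coerces the pydantic config into JSON-safe primitives, which is what experiment loggers expect. A minimal illustration with a made-up config model:

    from pydantic import BaseModel

    class TrainConf(BaseModel):  # stand-in for FullTrainingConfig
        learning_rate: float = 0.001
        t_max: int = 100

    class Conf(BaseModel):
        train: TrainConf = TrainConf()

    # mode="json" recursively converts values (enums, paths, datetimes,
    # ...) into JSON-compatible primitives before they reach the logger.
    print(Conf().model_dump(mode="json"))
    # {'train': {'learning_rate': 0.001, 't_max': 100}}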
@@ -6,10 +6,12 @@ import xarray as xr
 from soundevent import data
 
 from batdetect2.train import FullTrainingConfig, TrainingModule
+from batdetect2.train.train import build_training_module
 
 
 def build_default_module():
-    return TrainingModule(FullTrainingConfig())
+    config = FullTrainingConfig()
+    return build_training_module(config)
 
 
 def test_can_initialize_default_module():
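With the factory in place, the test helper mirrors production construction exactly. A hypothetical companion test for the same file (the test name is assumed, not in the diff), checking that the injected components actually end up on the module:

    def test_default_module_exposes_injected_components():
        module = build_default_module()
        # Attributes assigned in the new __init__ shown above.
        assert module.detector is not None
        assert module.loss is not None
        assert module.targets is not None
        assert module.preprocessor is not None
        assert module.postprocessor is not None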