Wrote the main train function

mbsantiago 2025-04-24 10:00:18 +01:00
parent 4b4c3ecdf5
commit 899f74efd5
5 changed files with 142 additions and 63 deletions

batdetect2/train/clips.py

@@ -69,12 +69,15 @@ class Clipper(ClipperProtocol):
     )
 
 
-def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
+def build_clipper(
+    config: Optional[ClipingConfig] = None,
+    random: Optional[bool] = None,
+) -> ClipperProtocol:
     config = config or ClipingConfig()
     return Clipper(
         duration=config.duration,
         max_empty=config.max_empty,
-        random=config.random,
+        random=config.random if random else False,
     )
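
For context, a minimal usage sketch of the new flag (illustrative only, not part of the commit): a truthy random lets the config's own random setting take effect, while a falsy value forces deterministic clipping.

    from batdetect2.train.clips import build_clipper

    train_clipper = build_clipper(random=True)   # honours config.random
    val_clipper = build_clipper(random=False)    # always non-random clips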

batdetect2/train/config.py

@@ -1,7 +1,7 @@
-from typing import Optional
+from typing import Optional, Union
 
 from pydantic import Field
-from soundevent.data import PathLike
+from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.train.augmentations import (
@@ -23,8 +23,29 @@ class OptimizerConfig(BaseConfig):
     t_max: int = 100
 
 
+class TrainerConfig(BaseConfig):
+    accelerator: str = "auto"
+    accumulate_grad_batches: int = 1
+    deterministic: bool = True
+    check_val_every_n_epoch: int = 1
+    devices: Union[str, int] = "auto"
+    enable_checkpointing: bool = True
+    gradient_clip_val: Optional[float] = None
+    limit_train_batches: Optional[Union[int, float]] = None
+    limit_test_batches: Optional[Union[int, float]] = None
+    limit_val_batches: Optional[Union[int, float]] = None
+    log_every_n_steps: Optional[int] = None
+    max_epochs: Optional[int] = 200
+    min_epochs: Optional[int] = None
+    max_steps: Optional[int] = None
+    min_steps: Optional[int] = None
+    max_time: Optional[str] = None
+    precision: Optional[str] = None
+    val_check_interval: Optional[Union[int, float]] = None
+
+
 class TrainingConfig(BaseConfig):
-    batch_size: int = 32
+    batch_size: int = 8
     loss: LossConfig = Field(default_factory=LossConfig)
@@ -36,9 +57,11 @@ class TrainingConfig(BaseConfig):
     cliping: ClipingConfig = Field(default_factory=ClipingConfig)
+    trainer: TrainerConfig = Field(default_factory=TrainerConfig)
 
 
 def load_train_config(
-    path: PathLike,
+    path: data.PathLike,
     field: Optional[str] = None,
 ) -> TrainingConfig:
     return load_config(path, schema=TrainingConfig, field=field)
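
A short sketch of how the nested trainer settings could now be overridden (illustrative; assumes the usual pydantic keyword constructor and defaults for the remaining fields):

    from batdetect2.train.config import TrainerConfig, TrainingConfig

    config = TrainingConfig(
        batch_size=16,
        trainer=TrainerConfig(max_epochs=50, accelerator="gpu"),
    )
    trainer_kwargs = config.trainer.model_dump()  # later unpacked into lightning.Trainer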

batdetect2/train/dataset.py

@@ -81,7 +81,10 @@ class LabeledDataset(Dataset):
         array: xr.DataArray,
         dtype=np.float32,
     ) -> torch.Tensor:
-        return torch.tensor(array.values.astype(dtype))
+        return torch.nan_to_num(
+            torch.tensor(array.values.astype(dtype)),
+            nan=0,
+        )
 
 
 def list_preprocessed_files(
@@ -91,7 +94,11 @@ def list_preprocessed_files(
 
 class RandomExampleSource:
-    def __init__(self, filenames: List[str], clipper: ClipperProtocol):
+    def __init__(
+        self,
+        filenames: List[data.PathLike],
+        clipper: ClipperProtocol,
+    ):
         self.filenames = filenames
         self.clipper = clipper
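
The NaN guard in the tensor conversion simply maps missing values to zero; a tiny illustration (not from the commit):

    import numpy as np
    import torch

    arr = np.array([1.0, np.nan, 2.0], dtype=np.float32)
    torch.nan_to_num(torch.tensor(arr), nan=0)  # tensor([1., 0., 2.])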

batdetect2/train/lightning.py

@@ -40,7 +40,9 @@ class TrainingModule(L.LightningModule):
         self.learning_rate = learning_rate
         self.t_max = t_max
 
-        self.save_hyperparameters()
+        # NOTE: Ignore detector and loss from hyperparameter saving
+        # as they are nn.Module and should be saved regardless.
+        self.save_hyperparameters(ignore=["detector", "loss"])
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)
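
One practical consequence (a hedged sketch, not part of this diff): because detector and loss are no longer stored as hyperparameters, they have to be supplied again when restoring the module from a checkpoint.

    # Hypothetical restore; the checkpoint path is illustrative.
    module = TrainingModule.load_from_checkpoint(
        "checkpoints/last.ckpt",
        detector=detector,
        loss=loss,
    )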

batdetect2/train/train.py

@@ -1,68 +1,112 @@
-from typing import Optional, Union
+from typing import List, Optional
 
-from lightning import LightningModule
-from lightning.pytorch import Trainer
-from soundevent.data import PathLike
+from lightning import Trainer
+from soundevent import data
 from torch.utils.data import DataLoader
 
-from batdetect2.configs import BaseConfig, load_config
-from batdetect2.train.dataset import LabeledDataset
+from batdetect2.models.types import DetectionModel
+from batdetect2.postprocess.types import PostprocessorProtocol
+from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets.types import TargetProtocol
+from batdetect2.train.augmentations import (
+    build_augmentations,
+)
+from batdetect2.train.clips import build_clipper
+from batdetect2.train.config import TrainingConfig
+from batdetect2.train.dataset import LabeledDataset, RandomExampleSource
+from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.losses import build_loss
 
 __all__ = [
     "train",
-    "TrainerConfig",
-    "load_trainer_config",
 ]
-class TrainerConfig(BaseConfig):
-    accelerator: str = "auto"
-    accumulate_grad_batches: int = 1
-    deterministic: bool = True
-    check_val_every_n_epoch: int = 1
-    devices: Union[str, int] = "auto"
-    enable_checkpointing: bool = True
-    gradient_clip_val: Optional[float] = None
-    limit_train_batches: Optional[Union[int, float]] = None
-    limit_test_batches: Optional[Union[int, float]] = None
-    limit_val_batches: Optional[Union[int, float]] = None
-    log_every_n_steps: Optional[int] = None
-    max_epochs: Optional[int] = None
-    min_epochs: Optional[int] = 100
-    max_steps: Optional[int] = None
-    min_steps: Optional[int] = None
-    max_time: Optional[str] = None
-    precision: Optional[str] = None
-    reload_dataloaders_every_n_epochs: Optional[int] = None
-    val_check_interval: Optional[Union[int, float]] = None
-
-
-def load_trainer_config(path: PathLike, field: Optional[str] = None):
-    return load_config(path, schema=TrainerConfig, field=field)
-
-
 def train(
-    module: LightningModule,
-    train_dataset: LabeledDataset,
-    trainer_config: Optional[TrainerConfig] = None,
-    dev_run: bool = False,
-    overfit_batches: bool = False,
-    profiler: Optional[str] = None,
-):
-    trainer_config = trainer_config or TrainerConfig()
-    trainer = Trainer(
-        **trainer_config.model_dump(
-            exclude_unset=True,
-            exclude_none=True,
-        ),
-        fast_dev_run=dev_run,
-        overfit_batches=overfit_batches,
-        profiler=profiler,
+    detector: DetectionModel,
+    targets: TargetProtocol,
+    preprocessor: PreprocessorProtocol,
+    postprocessor: PostprocessorProtocol,
+    train_examples: List[data.PathLike],
+    val_examples: Optional[List[data.PathLike]] = None,
+    config: Optional[TrainingConfig] = None,
+) -> None:
+    config = config or TrainingConfig()
+
+    train_dataset = build_dataset(
+        train_examples,
+        preprocessor,
+        config=config,
+        train=True,
     )
-    train_loader = DataLoader(
+
+    loss = build_loss(config.loss)
+
+    module = TrainingModule(
+        detector=detector,
+        loss=loss,
+        targets=targets,
+        preprocessor=preprocessor,
+        postprocessor=postprocessor,
+        learning_rate=config.optimizer.learning_rate,
+        t_max=config.optimizer.t_max,
+    )
+
+    trainer = Trainer(**config.trainer.model_dump())
+
+    train_dataloader = DataLoader(
         train_dataset,
-        batch_size=module.config.train.batch_size,
+        batch_size=config.batch_size,
         shuffle=True,
-        num_workers=7,
     )
-    trainer.fit(module, train_dataloaders=train_loader)
+
+    val_dataloader = None
+    if val_examples:
+        val_dataset = build_dataset(
+            val_examples,
+            preprocessor,
+            config=config,
+            train=False,
+        )
+        val_dataloader = DataLoader(
+            val_dataset,
+            batch_size=config.batch_size,
+            shuffle=False,
+        )
+
+    trainer.fit(
+        module,
+        train_dataloaders=train_dataloader,
+        val_dataloaders=val_dataloader,
+    )
+
+
+def build_dataset(
+    examples: List[data.PathLike],
+    preprocessor: PreprocessorProtocol,
+    config: Optional[TrainingConfig] = None,
+    train: bool = True,
+):
+    config = config or TrainingConfig()
+
+    clipper = build_clipper(config.cliping, random=train)
+
+    augmentations = None
+
+    if train:
+        random_example_source = RandomExampleSource(
+            examples,
+            clipper=clipper,
+        )
+
+        augmentations = build_augmentations(
+            preprocessor,
+            config=config.augmentations,
+            example_source=random_example_source,
+        )
+
+    return LabeledDataset(
+        examples,
+        clipper=clipper,
+        augmentation=augmentations,
+    )
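
Taken together, a hedged end-to-end sketch of the new entry point (the module path, the pre-built components, and the example lists are assumptions, not part of this diff):

    from batdetect2.train.config import TrainingConfig
    from batdetect2.train.train import train  # module path assumed

    config = TrainingConfig(batch_size=8)

    train(
        detector=detector,            # DetectionModel built elsewhere
        targets=targets,              # TargetProtocol implementation
        preprocessor=preprocessor,    # PreprocessorProtocol implementation
        postprocessor=postprocessor,  # PostprocessorProtocol implementation
        train_examples=train_files,   # paths to preprocessed training examples
        val_examples=val_files,
        config=config,
    )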