diff --git a/batdetect2/train/clips.py b/batdetect2/train/clips.py
index ec13b8e..7debd13 100644
--- a/batdetect2/train/clips.py
+++ b/batdetect2/train/clips.py
@@ -69,12 +69,15 @@ class Clipper(ClipperProtocol):
     )
 
 
-def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
+def build_clipper(
+    config: Optional[ClipingConfig] = None,
+    random: bool = False,
+) -> ClipperProtocol:
     config = config or ClipingConfig()
     return Clipper(
         duration=config.duration,
         max_empty=config.max_empty,
-        random=config.random,
+        random=config.random if random else False,
     )
 
 
diff --git a/batdetect2/train/config.py b/batdetect2/train/config.py
index e2ab8ec..b75d0fe 100644
--- a/batdetect2/train/config.py
+++ b/batdetect2/train/config.py
@@ -1,7 +1,7 @@
-from typing import Optional
+from typing import Optional, Union
 
 from pydantic import Field
-from soundevent.data import PathLike
+from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.train.augmentations import (
@@ -23,8 +23,29 @@ class OptimizerConfig(BaseConfig):
     t_max: int = 100
 
 
+class TrainerConfig(BaseConfig):
+    accelerator: str = "auto"
+    accumulate_grad_batches: int = 1
+    deterministic: bool = True
+    check_val_every_n_epoch: int = 1
+    devices: Union[str, int] = "auto"
+    enable_checkpointing: bool = True
+    gradient_clip_val: Optional[float] = None
+    limit_train_batches: Optional[Union[int, float]] = None
+    limit_test_batches: Optional[Union[int, float]] = None
+    limit_val_batches: Optional[Union[int, float]] = None
+    log_every_n_steps: Optional[int] = None
+    max_epochs: Optional[int] = 200
+    min_epochs: Optional[int] = None
+    max_steps: Optional[int] = None
+    min_steps: Optional[int] = None
+    max_time: Optional[str] = None
+    precision: Optional[str] = None
+    val_check_interval: Optional[Union[int, float]] = None
+
+
 class TrainingConfig(BaseConfig):
-    batch_size: int = 32
+    batch_size: int = 8
 
     loss: LossConfig = Field(default_factory=LossConfig)
 
@@ -36,9 +57,11 @@ class TrainingConfig(BaseConfig):
 
     cliping: ClipingConfig = Field(default_factory=ClipingConfig)
 
+    trainer: TrainerConfig = Field(default_factory=TrainerConfig)
+
 
 def load_train_config(
-    path: PathLike,
+    path: data.PathLike,
     field: Optional[str] = None,
 ) -> TrainingConfig:
     return load_config(path, schema=TrainingConfig, field=field)
diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py
index 9c0ea82..3426e9c 100644
--- a/batdetect2/train/dataset.py
+++ b/batdetect2/train/dataset.py
@@ -81,7 +81,10 @@ class LabeledDataset(Dataset):
         array: xr.DataArray,
         dtype=np.float32,
     ) -> torch.Tensor:
-        return torch.tensor(array.values.astype(dtype))
+        return torch.nan_to_num(
+            torch.tensor(array.values.astype(dtype)),
+            nan=0,
+        )
 
 
 def list_preprocessed_files(
@@ -91,7 +94,11 @@ def list_preprocessed_files(
 
 
 class RandomExampleSource:
-    def __init__(self, filenames: List[str], clipper: ClipperProtocol):
+    def __init__(
+        self,
+        filenames: List[data.PathLike],
+        clipper: ClipperProtocol,
+    ):
         self.filenames = filenames
         self.clipper = clipper
 
diff --git a/batdetect2/train/lightning.py b/batdetect2/train/lightning.py
index 096ccd9..ffdebb2 100644
--- a/batdetect2/train/lightning.py
+++ b/batdetect2/train/lightning.py
@@ -40,7 +40,9 @@ class TrainingModule(L.LightningModule):
        self.learning_rate = learning_rate
         self.t_max = t_max
 
-        self.save_hyperparameters()
+        # NOTE: Exclude detector and loss from hyperparameter saving;
+        # as nn.Modules, their weights are stored in the checkpoint anyway.
+        self.save_hyperparameters(ignore=["detector", "loss"])
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)
diff --git a/batdetect2/train/train.py b/batdetect2/train/train.py
index 6ab8a6c..039bc32 100644
--- a/batdetect2/train/train.py
+++ b/batdetect2/train/train.py
@@ -1,68 +1,112 @@
-from typing import Optional, Union
+from typing import List, Optional
 
-from lightning import LightningModule
-from lightning.pytorch import Trainer
-from soundevent.data import PathLike
+from lightning import Trainer
+from soundevent import data
 from torch.utils.data import DataLoader
 
-from batdetect2.configs import BaseConfig, load_config
-from batdetect2.train.dataset import LabeledDataset
+from batdetect2.models.types import DetectionModel
+from batdetect2.postprocess.types import PostprocessorProtocol
+from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets.types import TargetProtocol
+from batdetect2.train.augmentations import (
+    build_augmentations,
+)
+from batdetect2.train.clips import build_clipper
+from batdetect2.train.config import TrainingConfig
+from batdetect2.train.dataset import LabeledDataset, RandomExampleSource
+from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.losses import build_loss
 
 __all__ = [
     "train",
-    "TrainerConfig",
-    "load_trainer_config",
 ]
 
 
-class TrainerConfig(BaseConfig):
-    accelerator: str = "auto"
-    accumulate_grad_batches: int = 1
-    deterministic: bool = True
-    check_val_every_n_epoch: int = 1
-    devices: Union[str, int] = "auto"
-    enable_checkpointing: bool = True
-    gradient_clip_val: Optional[float] = None
-    limit_train_batches: Optional[Union[int, float]] = None
-    limit_test_batches: Optional[Union[int, float]] = None
-    limit_val_batches: Optional[Union[int, float]] = None
-    log_every_n_steps: Optional[int] = None
-    max_epochs: Optional[int] = None
-    min_epochs: Optional[int] = 100
-    max_steps: Optional[int] = None
-    min_steps: Optional[int] = None
-    max_time: Optional[str] = None
-    precision: Optional[str] = None
-    reload_dataloaders_every_n_epochs: Optional[int] = None
-    val_check_interval: Optional[Union[int, float]] = None
-
-
-def load_trainer_config(path: PathLike, field: Optional[str] = None):
-    return load_config(path, schema=TrainerConfig, field=field)
-
-
 def train(
-    module: LightningModule,
-    train_dataset: LabeledDataset,
-    trainer_config: Optional[TrainerConfig] = None,
-    dev_run: bool = False,
-    overfit_batches: bool = False,
-    profiler: Optional[str] = None,
-):
-    trainer_config = trainer_config or TrainerConfig()
-    trainer = Trainer(
-        **trainer_config.model_dump(
-            exclude_unset=True,
-            exclude_none=True,
-        ),
-        fast_dev_run=dev_run,
-        overfit_batches=overfit_batches,
-        profiler=profiler,
+    detector: DetectionModel,
+    targets: TargetProtocol,
+    preprocessor: PreprocessorProtocol,
+    postprocessor: PostprocessorProtocol,
+    train_examples: List[data.PathLike],
+    val_examples: Optional[List[data.PathLike]] = None,
+    config: Optional[TrainingConfig] = None,
+) -> None:
+    config = config or TrainingConfig()
+
+    train_dataset = build_dataset(
+        train_examples,
+        preprocessor,
+        config=config,
+        train=True,
     )
-    train_loader = DataLoader(
+
+    loss = build_loss(config.loss)
+
+    module = TrainingModule(
+        detector=detector,
+        loss=loss,
+        targets=targets,
+        preprocessor=preprocessor,
+        postprocessor=postprocessor,
+        learning_rate=config.optimizer.learning_rate,
+        t_max=config.optimizer.t_max,
+    )
+
+    trainer = Trainer(**config.trainer.model_dump(exclude_none=True))
+
+    train_dataloader = DataLoader(
         train_dataset,
-        batch_size=module.config.train.batch_size,
+        batch_size=config.batch_size,
         shuffle=True,
-        num_workers=7,
     )
-    trainer.fit(module, train_dataloaders=train_loader)
+
+    val_dataloader = None
+    if val_examples:
+        val_dataset = build_dataset(
+            val_examples,
+            preprocessor,
+            config=config,
+            train=False,
+        )
+        val_dataloader = DataLoader(
+            val_dataset,
+            batch_size=config.batch_size,
+            shuffle=False,
+        )
+
+    trainer.fit(
+        module,
+        train_dataloaders=train_dataloader,
+        val_dataloaders=val_dataloader,
+    )
+
+
+def build_dataset(
+    examples: List[data.PathLike],
+    preprocessor: PreprocessorProtocol,
+    config: Optional[TrainingConfig] = None,
+    train: bool = True,
+) -> LabeledDataset:
+    config = config or TrainingConfig()
+
+    clipper = build_clipper(config.cliping, random=train)
+
+    augmentations = None
+
+    if train:
+        random_example_source = RandomExampleSource(
+            examples,
+            clipper=clipper,
+        )
+
+        augmentations = build_augmentations(
+            preprocessor,
+            config=config.augmentations,
+            example_source=random_example_source,
+        )
+
+    return LabeledDataset(
+        examples,
+        clipper=clipper,
+        augmentation=augmentations,
+    )
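
With this change, `TrainerConfig` moves out of `batdetect2.train.train` and nests under `TrainingConfig.trainer`, so a single `TrainingConfig` drives the whole run. A minimal sketch of how the nested config feeds `lightning.Trainer`, using only what the diff defines (the field values are illustrative):

```python
from lightning import Trainer

from batdetect2.train.config import TrainerConfig

# TrainerConfig fields mirror lightning.Trainer keyword arguments, so the
# dumped dict can be splatted straight into the constructor, as `train`
# does with `config.trainer`. Dropping None-valued fields lets Lightning
# apply its own defaults (e.g. max_steps=-1) instead of receiving None.
trainer_config = TrainerConfig(max_epochs=50, precision="16-mixed")
trainer = Trainer(**trainer_config.model_dump(exclude_none=True))
```

And a hedged end-to-end sketch of the new `train` entry point. The `train` signature and `load_train_config` are taken from the diff; the `build_*` factories, file lists, and config path are placeholders, since the diff does not show how batdetect2 constructs the model and protocol objects:

```python
from batdetect2.train.config import load_train_config
from batdetect2.train.train import train

config = load_train_config("train.yaml")  # hypothetical config path

# Placeholder factories: substitute the project's real builders for the
# DetectionModel, TargetProtocol, PreprocessorProtocol, and
# PostprocessorProtocol instances that `train` expects.
detector = build_model()
targets = build_targets()
preprocessor = build_preprocessor()
postprocessor = build_postprocessor()

train(
    detector=detector,
    targets=targets,
    preprocessor=preprocessor,
    postprocessor=postprocessor,
    train_examples=train_files,  # lists of preprocessed example paths
    val_examples=val_files,
    config=config,
)
```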