diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py
index faff98b..e30bc25 100644
--- a/src/batdetect2/train/__init__.py
+++ b/src/batdetect2/train/__init__.py
@@ -22,7 +22,14 @@ from batdetect2.train.config import (
     load_full_training_config,
     load_train_config,
 )
-from batdetect2.train.dataset import TrainingDataset
+from batdetect2.train.dataset import (
+    TrainingDataset,
+    ValidationDataset,
+    build_train_dataset,
+    build_train_loader,
+    build_val_dataset,
+    build_val_loader,
+)
 from batdetect2.train.labels import build_clip_labeler, load_label_config
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.losses import (
@@ -33,14 +40,7 @@ from batdetect2.train.losses import (
     SizeLossConfig,
     build_loss,
 )
-from batdetect2.train.train import (
-    build_train_dataset,
-    build_train_loader,
-    build_trainer,
-    build_val_dataset,
-    build_val_loader,
-    train,
-)
+from batdetect2.train.train import build_trainer, train
 
 __all__ = [
     "AugmentationsConfig",
@@ -49,7 +49,6 @@ __all__ = [
     "EchoAugmentationConfig",
     "FrequencyMaskAugmentationConfig",
     "FullTrainingConfig",
-    "TrainingDataset",
     "LossConfig",
     "LossFunction",
     "PLTrainerConfig",
@@ -57,7 +56,9 @@ __all__ = [
     "SizeLossConfig",
     "TimeMaskAugmentationConfig",
     "TrainingConfig",
+    "TrainingDataset",
     "TrainingModule",
+    "ValidationDataset",
     "VolumeAugmentationConfig",
     "WarpAugmentationConfig",
     "add_echo",
diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py
index e4d586a..7d4fe9f 100644
--- a/src/batdetect2/train/config.py
+++ b/src/batdetect2/train/config.py
@@ -72,13 +72,16 @@ class TrainLoaderConfig(BaseConfig):
     )
 
 
-class TrainingConfig(BaseConfig):
+class OptimizerConfig(BaseConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
 
+
+class TrainingConfig(BaseConfig):
     train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
     val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
+    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
     loss: LossConfig = Field(default_factory=LossConfig)
     cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
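The `OptimizerConfig` split moves `learning_rate` and `t_max` off `TrainingConfig` onto a nested `optimizer` field, so code that previously read `config.train.learning_rate` must now read `config.train.optimizer.learning_rate`. A minimal sketch of the new shape, assuming `BaseConfig` follows the usual pydantic keyword-argument construction:

```python
from batdetect2.train.config import OptimizerConfig, TrainingConfig

# Optimizer settings now live on a nested model rather than directly
# on TrainingConfig.
config = TrainingConfig(
    optimizer=OptimizerConfig(learning_rate=1e-3, t_max=100),
)
assert config.optimizer.learning_rate == 1e-3
```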
diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py
index e72f75d..ced2028 100644
--- a/src/batdetect2/train/dataset.py
+++ b/src/batdetect2/train/dataset.py
@@ -1,18 +1,31 @@
-from typing import Optional, Sequence, Tuple
+from typing import List, Optional, Sequence
 
 import torch
+from loguru import logger
 from soundevent import data
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
 
+from batdetect2.plotting.clips import build_audio_loader
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.train.augmentations import (
+    RandomAudioSource,
+    build_augmentations,
+)
+from batdetect2.train.clips import build_clipper
+from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
+from batdetect2.train.labels import build_clip_labeler
 from batdetect2.typing import ClipperProtocol, TrainExample
 from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
-from batdetect2.typing.train import (
-    Augmentation,
-    ClipLabeller,
-)
+from batdetect2.typing.train import Augmentation, ClipLabeller
+from batdetect2.utils.arrays import adjust_width
 
 __all__ = [
     "TrainingDataset",
+    "ValidationDataset",
+    "build_val_loader",
+    "build_train_loader",
+    "build_train_dataset",
+    "build_val_dataset",
 ]
 
 
@@ -124,3 +137,174 @@ class ValidationDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
+
+
+def build_train_loader(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[TrainLoaderConfig] = None,
+    num_workers: Optional[int] = None,
+) -> DataLoader:
+    config = config or TrainLoaderConfig()
+
+    logger.info("Building training data loader...")
+    logger.opt(lazy=True).debug(
+        "Training data loader config: \n{config}",
+        config=lambda: config.to_yaml_string(exclude_none=True),
+    )
+
+    train_dataset = build_train_dataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
+        preprocessor=preprocessor,
+        config=config,
+    )
+
+    num_workers = num_workers or config.num_workers
+    return DataLoader(
+        train_dataset,
+        batch_size=config.batch_size,
+        shuffle=config.shuffle,
+        num_workers=num_workers,
+        collate_fn=_collate_fn,
+    )
+
+
+def build_val_loader(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[ValLoaderConfig] = None,
+    num_workers: Optional[int] = None,
+) -> DataLoader:
+    logger.info("Building validation data loader...")
+    config = config or ValLoaderConfig()
+    logger.opt(lazy=True).debug(
+        "Validation data loader config: \n{config}",
+        config=lambda: config.to_yaml_string(exclude_none=True),
+    )
+
+    val_dataset = build_val_dataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
+        preprocessor=preprocessor,
+        config=config,
+    )
+
+    num_workers = num_workers or config.num_workers
+    return DataLoader(
+        val_dataset,
+        batch_size=1,
+        shuffle=False,
+        num_workers=num_workers,
+        collate_fn=_collate_fn,
+    )
+
+
+def build_train_dataset(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[TrainLoaderConfig] = None,
+) -> TrainingDataset:
+    logger.info("Building training dataset...")
+    config = config or TrainLoaderConfig()
+
+    clipper = build_clipper(config=config.clipping_strategy)
+
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    if labeller is None:
+        labeller = build_clip_labeler(
+            min_freq=preprocessor.min_freq,
+            max_freq=preprocessor.max_freq,
+        )
+
+    random_example_source = RandomAudioSource(
+        clip_annotations,
+        audio_loader=audio_loader,
+    )
+
+    if config.augmentations.enabled:
+        audio_augmentation, spectrogram_augmentation = build_augmentations(
+            samplerate=preprocessor.input_samplerate,
+            config=config.augmentations,
+            audio_source=random_example_source,
+        )
+    else:
+        logger.debug("No augmentations configured for training dataset.")
+        audio_augmentation = None
+        spectrogram_augmentation = None
+
+    return TrainingDataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
+        clipper=clipper,
+        preprocessor=preprocessor,
+        audio_augmentation=audio_augmentation,
+        spectrogram_augmentation=spectrogram_augmentation,
+    )
+
+
+def build_val_dataset(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[ValLoaderConfig] = None,
+) -> ValidationDataset:
+    logger.info("Building validation dataset...")
+    config = config or ValLoaderConfig()
+
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    if labeller is None:
+        labeller = build_clip_labeler(
+            min_freq=preprocessor.min_freq,
+            max_freq=preprocessor.max_freq,
+        )
+
+    clipper = build_clipper(config.clipping_strategy)
+    return ValidationDataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
+        preprocessor=preprocessor,
+        clipper=clipper,
+    )
+
+
+def _collate_fn(batch: List[TrainExample]) -> TrainExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return TrainExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        detection_heatmap=torch.stack(
+            [adjust_width(item.detection_heatmap, max_width) for item in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(item.size_heatmap, max_width) for item in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(item.class_heatmap, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
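With the loader and dataset builders now living in `batdetect2.train.dataset`, they can be used without pulling in `batdetect2.train.train`. A usage sketch, where `train_annotations` and `val_annotations` are placeholders for whatever `Sequence[data.ClipAnnotation]` the caller has loaded:

```python
from batdetect2.train.dataset import build_train_loader, build_val_loader

# Omitted components (audio loader, preprocessor, labeller) fall back to
# the defaults built inside the functions above.
train_loader = build_train_loader(train_annotations, num_workers=4)
val_loader = build_val_loader(val_annotations)  # batch_size is fixed at 1

batch = next(iter(train_loader))
# _collate_fn pads every tensor to the widest spectrogram in the batch
# before stacking, so batch.spec is a regular dense tensor.
print(batch.spec.shape)
```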
diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py
index 317f65b..8970c0e 100644
--- a/src/batdetect2/train/lightning.py
+++ b/src/batdetect2/train/lightning.py
@@ -77,3 +77,15 @@ def load_model_from_checkpoint(
 ) -> Tuple[Model, FullTrainingConfig]:
     module = TrainingModule.load_from_checkpoint(path)  # type: ignore
     return module.model, module.config
+
+
+def build_training_module(
+    config: Optional[FullTrainingConfig] = None,
+    t_max: int = 200,
+) -> TrainingModule:
+    config = config or FullTrainingConfig()
+    return TrainingModule(
+        config=config,
+        learning_rate=config.train.optimizer.learning_rate,
+        t_max=t_max,
+    )
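`build_training_module` now reads the learning rate from the nested `config.train.optimizer`. A sketch of standalone use; `batches_per_epoch` is a hypothetical stand-in for the `len(train_dataloader)` factor that `train()` applies when scaling `t_max`:

```python
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.lightning import build_training_module

config = FullTrainingConfig()
batches_per_epoch = 250  # hypothetical; train() uses len(train_dataloader)
module = build_training_module(
    config,
    t_max=config.train.optimizer.t_max * batches_per_epoch,
)
```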
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index 8837da4..a071db2 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -2,47 +2,31 @@ from collections.abc import Sequence
 from pathlib import Path
 from typing import List, Optional
 
-import torch
 from lightning import Trainer, seed_everything
 from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
 from soundevent import data
-from torch.utils.data import DataLoader
 
-from batdetect2.evaluate.config import EvaluationConfig
 from batdetect2.evaluate.evaluator import build_evaluator
-from batdetect2.plotting.clips import AudioLoader, build_audio_loader
+from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
-from batdetect2.train.augmentations import (
-    RandomAudioSource,
-    build_augmentations,
-)
 from batdetect2.train.callbacks import ValidationMetrics
-from batdetect2.train.clips import build_clipper
 from batdetect2.train.config import (
     FullTrainingConfig,
-    TrainLoaderConfig,
-    ValLoaderConfig,
 )
-from batdetect2.train.dataset import TrainingDataset, ValidationDataset
+from batdetect2.train.dataset import build_train_loader, build_val_loader
 from batdetect2.train.labels import build_clip_labeler
-from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.lightning import TrainingModule, build_training_module
 from batdetect2.train.logging import build_logger
 from batdetect2.typing import (
-    PreprocessorProtocol,
     TargetProtocol,
-    TrainExample,
 )
+from batdetect2.typing.preprocess import AudioLoader
 from batdetect2.typing.train import ClipLabeller
-from batdetect2.utils.arrays import adjust_width
 
 __all__ = [
-    "build_train_dataset",
-    "build_train_loader",
     "build_trainer",
-    "build_val_dataset",
-    "build_val_loader",
     "train",
 ]
@@ -52,6 +36,11 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
 def train(
     train_annotations: Sequence[data.ClipAnnotation],
     val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
+    trainer: Optional[Trainer] = None,
+    targets: Optional[TargetProtocol] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
     config: Optional[FullTrainingConfig] = None,
     model_path: Optional[data.PathLike] = None,
     train_workers: Optional[int] = None,
@@ -67,13 +56,15 @@
 
     config = config or FullTrainingConfig()
 
-    targets = build_targets(config.targets)
+    targets = targets or build_targets(config.targets)
 
-    preprocessor = build_preprocessor(config.preprocess)
+    preprocessor = preprocessor or build_preprocessor(config.preprocess)
 
-    audio_loader = build_audio_loader(config=config.preprocess.audio)
+    audio_loader = audio_loader or build_audio_loader(
+        config=config.preprocess.audio
+    )
 
-    labeller = build_clip_labeler(
+    labeller = labeller or build_clip_labeler(
         targets,
         min_freq=preprocessor.min_freq,
         max_freq=preprocessor.max_freq,
@@ -108,10 +99,10 @@
     else:
         module = build_training_module(
            config,
-            t_max=config.train.t_max * len(train_dataloader),
+            t_max=config.train.optimizer.t_max * len(train_dataloader),
         )
 
-    trainer = build_trainer(
+    trainer = trainer or build_trainer(
         config,
         targets=targets,
         checkpoint_dir=checkpoint_dir,
@@ -129,21 +120,9 @@
     logger.info("Training complete.")
 
 
-def build_training_module(
-    config: Optional[FullTrainingConfig] = None,
-    t_max: int = 200,
-) -> TrainingModule:
-    config = config or FullTrainingConfig()
-    return TrainingModule(
-        config=config,
-        learning_rate=config.train.learning_rate,
-        t_max=t_max,
-    )
-
-
 def build_trainer_callbacks(
     targets: TargetProtocol,
-    config: EvaluationConfig,
+    config: FullTrainingConfig,
     checkpoint_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
@@ -157,7 +136,7 @@
     if run_name is not None:
         checkpoint_dir = checkpoint_dir / run_name
 
-    evaluator = build_evaluator(config=config, targets=targets)
+    evaluator = build_evaluator(config=config.evaluation, targets=targets)
 
     return [
         ModelCheckpoint(
@@ -202,180 +181,9 @@ def build_trainer(
         logger=train_logger,
         callbacks=build_trainer_callbacks(
             targets,
-            config=conf.evaluation,
+            config=conf,
             checkpoint_dir=checkpoint_dir,
             experiment_name=experiment_name,
             run_name=run_name,
         ),
     )
-
-
-def build_train_loader(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: Optional[AudioLoader] = None,
-    labeller: Optional[ClipLabeller] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    config: Optional[TrainLoaderConfig] = None,
-    num_workers: Optional[int] = None,
-) -> DataLoader:
-    config = config or TrainLoaderConfig()
-
-    logger.info("Building training data loader...")
-    logger.opt(lazy=True).debug(
-        "Training data loader config: \n{config}",
-        config=lambda: config.to_yaml_string(exclude_none=True),
-    )
-
-    train_dataset = build_train_dataset(
-        clip_annotations,
-        audio_loader=audio_loader,
-        labeller=labeller,
-        preprocessor=preprocessor,
-        config=config,
-    )
-
-    num_workers = num_workers or config.num_workers
-    return DataLoader(
-        train_dataset,
-        batch_size=config.batch_size,
-        shuffle=config.shuffle,
-        num_workers=num_workers,
-        collate_fn=_collate_fn,
-    )
-
-
-def build_val_loader(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: Optional[AudioLoader] = None,
-    labeller: Optional[ClipLabeller] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    config: Optional[ValLoaderConfig] = None,
-    num_workers: Optional[int] = None,
-):
-    logger.info("Building validation data loader...")
-    config = config or ValLoaderConfig()
-    logger.opt(lazy=True).debug(
-        "Validation data loader config: \n{config}",
-        config=lambda: config.to_yaml_string(exclude_none=True),
-    )
-
-    val_dataset = build_val_dataset(
-        clip_annotations,
-        audio_loader=audio_loader,
-        labeller=labeller,
-        preprocessor=preprocessor,
-        config=config,
-    )
-
-    num_workers = num_workers or config.num_workers
-    return DataLoader(
-        val_dataset,
-        batch_size=1,
-        shuffle=False,
-        num_workers=num_workers,
-        collate_fn=_collate_fn,
-    )
-
-
-def build_train_dataset(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: Optional[AudioLoader] = None,
-    labeller: Optional[ClipLabeller] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    config: Optional[TrainLoaderConfig] = None,
-) -> TrainingDataset:
-    logger.info("Building training dataset...")
-    config = config or TrainLoaderConfig()
-
-    clipper = build_clipper(config=config.clipping_strategy)
-
-    if audio_loader is None:
-        audio_loader = build_audio_loader()
-
-    if preprocessor is None:
-        preprocessor = build_preprocessor()
-
-    if labeller is None:
-        labeller = build_clip_labeler(
-            min_freq=preprocessor.min_freq,
-            max_freq=preprocessor.max_freq,
-        )
-
-    random_example_source = RandomAudioSource(
-        clip_annotations,
-        audio_loader=audio_loader,
-    )
-
-    if config.augmentations.enabled:
-        audio_augmentation, spectrogram_augmentation = build_augmentations(
-            samplerate=preprocessor.input_samplerate,
-            config=config.augmentations,
-            audio_source=random_example_source,
-        )
-    else:
-        logger.debug("No augmentations configured for training dataset.")
-        audio_augmentation = None
-        spectrogram_augmentation = None
-
-    return TrainingDataset(
-        clip_annotations,
-        audio_loader=audio_loader,
-        labeller=labeller,
-        clipper=clipper,
-        preprocessor=preprocessor,
-        audio_augmentation=audio_augmentation,
-        spectrogram_augmentation=spectrogram_augmentation,
-    )
-
-
-def build_val_dataset(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: Optional[AudioLoader] = None,
-    labeller: Optional[ClipLabeller] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    config: Optional[ValLoaderConfig] = None,
-) -> ValidationDataset:
-    logger.info("Building validation dataset...")
-    config = config or ValLoaderConfig()
-
-    if audio_loader is None:
-        audio_loader = build_audio_loader()
-
-    if preprocessor is None:
-        preprocessor = build_preprocessor()
-
-    if labeller is None:
-        labeller = build_clip_labeler(
-            min_freq=preprocessor.min_freq,
-            max_freq=preprocessor.max_freq,
-        )
-
-    clipper = build_clipper(config.clipping_strategy)
-    return ValidationDataset(
-        clip_annotations,
-        audio_loader=audio_loader,
-        labeller=labeller,
-        preprocessor=preprocessor,
-        clipper=clipper,
-    )
-
-
-def _collate_fn(batch: List[TrainExample]) -> TrainExample:
-    max_width = max(item.spec.shape[-1] for item in batch)
-    return TrainExample(
-        spec=torch.stack(
-            [adjust_width(item.spec, max_width) for item in batch]
-        ),
-        detection_heatmap=torch.stack(
-            [adjust_width(item.detection_heatmap, max_width) for item in batch]
-        ),
-        size_heatmap=torch.stack(
-            [adjust_width(item.size_heatmap, max_width) for item in batch]
-        ),
-        class_heatmap=torch.stack(
-            [adjust_width(item.class_heatmap, max_width) for item in batch]
-        ),
-        idx=torch.stack([item.idx for item in batch]),
-        start_time=torch.stack([item.start_time for item in batch]),
-        end_time=torch.stack([item.end_time for item in batch]),
-    )
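Taken together, the new keyword arguments on `train()` let callers inject prebuilt components instead of having them rebuilt from the config, which is useful for tests and notebooks; anything left as `None` is still constructed from the config as before. A sketch, where `fast_trainer` is a hypothetical caller-supplied `lightning.Trainer`:

```python
from lightning import Trainer

from batdetect2.train.train import train

fast_trainer = Trainer(fast_dev_run=True)  # hypothetical smoke-test trainer

train(
    train_annotations,  # Sequence[data.ClipAnnotation]
    val_annotations=val_annotations,
    trainer=fast_trainer,  # skips build_trainer() entirely
)
```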