Cleaning train module

mbsantiago 2025-09-15 16:50:08 +01:00
parent e752e96b93
commit 704b28292b
5 changed files with 237 additions and 229 deletions

batdetect2/train/__init__.py View File

@@ -22,7 +22,14 @@ from batdetect2.train.config import (
load_full_training_config,
load_train_config,
)
from batdetect2.train.dataset import TrainingDataset
from batdetect2.train.dataset import (
TrainingDataset,
ValidationDataset,
build_train_dataset,
build_train_loader,
build_val_dataset,
build_val_loader,
)
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import (
@@ -33,14 +40,7 @@ from batdetect2.train.losses import (
SizeLossConfig,
build_loss,
)
from batdetect2.train.train import (
build_train_dataset,
build_train_loader,
build_trainer,
build_val_dataset,
build_val_loader,
train,
)
from batdetect2.train.train import build_trainer, train
__all__ = [
"AugmentationsConfig",
@@ -49,7 +49,6 @@ __all__ = [
"EchoAugmentationConfig",
"FrequencyMaskAugmentationConfig",
"FullTrainingConfig",
"TrainingDataset",
"LossConfig",
"LossFunction",
"PLTrainerConfig",
@@ -57,7 +56,9 @@ __all__ = [
"SizeLossConfig",
"TimeMaskAugmentationConfig",
"TrainingConfig",
"TrainingDataset",
"TrainingModule",
"ValidationDataset",
"VolumeAugmentationConfig",
"WarpAugmentationConfig",
"add_echo",

batdetect2/train/config.py View File

@@ -72,13 +72,16 @@ class TrainLoaderConfig(BaseConfig):
)
class TrainingConfig(BaseConfig):
class OptimizerConfig(BaseConfig):
learning_rate: float = 1e-3
t_max: int = 100
class TrainingConfig(BaseConfig):
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
loss: LossConfig = Field(default_factory=LossConfig)
clipping: RandomClipConfig = Field(default_factory=RandomClipConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
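
The new OptimizerConfig block means learning_rate and t_max are now configured under train.optimizer instead of directly on TrainingConfig. A small usage sketch, assuming both classes are importable from batdetect2.train.config as defined above:

from batdetect2.train.config import OptimizerConfig, TrainingConfig

# Override only the optimizer block; the other fields keep their
# default_factory values.
config = TrainingConfig(
    optimizer=OptimizerConfig(learning_rate=5e-4, t_max=150),
)
assert config.optimizer.learning_rate == 5e-4
assert config.optimizer.t_max == 150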

batdetect2/train/dataset.py View File

@@ -1,18 +1,31 @@
from typing import Optional, Sequence, Tuple
from typing import List, Optional, Sequence
import torch
from loguru import logger
from soundevent import data
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset
from batdetect2.plotting.clips import build_audio_loader
from batdetect2.preprocess import build_preprocessor
from batdetect2.train.augmentations import (
RandomAudioSource,
build_augmentations,
)
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.train import (
Augmentation,
ClipLabeller,
)
from batdetect2.typing.train import Augmentation, ClipLabeller
from batdetect2.utils.arrays import adjust_width
__all__ = [
"TrainingDataset",
"ValidationDataset",
"build_val_loader",
"build_train_loader",
"build_train_dataset",
"build_val_dataset",
]
@@ -124,3 +137,174 @@ class ValidationDataset(Dataset):
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
def build_train_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TrainLoaderConfig] = None,
num_workers: Optional[int] = None,
) -> DataLoader:
config = config or TrainLoaderConfig()
logger.info("Building training data loader...")
logger.opt(lazy=True).debug(
"Training data loader config: \n{config}",
config=lambda: config.to_yaml_string(exclude_none=True),
)
train_dataset = build_train_dataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
config=config,
)
num_workers = num_workers or config.num_workers
return DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=config.shuffle,
num_workers=num_workers,
collate_fn=_collate_fn,
)
def build_val_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[ValLoaderConfig] = None,
num_workers: Optional[int] = None,
) -> DataLoader:
logger.info("Building validation data loader...")
config = config or ValLoaderConfig()
logger.opt(lazy=True).debug(
"Validation data loader config: \n{config}",
config=lambda: config.to_yaml_string(exclude_none=True),
)
val_dataset = build_val_dataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
config=config,
)
num_workers = num_workers or config.num_workers
return DataLoader(
val_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
collate_fn=_collate_fn,
)
def build_train_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TrainLoaderConfig] = None,
) -> TrainingDataset:
logger.info("Building training dataset...")
config = config or TrainLoaderConfig()
clipper = build_clipper(config=config.clipping_strategy)
if audio_loader is None:
audio_loader = build_audio_loader()
if preprocessor is None:
preprocessor = build_preprocessor()
if labeller is None:
labeller = build_clip_labeler(
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
random_example_source = RandomAudioSource(
clip_annotations,
audio_loader=audio_loader,
)
if config.augmentations.enabled:
audio_augmentation, spectrogram_augmentation = build_augmentations(
samplerate=preprocessor.input_samplerate,
config=config.augmentations,
audio_source=random_example_source,
)
else:
logger.debug("No augmentations configured for training dataset.")
audio_augmentation = None
spectrogram_augmentation = None
return TrainingDataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
clipper=clipper,
preprocessor=preprocessor,
audio_augmentation=audio_augmentation,
spectrogram_augmentation=spectrogram_augmentation,
)
def build_val_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[ValLoaderConfig] = None,
) -> ValidationDataset:
logger.info("Building validation dataset...")
config = config or ValLoaderConfig()
if audio_loader is None:
audio_loader = build_audio_loader()
if preprocessor is None:
preprocessor = build_preprocessor()
if labeller is None:
labeller = build_clip_labeler(
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
clipper = build_clipper(config.clipping_strategy)
return ValidationDataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
clipper=clipper,
)
def _collate_fn(batch: List[TrainExample]) -> TrainExample:
max_width = max(item.spec.shape[-1] for item in batch)
return TrainExample(
spec=torch.stack(
[adjust_width(item.spec, max_width) for item in batch]
),
detection_heatmap=torch.stack(
[adjust_width(item.detection_heatmap, max_width) for item in batch]
),
size_heatmap=torch.stack(
[adjust_width(item.size_heatmap, max_width) for item in batch]
),
class_heatmap=torch.stack(
[adjust_width(item.class_heatmap, max_width) for item in batch]
),
idx=torch.stack([item.idx for item in batch]),
start_time=torch.stack([item.start_time for item in batch]),
end_time=torch.stack([item.end_time for item in batch]),
)
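
With these additions, everything needed to go from clip annotations to batched tensors lives in one module. A hypothetical usage sketch (train_annotations and val_annotations are placeholders for real sequences of soundevent data.ClipAnnotation objects; every other collaborator falls back to the defaults built inside the functions):

from batdetect2.train.dataset import build_train_loader, build_val_loader

train_loader = build_train_loader(train_annotations, num_workers=4)
val_loader = build_val_loader(val_annotations)  # always batch_size=1, unshuffled

for batch in train_loader:
    # _collate_fn pads each spectrogram and heatmap to the widest example
    # in the batch, so these tensors share a common time dimension.
    print(batch.spec.shape, batch.detection_heatmap.shape)
    break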

batdetect2/train/lightning.py View File

@@ -77,3 +77,15 @@ def load_model_from_checkpoint(
) -> Tuple[Model, FullTrainingConfig]:
module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.config
def build_training_module(
config: Optional[FullTrainingConfig] = None,
t_max: int = 200,
) -> TrainingModule:
config = config or FullTrainingConfig()
return TrainingModule(
config=config,
learning_rate=config.train.optimizer.learning_rate,
t_max=t_max,
)
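
This helper mirrors how train() now constructs the module: learning_rate is read from the nested optimizer config, while t_max is passed in explicitly (train() scales it to config.train.optimizer.t_max * len(train_dataloader)). A sketch with a hypothetical steps_per_epoch standing in for the loader length:

from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.lightning import build_training_module

config = FullTrainingConfig()
steps_per_epoch = 250  # hypothetical stand-in for len(train_dataloader)

module = build_training_module(
    config,
    t_max=config.train.optimizer.t_max * steps_per_epoch,
)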

batdetect2/train/train.py View File

@@ -2,47 +2,31 @@ from collections.abc import Sequence
from pathlib import Path
from typing import List, Optional
import torch
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.augmentations import (
RandomAudioSource,
build_augmentations,
)
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import (
FullTrainingConfig,
TrainLoaderConfig,
ValLoaderConfig,
)
from batdetect2.train.dataset import TrainingDataset, ValidationDataset
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.lightning import TrainingModule, build_training_module
from batdetect2.train.logging import build_logger
from batdetect2.typing import (
PreprocessorProtocol,
TargetProtocol,
TrainExample,
)
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.typing.train import ClipLabeller
from batdetect2.utils.arrays import adjust_width
__all__ = [
"build_train_dataset",
"build_train_loader",
"build_trainer",
"build_val_dataset",
"build_val_loader",
"train",
]
@@ -52,6 +36,11 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
def train(
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
trainer: Optional[Trainer] = None,
targets: Optional[TargetProtocol] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
config: Optional[FullTrainingConfig] = None,
model_path: Optional[data.PathLike] = None,
train_workers: Optional[int] = None,
@@ -67,13 +56,15 @@ def train(
config = config or FullTrainingConfig()
targets = build_targets(config.targets)
targets = targets or build_targets(config.targets)
preprocessor = build_preprocessor(config.preprocess)
preprocessor = preprocessor or build_preprocessor(config.preprocess)
audio_loader = build_audio_loader(config=config.preprocess.audio)
audio_loader = audio_loader or build_audio_loader(
config=config.preprocess.audio
)
labeller = build_clip_labeler(
labeller = labeller or build_clip_labeler(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
@@ -108,10 +99,10 @@ def train(
else:
module = build_training_module(
config,
t_max=config.train.t_max * len(train_dataloader),
t_max=config.train.optimizer.t_max * len(train_dataloader),
)
trainer = build_trainer(
trainer = trainer or build_trainer(
config,
targets=targets,
checkpoint_dir=checkpoint_dir,
@@ -129,21 +120,9 @@ def train(
logger.info("Training complete.")
def build_training_module(
config: Optional[FullTrainingConfig] = None,
t_max: int = 200,
) -> TrainingModule:
config = config or FullTrainingConfig()
return TrainingModule(
config=config,
learning_rate=config.train.learning_rate,
t_max=t_max,
)
def build_trainer_callbacks(
targets: TargetProtocol,
config: EvaluationConfig,
config: FullTrainingConfig,
checkpoint_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
@@ -157,7 +136,7 @@ def build_trainer_callbacks(
if run_name is not None:
checkpoint_dir = checkpoint_dir / run_name
evaluator = build_evaluator(config=config, targets=targets)
evaluator = build_evaluator(config=config.evaluation, targets=targets)
return [
ModelCheckpoint(
@@ -202,180 +181,9 @@ def build_trainer(
logger=train_logger,
callbacks=build_trainer_callbacks(
targets,
config=conf.evaluation,
config=conf,
checkpoint_dir=checkpoint_dir,
experiment_name=experiment_name,
run_name=run_name,
),
)
def build_train_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TrainLoaderConfig] = None,
num_workers: Optional[int] = None,
) -> DataLoader:
config = config or TrainLoaderConfig()
logger.info("Building training data loader...")
logger.opt(lazy=True).debug(
"Training data loader config: \n{config}",
config=lambda: config.to_yaml_string(exclude_none=True),
)
train_dataset = build_train_dataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
config=config,
)
num_workers = num_workers or config.num_workers
return DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=config.shuffle,
num_workers=num_workers,
collate_fn=_collate_fn,
)
def build_val_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[ValLoaderConfig] = None,
num_workers: Optional[int] = None,
):
logger.info("Building validation data loader...")
config = config or ValLoaderConfig()
logger.opt(lazy=True).debug(
"Validation data loader config: \n{config}",
config=lambda: config.to_yaml_string(exclude_none=True),
)
val_dataset = build_val_dataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
config=config,
)
num_workers = num_workers or config.num_workers
return DataLoader(
val_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
collate_fn=_collate_fn,
)
def build_train_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TrainLoaderConfig] = None,
) -> TrainingDataset:
logger.info("Building training dataset...")
config = config or TrainLoaderConfig()
clipper = build_clipper(config=config.clipping_strategy)
if audio_loader is None:
audio_loader = build_audio_loader()
if preprocessor is None:
preprocessor = build_preprocessor()
if labeller is None:
labeller = build_clip_labeler(
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
random_example_source = RandomAudioSource(
clip_annotations,
audio_loader=audio_loader,
)
if config.augmentations.enabled:
audio_augmentation, spectrogram_augmentation = build_augmentations(
samplerate=preprocessor.input_samplerate,
config=config.augmentations,
audio_source=random_example_source,
)
else:
logger.debug("No augmentations configured for training dataset.")
audio_augmentation = None
spectrogram_augmentation = None
return TrainingDataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
clipper=clipper,
preprocessor=preprocessor,
audio_augmentation=audio_augmentation,
spectrogram_augmentation=spectrogram_augmentation,
)
def build_val_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[ValLoaderConfig] = None,
) -> ValidationDataset:
logger.info("Building validation dataset...")
config = config or ValLoaderConfig()
if audio_loader is None:
audio_loader = build_audio_loader()
if preprocessor is None:
preprocessor = build_preprocessor()
if labeller is None:
labeller = build_clip_labeler(
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
clipper = build_clipper(config.clipping_strategy)
return ValidationDataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
clipper=clipper,
)
def _collate_fn(batch: List[TrainExample]) -> TrainExample:
max_width = max(item.spec.shape[-1] for item in batch)
return TrainExample(
spec=torch.stack(
[adjust_width(item.spec, max_width) for item in batch]
),
detection_heatmap=torch.stack(
[adjust_width(item.detection_heatmap, max_width) for item in batch]
),
size_heatmap=torch.stack(
[adjust_width(item.size_heatmap, max_width) for item in batch]
),
class_heatmap=torch.stack(
[adjust_width(item.class_heatmap, max_width) for item in batch]
),
idx=torch.stack([item.idx for item in batch]),
start_time=torch.stack([item.start_time for item in batch]),
end_time=torch.stack([item.end_time for item in batch]),
)
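
After this cleanup, train() accepts every collaborator (trainer, targets, preprocessor, audio loader, labeller) as an optional argument and builds any that are missing from the config, so callers can drive the pipeline end to end or override a single piece. A hypothetical call, with train_annotations and val_annotations as placeholders for real ClipAnnotation sequences:

from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.train import train

train(
    train_annotations,                # placeholder Sequence[data.ClipAnnotation]
    val_annotations=val_annotations,  # placeholder; may be omitted
    config=FullTrainingConfig(),      # omit to fall back to all defaults
    train_workers=4,
)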