From 587742b41ea59b6ef47b75c886d87eeb228a8b8f Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 26 Jun 2025 16:02:41 -0600 Subject: [PATCH] Change train to use full config --- src/batdetect2/cli/__init__.py | 4 +- src/batdetect2/cli/train.py | 239 ++++-------------------------- src/batdetect2/train/__init__.py | 17 ++- src/batdetect2/train/config.py | 15 +- src/batdetect2/train/train.py | 245 +++++++++++++++++++++---------- 5 files changed, 213 insertions(+), 307 deletions(-) diff --git a/src/batdetect2/cli/__init__.py b/src/batdetect2/cli/__init__.py index 5643b47..59ec0c0 100644 --- a/src/batdetect2/cli/__init__.py +++ b/src/batdetect2/cli/__init__.py @@ -2,13 +2,13 @@ from batdetect2.cli.base import cli from batdetect2.cli.compat import detect from batdetect2.cli.data import data from batdetect2.cli.preprocess import preprocess -from batdetect2.cli.train import train +from batdetect2.cli.train import train_detector __all__ = [ "cli", "detect", "data", - "train", + "train_detector", "preprocess", ] diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 8c5169c..e9bf4b4 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -5,236 +5,53 @@ import click from loguru import logger from batdetect2.cli.base import cli -from batdetect2.evaluate.metrics import ( - ClassificationAccuracy, - ClassificationMeanAveragePrecision, - DetectionAveragePrecision, +from batdetect2.train import ( + FullTrainingConfig, + load_full_training_config, + train, ) -from batdetect2.models import build_model -from batdetect2.models.backbones import load_backbone_config -from batdetect2.postprocess import build_postprocessor, load_postprocess_config -from batdetect2.preprocess import build_preprocessor, load_preprocessing_config -from batdetect2.targets import build_targets, load_target_config -from batdetect2.train import train -from batdetect2.train.callbacks import ValidationMetrics -from batdetect2.train.config import TrainingConfig, load_train_config from batdetect2.train.dataset import list_preprocessed_files __all__ = [ "train_command", ] -DEFAULT_CONFIG_FILE = Path("config.yaml") - @cli.command(name="train") -@click.option( - "--train-examples", - type=click.Path(exists=True), - required=True, -) -@click.option("--val-examples", type=click.Path(exists=True)) -@click.option( - "--model-path", - type=click.Path(exists=True), -) -@click.option( - "--train-config", - type=click.Path(exists=True), - default=DEFAULT_CONFIG_FILE, -) -@click.option( - "--train-config-field", - type=str, - default="train", -) -@click.option( - "--preprocess-config", - type=click.Path(exists=True), - help=( - "Path to the preprocessing configuration file. This file tells " - "the program how to prepare your audio data before training, such " - "as resampling or applying filters." - ), - default=DEFAULT_CONFIG_FILE, -) -@click.option( - "--preprocess-config-field", - type=str, - help=( - "If the preprocessing settings are inside a nested dictionary " - "within the preprocessing configuration file, specify the key " - "here to access them. If the preprocessing settings are at the " - "top level, you don't need to specify this." - ), - default="preprocess", -) -@click.option( - "--target-config", - type=click.Path(exists=True), - help=( - "Path to the training target configuration file. This file " - "specifies what sounds the model should learn to predict." 
- ), - default=DEFAULT_CONFIG_FILE, -) -@click.option( - "--target-config-field", - type=str, - help=( - "If the target settings are inside a nested dictionary " - "within the target configuration file, specify the key here. " - "If the settings are at the top level, you don't need to specify this." - ), - default="targets", -) -@click.option( - "--postprocess-config", - type=click.Path(exists=True), - default=DEFAULT_CONFIG_FILE, -) -@click.option( - "--postprocess-config-field", - type=str, - default="postprocess", -) -@click.option( - "--model-config", - type=click.Path(exists=True), - default=DEFAULT_CONFIG_FILE, -) -@click.option( - "--model-config-field", - type=str, - default="model", -) -@click.option( - "--train-workers", - type=int, - default=0, -) -@click.option( - "--val-workers", - type=int, - default=0, -) +@click.option("--train-dir", type=click.Path(exists=True), required=True) +@click.option("--val-dir", type=click.Path(exists=True)) +@click.option("--model-path", type=click.Path(exists=True)) +@click.option("--config", type=click.Path(exists=True)) +@click.option("--config-field", type=str) +@click.option("--train-workers", type=int, default=0) +@click.option("--val-workers", type=int, default=0) def train_command( - train_examples: Path, - val_examples: Optional[Path] = None, + train_dir: Path, + val_dir: Optional[Path] = None, model_path: Optional[Path] = None, - train_config: Path = DEFAULT_CONFIG_FILE, - train_config_field: str = "train", - preprocess_config: Path = DEFAULT_CONFIG_FILE, - preprocess_config_field: str = "preprocess", - target_config: Path = DEFAULT_CONFIG_FILE, - target_config_field: str = "targets", - postprocess_config: Path = DEFAULT_CONFIG_FILE, - postprocess_config_field: str = "postprocess", - model_config: Path = DEFAULT_CONFIG_FILE, - model_config_field: str = "model", + config: Optional[Path] = None, + config_field: Optional[str] = None, train_workers: int = 0, val_workers: int = 0, ): logger.info("Starting training!") - try: - target_config_loaded = load_target_config( - path=target_config, - field=target_config_field, - ) - targets = build_targets(config=target_config_loaded) - logger.debug( - "Loaded targets info from config file {path}", path=target_config - ) - except IOError: - logger.debug( - "Could not load target info from config file, using default" - ) - targets = build_targets() - - try: - preprocess_config_loaded = load_preprocessing_config( - path=preprocess_config, - field=preprocess_config_field, - ) - preprocessor = build_preprocessor(preprocess_config_loaded) - logger.debug( - "Loaded preprocessor from config file {path}", path=target_config - ) - - except IOError: - logger.debug( - "Could not load preprocessor from config file, using default" - ) - preprocessor = build_preprocessor() - - try: - model_config_loaded = load_backbone_config( - path=model_config, field=model_config_field - ) - model = build_model( - num_classes=len(targets.class_names), - config=model_config_loaded, - ) - except IOError: - model = build_model(num_classes=len(targets.class_names)) - - try: - postprocess_config_loaded = load_postprocess_config( - path=postprocess_config, - field=postprocess_config_field, - ) - postprocessor = build_postprocessor( - targets=targets, - config=postprocess_config_loaded, - ) - logger.debug( - "Loaded postprocessor from file {path}", path=postprocess_config - ) - except IOError: - logger.debug( - "Could not load postprocessor config from file. 
Using default" - ) - postprocessor = build_postprocessor(targets=targets) - - try: - train_config_loaded = load_train_config( - path=train_config, field=train_config_field - ) - logger.debug( - "Loaded training config from file {path}", - path=train_config, - ) - except IOError: - train_config_loaded = TrainingConfig() - logger.debug("Could not load training config from file. Using default") - - train_files = list_preprocessed_files(train_examples) - - val_files = ( - None if val_examples is None else list_preprocessed_files(val_examples) + conf = ( + load_full_training_config(config, field=config_field) + if config is not None + else FullTrainingConfig() ) - return train( - detector=model, - train_examples=train_files, # type: ignore - val_examples=val_files, # type: ignore + train_examples = list_preprocessed_files(train_dir) + val_examples = ( + list_preprocessed_files(val_dir) if val_dir is not None else None + ) + + train( + train_examples=train_examples, + val_examples=val_examples, + config=conf, model_path=model_path, - preprocessor=preprocessor, - postprocessor=postprocessor, - targets=targets, - config=train_config_loaded, - callbacks=[ - ValidationMetrics( - metrics=[ - DetectionAveragePrecision(), - ClassificationMeanAveragePrecision( - class_names=targets.class_names, - ), - ClassificationAccuracy(class_names=targets.class_names), - ] - ) - ], train_workers=train_workers, val_workers=val_workers, ) diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py index d0baebc..44a0c58 100644 --- a/src/batdetect2/train/__init__.py +++ b/src/batdetect2/train/__init__.py @@ -15,7 +15,7 @@ from batdetect2.train.augmentations import ( ) from batdetect2.train.clips import build_clipper, select_subclip from batdetect2.train.config import ( - TrainerConfig, + PLTrainerConfig, TrainingConfig, load_train_config, ) @@ -39,8 +39,14 @@ from batdetect2.train.preprocess import ( preprocess_annotations, ) from batdetect2.train.train import ( + FullTrainingConfig, build_train_dataset, + build_train_loader, + build_trainer, + build_training_module, build_val_dataset, + build_val_loader, + load_full_training_config, train, ) @@ -50,14 +56,15 @@ __all__ = [ "DetectionLossConfig", "EchoAugmentationConfig", "FrequencyMaskAugmentationConfig", + "FullTrainingConfig", "LabeledDataset", "LossConfig", "LossFunction", + "PLTrainerConfig", "RandomExampleSource", "SizeLossConfig", "TimeMaskAugmentationConfig", "TrainExample", - "TrainerConfig", "TrainingConfig", "VolumeAugmentationConfig", "WarpAugmentationConfig", @@ -67,9 +74,14 @@ __all__ = [ "build_clipper", "build_loss", "build_train_dataset", + "build_train_loader", + "build_trainer", + "build_training_module", "build_val_dataset", + "build_val_loader", "generate_train_example", "list_preprocessed_files", + "load_full_training_config", "load_label_config", "load_train_config", "mask_frequency", @@ -79,6 +91,5 @@ __all__ = [ "scale_volume", "select_subclip", "train", - "train", "warp_spectrogram", ] diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index 2ca4338..d6a2a31 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -13,18 +13,12 @@ from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig from batdetect2.train.losses import LossConfig __all__ = [ - "OptimizerConfig", "TrainingConfig", "load_train_config", ] -class OptimizerConfig(BaseConfig): - learning_rate: float = 1e-3 - t_max: int = 100 - - -class TrainerConfig(BaseConfig): +class 
PLTrainerConfig(BaseConfig): accelerator: str = "auto" accumulate_grad_batches: int = 1 deterministic: bool = True @@ -45,15 +39,16 @@ class TrainerConfig(BaseConfig): val_check_interval: Optional[Union[int, float]] = None -class TrainingConfig(BaseConfig): +class TrainingConfig(PLTrainerConfig): batch_size: int = 8 + learning_rate: float = 1e-3 + t_max: int = 100 loss: LossConfig = Field(default_factory=LossConfig) - optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) augmentations: AugmentationsConfig = Field( default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG ) cliping: ClipingConfig = Field(default_factory=ClipingConfig) - trainer: TrainerConfig = Field(default_factory=TrainerConfig) + trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig) logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index d386e98..c2810dd 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -1,20 +1,28 @@ +from collections.abc import Sequence from typing import List, Optional from lightning import Trainer from lightning.pytorch.callbacks import Callback +from pydantic import Field from soundevent import data from torch.utils.data import DataLoader -from batdetect2.models.types import DetectionModel -from batdetect2.postprocess import build_postprocessor -from batdetect2.postprocess.types import PostprocessorProtocol -from batdetect2.preprocess import build_preprocessor -from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets import build_targets -from batdetect2.targets.types import TargetProtocol -from batdetect2.train.augmentations import ( - build_augmentations, +from batdetect2.configs import BaseConfig, load_config +from batdetect2.evaluate.metrics import ( + ClassificationAccuracy, + ClassificationMeanAveragePrecision, + DetectionAveragePrecision, ) +from batdetect2.models import BackboneConfig, build_model +from batdetect2.postprocess import PostprocessConfig, build_postprocessor +from batdetect2.preprocess import ( + PreprocessingConfig, + PreprocessorProtocol, + build_preprocessor, +) +from batdetect2.targets import TargetConfig, TargetProtocol, build_targets +from batdetect2.train.augmentations import build_augmentations +from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.clips import build_clipper from batdetect2.train.config import TrainingConfig from batdetect2.train.dataset import ( @@ -27,94 +35,71 @@ from batdetect2.train.logging import build_logger from batdetect2.train.losses import build_loss __all__ = [ - "train", - "build_val_dataset", + "FullTrainingConfig", "build_train_dataset", + "build_train_loader", + "build_trainer", + "build_training_module", + "build_val_dataset", + "build_val_loader", + "load_full_training_config", + "train", ] +class FullTrainingConfig(BaseConfig): + """Full training configuration.""" + + train: TrainingConfig = Field(default_factory=TrainingConfig) + targets: TargetConfig = Field(default_factory=TargetConfig) + model: BackboneConfig = Field(default_factory=BackboneConfig) + preprocess: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) + + +def load_full_training_config( + path: data.PathLike, + field: Optional[str] = None, +) -> FullTrainingConfig: + """Load the full training configuration.""" + return load_config(path, schema=FullTrainingConfig, field=field) + + def 
train( - detector: DetectionModel, - train_examples: List[data.PathLike], - targets: Optional[TargetProtocol] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - postprocessor: Optional[PostprocessorProtocol] = None, - val_examples: Optional[List[data.PathLike]] = None, - config: Optional[TrainingConfig] = None, - callbacks: Optional[List[Callback]] = None, + train_examples: Sequence[data.PathLike], + val_examples: Optional[Sequence[data.PathLike]] = None, + config: Optional[FullTrainingConfig] = None, model_path: Optional[data.PathLike] = None, train_workers: int = 0, val_workers: int = 0, - **trainer_kwargs, -) -> None: - config = config or TrainingConfig() +): + conf = config or FullTrainingConfig() - if model_path is None: - if preprocessor is None: - preprocessor = build_preprocessor() - - if targets is None: - targets = build_targets() - - if postprocessor is None: - postprocessor = build_postprocessor( - targets, - min_freq=preprocessor.min_freq, - max_freq=preprocessor.max_freq, - ) - - loss = build_loss(config.loss) - - module = TrainingModule( - detector=detector, - loss=loss, - targets=targets, - preprocessor=preprocessor, - postprocessor=postprocessor, - learning_rate=config.optimizer.learning_rate, - t_max=config.optimizer.t_max, - ) - else: + if model_path is not None: module = TrainingModule.load_from_checkpoint(model_path) # type: ignore + else: + module = build_training_module(conf) - train_dataset = build_train_dataset( + trainer = build_trainer(conf, targets=module.targets) + + train_dataloader = build_train_loader( train_examples, preprocessor=module.preprocessor, - config=config, - ) - - logger = build_logger(config.logger) - if logger and hasattr(logger, 'log_hyperparams'): - logger.log_hyperparams(config.model_dump(exclude_none=True)) - - trainer = Trainer( - **config.trainer.model_dump(exclude_none=True, exclude={"logger"}), - callbacks=callbacks, - logger=logger, - **trainer_kwargs, - ) - - train_dataloader = DataLoader( - train_dataset, - batch_size=config.batch_size, - shuffle=True, + config=conf.train, num_workers=train_workers, - collate_fn=collate_fn, ) - val_dataloader = None - if val_examples: - val_dataset = build_val_dataset( + val_dataloader = ( + build_val_loader( val_examples, - config=config, - ) - val_dataloader = DataLoader( - val_dataset, - batch_size=config.batch_size, - shuffle=False, + config=conf.train, num_workers=val_workers, - collate_fn=collate_fn, ) + if val_examples is not None + else None + ) trainer.fit( module, @@ -123,8 +108,106 @@ def train( ) +def build_training_module(conf: FullTrainingConfig) -> TrainingModule: + preprocessor = build_preprocessor(conf.preprocess) + + targets = build_targets(conf.targets) + + postprocessor = build_postprocessor( + targets, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, + ) + + model = build_model( + num_classes=len(targets.class_names), + config=conf.model, + ) + + loss = build_loss(conf.train.loss) + + return TrainingModule( + detector=model, + loss=loss, + targets=targets, + preprocessor=preprocessor, + postprocessor=postprocessor, + learning_rate=conf.train.learning_rate, + t_max=conf.train.t_max, + ) + + +def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]: + return [ + ValidationMetrics( + metrics=[ + DetectionAveragePrecision(), + ClassificationMeanAveragePrecision( + class_names=targets.class_names + ), + ClassificationAccuracy(class_names=targets.class_names), + ] + ) + ] + + +def build_trainer( + conf: FullTrainingConfig, + targets: 
TargetProtocol, +) -> Trainer: + logger = build_logger(conf.train.logger) + + if logger and hasattr(logger, "log_hyperparams"): + logger.log_hyperparams(conf.model_dump(exclude_none=True)) + + return Trainer( + accelerator=conf.train.accelerator, + logger=logger, + callbacks=build_trainer_callbacks(targets), + ) + + +def build_train_loader( + train_examples: Sequence[data.PathLike], + preprocessor: PreprocessorProtocol, + config: TrainingConfig, + num_workers: Optional[int] = None, +) -> DataLoader: + train_dataset = build_train_dataset( + train_examples, + preprocessor=preprocessor, + config=config, + ) + + return DataLoader( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + num_workers=num_workers or 0, + collate_fn=collate_fn, + ) + + +def build_val_loader( + val_examples: Sequence[data.PathLike], + config: TrainingConfig, + num_workers: Optional[int] = None, +): + val_dataset = build_val_dataset( + val_examples, + config=config, + ) + return DataLoader( + val_dataset, + batch_size=config.batch_size, + shuffle=False, + num_workers=num_workers or 0, + collate_fn=collate_fn, + ) + + def build_train_dataset( - examples: List[data.PathLike], + examples: Sequence[data.PathLike], preprocessor: PreprocessorProtocol, config: Optional[TrainingConfig] = None, ) -> LabeledDataset: @@ -133,7 +216,7 @@ def build_train_dataset( clipper = build_clipper(config.cliping, random=True) random_example_source = RandomExampleSource( - examples, + list(examples), clipper=clipper, ) @@ -151,7 +234,7 @@ def build_train_dataset( def build_val_dataset( - examples: List[data.PathLike], + examples: Sequence[data.PathLike], config: Optional[TrainingConfig] = None, train: bool = True, ) -> LabeledDataset:
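
Example usage (illustrative only, not part of the diff): a minimal sketch of calling the consolidated entry point introduced above, assuming a single YAML file that carries the `train`, `targets`, `model`, `preprocess`, and `postprocess` sections of `FullTrainingConfig`. The file name `config.yaml` and the example directory paths are placeholders, not values defined by this patch.

    from batdetect2.train import FullTrainingConfig, load_full_training_config, train
    from batdetect2.train.dataset import list_preprocessed_files

    # Load every sub-config from one file; without a file, the CLI falls back
    # to FullTrainingConfig() defaults.
    config = load_full_training_config("config.yaml")

    train(
        train_examples=list_preprocessed_files("train_examples/"),
        val_examples=list_preprocessed_files("val_examples/"),
        config=config,
        train_workers=0,
        val_workers=0,
    )

The CLI equivalent (assuming the usual `batdetect2` entry point) would be `batdetect2 train --train-dir train_examples/ --val-dir val_examples/ --config config.yaml`.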