diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py
index 823e7bb..330c0c1 100644
--- a/src/batdetect2/cli/train.py
+++ b/src/batdetect2/cli/train.py
@@ -1,3 +1,4 @@
+import sys
 from pathlib import Path
 from typing import Optional
 
@@ -25,6 +26,12 @@ __all__ = [
 @click.option("--config-field", type=str)
 @click.option("--train-workers", type=int, default=0)
 @click.option("--val-workers", type=int, default=0)
+@click.option(
+    "-v",
+    "--verbose",
+    count=True,
+    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
+)
 def train_command(
     train_dir: Path,
     val_dir: Optional[Path] = None,
@@ -33,7 +40,17 @@ def train_command(
     config_field: Optional[str] = None,
     train_workers: int = 0,
     val_workers: int = 0,
+    verbose: int = 0,
 ):
+    logger.remove()
+    if verbose == 0:
+        log_level = "WARNING"
+    elif verbose == 1:
+        log_level = "INFO"
+    else:
+        log_level = "DEBUG"
+    logger.add(sys.stderr, level=log_level)
+
     logger.info("Starting training!")
 
     conf = (
diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py
index 2f1ff42..c6c17af 100644
--- a/src/batdetect2/train/config.py
+++ b/src/batdetect2/train/config.py
@@ -50,7 +50,7 @@ class TrainingConfig(PLTrainerConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
     loss: LossConfig = Field(default_factory=LossConfig)
-    augmentations: AugmentationsConfig = Field(
+    augmentations: Optional[AugmentationsConfig] = Field(
         default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
     )
     cliping: ClipingConfig = Field(default_factory=ClipingConfig)
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index 1feb5a1..dce0221 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -164,10 +164,17 @@ def build_train_dataset(
         clipper=clipper,
     )
 
-    augmentations = build_augmentations(
-        preprocessor,
-        config=config.augmentations,
-        example_source=random_example_source,
+    logger.debug(
+        "Augmentations config: {}.", config.augmentations
+    )
+    augmentations = (
+        build_augmentations(
+            preprocessor,
+            config=config.augmentations,
+            example_source=random_example_source,
+        )
+        if config.augmentations
+        else None
     )
 
     return LabeledDataset(
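
For context, here is a minimal, self-contained sketch of the verbosity-to-log-level mapping added to `train_command` above. It uses only the `click` and `loguru` calls that appear in the diff; the `main` command name and the log messages are illustrative, not part of batdetect2:

```python
import sys

import click
from loguru import logger


@click.command()
@click.option(
    "-v",
    "--verbose",
    count=True,
    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def main(verbose: int = 0):
    # Drop loguru's default handler and install one at the requested level.
    logger.remove()
    log_level = {0: "WARNING", 1: "INFO"}.get(verbose, "DEBUG")
    logger.add(sys.stderr, level=log_level)

    logger.warning("always shown")
    logger.info("shown with -v or -vv")
    logger.debug("shown only with -vv")


if __name__ == "__main__":
    main()
```

Because the option uses `count=True`, repeated flags accumulate, so `-vvv` and beyond also map to DEBUG in this sketch.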
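
The `Optional[AugmentationsConfig]` change in `TrainingConfig` follows a standard pydantic pattern: leaving the field out keeps the default augmentations, while setting it explicitly to `None` (or `null` in a YAML config) disables them, which is what the `if config.augmentations else None` guard in `build_train_dataset` relies on. A rough sketch of the pattern with placeholder models (field names here are illustrative, not batdetect2's real config):

```python
from typing import Optional

from pydantic import BaseModel, Field


class AugmentationsConfig(BaseModel):
    # Placeholder field; the real config is richer.
    mix_probability: float = 0.5


class TrainingConfig(BaseModel):
    learning_rate: float = 1e-3
    # Omitted -> default augmentations; explicit None/null -> disabled.
    augmentations: Optional[AugmentationsConfig] = Field(
        default_factory=AugmentationsConfig
    )


print(TrainingConfig().augmentations)                    # default AugmentationsConfig
print(TrainingConfig(augmentations=None).augmentations)  # None, so no augmentations are built
```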