Improve augmentations config and logging

This commit is contained in:
mbsantiago 2025-06-28 11:36:18 -06:00
parent ed67d8ceec
commit 19e873dd0b
3 changed files with 29 additions and 5 deletions

View File

@ -1,3 +1,4 @@
import sys
from pathlib import Path
from typing import Optional
@ -25,6 +26,12 @@ __all__ = [
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int, default=0)
@click.option("--val-workers", type=int, default=0)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def train_command(
train_dir: Path,
val_dir: Optional[Path] = None,
@ -33,7 +40,17 @@ def train_command(
config_field: Optional[str] = None,
train_workers: int = 0,
val_workers: int = 0,
verbose: int = 0,
):
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Starting training!")
conf = (

View File

@ -50,7 +50,7 @@ class TrainingConfig(PLTrainerConfig):
learning_rate: float = 1e-3
t_max: int = 100
loss: LossConfig = Field(default_factory=LossConfig)
augmentations: AugmentationsConfig = Field(
augmentations: Optional[AugmentationsConfig] = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
)
cliping: ClipingConfig = Field(default_factory=ClipingConfig)

View File

@ -164,10 +164,17 @@ def build_train_dataset(
clipper=clipper,
)
augmentations = build_augmentations(
preprocessor,
config=config.augmentations,
example_source=random_example_source,
logger.debug(
"Augmentations config: {}.", config.augmentations
)
augmentations = (
build_augmentations(
preprocessor,
config=config.augmentations,
example_source=random_example_source,
)
if config.augmentations
else None
)
return LabeledDataset(