Tweaks of augmentation config

This commit is contained in:
mbsantiago 2025-08-27 18:23:38 +01:00
parent ed76ec24b6
commit ff754a1269
3 changed files with 5 additions and 3 deletions

View File

@ -400,6 +400,8 @@ AugmentationConfig = Annotated[
class AugmentationsConfig(BaseConfig):
"""Configuration for a sequence of data augmentations."""
enabled: bool = True
steps: List[AugmentationConfig] = Field(default_factory=list)

View File

@ -67,8 +67,8 @@ class TrainingConfig(BaseConfig):
t_max: int = 100
dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
loss: LossConfig = Field(default_factory=LossConfig)
augmentations: Optional[AugmentationsConfig] = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
)
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)

View File

@ -239,7 +239,7 @@ def build_train_dataset(
clipper=clipper,
)
if config.augmentations and config.augmentations.steps:
if config.augmentations.enabled and config.augmentations.steps:
augmentations = build_augmentations(
preprocessor,
config=config.augmentations,