diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index b89ae97..84499c0 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -400,6 +400,8 @@ AugmentationConfig = Annotated[ class AugmentationsConfig(BaseConfig): """Configuration for a sequence of data augmentations.""" + enabled: bool = True + steps: List[AugmentationConfig] = Field(default_factory=list) diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index e2cb1cd..5e77dc5 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -67,8 +67,8 @@ class TrainingConfig(BaseConfig): t_max: int = 100 dataloaders: LoadersConfig = Field(default_factory=LoadersConfig) loss: LossConfig = Field(default_factory=LossConfig) - augmentations: Optional[AugmentationsConfig] = Field( - default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG + augmentations: AugmentationsConfig = Field( + default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy() ) cliping: ClipingConfig = Field(default_factory=ClipingConfig) trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 9da168c..1cb899a 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -239,7 +239,7 @@ def build_train_dataset( clipper=clipper, ) - if config.augmentations and config.augmentations.steps: + if config.augmentations.enabled and config.augmentations.steps: augmentations = build_augmentations( preprocessor, config=config.augmentations,