mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Tweaks of augmentation config
This commit is contained in:
parent
ed76ec24b6
commit
ff754a1269
@ -400,6 +400,8 @@ AugmentationConfig = Annotated[
|
||||
class AugmentationsConfig(BaseConfig):
|
||||
"""Configuration for a sequence of data augmentations."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
steps: List[AugmentationConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
|
||||
@ -67,8 +67,8 @@ class TrainingConfig(BaseConfig):
|
||||
t_max: int = 100
|
||||
dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
|
||||
loss: LossConfig = Field(default_factory=LossConfig)
|
||||
augmentations: Optional[AugmentationsConfig] = Field(
|
||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
||||
augmentations: AugmentationsConfig = Field(
|
||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
|
||||
)
|
||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||
|
||||
@ -239,7 +239,7 @@ def build_train_dataset(
|
||||
clipper=clipper,
|
||||
)
|
||||
|
||||
if config.augmentations and config.augmentations.steps:
|
||||
if config.augmentations.enabled and config.augmentations.steps:
|
||||
augmentations = build_augmentations(
|
||||
preprocessor,
|
||||
config=config.augmentations,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user