class AugmentationSequence(torch.nn.Module):
    """Chain of augmentation modules applied one after another.

    Each wrapped module is called with a tensor and a clip annotation and
    must return a ``(tensor, clip_annotation)`` pair, which is then handed
    to the next module in the chain.
    """

    def __init__(self, augmentations: List[torch.nn.Module]):
        """Store the augmentation modules in a ``torch.nn.ModuleList``.

        Parameters
        ----------
        augmentations
            Modules to apply, in the order they should run.
        """
        super().__init__()
        # ModuleList (rather than a plain list) registers every step as a
        # submodule of this module.
        self.augmentations = torch.nn.ModuleList(augmentations)

    def forward(
        self,
        tensor: torch.Tensor,
        clip_annotation: data.ClipAnnotation,
    ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
        """Run every augmentation in order and return the final pair.

        Parameters
        ----------
        tensor
            Input passed to the first augmentation.
        clip_annotation
            Annotation passed alongside ``tensor`` to the first augmentation.

        Returns
        -------
        Tuple[torch.Tensor, data.ClipAnnotation]
            Output of the last augmentation, or the inputs unchanged when
            the sequence is empty.
        """
        for step in self.augmentations:
            # Each step consumes the previous step's outputs.
            result = step(tensor, clip_annotation)
            tensor, clip_annotation = result

        return tensor, clip_annotation