Use custom AugmentationSequence instead of nn.Sequential

This commit is contained in:
mbsantiago 2025-08-31 19:27:15 +01:00
parent 2f4edeffff
commit 8d093c3ca2

View File

@ -551,6 +551,22 @@ DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
) )
class AugmentationSequence(torch.nn.Module):
def __init__(self, augmentations: List[torch.nn.Module]):
super().__init__()
self.augmentations = torch.nn.ModuleList(augmentations)
def forward(
self,
tensor: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
for aug in self.augmentations:
tensor, clip_annotation = aug(tensor, clip_annotation)
return tensor, clip_annotation
def build_augmentation_sequence( def build_augmentation_sequence(
samplerate: int, samplerate: int,
steps: Optional[Sequence[AugmentationConfig]] = None, steps: Optional[Sequence[AugmentationConfig]] = None,
@ -578,7 +594,7 @@ def build_augmentation_sequence(
) )
) )
return torch.nn.Sequential(*augmentations) return AugmentationSequence(augmentations)
def build_augmentations( def build_augmentations(
@ -599,6 +615,7 @@ def build_augmentations(
steps=config.audio, steps=config.audio,
audio_source=audio_source, audio_source=audio_source,
) )
spectrogram_augmentation = build_augmentation_sequence( spectrogram_augmentation = build_augmentation_sequence(
samplerate, samplerate,
steps=config.audio, steps=config.audio,