Use custom AugmentationSequence instead of nn.Sequential

This commit is contained in:
mbsantiago 2025-08-31 19:27:15 +01:00
parent 2f4edeffff
commit 8d093c3ca2

View File

@ -551,6 +551,22 @@ DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
)
class AugmentationSequence(torch.nn.Module):
def __init__(self, augmentations: List[torch.nn.Module]):
super().__init__()
self.augmentations = torch.nn.ModuleList(augmentations)
def forward(
self,
tensor: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
for aug in self.augmentations:
tensor, clip_annotation = aug(tensor, clip_annotation)
return tensor, clip_annotation
def build_augmentation_sequence(
samplerate: int,
steps: Optional[Sequence[AugmentationConfig]] = None,
@ -578,7 +594,7 @@ def build_augmentation_sequence(
)
)
return torch.nn.Sequential(*augmentations)
return AugmentationSequence(augmentations)
def build_augmentations(
@ -599,6 +615,7 @@ def build_augmentations(
steps=config.audio,
audio_source=audio_source,
)
spectrogram_augmentation = build_augmentation_sequence(
samplerate,
steps=config.audio,