mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Use custom AugmentationSequence instead of nn.Sequential
This commit is contained in:
parent
2f4edeffff
commit
8d093c3ca2
@ -551,6 +551,22 @@ DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
|
||||
)
|
||||
|
||||
|
||||
class AugmentationSequence(torch.nn.Module):
|
||||
def __init__(self, augmentations: List[torch.nn.Module]):
|
||||
super().__init__()
|
||||
self.augmentations = torch.nn.ModuleList(augmentations)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
|
||||
for aug in self.augmentations:
|
||||
tensor, clip_annotation = aug(tensor, clip_annotation)
|
||||
|
||||
return tensor, clip_annotation
|
||||
|
||||
|
||||
def build_augmentation_sequence(
|
||||
samplerate: int,
|
||||
steps: Optional[Sequence[AugmentationConfig]] = None,
|
||||
@ -578,7 +594,7 @@ def build_augmentation_sequence(
|
||||
)
|
||||
)
|
||||
|
||||
return torch.nn.Sequential(*augmentations)
|
||||
return AugmentationSequence(augmentations)
|
||||
|
||||
|
||||
def build_augmentations(
|
||||
@ -599,6 +615,7 @@ def build_augmentations(
|
||||
steps=config.audio,
|
||||
audio_source=audio_source,
|
||||
)
|
||||
|
||||
spectrogram_augmentation = build_augmentation_sequence(
|
||||
samplerate,
|
||||
steps=config.audio,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user