mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Use custom AugmentationSequence instead of nn.Sequential
This commit is contained in:
parent
2f4edeffff
commit
8d093c3ca2
@ -551,6 +551,22 @@ DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AugmentationSequence(torch.nn.Module):
|
||||||
|
def __init__(self, augmentations: List[torch.nn.Module]):
|
||||||
|
super().__init__()
|
||||||
|
self.augmentations = torch.nn.ModuleList(augmentations)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tensor: torch.Tensor,
|
||||||
|
clip_annotation: data.ClipAnnotation,
|
||||||
|
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
|
||||||
|
for aug in self.augmentations:
|
||||||
|
tensor, clip_annotation = aug(tensor, clip_annotation)
|
||||||
|
|
||||||
|
return tensor, clip_annotation
|
||||||
|
|
||||||
|
|
||||||
def build_augmentation_sequence(
|
def build_augmentation_sequence(
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
steps: Optional[Sequence[AugmentationConfig]] = None,
|
steps: Optional[Sequence[AugmentationConfig]] = None,
|
||||||
@ -578,7 +594,7 @@ def build_augmentation_sequence(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return torch.nn.Sequential(*augmentations)
|
return AugmentationSequence(augmentations)
|
||||||
|
|
||||||
|
|
||||||
def build_augmentations(
|
def build_augmentations(
|
||||||
@ -599,6 +615,7 @@ def build_augmentations(
|
|||||||
steps=config.audio,
|
steps=config.audio,
|
||||||
audio_source=audio_source,
|
audio_source=audio_source,
|
||||||
)
|
)
|
||||||
|
|
||||||
spectrogram_augmentation = build_augmentation_sequence(
|
spectrogram_augmentation = build_augmentation_sequence(
|
||||||
samplerate,
|
samplerate,
|
||||||
steps=config.audio,
|
steps=config.audio,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user