diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index b79ef81..067a965 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -370,10 +370,10 @@ class SpectrogramPipeline(torch.nn.Module): self.resizer = resizer def forward(self, wav: torch.Tensor) -> torch.Tensor: - spec = self.spec_builder.to(wav)(wav) - spec = self.freq_cutter.to(wav)(spec) - spec = self.transforms.to(wav)(spec) - return self.resizer.to(wav)(spec) + spec = self.spec_builder(wav) + spec = self.freq_cutter(spec) + spec = self.transforms(spec) + return self.resizer(spec) def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: return self.spec_builder(wav)