From 3043230d4f4c9141f6df14e70ade1188b886d8ee Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 25 Aug 2025 23:20:30 +0100 Subject: [PATCH] Device fixing #5 --- src/batdetect2/preprocess/spectrogram.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index 5e0344c..dde3e13 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -9,7 +9,6 @@ from pydantic import Field from batdetect2.configs import BaseConfig from batdetect2.preprocess.common import PeakNormalize -from batdetect2.typing.preprocess import SpectrogramBuilder __all__ = [ "STFTConfig", @@ -83,7 +82,7 @@ def _spec_params_from_config(samplerate: int, conf: STFTConfig): def build_spectrogram_builder( samplerate: int, conf: STFTConfig, -) -> SpectrogramBuilder: +) -> torch.nn.Module: n_fft, hop_length = _spec_params_from_config(samplerate, conf) return torchaudio.transforms.Spectrogram( n_fft=n_fft, @@ -339,7 +338,7 @@ def build_spectrogram_transform( class SpectrogramPipeline(torch.nn.Module): def __init__( self, - spec_builder: SpectrogramBuilder, + spec_builder: torch.nn.Module, freq_cutter: torch.nn.Module, transforms: torch.nn.Module, resizer: torch.nn.Module, @@ -351,10 +350,10 @@ class SpectrogramPipeline(torch.nn.Module): self.resizer = resizer def forward(self, wav: torch.Tensor) -> torch.Tensor: - spec = self.spec_builder(wav) - spec = self.freq_cutter(spec) - spec = self.transforms(spec) - return self.resizer(spec) + spec = self.spec_builder.to(wav)(wav) + spec = self.freq_cutter.to(wav)(spec) + spec = self.transforms.to(wav)(spec) + return self.resizer.to(wav)(spec) def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: return self.spec_builder(wav)