Device fixing #5

This commit is contained in:
mbsantiago 2025-08-25 23:20:30 +01:00
parent 67e37227f5
commit 3043230d4f

View File

@@ -9,7 +9,6 @@ from pydantic import Field

 from batdetect2.configs import BaseConfig
 from batdetect2.preprocess.common import PeakNormalize
-from batdetect2.typing.preprocess import SpectrogramBuilder

 __all__ = [
     "STFTConfig",
@@ -83,7 +82,7 @@ def _spec_params_from_config(samplerate: int, conf: STFTConfig):
 def build_spectrogram_builder(
     samplerate: int,
     conf: STFTConfig,
-) -> SpectrogramBuilder:
+) -> torch.nn.Module:
     n_fft, hop_length = _spec_params_from_config(samplerate, conf)
     return torchaudio.transforms.Spectrogram(
         n_fft=n_fft,
@@ -339,7 +338,7 @@ def build_spectrogram_transform(
 class SpectrogramPipeline(torch.nn.Module):
     def __init__(
         self,
-        spec_builder: SpectrogramBuilder,
+        spec_builder: torch.nn.Module,
         freq_cutter: torch.nn.Module,
         transforms: torch.nn.Module,
         resizer: torch.nn.Module,
@@ -351,10 +350,10 @@ class SpectrogramPipeline(torch.nn.Module):
         self.resizer = resizer

     def forward(self, wav: torch.Tensor) -> torch.Tensor:
-        spec = self.spec_builder(wav)
-        spec = self.freq_cutter(spec)
-        spec = self.transforms(spec)
-        return self.resizer(spec)
+        spec = self.spec_builder.to(wav)(wav)
+        spec = self.freq_cutter.to(wav)(spec)
+        spec = self.transforms.to(wav)(spec)
+        return self.resizer.to(wav)(spec)

     def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
         return self.spec_builder(wav)