mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Device fixing #5
This commit is contained in:
parent
67e37227f5
commit
3043230d4f
@ -9,7 +9,6 @@ from pydantic import Field
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess.common import PeakNormalize
|
||||
from batdetect2.typing.preprocess import SpectrogramBuilder
|
||||
|
||||
__all__ = [
|
||||
"STFTConfig",
|
||||
@ -83,7 +82,7 @@ def _spec_params_from_config(samplerate: int, conf: STFTConfig):
|
||||
def build_spectrogram_builder(
|
||||
samplerate: int,
|
||||
conf: STFTConfig,
|
||||
) -> SpectrogramBuilder:
|
||||
) -> torch.nn.Module:
|
||||
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
|
||||
return torchaudio.transforms.Spectrogram(
|
||||
n_fft=n_fft,
|
||||
@ -339,7 +338,7 @@ def build_spectrogram_transform(
|
||||
class SpectrogramPipeline(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
spec_builder: SpectrogramBuilder,
|
||||
spec_builder: torch.nn.Module,
|
||||
freq_cutter: torch.nn.Module,
|
||||
transforms: torch.nn.Module,
|
||||
resizer: torch.nn.Module,
|
||||
@ -351,10 +350,10 @@ class SpectrogramPipeline(torch.nn.Module):
|
||||
self.resizer = resizer
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
spec = self.spec_builder(wav)
|
||||
spec = self.freq_cutter(spec)
|
||||
spec = self.transforms(spec)
|
||||
return self.resizer(spec)
|
||||
spec = self.spec_builder.to(wav)(wav)
|
||||
spec = self.freq_cutter.to(wav)(spec)
|
||||
spec = self.transforms.to(wav)(spec)
|
||||
return self.resizer.to(wav)(spec)
|
||||
|
||||
def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
return self.spec_builder(wav)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user