mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Device fixing #5
This commit is contained in:
parent
67e37227f5
commit
3043230d4f
@ -9,7 +9,6 @@ from pydantic import Field
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.preprocess.common import PeakNormalize
|
from batdetect2.preprocess.common import PeakNormalize
|
||||||
from batdetect2.typing.preprocess import SpectrogramBuilder
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"STFTConfig",
|
"STFTConfig",
|
||||||
@ -83,7 +82,7 @@ def _spec_params_from_config(samplerate: int, conf: STFTConfig):
|
|||||||
def build_spectrogram_builder(
|
def build_spectrogram_builder(
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
conf: STFTConfig,
|
conf: STFTConfig,
|
||||||
) -> SpectrogramBuilder:
|
) -> torch.nn.Module:
|
||||||
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
|
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
|
||||||
return torchaudio.transforms.Spectrogram(
|
return torchaudio.transforms.Spectrogram(
|
||||||
n_fft=n_fft,
|
n_fft=n_fft,
|
||||||
@ -339,7 +338,7 @@ def build_spectrogram_transform(
|
|||||||
class SpectrogramPipeline(torch.nn.Module):
|
class SpectrogramPipeline(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
spec_builder: SpectrogramBuilder,
|
spec_builder: torch.nn.Module,
|
||||||
freq_cutter: torch.nn.Module,
|
freq_cutter: torch.nn.Module,
|
||||||
transforms: torch.nn.Module,
|
transforms: torch.nn.Module,
|
||||||
resizer: torch.nn.Module,
|
resizer: torch.nn.Module,
|
||||||
@ -351,10 +350,10 @@ class SpectrogramPipeline(torch.nn.Module):
|
|||||||
self.resizer = resizer
|
self.resizer = resizer
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
spec = self.spec_builder(wav)
|
spec = self.spec_builder.to(wav)(wav)
|
||||||
spec = self.freq_cutter(spec)
|
spec = self.freq_cutter.to(wav)(spec)
|
||||||
spec = self.transforms(spec)
|
spec = self.transforms.to(wav)(spec)
|
||||||
return self.resizer(spec)
|
return self.resizer.to(wav)(spec)
|
||||||
|
|
||||||
def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
|
def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
return self.spec_builder(wav)
|
return self.spec_builder(wav)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user