mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Ensure preprocessor is in CPU
This commit is contained in:
parent
0d90cb5cc3
commit
fb3dc3eaf0
@ -29,3 +29,6 @@ class PreprocessorProtocol(Protocol):
|
||||
|
||||
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
|
||||
return self(torch.tensor(wav)).numpy()
|
||||
|
||||
def cpu(self) -> "PreprocessorProtocol":
|
||||
return self
|
||||
|
||||
@ -272,7 +272,7 @@ def build_train_dataset(
|
||||
audio_loader=audio_loader,
|
||||
labeller=labeller,
|
||||
clipper=clipper,
|
||||
preprocessor=preprocessor,
|
||||
preprocessor=preprocessor.cpu(),
|
||||
audio_augmentation=audio_augmentation,
|
||||
spectrogram_augmentation=spectrogram_augmentation,
|
||||
)
|
||||
@ -305,7 +305,7 @@ def build_val_dataset(
|
||||
clip_annotations,
|
||||
audio_loader=audio_loader,
|
||||
labeller=labeller,
|
||||
preprocessor=preprocessor,
|
||||
preprocessor=preprocessor.cpu(),
|
||||
clipper=clipper,
|
||||
)
|
||||
|
||||
|
||||
@ -73,8 +73,9 @@ def run_train(
|
||||
|
||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||
|
||||
# NOTE: Create a new preprocessor instead of using the one from the model
|
||||
# to avoid issues with device placement.
|
||||
if model is not None:
|
||||
preprocessor = preprocessor or model.preprocessor
|
||||
|
||||
preprocessor = preprocessor or build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=model_config.preprocess,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user