Ensure preprocessor is in CPU

This commit is contained in:
mbsantiago 2026-03-19 00:09:30 +00:00
parent 0d90cb5cc3
commit fb3dc3eaf0
3 changed files with 8 additions and 4 deletions

View File

@ -29,3 +29,6 @@ class PreprocessorProtocol(Protocol):
def process_numpy(self, wav: np.ndarray) -> np.ndarray: def process_numpy(self, wav: np.ndarray) -> np.ndarray:
return self(torch.tensor(wav)).numpy() return self(torch.tensor(wav)).numpy()
def cpu(self) -> "PreprocessorProtocol":
return self

View File

@ -272,7 +272,7 @@ def build_train_dataset(
audio_loader=audio_loader, audio_loader=audio_loader,
labeller=labeller, labeller=labeller,
clipper=clipper, clipper=clipper,
preprocessor=preprocessor, preprocessor=preprocessor.cpu(),
audio_augmentation=audio_augmentation, audio_augmentation=audio_augmentation,
spectrogram_augmentation=spectrogram_augmentation, spectrogram_augmentation=spectrogram_augmentation,
) )
@ -305,7 +305,7 @@ def build_val_dataset(
clip_annotations, clip_annotations,
audio_loader=audio_loader, audio_loader=audio_loader,
labeller=labeller, labeller=labeller,
preprocessor=preprocessor, preprocessor=preprocessor.cpu(),
clipper=clipper, clipper=clipper,
) )

View File

@ -73,8 +73,9 @@ def run_train(
audio_loader = audio_loader or build_audio_loader(config=audio_config) audio_loader = audio_loader or build_audio_loader(config=audio_config)
# NOTE: Create a new preprocessor instead of using the one from the model if model is not None:
# to avoid issues with device placement. preprocessor = preprocessor or model.preprocessor
preprocessor = preprocessor or build_preprocessor( preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=model_config.preprocess, config=model_config.preprocess,