diff --git a/src/batdetect2/preprocess/types.py b/src/batdetect2/preprocess/types.py
index 39485e9..dbf3758 100644
--- a/src/batdetect2/preprocess/types.py
+++ b/src/batdetect2/preprocess/types.py
@@ -29,3 +29,6 @@ class PreprocessorProtocol(Protocol):
 
     def process_numpy(self, wav: np.ndarray) -> np.ndarray:
         return self(torch.tensor(wav)).numpy()
+
+    def cpu(self) -> "PreprocessorProtocol":
+        return self
diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py
index e9e1883..570106d 100644
--- a/src/batdetect2/train/dataset.py
+++ b/src/batdetect2/train/dataset.py
@@ -272,7 +272,7 @@ def build_train_dataset(
         audio_loader=audio_loader,
         labeller=labeller,
         clipper=clipper,
-        preprocessor=preprocessor,
+        preprocessor=preprocessor.cpu(),
         audio_augmentation=audio_augmentation,
         spectrogram_augmentation=spectrogram_augmentation,
     )
@@ -305,7 +305,7 @@ def build_val_dataset(
         clip_annotations,
         audio_loader=audio_loader,
         labeller=labeller,
-        preprocessor=preprocessor,
+        preprocessor=preprocessor.cpu(),
         clipper=clipper,
     )
 
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index 8845fa8..1f1745a 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -73,8 +73,9 @@ def run_train(
 
     audio_loader = audio_loader or build_audio_loader(config=audio_config)
 
-    # NOTE: Create a new preprocessor instead of using the one from the model
-    # to avoid issues with device placement.
+    if model is not None:
+        preprocessor = preprocessor or model.preprocessor
+
     preprocessor = preprocessor or build_preprocessor(
         input_samplerate=audio_loader.samplerate,
         config=model_config.preprocess,