diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 8e8cae5..a714ec6 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -484,11 +484,19 @@ class BatDetect2API: transform=output_transform, ) + # NOTE: Build separate instances of preprocessor and postprocessor + # to avoid device mismatch errors model = build_model( config=config.model, - targets=targets, - preprocessor=preprocessor, - postprocessor=postprocessor, + targets=build_targets(config=config.model.targets), + preprocessor=build_preprocessor( + input_samplerate=audio_loader.samplerate, + config=config.model.preprocess, + ), + postprocessor=build_postprocessor( + preprocessor, + config=config.model.postprocess, + ), ) return cls( diff --git a/src/batdetect2/preprocess/types.py b/src/batdetect2/preprocess/types.py index dbf3758..39485e9 100644 --- a/src/batdetect2/preprocess/types.py +++ b/src/batdetect2/preprocess/types.py @@ -29,6 +29,3 @@ class PreprocessorProtocol(Protocol): def process_numpy(self, wav: np.ndarray) -> np.ndarray: return self(torch.tensor(wav)).numpy() - - def cpu(self) -> "PreprocessorProtocol": - return self diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index 570106d..e9e1883 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -272,7 +272,7 @@ def build_train_dataset( audio_loader=audio_loader, labeller=labeller, clipper=clipper, - preprocessor=preprocessor.cpu(), + preprocessor=preprocessor, audio_augmentation=audio_augmentation, spectrogram_augmentation=spectrogram_augmentation, ) @@ -305,7 +305,7 @@ def build_val_dataset( clip_annotations, audio_loader=audio_loader, labeller=labeller, - preprocessor=preprocessor.cpu(), + preprocessor=preprocessor, clipper=clipper, ) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 1f1745a..2842738 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -73,9 +73,6 @@ def run_train( audio_loader = audio_loader or build_audio_loader(config=audio_config) - if model is not None: - preprocessor = preprocessor or model.preprocessor - preprocessor = preprocessor or build_preprocessor( input_samplerate=audio_loader.samplerate, config=model_config.preprocess,