diff --git a/batdetect2/postprocess/remapping.py b/batdetect2/postprocess/remapping.py index 51560ea..7112046 100644 --- a/batdetect2/postprocess/remapping.py +++ b/batdetect2/postprocess/remapping.py @@ -84,7 +84,7 @@ def features_to_xarray( freqs = np.linspace(min_freq, max_freq, height, endpoint=False) return xr.DataArray( - data=features.detach().numpy(), + data=features.detach().cpu().numpy(), dims=[ Dimensions.feature.value, Dimensions.frequency.value, @@ -157,7 +157,7 @@ def detection_to_xarray( freqs = np.linspace(min_freq, max_freq, height, endpoint=False) return xr.DataArray( - data=detection.squeeze(dim=0).detach().numpy(), + data=detection.squeeze(dim=0).detach().cpu().numpy(), dims=[ Dimensions.frequency.value, Dimensions.time.value, @@ -233,7 +233,7 @@ def classification_to_xarray( freqs = np.linspace(min_freq, max_freq, height, endpoint=False) return xr.DataArray( - data=classes.detach().numpy(), + data=classes.detach().cpu().numpy(), dims=[ "category", Dimensions.frequency.value, @@ -302,7 +302,7 @@ def sizes_to_xarray( freqs = np.linspace(min_freq, max_freq, height, endpoint=False) return xr.DataArray( - data=sizes.detach().numpy(), + data=sizes.detach().cpu().numpy(), dims=[ "dimension", Dimensions.frequency.value,