This commit is contained in:
Santiago Martinez Balvanera 2025-06-19 00:46:43 +01:00
parent 434fc652a2
commit a62f07ebdd

View File

@ -84,7 +84,7 @@ def features_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=features.detach().numpy(),
data=features.detach().cpu().numpy(),
dims=[
Dimensions.feature.value,
Dimensions.frequency.value,
@ -157,7 +157,7 @@ def detection_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=detection.squeeze(dim=0).detach().numpy(),
data=detection.squeeze(dim=0).detach().cpu().numpy(),
dims=[
Dimensions.frequency.value,
Dimensions.time.value,
@ -233,7 +233,7 @@ def classification_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=classes.detach().numpy(),
data=classes.detach().cpu().numpy(),
dims=[
"category",
Dimensions.frequency.value,
@ -302,7 +302,7 @@ def sizes_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=sizes.detach().numpy(),
data=sizes.detach().cpu().numpy(),
dims=[
"dimension",
Dimensions.frequency.value,