This commit is contained in:
Santiago Martinez Balvanera 2025-06-19 00:46:43 +01:00
parent 434fc652a2
commit a62f07ebdd

View File

@ -84,7 +84,7 @@ def features_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False) freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray( return xr.DataArray(
data=features.detach().numpy(), data=features.detach().cpu().numpy(),
dims=[ dims=[
Dimensions.feature.value, Dimensions.feature.value,
Dimensions.frequency.value, Dimensions.frequency.value,
@ -157,7 +157,7 @@ def detection_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False) freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray( return xr.DataArray(
data=detection.squeeze(dim=0).detach().numpy(), data=detection.squeeze(dim=0).detach().cpu().numpy(),
dims=[ dims=[
Dimensions.frequency.value, Dimensions.frequency.value,
Dimensions.time.value, Dimensions.time.value,
@ -233,7 +233,7 @@ def classification_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False) freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray( return xr.DataArray(
data=classes.detach().numpy(), data=classes.detach().cpu().numpy(),
dims=[ dims=[
"category", "category",
Dimensions.frequency.value, Dimensions.frequency.value,
@ -302,7 +302,7 @@ def sizes_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False) freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray( return xr.DataArray(
data=sizes.detach().numpy(), data=sizes.detach().cpu().numpy(),
dims=[ dims=[
"dimension", "dimension",
Dimensions.frequency.value, Dimensions.frequency.value,