diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index 27e45df..b1c0a3a 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -24,7 +24,6 @@ from batdetect2.preprocess import MAX_FREQ, MIN_FREQ from batdetect2.typing import ModelOutput from batdetect2.typing.postprocess import ( BatDetect2Prediction, - DetectionsArray, DetectionsTensor, PostprocessorProtocol, RawPrediction, diff --git a/src/batdetect2/typing/postprocess.py b/src/batdetect2/typing/postprocess.py index 42afb4a..f75ca81 100644 --- a/src/batdetect2/typing/postprocess.py +++ b/src/batdetect2/typing/postprocess.py @@ -96,12 +96,12 @@ class DetectionsTensor(NamedTuple): def numpy(self) -> DetectionsArray: return DetectionsArray( - scores=self.scores.detach().numpy(), - sizes=self.sizes.detach().numpy(), - class_scores=self.class_scores.detach().numpy(), - times=self.times.detach().numpy(), - frequencies=self.frequencies.detach().numpy(), - features=self.features.detach().numpy(), + scores=self.scores.detach().cpu().numpy(), + sizes=self.sizes.detach().cpu().numpy(), + class_scores=self.class_scores.detach().cpu().numpy(), + times=self.times.detach().cpu().numpy(), + frequencies=self.frequencies.detach().cpu().numpy(), + features=self.features.detach().cpu().numpy(), )