diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index 1d92b26..c5dd4d0 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -731,7 +731,7 @@ def process_file( config["chunk_size"], ): # Run detection model on chunk - pred_nms, features, spec_np = _process_audio_array( + pred_nms, features, spec = _process_audio_array( audio, sampling_rate, model, @@ -739,6 +739,9 @@ def process_file( device, ) + # convert to numpy + spec_np = spec.detach().cpu().numpy() + # add chunk time to start and end times pred_nms["start_times"] += chunk_time pred_nms["end_times"] += chunk_time @@ -756,6 +759,7 @@ def process_file( cnn_feats.append(features[0]) if config["spec_slices"]: + # FIX: This is not currently working. Returns empty slices spec_slices.extend( feats.extract_spec_slices(spec_np, pred_nms, config) ) diff --git a/tests/test_api.py b/tests/test_api.py index d20ebf4..942a1f1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -251,3 +251,14 @@ def test_postprocess_model_outputs(): assert isinstance(features, np.ndarray) assert features.shape[0] == len(predictions) assert features.shape[1] == 32 + + +def test_process_file_with_spec_slices(): + """Test process file returns spec slices.""" + config = api.get_config(spec_slices=True) + results = api.process_file(TEST_DATA[0], config=config) + detections = results["pred_dict"]["annotation"] + + assert "spec_slices" in results + assert isinstance(results["spec_slices"], list) + assert len(results["spec_slices"]) == len(detections)