From b8bbfe8ad489778d4fbd24a8f5ecbb7c10f90973 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Fri, 7 Apr 2023 15:20:15 -0600 Subject: [PATCH] updated extract spec slices bit --- batdetect2/utils/detector_utils.py | 6 +++++- tests/test_api.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index 1d92b26..c5dd4d0 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -731,7 +731,7 @@ def process_file( config["chunk_size"], ): # Run detection model on chunk - pred_nms, features, spec_np = _process_audio_array( + pred_nms, features, spec = _process_audio_array( audio, sampling_rate, model, @@ -739,6 +739,9 @@ def process_file( device, ) + # convert to numpy + spec_np = spec.detach().cpu().numpy() + # add chunk time to start and end times pred_nms["start_times"] += chunk_time pred_nms["end_times"] += chunk_time @@ -756,6 +759,7 @@ def process_file( cnn_feats.append(features[0]) if config["spec_slices"]: + # FIX: This is not currently working. Returns empty slices spec_slices.extend( feats.extract_spec_slices(spec_np, pred_nms, config) ) diff --git a/tests/test_api.py b/tests/test_api.py index d20ebf4..942a1f1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -251,3 +251,14 @@ def test_postprocess_model_outputs(): assert isinstance(features, np.ndarray) assert features.shape[0] == len(predictions) assert features.shape[1] == 32 + + +def test_process_file_with_spec_slices(): + """Test process file returns spec slices.""" + config = api.get_config(spec_slices=True) + results = api.process_file(TEST_DATA[0], config=config) + detections = results["pred_dict"]["annotation"] + + assert "spec_slices" in results + assert isinstance(results["spec_slices"], list) + assert len(results["spec_slices"]) == len(detections)