updated extract spec slices bit

This commit is contained in:
Santiago Martinez 2023-04-07 15:20:15 -06:00
parent 5e08607eec
commit b8bbfe8ad4
2 changed files with 16 additions and 1 deletions

View File

@ -731,7 +731,7 @@ def process_file(
config["chunk_size"],
):
# Run detection model on chunk
pred_nms, features, spec_np = _process_audio_array(
pred_nms, features, spec = _process_audio_array(
audio,
sampling_rate,
model,
@ -739,6 +739,9 @@ def process_file(
device,
)
# convert to numpy
spec_np = spec.detach().cpu().numpy()
# add chunk time to start and end times
pred_nms["start_times"] += chunk_time
pred_nms["end_times"] += chunk_time
@ -756,6 +759,7 @@ def process_file(
cnn_feats.append(features[0])
if config["spec_slices"]:
# FIX: This is not currently working. Returns empty slices
spec_slices.extend(
feats.extract_spec_slices(spec_np, pred_nms, config)
)

View File

@ -251,3 +251,14 @@ def test_postprocess_model_outputs():
assert isinstance(features, np.ndarray)
assert features.shape[0] == len(predictions)
assert features.shape[1] == 32
def test_process_file_with_spec_slices():
"""Test process file returns spec slices."""
config = api.get_config(spec_slices=True)
results = api.process_file(TEST_DATA[0], config=config)
detections = results["pred_dict"]["annotation"]
assert "spec_slices" in results
assert isinstance(results["spec_slices"], list)
assert len(results["spec_slices"]) == len(detections)