updated extract spec slices bit

This commit is contained in:
Santiago Martinez 2023-04-07 15:20:15 -06:00
parent 5e08607eec
commit b8bbfe8ad4
2 changed files with 16 additions and 1 deletions

View File

@ -731,7 +731,7 @@ def process_file(
config["chunk_size"], config["chunk_size"],
): ):
# Run detection model on chunk # Run detection model on chunk
pred_nms, features, spec_np = _process_audio_array( pred_nms, features, spec = _process_audio_array(
audio, audio,
sampling_rate, sampling_rate,
model, model,
@ -739,6 +739,9 @@ def process_file(
device, device,
) )
# convert to numpy
spec_np = spec.detach().cpu().numpy()
# add chunk time to start and end times # add chunk time to start and end times
pred_nms["start_times"] += chunk_time pred_nms["start_times"] += chunk_time
pred_nms["end_times"] += chunk_time pred_nms["end_times"] += chunk_time
@ -756,6 +759,7 @@ def process_file(
cnn_feats.append(features[0]) cnn_feats.append(features[0])
if config["spec_slices"]: if config["spec_slices"]:
# FIX: This is not currently working. Returns empty slices
spec_slices.extend( spec_slices.extend(
feats.extract_spec_slices(spec_np, pred_nms, config) feats.extract_spec_slices(spec_np, pred_nms, config)
) )

View File

@ -251,3 +251,14 @@ def test_postprocess_model_outputs():
assert isinstance(features, np.ndarray) assert isinstance(features, np.ndarray)
assert features.shape[0] == len(predictions) assert features.shape[0] == len(predictions)
assert features.shape[1] == 32 assert features.shape[1] == 32
def test_process_file_with_spec_slices():
"""Test process file returns spec slices."""
config = api.get_config(spec_slices=True)
results = api.process_file(TEST_DATA[0], config=config)
detections = results["pred_dict"]["annotation"]
assert "spec_slices" in results
assert isinstance(results["spec_slices"], list)
assert len(results["spec_slices"]) == len(detections)