mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
updated extract spec slices bit
This commit is contained in:
parent
5e08607eec
commit
b8bbfe8ad4
@ -731,7 +731,7 @@ def process_file(
|
||||
config["chunk_size"],
|
||||
):
|
||||
# Run detection model on chunk
|
||||
pred_nms, features, spec_np = _process_audio_array(
|
||||
pred_nms, features, spec = _process_audio_array(
|
||||
audio,
|
||||
sampling_rate,
|
||||
model,
|
||||
@ -739,6 +739,9 @@ def process_file(
|
||||
device,
|
||||
)
|
||||
|
||||
# convert to numpy
|
||||
spec_np = spec.detach().cpu().numpy()
|
||||
|
||||
# add chunk time to start and end times
|
||||
pred_nms["start_times"] += chunk_time
|
||||
pred_nms["end_times"] += chunk_time
|
||||
@ -756,6 +759,7 @@ def process_file(
|
||||
cnn_feats.append(features[0])
|
||||
|
||||
if config["spec_slices"]:
|
||||
# FIX: This is not currently working. Returns empty slices
|
||||
spec_slices.extend(
|
||||
feats.extract_spec_slices(spec_np, pred_nms, config)
|
||||
)
|
||||
|
@ -251,3 +251,14 @@ def test_postprocess_model_outputs():
|
||||
assert isinstance(features, np.ndarray)
|
||||
assert features.shape[0] == len(predictions)
|
||||
assert features.shape[1] == 32
|
||||
|
||||
|
||||
def test_process_file_with_spec_slices():
|
||||
"""Test process file returns spec slices."""
|
||||
config = api.get_config(spec_slices=True)
|
||||
results = api.process_file(TEST_DATA[0], config=config)
|
||||
detections = results["pred_dict"]["annotation"]
|
||||
|
||||
assert "spec_slices" in results
|
||||
assert isinstance(results["spec_slices"], list)
|
||||
assert len(results["spec_slices"]) == len(detections)
|
||||
|
Loading…
Reference in New Issue
Block a user