mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
updated extract spec slices bit
This commit is contained in:
parent
5e08607eec
commit
b8bbfe8ad4
@ -731,7 +731,7 @@ def process_file(
|
|||||||
config["chunk_size"],
|
config["chunk_size"],
|
||||||
):
|
):
|
||||||
# Run detection model on chunk
|
# Run detection model on chunk
|
||||||
pred_nms, features, spec_np = _process_audio_array(
|
pred_nms, features, spec = _process_audio_array(
|
||||||
audio,
|
audio,
|
||||||
sampling_rate,
|
sampling_rate,
|
||||||
model,
|
model,
|
||||||
@ -739,6 +739,9 @@ def process_file(
|
|||||||
device,
|
device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# convert to numpy
|
||||||
|
spec_np = spec.detach().cpu().numpy()
|
||||||
|
|
||||||
# add chunk time to start and end times
|
# add chunk time to start and end times
|
||||||
pred_nms["start_times"] += chunk_time
|
pred_nms["start_times"] += chunk_time
|
||||||
pred_nms["end_times"] += chunk_time
|
pred_nms["end_times"] += chunk_time
|
||||||
@ -756,6 +759,7 @@ def process_file(
|
|||||||
cnn_feats.append(features[0])
|
cnn_feats.append(features[0])
|
||||||
|
|
||||||
if config["spec_slices"]:
|
if config["spec_slices"]:
|
||||||
|
# FIX: This is not currently working. Returns empty slices
|
||||||
spec_slices.extend(
|
spec_slices.extend(
|
||||||
feats.extract_spec_slices(spec_np, pred_nms, config)
|
feats.extract_spec_slices(spec_np, pred_nms, config)
|
||||||
)
|
)
|
||||||
|
@ -251,3 +251,14 @@ def test_postprocess_model_outputs():
|
|||||||
assert isinstance(features, np.ndarray)
|
assert isinstance(features, np.ndarray)
|
||||||
assert features.shape[0] == len(predictions)
|
assert features.shape[0] == len(predictions)
|
||||||
assert features.shape[1] == 32
|
assert features.shape[1] == 32
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_file_with_spec_slices():
|
||||||
|
"""Test process file returns spec slices."""
|
||||||
|
config = api.get_config(spec_slices=True)
|
||||||
|
results = api.process_file(TEST_DATA[0], config=config)
|
||||||
|
detections = results["pred_dict"]["annotation"]
|
||||||
|
|
||||||
|
assert "spec_slices" in results
|
||||||
|
assert isinstance(results["spec_slices"], list)
|
||||||
|
assert len(results["spec_slices"]) == len(detections)
|
||||||
|
Loading…
Reference in New Issue
Block a user