From c5b2446978271a6e7d6eda04c197d4d0bfadfefc Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sat, 28 Mar 2026 14:42:15 +0000 Subject: [PATCH] Add tests for detection threshold --- tests/test_api_v2/test_api_v2.py | 57 +++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index 71a102a..a85dfd1 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -243,7 +243,7 @@ def test_user_can_load_checkpoint_with_new_targets( detector = cast(Detector, api.model.detector) classifier_head = cast(ClassifierHead, detector.classifier_head) - assert api.targets.config == sample_targets.config + assert api.targets.config == sample_targets.config # type: ignore assert detector.num_classes == len(sample_targets.class_names) assert ( classifier_head.classifier.out_channels @@ -399,6 +399,61 @@ def test_process_file_uses_resolved_batch_size_by_default( assert captured["batch_size"] == api_v2.inference_config.loader.batch_size +def test_detection_threshold_override_changes_process_file_results( + api_v2: BatDetect2API, + example_audio_files: list[Path], +) -> None: + """User story: users can override threshold in process_file.""" + + default_prediction = api_v2.process_file(example_audio_files[0]) + strict_prediction = api_v2.process_file( + example_audio_files[0], + detection_threshold=1.0, + ) + + assert len(strict_prediction.detections) <= len( + default_prediction.detections + ) + + +def test_detection_threshold_override_is_ephemeral_in_process_file( + api_v2: BatDetect2API, + example_audio_files: list[Path], +) -> None: + """User story: per-call threshold override does not change defaults.""" + + before = api_v2.process_file(example_audio_files[0]) + _ = api_v2.process_file( + example_audio_files[0], + detection_threshold=1.0, + ) + after = api_v2.process_file(example_audio_files[0]) + + assert len(before.detections) == len(after.detections) + np.testing.assert_allclose( + [det.detection_score for det in before.detections], + [det.detection_score for det in after.detections], + atol=1e-6, + ) + + +def test_detection_threshold_override_changes_spectrogram_results( + api_v2: BatDetect2API, + example_audio_files: list[Path], +) -> None: + """User story: threshold override works in spectrogram path.""" + + audio = api_v2.load_audio(example_audio_files[0]) + spec = api_v2.generate_spectrogram(audio) + default_detections = api_v2.process_spectrogram(spec) + strict_detections = api_v2.process_spectrogram( + spec, + detection_threshold=1.0, + ) + + assert len(strict_detections) <= len(default_detections) + + def test_per_call_overrides_are_ephemeral(monkeypatch) -> None: """User story: call-level overrides do not mutate resolved defaults."""