Add tests for detection threshold

This commit is contained in:
mbsantiago 2026-03-28 14:42:15 +00:00
parent 8e33473b4e
commit c5b2446978

View File

@ -243,7 +243,7 @@ def test_user_can_load_checkpoint_with_new_targets(
detector = cast(Detector, api.model.detector) detector = cast(Detector, api.model.detector)
classifier_head = cast(ClassifierHead, detector.classifier_head) classifier_head = cast(ClassifierHead, detector.classifier_head)
assert api.targets.config == sample_targets.config assert api.targets.config == sample_targets.config # type: ignore
assert detector.num_classes == len(sample_targets.class_names) assert detector.num_classes == len(sample_targets.class_names)
assert ( assert (
classifier_head.classifier.out_channels classifier_head.classifier.out_channels
@ -399,6 +399,61 @@ def test_process_file_uses_resolved_batch_size_by_default(
assert captured["batch_size"] == api_v2.inference_config.loader.batch_size assert captured["batch_size"] == api_v2.inference_config.loader.batch_size
def test_detection_threshold_override_changes_process_file_results(
api_v2: BatDetect2API,
example_audio_files: list[Path],
) -> None:
"""User story: users can override threshold in process_file."""
default_prediction = api_v2.process_file(example_audio_files[0])
strict_prediction = api_v2.process_file(
example_audio_files[0],
detection_threshold=1.0,
)
assert len(strict_prediction.detections) <= len(
default_prediction.detections
)
def test_detection_threshold_override_is_ephemeral_in_process_file(
api_v2: BatDetect2API,
example_audio_files: list[Path],
) -> None:
"""User story: per-call threshold override does not change defaults."""
before = api_v2.process_file(example_audio_files[0])
_ = api_v2.process_file(
example_audio_files[0],
detection_threshold=1.0,
)
after = api_v2.process_file(example_audio_files[0])
assert len(before.detections) == len(after.detections)
np.testing.assert_allclose(
[det.detection_score for det in before.detections],
[det.detection_score for det in after.detections],
atol=1e-6,
)
def test_detection_threshold_override_changes_spectrogram_results(
api_v2: BatDetect2API,
example_audio_files: list[Path],
) -> None:
"""User story: threshold override works in spectrogram path."""
audio = api_v2.load_audio(example_audio_files[0])
spec = api_v2.generate_spectrogram(audio)
default_detections = api_v2.process_spectrogram(spec)
strict_detections = api_v2.process_spectrogram(
spec,
detection_threshold=1.0,
)
assert len(strict_detections) <= len(default_detections)
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None: def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
"""User story: call-level overrides do not mutate resolved defaults.""" """User story: call-level overrides do not mutate resolved defaults."""