mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add tests for detection threshold
This commit is contained in:
parent
8e33473b4e
commit
c5b2446978
@ -243,7 +243,7 @@ def test_user_can_load_checkpoint_with_new_targets(
|
||||
detector = cast(Detector, api.model.detector)
|
||||
classifier_head = cast(ClassifierHead, detector.classifier_head)
|
||||
|
||||
assert api.targets.config == sample_targets.config
|
||||
assert api.targets.config == sample_targets.config # type: ignore
|
||||
assert detector.num_classes == len(sample_targets.class_names)
|
||||
assert (
|
||||
classifier_head.classifier.out_channels
|
||||
@ -399,6 +399,61 @@ def test_process_file_uses_resolved_batch_size_by_default(
|
||||
assert captured["batch_size"] == api_v2.inference_config.loader.batch_size
|
||||
|
||||
|
||||
def test_detection_threshold_override_changes_process_file_results(
|
||||
api_v2: BatDetect2API,
|
||||
example_audio_files: list[Path],
|
||||
) -> None:
|
||||
"""User story: users can override threshold in process_file."""
|
||||
|
||||
default_prediction = api_v2.process_file(example_audio_files[0])
|
||||
strict_prediction = api_v2.process_file(
|
||||
example_audio_files[0],
|
||||
detection_threshold=1.0,
|
||||
)
|
||||
|
||||
assert len(strict_prediction.detections) <= len(
|
||||
default_prediction.detections
|
||||
)
|
||||
|
||||
|
||||
def test_detection_threshold_override_is_ephemeral_in_process_file(
|
||||
api_v2: BatDetect2API,
|
||||
example_audio_files: list[Path],
|
||||
) -> None:
|
||||
"""User story: per-call threshold override does not change defaults."""
|
||||
|
||||
before = api_v2.process_file(example_audio_files[0])
|
||||
_ = api_v2.process_file(
|
||||
example_audio_files[0],
|
||||
detection_threshold=1.0,
|
||||
)
|
||||
after = api_v2.process_file(example_audio_files[0])
|
||||
|
||||
assert len(before.detections) == len(after.detections)
|
||||
np.testing.assert_allclose(
|
||||
[det.detection_score for det in before.detections],
|
||||
[det.detection_score for det in after.detections],
|
||||
atol=1e-6,
|
||||
)
|
||||
|
||||
|
||||
def test_detection_threshold_override_changes_spectrogram_results(
|
||||
api_v2: BatDetect2API,
|
||||
example_audio_files: list[Path],
|
||||
) -> None:
|
||||
"""User story: threshold override works in spectrogram path."""
|
||||
|
||||
audio = api_v2.load_audio(example_audio_files[0])
|
||||
spec = api_v2.generate_spectrogram(audio)
|
||||
default_detections = api_v2.process_spectrogram(spec)
|
||||
strict_detections = api_v2.process_spectrogram(
|
||||
spec,
|
||||
detection_threshold=1.0,
|
||||
)
|
||||
|
||||
assert len(strict_detections) <= len(default_detections)
|
||||
|
||||
|
||||
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
|
||||
"""User story: call-level overrides do not mutate resolved defaults."""
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user