mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add tests for detection threshold
This commit is contained in:
parent
8e33473b4e
commit
c5b2446978
@ -243,7 +243,7 @@ def test_user_can_load_checkpoint_with_new_targets(
|
|||||||
detector = cast(Detector, api.model.detector)
|
detector = cast(Detector, api.model.detector)
|
||||||
classifier_head = cast(ClassifierHead, detector.classifier_head)
|
classifier_head = cast(ClassifierHead, detector.classifier_head)
|
||||||
|
|
||||||
assert api.targets.config == sample_targets.config
|
assert api.targets.config == sample_targets.config # type: ignore
|
||||||
assert detector.num_classes == len(sample_targets.class_names)
|
assert detector.num_classes == len(sample_targets.class_names)
|
||||||
assert (
|
assert (
|
||||||
classifier_head.classifier.out_channels
|
classifier_head.classifier.out_channels
|
||||||
@ -399,6 +399,61 @@ def test_process_file_uses_resolved_batch_size_by_default(
|
|||||||
assert captured["batch_size"] == api_v2.inference_config.loader.batch_size
|
assert captured["batch_size"] == api_v2.inference_config.loader.batch_size
|
||||||
|
|
||||||
|
|
||||||
|
def test_detection_threshold_override_changes_process_file_results(
|
||||||
|
api_v2: BatDetect2API,
|
||||||
|
example_audio_files: list[Path],
|
||||||
|
) -> None:
|
||||||
|
"""User story: users can override threshold in process_file."""
|
||||||
|
|
||||||
|
default_prediction = api_v2.process_file(example_audio_files[0])
|
||||||
|
strict_prediction = api_v2.process_file(
|
||||||
|
example_audio_files[0],
|
||||||
|
detection_threshold=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(strict_prediction.detections) <= len(
|
||||||
|
default_prediction.detections
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_detection_threshold_override_is_ephemeral_in_process_file(
|
||||||
|
api_v2: BatDetect2API,
|
||||||
|
example_audio_files: list[Path],
|
||||||
|
) -> None:
|
||||||
|
"""User story: per-call threshold override does not change defaults."""
|
||||||
|
|
||||||
|
before = api_v2.process_file(example_audio_files[0])
|
||||||
|
_ = api_v2.process_file(
|
||||||
|
example_audio_files[0],
|
||||||
|
detection_threshold=1.0,
|
||||||
|
)
|
||||||
|
after = api_v2.process_file(example_audio_files[0])
|
||||||
|
|
||||||
|
assert len(before.detections) == len(after.detections)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
[det.detection_score for det in before.detections],
|
||||||
|
[det.detection_score for det in after.detections],
|
||||||
|
atol=1e-6,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_detection_threshold_override_changes_spectrogram_results(
|
||||||
|
api_v2: BatDetect2API,
|
||||||
|
example_audio_files: list[Path],
|
||||||
|
) -> None:
|
||||||
|
"""User story: threshold override works in spectrogram path."""
|
||||||
|
|
||||||
|
audio = api_v2.load_audio(example_audio_files[0])
|
||||||
|
spec = api_v2.generate_spectrogram(audio)
|
||||||
|
default_detections = api_v2.process_spectrogram(spec)
|
||||||
|
strict_detections = api_v2.process_spectrogram(
|
||||||
|
spec,
|
||||||
|
detection_threshold=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(strict_detections) <= len(default_detections)
|
||||||
|
|
||||||
|
|
||||||
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
|
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
|
||||||
"""User story: call-level overrides do not mutate resolved defaults."""
|
"""User story: call-level overrides do not mutate resolved defaults."""
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user