Compare commits

...

6 Commits

Author SHA1 Message Date
mbsantiago
c7b110feeb Limit number of parallel processes to match predictions and annotations 2025-08-12 19:06:44 +01:00
mbsantiago
7d92ec772b Fix number formatting in gallery plot 2025-08-12 18:49:21 +01:00
mbsantiago
0bb703f3c1 Add call type alias 2025-08-12 18:45:54 +01:00
mbsantiago
81c7b68b0b Update dev deps 2025-08-12 18:45:43 +01:00
mbsantiago
59aaf07af5 Update tests after incorporating term registry from soundevent 2025-08-12 18:44:18 +01:00
mbsantiago
51d0a49da9 Improve performance of postprocessing code 2025-08-12 17:47:17 +01:00
12 changed files with 265 additions and 302 deletions

View File

@ -85,6 +85,7 @@ dev = [
"ty>=0.0.1a12",
"rust-just>=1.40.0",
"pandas-stubs>=2.2.2.240807",
"python-lsp-server>=1.13.0",
]
dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"]

View File

@ -179,7 +179,7 @@ def plot_false_positive_match(
plt.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,

View File

@ -543,7 +543,9 @@ class Postprocessor(PostprocessorProtocol):
]
def get_sound_event_predictions(
self, output: ModelOutput, clips: List[data.Clip]
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[BatDetect2Prediction]]:
raw_predictions = self.get_raw_predictions(output, clips)
return [
@ -553,8 +555,7 @@ class Postprocessor(PostprocessorProtocol):
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags,
targets=self.targets,
classification_threshold=self.config.classification_threshold,
),
)
@ -590,8 +591,7 @@ class Postprocessor(PostprocessorProtocol):
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags,
targets=self.targets,
classification_threshold=self.config.classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)

View File

@ -33,8 +33,7 @@ import xarray as xr
from soundevent import data
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
from batdetect2.targets.classes import SoundEventDecoder
from batdetect2.utils.arrays import iterate_over_array
from batdetect2.targets.types import TargetProtocol
__all__ = [
"convert_xr_dataset_to_raw_prediction",
@ -92,25 +91,30 @@ def convert_xr_dataset_to_raw_prediction(
"""
detections = []
for det_num in range(detection_dataset.sizes["detection"]):
det_info = detection_dataset.sel(detection=det_num)
categories = detection_dataset.category.values
highest_scoring_class = det_info.coords["category"][
det_info["classes"].argmax()
].item()
for score, class_scores, time, freq, dims, feats in zip(
detection_dataset["scores"].values,
detection_dataset["classes"].values,
detection_dataset["time"].values,
detection_dataset["frequency"].values,
detection_dataset["dimensions"].values,
detection_dataset["features"].values,
):
highest_scoring_class = categories[class_scores.argmax()]
geom = geometry_decoder(
(det_info.time, det_info.frequency),
det_info.dimensions,
(time, freq),
dims,
class_name=highest_scoring_class,
)
detections.append(
RawPrediction(
detection_score=det_info.scores,
detection_score=score,
geometry=geom,
class_scores=det_info.classes,
features=det_info.features,
class_scores=class_scores,
features=feats,
)
)
@ -120,8 +124,7 @@ def convert_xr_dataset_to_raw_prediction(
def convert_raw_predictions_to_clip_prediction(
raw_predictions: List[RawPrediction],
clip: data.Clip,
sound_event_decoder: SoundEventDecoder,
generic_class_tags: List[data.Tag],
targets: TargetProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False,
) -> data.ClipPrediction:
@ -160,8 +163,7 @@ def convert_raw_predictions_to_clip_prediction(
convert_raw_prediction_to_sound_event_prediction(
prediction,
recording=clip.recording,
sound_event_decoder=sound_event_decoder,
generic_class_tags=generic_class_tags,
targets=targets,
classification_threshold=classification_threshold,
top_class_only=top_class_only,
)
@ -173,8 +175,7 @@ def convert_raw_predictions_to_clip_prediction(
def convert_raw_prediction_to_sound_event_prediction(
raw_prediction: RawPrediction,
recording: data.Recording,
sound_event_decoder: SoundEventDecoder,
generic_class_tags: List[data.Tag],
targets: TargetProtocol,
classification_threshold: Optional[
float
] = DEFAULT_CLASSIFICATION_THRESHOLD,
@ -251,11 +252,11 @@ def convert_raw_prediction_to_sound_event_prediction(
tags = [
*get_generic_tags(
raw_prediction.detection_score,
generic_class_tags=generic_class_tags,
generic_class_tags=targets.generic_class_tags,
),
*get_class_tags(
raw_prediction.class_scores,
sound_event_decoder,
targets=targets,
top_class_only=top_class_only,
threshold=classification_threshold,
),
@ -297,7 +298,7 @@ def get_generic_tags(
]
def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
"""Convert an extracted feature vector DataArray into soundevent Features.
Parameters
@ -320,19 +321,19 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
return [
data.Feature(
term=data.Term(
name=f"batdetect2:{feat_name}",
label=feat_name,
name=f"batdetect2:f{index}",
label=f"BatDetect Feature {index}",
definition="Automatically extracted features by BatDetect2",
),
value=value,
)
for feat_name, value in iterate_over_array(features)
for index, value in enumerate(features)
]
def get_class_tags(
class_scores: xr.DataArray,
sound_event_decoder: SoundEventDecoder,
class_scores: np.ndarray,
targets: TargetProtocol,
top_class_only: bool = False,
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.PredictedTag]:
@ -367,11 +368,13 @@ def get_class_tags(
"""
tags = []
if threshold is not None:
class_scores = class_scores.where(class_scores > threshold, drop=True)
for class_name, score in _iterate_sorted(
class_scores, targets.class_names
):
if threshold is not None and score < threshold:
continue
for class_name, score in _iterate_sorted(class_scores):
class_tags = sound_event_decoder(class_name)
class_tags = targets.decode_class(class_name)
for tag in class_tags:
tags.append(
@ -387,9 +390,7 @@ def get_class_tags(
return tags
def _iterate_sorted(array: xr.DataArray):
dim_name = array.dims[0]
coords = array.coords[dim_name].values
indices = np.argsort(-array.values)
def _iterate_sorted(array: np.ndarray, class_names: List[str]):
indices = np.argsort(-array)
for index in indices:
yield str(coords[index]), float(array.values[index])
yield str(class_names[index]), float(array[index])

View File

@ -14,6 +14,7 @@ system that deal with model predictions.
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol
import numpy as np
import xarray as xr
from soundevent import data
@ -72,8 +73,8 @@ class RawPrediction(NamedTuple):
geometry: data.Geometry
detection_score: float
class_scores: xr.DataArray
features: xr.DataArray
class_scores: np.ndarray
features: np.ndarray
@dataclass

View File

@ -91,6 +91,7 @@ terms.register_term_set(
"individual": individual.name,
"event": call_type.name,
"source": data_source.name,
"call_type": call_type.name,
},
),
override_existing=True,

View File

@ -1,3 +1,4 @@
import os
from functools import partial
from multiprocessing import Pool
from typing import List, Optional, Tuple
@ -15,7 +16,10 @@ from batdetect2.evaluate.match import (
)
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.postprocess.types import (
BatDetect2Prediction,
PostprocessorProtocol,
)
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.lightning import TrainingModule
@ -114,33 +118,51 @@ class ValidationMetrics(Callback):
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
dataset = self.get_dataset(trainer)
clip_annotations = [
_get_subclip(
dataset.get_clip_annotation(example_id),
start_time=start_time.item(),
end_time=end_time.item(),
self._matches.extend(
_get_batch_clips_and_predictions(
batch,
outputs,
dataset=self.get_dataset(trainer),
postprocessor=pl_module.postprocessor,
targets=pl_module.targets,
)
for example_id, start_time, end_time in zip(
batch.idx,
batch.start_time,
batch.end_time,
)
]
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
outputs,
clips,
)
def _get_batch_clips_and_predictions(
batch: TrainExample,
outputs: ModelOutput,
dataset: LabeledDataset,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
clip_annotations = [
_get_subclip(
dataset.get_clip_annotation(example_id),
start_time=start_time.item(),
end_time=end_time.item(),
targets=targets,
)
for example_id, start_time, end_time in zip(
batch.idx,
batch.start_time,
batch.end_time,
)
]
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
raw_predictions = postprocessor.get_sound_event_predictions(
outputs,
clips,
)
return [
(clip_annotation, clip_predictions)
for clip_annotation, clip_predictions in zip(
clip_annotations, raw_predictions
):
self._matches.append((clip_annotation, clip_predictions))
)
]
def _match_all_collected_examples(
@ -150,7 +172,8 @@ def _match_all_collected_examples(
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions")
with Pool() as p:
cpu_count = os.cpu_count() or 1
with Pool(processes=min(cpu_count, 4)) as p:
matches = p.starmap(
partial(
match_sound_events_and_raw_predictions,

View File

@ -254,11 +254,13 @@ class TestLoadBatDetect2Files:
assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1
assert clip_ann.notes[0].message == "Standard notes."
clip_tag = data.find_tag(clip_ann.tags, "Class")
clip_tag = data.find_tag(clip_ann.tags, term_label="Class")
assert clip_tag is not None
assert clip_tag.value == "Myotis"
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
recording_tag = data.find_tag(
clip_ann.clip.recording.tags, term_label="Class"
)
assert recording_tag is not None
assert recording_tag.value == "Myotis"
@ -271,15 +273,15 @@ class TestLoadBatDetect2Files:
40000,
]
se_class_tag = data.find_tag(se_ann.tags, "Class")
se_class_tag = data.find_tag(se_ann.tags, term_label="Class")
assert se_class_tag is not None
assert se_class_tag.value == "Myotis"
se_event_tag = data.find_tag(se_ann.tags, "Call Type")
se_event_tag = data.find_tag(se_ann.tags, term_label="Call Type")
assert se_event_tag is not None
assert se_event_tag.value == "Echolocation"
se_individual_tag = data.find_tag(se_ann.tags, "Individual")
se_individual_tag = data.find_tag(se_ann.tags, term_label="Individual")
assert se_individual_tag is not None
assert se_individual_tag.value == "0"
@ -439,7 +441,7 @@ class TestLoadBatDetect2Merged:
assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1
clip_class_tag = data.find_tag(clip_ann.tags, "Class")
clip_class_tag = data.find_tag(clip_ann.tags, term_label="Class")
assert clip_class_tag is not None
assert clip_class_tag.value == "Myotis"

View File

@ -1,5 +1,5 @@
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional
import numpy as np
import pytest
@ -16,35 +16,11 @@ from batdetect2.postprocess.decoding import (
get_prediction_features,
)
from batdetect2.postprocess.types import RawPrediction
from batdetect2.targets.types import TargetProtocol
@pytest.fixture
def dummy_geometry_builder():
"""A simple GeometryBuilder that creates a BBox around the point."""
def _builder(
position: Tuple[float, float],
dimensions: xr.DataArray,
class_name: Optional[str] = None,
) -> data.BoundingBox:
time, freq = position
width = dimensions.sel(dimension="width").item()
height = dimensions.sel(dimension="height").item()
return data.BoundingBox(
coordinates=[
time - width / 2,
freq - height / 2,
time + width / 2,
freq + height / 2,
]
)
return _builder
@pytest.fixture
def dummy_sound_event_decoder():
"""A simple SoundEventDecoder mapping names to tags."""
def dummy_targets() -> TargetProtocol:
tag_map = {
"bat": [
data.Tag(term=data.term_from_key(key="species"), value="Myotis")
@ -57,18 +33,56 @@ def dummy_sound_event_decoder():
],
}
def _decoder(class_name: str) -> List[data.Tag]:
return tag_map.get(class_name.lower(), [])
class DummyTargets(TargetProtocol):
class_names = [
"bat",
"noise",
"unknown",
]
return _decoder
dimension_names = ["width", "height"]
generic_class_tags = [
data.Tag(
term=data.term_from_key(key="detector"), value="batdetect2"
)
]
@pytest.fixture
def generic_tags() -> List[data.Tag]:
"""Sample generic tags."""
return [
data.Tag(term=data.term_from_key(key="detector"), value="batdetect2")
]
def filter(self, sound_event: data.SoundEventAnnotation):
return True
def transform(self, sound_event: data.SoundEventAnnotation):
return sound_event
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
return "bat"
def decode_class(self, class_label: str) -> List[data.Tag]:
return tag_map.get(class_label.lower(), [])
def encode_roi(self, sound_event: data.SoundEventAnnotation):
return np.array([0.0, 0.0]), np.array([0.0, 0.0])
def decode_roi(
self,
position,
size: np.ndarray,
class_name: Optional[str] = None,
):
time, freq = position
width, height = size
return data.BoundingBox(
coordinates=[
time - width / 2,
freq - height / 2,
time + width / 2,
freq + height / 2,
]
)
return DummyTargets()
@pytest.fixture
@ -156,7 +170,7 @@ def empty_detection_dataset() -> xr.Dataset:
"""Creates an empty detection dataset with correct structure."""
detection_coords = {
"time": ("detection", np.array([], dtype=np.float64)),
"freq": ("detection", np.array([], dtype=np.float64)),
"frequency": ("detection", np.array([], dtype=np.float64)),
}
scores = xr.DataArray(
np.array([], dtype=np.float64),
@ -184,7 +198,7 @@ def empty_detection_dataset() -> xr.Dataset:
)
return xr.Dataset(
{
"score": scores,
"scores": scores,
"dimensions": dimensions,
"classes": classes,
"features": features,
@ -215,8 +229,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
300 + 16 / 2,
]
),
class_scores=pred1_classes,
features=pred1_features,
class_scores=pred1_classes.values,
features=pred1_features.values,
)
pred2_classes = xr.DataArray(
@ -237,8 +251,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
200 + 12 / 2,
]
),
class_scores=pred2_classes,
features=pred2_features,
class_scores=pred2_classes.values,
features=pred2_features.values,
)
pred3_classes = xr.DataArray(
@ -259,18 +273,17 @@ def sample_raw_predictions() -> List[RawPrediction]:
60.0,
]
),
class_scores=pred3_classes,
features=pred3_features,
class_scores=pred3_classes.values,
features=pred3_features.values,
)
return [pred1, pred2, pred3]
def test_convert_xr_dataset_basic(
sample_detection_dataset, dummy_geometry_builder
):
def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
"""Test basic conversion of a dataset to RawPrediction list."""
raw_predictions = convert_xr_dataset_to_raw_prediction(
sample_detection_dataset, dummy_geometry_builder
sample_detection_dataset,
dummy_targets.decode_roi,
)
assert isinstance(raw_predictions, list)
@ -286,11 +299,11 @@ def test_convert_xr_dataset_basic(
20 + 7 / 2,
300 + 16 / 2,
]
xr.testing.assert_allclose(
np.testing.assert_allclose(
pred1.class_scores,
sample_detection_dataset["classes"].sel(detection=0),
)
xr.testing.assert_allclose(
np.testing.assert_allclose(
pred1.features, sample_detection_dataset["features"].sel(detection=0)
)
@ -304,21 +317,20 @@ def test_convert_xr_dataset_basic(
10 + 3 / 2,
200 + 12 / 2,
]
xr.testing.assert_allclose(
np.testing.assert_allclose(
pred2.class_scores,
sample_detection_dataset["classes"].sel(detection=1),
)
xr.testing.assert_allclose(
np.testing.assert_allclose(
pred2.features, sample_detection_dataset["features"].sel(detection=1)
)
def test_convert_xr_dataset_empty(
empty_detection_dataset, dummy_geometry_builder
):
def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets):
"""Test conversion of an empty dataset."""
raw_predictions = convert_xr_dataset_to_raw_prediction(
empty_detection_dataset, dummy_geometry_builder
empty_detection_dataset,
dummy_targets.decode_roi,
)
assert isinstance(raw_predictions, list)
assert len(raw_predictions) == 0
@ -327,8 +339,7 @@ def test_convert_xr_dataset_empty(
def test_convert_raw_to_sound_event_basic(
sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
dummy_targets,
):
"""Test basic conversion, default threshold, multi-label."""
@ -337,8 +348,7 @@ def test_convert_raw_to_sound_event_basic(
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
targets=dummy_targets,
)
assert isinstance(se_pred, data.SoundEventPrediction)
@ -357,10 +367,11 @@ def test_convert_raw_to_sound_event_basic(
)
assert feat_dict["batdetect2:f0"] == 7.0
generic_tags = dummy_targets.generic_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("soundevent:category", "noise", 0.85),
("soundevent:species", "Myotis", 0.43),
("category", "noise", 0.85),
("dwc:scientificName", "Myotis", 0.43),
}
actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -369,10 +380,7 @@ def test_convert_raw_to_sound_event_basic(
def test_convert_raw_to_sound_event_thresholding(
sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
sample_raw_predictions, sample_recording, dummy_targets
):
"""Test effect of classification threshold."""
raw_pred = sample_raw_predictions[0]
@ -381,15 +389,15 @@ def test_convert_raw_to_sound_event_thresholding(
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
targets=dummy_targets,
classification_threshold=high_threshold,
top_class_only=False,
)
generic_tags = dummy_targets.generic_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("soundevent:category", "noise", 0.85),
("category", "noise", 0.85),
}
actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -400,8 +408,7 @@ def test_convert_raw_to_sound_event_thresholding(
def test_convert_raw_to_sound_event_no_threshold(
sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
dummy_targets,
):
"""Test when classification_threshold is None."""
raw_pred = sample_raw_predictions[2]
@ -409,16 +416,16 @@ def test_convert_raw_to_sound_event_no_threshold(
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
targets=dummy_targets,
classification_threshold=None,
top_class_only=False,
)
generic_tags = dummy_targets.generic_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
("soundevent:species", "Myotis", 0.05),
("soundevent:category", "noise", 0.02),
("dwc:scientificName", "Myotis", 0.05),
("category", "noise", 0.02),
}
actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -429,8 +436,7 @@ def test_convert_raw_to_sound_event_no_threshold(
def test_convert_raw_to_sound_event_top_class(
sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
dummy_targets,
):
"""Test top_class_only=True behavior."""
raw_pred = sample_raw_predictions[0]
@ -438,15 +444,15 @@ def test_convert_raw_to_sound_event_top_class(
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
targets=dummy_targets,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=True,
)
generic_tags = dummy_targets.generic_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("soundevent:category", "noise", 0.85),
("category", "noise", 0.85),
}
actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -457,8 +463,7 @@ def test_convert_raw_to_sound_event_top_class(
def test_convert_raw_to_sound_event_all_below_threshold(
sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
dummy_targets,
):
"""Test when all class scores are below the default threshold."""
raw_pred = sample_raw_predictions[2]
@ -466,12 +471,12 @@ def test_convert_raw_to_sound_event_all_below_threshold(
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
targets=dummy_targets,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=False,
)
generic_tags = dummy_targets.generic_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
}
@ -484,15 +489,13 @@ def test_convert_raw_to_sound_event_all_below_threshold(
def test_convert_raw_list_to_clip_basic(
sample_raw_predictions,
sample_clip,
dummy_sound_event_decoder,
generic_tags,
dummy_targets,
):
"""Test converting a list of RawPredictions to a ClipPrediction."""
clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=sample_raw_predictions,
clip=sample_clip,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
targets=dummy_targets,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=False,
)
@ -515,23 +518,19 @@ def test_convert_raw_list_to_clip_basic(
(pt.tag.term.name, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[2].tags
}
generic_tags = dummy_targets.generic_class_tags
expected_tags3 = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
}
assert se_pred3_tags == expected_tags3
def test_convert_raw_list_to_clip_empty(
sample_clip,
dummy_sound_event_decoder,
generic_tags,
):
def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
"""Test converting an empty list of RawPredictions."""
clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=[],
clip=sample_clip,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
targets=dummy_targets,
)
assert isinstance(clip_pred, data.ClipPrediction)
@ -542,16 +541,14 @@ def test_convert_raw_list_to_clip_empty(
def test_convert_raw_list_to_clip_passes_args(
sample_raw_predictions,
sample_clip,
dummy_sound_event_decoder,
generic_tags,
dummy_targets,
):
"""Test that arguments like top_class_only are passed through."""
clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=sample_raw_predictions,
clip=sample_clip,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
targets=dummy_targets,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=True,
)
@ -562,16 +559,18 @@ def test_convert_raw_list_to_clip_passes_args(
(pt.tag.term.name, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[0].tags
}
generic_tags = dummy_targets.generic_class_tags
expected_tags1 = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("soundevent:category", "noise", 0.85),
("category", "noise", 0.85),
}
assert se_pred1_tags == expected_tags1
def test_get_generic_tags_basic(generic_tags):
def test_get_generic_tags_basic(dummy_targets):
"""Test creation of generic tags with score."""
detection_score = 0.75
generic_tags = dummy_targets.generic_class_tags
predicted_tags = get_generic_tags(
detection_score=detection_score, generic_class_tags=generic_tags
)
@ -589,17 +588,19 @@ def test_get_prediction_features_basic():
coords={"feature": ["feat1", "feat2", "feat3"]},
dims=["feature"],
)
features = get_prediction_features(feature_data)
features = get_prediction_features(feature_data.values)
assert len(features) == 3
for feature, feat_name, feat_value in zip(
features, ["feat1", "feat2", "feat3"], [1.1, 2.2, 3.3]
features,
["f0", "f1", "f2"],
[1.1, 2.2, 3.3],
):
assert isinstance(feature, data.Feature)
assert feature.term.name == f"batdetect2:{feat_name}"
assert feature.value == feat_value
def test_get_class_tags_basic(dummy_sound_event_decoder):
def test_get_class_tags_basic(dummy_targets):
"""Test creation of class tags based on scores and decoder."""
class_scores = xr.DataArray(
[0.6, 0.2, 0.9],
@ -607,8 +608,8 @@ def test_get_class_tags_basic(dummy_sound_event_decoder):
dims=["category"],
)
predicted_tags = get_class_tags(
class_scores=class_scores,
sound_event_decoder=dummy_sound_event_decoder,
class_scores=class_scores.values,
targets=dummy_targets,
)
assert len(predicted_tags) == 3
tag_values = [pt.tag.value for pt in predicted_tags]
@ -622,7 +623,7 @@ def test_get_class_tags_basic(dummy_sound_event_decoder):
assert 0.9 in tag_scores
def test_get_class_tags_thresholding(dummy_sound_event_decoder):
def test_get_class_tags_thresholding(dummy_targets):
"""Test class tag creation with a threshold."""
class_scores = xr.DataArray(
[0.6, 0.2, 0.9],
@ -631,8 +632,8 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder):
)
threshold = 0.5
predicted_tags = get_class_tags(
class_scores=class_scores,
sound_event_decoder=dummy_sound_event_decoder,
class_scores=class_scores.values,
targets=dummy_targets,
threshold=threshold,
)
@ -643,7 +644,7 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder):
assert "uncertain" in tag_values
def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
def test_get_class_tags_top_class_only(dummy_targets):
"""Test class tag creation with top_class_only."""
class_scores = xr.DataArray(
[0.6, 0.2, 0.9],
@ -651,8 +652,8 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
dims=["category"],
)
predicted_tags = get_class_tags(
class_scores=class_scores,
sound_event_decoder=dummy_sound_event_decoder,
class_scores=class_scores.values,
targets=dummy_targets,
top_class_only=True,
)
@ -661,11 +662,11 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
assert predicted_tags[0].score == 0.9
def test_get_class_tags_empty(dummy_sound_event_decoder):
def test_get_class_tags_empty(dummy_targets):
"""Test with empty class scores."""
class_scores = xr.DataArray([], coords={"category": []}, dims=["category"])
predicted_tags = get_class_tags(
class_scores=class_scores,
sound_event_decoder=dummy_sound_event_decoder,
class_scores=class_scores.values,
targets=dummy_targets,
)
assert len(predicted_tags) == 0

View File

@ -5,6 +5,7 @@ from uuid import uuid4
import pytest
from pydantic import ValidationError
from soundevent import data
from soundevent.terms import get_term
from batdetect2.targets.classes import (
DEFAULT_SPECIES_LIST,
@ -21,26 +22,19 @@ from batdetect2.targets.classes import (
load_decoder_from_config,
load_encoder_from_config,
)
from batdetect2.targets.terms import TagInfo, TermRegistry
from batdetect2.targets.terms import TagInfo
@pytest.fixture
def sample_annotation(
sound_event: data.SoundEvent,
sample_term_registry: TermRegistry,
) -> data.SoundEventAnnotation:
"""Fixture for a sample SoundEventAnnotation."""
return data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(
term=sample_term_registry.get_term("species"),
value="Pipistrellus pipistrellus",
),
data.Tag(
term=sample_term_registry.get_term("quality"),
value="Good",
),
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
],
)
@ -136,59 +130,33 @@ def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
def test_is_target_class_match_all(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
):
tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
),
data.Tag(term=sample_term_registry["quality"], value="Good"),
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
}
assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
)
}
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {
data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii"
)
}
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=True) is False
def test_is_target_class_match_any(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
):
tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
),
data.Tag(term=sample_term_registry["quality"], value="Good"),
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
}
assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
)
}
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {
data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii"
)
}
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=False) is False
@ -208,7 +176,6 @@ def test_get_class_names_from_config():
def test_build_encoder_from_config(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
):
config = ClassesConfig(
classes=[
@ -220,25 +187,18 @@ def test_build_encoder_from_config(
)
]
)
encoder = build_sound_event_encoder(
config,
term_registry=sample_term_registry,
)
encoder = build_sound_event_encoder(config)
result = encoder(sample_annotation)
assert result == "pippip"
config = ClassesConfig(classes=[])
encoder = build_sound_event_encoder(
config,
term_registry=sample_term_registry,
)
encoder = build_sound_event_encoder(config)
result = encoder(sample_annotation)
assert result is None
def test_load_encoder_from_config_valid(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
create_temp_yaml: Callable[[str], Path],
):
yaml_content = """
@ -249,10 +209,7 @@ def test_load_encoder_from_config_valid(
value: Pipistrellus pipistrellus
"""
temp_yaml_path = create_temp_yaml(yaml_content)
encoder = load_encoder_from_config(
temp_yaml_path,
term_registry=sample_term_registry,
)
encoder = load_encoder_from_config(temp_yaml_path)
# We cannot directly compare the function, so we test it.
result = encoder(sample_annotation) # type: ignore
assert result == "pippip"
@ -260,7 +217,6 @@ def test_load_encoder_from_config_valid(
def test_load_encoder_from_config_invalid(
create_temp_yaml: Callable[[str], Path],
sample_term_registry: TermRegistry,
):
yaml_content = """
classes:
@ -271,10 +227,7 @@ def test_load_encoder_from_config_invalid(
"""
temp_yaml_path = create_temp_yaml(yaml_content)
with pytest.raises(KeyError):
load_encoder_from_config(
temp_yaml_path,
term_registry=sample_term_registry,
)
load_encoder_from_config(temp_yaml_path)
def test_get_default_class_name():
@ -291,7 +244,7 @@ def test_get_default_classes():
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
def test_build_decoder_from_config(sample_term_registry: TermRegistry):
def test_build_decoder_from_config():
config = ClassesConfig(
classes=[
TargetClass(
@ -304,12 +257,10 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry
)
decoder = build_sound_event_decoder(config)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == sample_term_registry["call_type"]
assert tags[0].term == get_term("event")
assert tags[0].value == "Echolocation"
# Test when output_tags is None, should fall back to tags
@ -324,32 +275,25 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry
)
decoder = build_sound_event_decoder(config)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == sample_term_registry["species"]
assert tags[0].term == get_term("species")
assert tags[0].value == "Pipistrellus pipistrellus"
# Test raise_on_unmapped=True
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry, raise_on_unmapped=True
)
decoder = build_sound_event_decoder(config, raise_on_unmapped=True)
with pytest.raises(ValueError):
decoder("unknown_class")
# Test raise_on_unmapped=False
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry, raise_on_unmapped=False
)
decoder = build_sound_event_decoder(config, raise_on_unmapped=False)
tags = decoder("unknown_class")
assert len(tags) == 0
def test_load_decoder_from_config_valid(
create_temp_yaml: Callable[[str], Path],
sample_term_registry: TermRegistry,
):
yaml_content = """
classes:
@ -366,17 +310,15 @@ def test_load_decoder_from_config_valid(
"""
temp_yaml_path = create_temp_yaml(yaml_content)
decoder = load_decoder_from_config(
temp_yaml_path, term_registry=sample_term_registry
temp_yaml_path,
)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == sample_term_registry["call_type"]
assert tags[0].term == get_term("call_type")
assert tags[0].value == "Echolocation"
def test_build_generic_class_tags_from_config(
sample_term_registry: TermRegistry,
):
def test_build_generic_class_tags_from_config():
config = ClassesConfig(
classes=[
TargetClass(
@ -391,11 +333,9 @@ def test_build_generic_class_tags_from_config(
TagInfo(key="call_type", value="Echolocation"),
],
)
generic_tags = build_generic_class_tags(
config, term_registry=sample_term_registry
)
generic_tags = build_generic_class_tags(config)
assert len(generic_tags) == 2
assert generic_tags[0].term == sample_term_registry["order"]
assert generic_tags[0].term == get_term("order")
assert generic_tags[0].value == "Chiroptera"
assert generic_tags[1].term == sample_term_registry["call_type"]
assert generic_tags[1].term == get_term("call_type")
assert generic_tags[1].value == "Echolocation"

View File

@ -80,7 +80,6 @@ def test_generated_heatmaps_have_correct_dimensions(
def test_generated_heatmap_are_non_zero_at_correct_positions(
sample_target_config: TargetConfig,
sample_term_registry: TermRegistry,
pippip_tag: TagInfo,
):
config = sample_target_config.model_copy(
@ -92,7 +91,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
)
)
targets = build_targets(config, term_registry=sample_term_registry)
targets = build_targets(config)
spec = xr.DataArray(
data=np.random.rand(100, 100),
@ -113,12 +112,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
coordinates=[10, 10, 20, 20],
),
),
tags=[
data.Tag(
term=sample_term_registry[pippip_tag.key],
value=pippip_tag.value,
)
],
tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore
)
],
)

View File

@ -2,12 +2,12 @@ import pytest
import torch
import xarray as xr
from soundevent import data
from soundevent.terms import get_term
from batdetect2.models.types import ModelOutput
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.targets.terms import get_term_from_key
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.preprocess import generate_train_example
@ -15,7 +15,6 @@ from batdetect2.train.preprocess import generate_train_example
@pytest.fixture
def build_from_config(
create_temp_yaml,
sample_term_registry,
):
def build(yaml_content):
config_path = create_temp_yaml(yaml_content)
@ -31,9 +30,7 @@ def build_from_config(
field="postprocessing",
)
targets = build_targets(
targets_config, term_registry=sample_term_registry
)
targets = build_targets(targets_config)
preprocessor = build_preprocessor(preprocessing_config)
labeller = build_clip_labeler(
targets=targets,
@ -54,7 +51,6 @@ def build_from_config(
# TODO: better name
def test_generated_train_example_has_expected_outputs(
build_from_config,
sample_term_registry,
recording,
):
yaml_content = """
@ -78,10 +74,11 @@ def test_generated_train_example_has_expected_outputs(
_, preprocessor, labeller, _ = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
],
)
clip_annotation = data.ClipAnnotation(
clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
@ -108,7 +105,6 @@ def test_generated_train_example_has_expected_outputs(
def test_encoding_decoding_roundtrip_recovers_object(
build_from_config,
sample_term_registry,
recording,
):
yaml_content = """
@ -131,10 +127,11 @@ def test_encoding_decoding_roundtrip_recovers_object(
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
],
)
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
@ -171,14 +168,16 @@ def test_encoding_decoding_roundtrip_recovers_object(
assert len(recovered.tags) == 2
predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == species), None
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
None,
)
assert predicted_species_tag is not None
assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
None,
)
assert predicted_order_tag is not None
assert predicted_order_tag.score == 1
@ -187,7 +186,6 @@ def test_encoding_decoding_roundtrip_recovers_object(
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
build_from_config,
sample_term_registry,
recording,
):
yaml_content = """
@ -217,10 +215,9 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(term=species, value="Myotis myotis")],
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
@ -257,14 +254,16 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
assert len(recovered.tags) == 2
predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == species), None
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
None,
)
assert predicted_species_tag is not None
assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Myotis myotis"
predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
None,
)
assert predicted_order_tag is not None
assert predicted_order_tag.score == 1