Compare commits

..

No commits in common. "c7b110feebff0bb7fcde361484435b9bd809687a" and "b997a122f1e41e70bf686a8ed72e6fce776c3dd0" have entirely different histories.

12 changed files with 302 additions and 265 deletions

View File

@ -85,7 +85,6 @@ dev = [
"ty>=0.0.1a12",
"rust-just>=1.40.0",
"pandas-stubs>=2.2.2.240807",
"python-lsp-server>=1.13.0",
]
dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"]

View File

@ -179,7 +179,7 @@ def plot_false_positive_match(
plt.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,

View File

@ -543,9 +543,7 @@ class Postprocessor(PostprocessorProtocol):
]
def get_sound_event_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[BatDetect2Prediction]]:
raw_predictions = self.get_raw_predictions(output, clips)
return [
@ -555,7 +553,8 @@ class Postprocessor(PostprocessorProtocol):
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=self.targets,
sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold,
),
)
@ -591,7 +590,8 @@ class Postprocessor(PostprocessorProtocol):
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
targets=self.targets,
sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)

View File

@ -33,7 +33,8 @@ import xarray as xr
from soundevent import data
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.targets.classes import SoundEventDecoder
from batdetect2.utils.arrays import iterate_over_array
__all__ = [
"convert_xr_dataset_to_raw_prediction",
@ -91,30 +92,25 @@ def convert_xr_dataset_to_raw_prediction(
"""
detections = []
categories = detection_dataset.category.values
for det_num in range(detection_dataset.sizes["detection"]):
det_info = detection_dataset.sel(detection=det_num)
for score, class_scores, time, freq, dims, feats in zip(
detection_dataset["scores"].values,
detection_dataset["classes"].values,
detection_dataset["time"].values,
detection_dataset["frequency"].values,
detection_dataset["dimensions"].values,
detection_dataset["features"].values,
):
highest_scoring_class = categories[class_scores.argmax()]
highest_scoring_class = det_info.coords["category"][
det_info["classes"].argmax()
].item()
geom = geometry_decoder(
(time, freq),
dims,
(det_info.time, det_info.frequency),
det_info.dimensions,
class_name=highest_scoring_class,
)
detections.append(
RawPrediction(
detection_score=score,
detection_score=det_info.scores,
geometry=geom,
class_scores=class_scores,
features=feats,
class_scores=det_info.classes,
features=det_info.features,
)
)
@ -124,7 +120,8 @@ def convert_xr_dataset_to_raw_prediction(
def convert_raw_predictions_to_clip_prediction(
raw_predictions: List[RawPrediction],
clip: data.Clip,
targets: TargetProtocol,
sound_event_decoder: SoundEventDecoder,
generic_class_tags: List[data.Tag],
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False,
) -> data.ClipPrediction:
@ -163,7 +160,8 @@ def convert_raw_predictions_to_clip_prediction(
convert_raw_prediction_to_sound_event_prediction(
prediction,
recording=clip.recording,
targets=targets,
sound_event_decoder=sound_event_decoder,
generic_class_tags=generic_class_tags,
classification_threshold=classification_threshold,
top_class_only=top_class_only,
)
@ -175,7 +173,8 @@ def convert_raw_predictions_to_clip_prediction(
def convert_raw_prediction_to_sound_event_prediction(
raw_prediction: RawPrediction,
recording: data.Recording,
targets: TargetProtocol,
sound_event_decoder: SoundEventDecoder,
generic_class_tags: List[data.Tag],
classification_threshold: Optional[
float
] = DEFAULT_CLASSIFICATION_THRESHOLD,
@ -252,11 +251,11 @@ def convert_raw_prediction_to_sound_event_prediction(
tags = [
*get_generic_tags(
raw_prediction.detection_score,
generic_class_tags=targets.generic_class_tags,
generic_class_tags=generic_class_tags,
),
*get_class_tags(
raw_prediction.class_scores,
targets=targets,
sound_event_decoder,
top_class_only=top_class_only,
threshold=classification_threshold,
),
@ -298,7 +297,7 @@ def get_generic_tags(
]
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
"""Convert an extracted feature vector DataArray into soundevent Features.
Parameters
@ -321,19 +320,19 @@ def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
return [
data.Feature(
term=data.Term(
name=f"batdetect2:f{index}",
label=f"BatDetect Feature {index}",
name=f"batdetect2:{feat_name}",
label=feat_name,
definition="Automatically extracted features by BatDetect2",
),
value=value,
)
for index, value in enumerate(features)
for feat_name, value in iterate_over_array(features)
]
def get_class_tags(
class_scores: np.ndarray,
targets: TargetProtocol,
class_scores: xr.DataArray,
sound_event_decoder: SoundEventDecoder,
top_class_only: bool = False,
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.PredictedTag]:
@ -368,13 +367,11 @@ def get_class_tags(
"""
tags = []
for class_name, score in _iterate_sorted(
class_scores, targets.class_names
):
if threshold is not None and score < threshold:
continue
if threshold is not None:
class_scores = class_scores.where(class_scores > threshold, drop=True)
class_tags = targets.decode_class(class_name)
for class_name, score in _iterate_sorted(class_scores):
class_tags = sound_event_decoder(class_name)
for tag in class_tags:
tags.append(
@ -390,7 +387,9 @@ def get_class_tags(
return tags
def _iterate_sorted(array: np.ndarray, class_names: List[str]):
indices = np.argsort(-array)
def _iterate_sorted(array: xr.DataArray):
dim_name = array.dims[0]
coords = array.coords[dim_name].values
indices = np.argsort(-array.values)
for index in indices:
yield str(class_names[index]), float(array[index])
yield str(coords[index]), float(array.values[index])

View File

@ -14,7 +14,6 @@ system that deal with model predictions.
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol
import numpy as np
import xarray as xr
from soundevent import data
@ -73,8 +72,8 @@ class RawPrediction(NamedTuple):
geometry: data.Geometry
detection_score: float
class_scores: np.ndarray
features: np.ndarray
class_scores: xr.DataArray
features: xr.DataArray
@dataclass

View File

@ -91,7 +91,6 @@ terms.register_term_set(
"individual": individual.name,
"event": call_type.name,
"source": data_source.name,
"call_type": call_type.name,
},
),
override_existing=True,

View File

@ -1,4 +1,3 @@
import os
from functools import partial
from multiprocessing import Pool
from typing import List, Optional, Tuple
@ -16,10 +15,7 @@ from batdetect2.evaluate.match import (
)
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess.types import (
BatDetect2Prediction,
PostprocessorProtocol,
)
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.lightning import TrainingModule
@ -118,30 +114,14 @@ class ValidationMetrics(Callback):
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
self._matches.extend(
_get_batch_clips_and_predictions(
batch,
outputs,
dataset=self.get_dataset(trainer),
postprocessor=pl_module.postprocessor,
targets=pl_module.targets,
)
)
dataset = self.get_dataset(trainer)
def _get_batch_clips_and_predictions(
batch: TrainExample,
outputs: ModelOutput,
dataset: LabeledDataset,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
clip_annotations = [
_get_subclip(
dataset.get_clip_annotation(example_id),
start_time=start_time.item(),
end_time=end_time.item(),
targets=targets,
targets=pl_module.targets,
)
for example_id, start_time, end_time in zip(
batch.idx,
@ -152,17 +132,15 @@ def _get_batch_clips_and_predictions(
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
raw_predictions = postprocessor.get_sound_event_predictions(
raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
outputs,
clips,
)
return [
(clip_annotation, clip_predictions)
for clip_annotation, clip_predictions in zip(
clip_annotations, raw_predictions
)
]
):
self._matches.append((clip_annotation, clip_predictions))
def _match_all_collected_examples(
@ -172,8 +150,7 @@ def _match_all_collected_examples(
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions")
cpu_count = os.cpu_count() or 1
with Pool(processes=min(cpu_count, 4)) as p:
with Pool() as p:
matches = p.starmap(
partial(
match_sound_events_and_raw_predictions,

View File

@ -254,13 +254,11 @@ class TestLoadBatDetect2Files:
assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1
assert clip_ann.notes[0].message == "Standard notes."
clip_tag = data.find_tag(clip_ann.tags, term_label="Class")
clip_tag = data.find_tag(clip_ann.tags, "Class")
assert clip_tag is not None
assert clip_tag.value == "Myotis"
recording_tag = data.find_tag(
clip_ann.clip.recording.tags, term_label="Class"
)
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
assert recording_tag is not None
assert recording_tag.value == "Myotis"
@ -273,15 +271,15 @@ class TestLoadBatDetect2Files:
40000,
]
se_class_tag = data.find_tag(se_ann.tags, term_label="Class")
se_class_tag = data.find_tag(se_ann.tags, "Class")
assert se_class_tag is not None
assert se_class_tag.value == "Myotis"
se_event_tag = data.find_tag(se_ann.tags, term_label="Call Type")
se_event_tag = data.find_tag(se_ann.tags, "Call Type")
assert se_event_tag is not None
assert se_event_tag.value == "Echolocation"
se_individual_tag = data.find_tag(se_ann.tags, term_label="Individual")
se_individual_tag = data.find_tag(se_ann.tags, "Individual")
assert se_individual_tag is not None
assert se_individual_tag.value == "0"
@ -441,7 +439,7 @@ class TestLoadBatDetect2Merged:
assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1
clip_class_tag = data.find_tag(clip_ann.tags, term_label="Class")
clip_class_tag = data.find_tag(clip_ann.tags, "Class")
assert clip_class_tag is not None
assert clip_class_tag.value == "Myotis"

View File

@ -1,5 +1,5 @@
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
import numpy as np
import pytest
@ -16,11 +16,35 @@ from batdetect2.postprocess.decoding import (
get_prediction_features,
)
from batdetect2.postprocess.types import RawPrediction
from batdetect2.targets.types import TargetProtocol
@pytest.fixture
def dummy_targets() -> TargetProtocol:
def dummy_geometry_builder():
"""A simple GeometryBuilder that creates a BBox around the point."""
def _builder(
position: Tuple[float, float],
dimensions: xr.DataArray,
class_name: Optional[str] = None,
) -> data.BoundingBox:
time, freq = position
width = dimensions.sel(dimension="width").item()
height = dimensions.sel(dimension="height").item()
return data.BoundingBox(
coordinates=[
time - width / 2,
freq - height / 2,
time + width / 2,
freq + height / 2,
]
)
return _builder
@pytest.fixture
def dummy_sound_event_decoder():
"""A simple SoundEventDecoder mapping names to tags."""
tag_map = {
"bat": [
data.Tag(term=data.term_from_key(key="species"), value="Myotis")
@ -33,57 +57,19 @@ def dummy_targets() -> TargetProtocol:
],
}
class DummyTargets(TargetProtocol):
class_names = [
"bat",
"noise",
"unknown",
def _decoder(class_name: str) -> List[data.Tag]:
return tag_map.get(class_name.lower(), [])
return _decoder
@pytest.fixture
def generic_tags() -> List[data.Tag]:
"""Sample generic tags."""
return [
data.Tag(term=data.term_from_key(key="detector"), value="batdetect2")
]
dimension_names = ["width", "height"]
generic_class_tags = [
data.Tag(
term=data.term_from_key(key="detector"), value="batdetect2"
)
]
def filter(self, sound_event: data.SoundEventAnnotation):
return True
def transform(self, sound_event: data.SoundEventAnnotation):
return sound_event
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
return "bat"
def decode_class(self, class_label: str) -> List[data.Tag]:
return tag_map.get(class_label.lower(), [])
def encode_roi(self, sound_event: data.SoundEventAnnotation):
return np.array([0.0, 0.0]), np.array([0.0, 0.0])
def decode_roi(
self,
position,
size: np.ndarray,
class_name: Optional[str] = None,
):
time, freq = position
width, height = size
return data.BoundingBox(
coordinates=[
time - width / 2,
freq - height / 2,
time + width / 2,
freq + height / 2,
]
)
return DummyTargets()
@pytest.fixture
def sample_recording() -> data.Recording:
@ -170,7 +156,7 @@ def empty_detection_dataset() -> xr.Dataset:
"""Creates an empty detection dataset with correct structure."""
detection_coords = {
"time": ("detection", np.array([], dtype=np.float64)),
"frequency": ("detection", np.array([], dtype=np.float64)),
"freq": ("detection", np.array([], dtype=np.float64)),
}
scores = xr.DataArray(
np.array([], dtype=np.float64),
@ -198,7 +184,7 @@ def empty_detection_dataset() -> xr.Dataset:
)
return xr.Dataset(
{
"scores": scores,
"score": scores,
"dimensions": dimensions,
"classes": classes,
"features": features,
@ -229,8 +215,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
300 + 16 / 2,
]
),
class_scores=pred1_classes.values,
features=pred1_features.values,
class_scores=pred1_classes,
features=pred1_features,
)
pred2_classes = xr.DataArray(
@ -251,8 +237,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
200 + 12 / 2,
]
),
class_scores=pred2_classes.values,
features=pred2_features.values,
class_scores=pred2_classes,
features=pred2_features,
)
pred3_classes = xr.DataArray(
@ -273,17 +259,18 @@ def sample_raw_predictions() -> List[RawPrediction]:
60.0,
]
),
class_scores=pred3_classes.values,
features=pred3_features.values,
class_scores=pred3_classes,
features=pred3_features,
)
return [pred1, pred2, pred3]
def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
def test_convert_xr_dataset_basic(
sample_detection_dataset, dummy_geometry_builder
):
"""Test basic conversion of a dataset to RawPrediction list."""
raw_predictions = convert_xr_dataset_to_raw_prediction(
sample_detection_dataset,
dummy_targets.decode_roi,
sample_detection_dataset, dummy_geometry_builder
)
assert isinstance(raw_predictions, list)
@ -299,11 +286,11 @@ def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
20 + 7 / 2,
300 + 16 / 2,
]
np.testing.assert_allclose(
xr.testing.assert_allclose(
pred1.class_scores,
sample_detection_dataset["classes"].sel(detection=0),
)
np.testing.assert_allclose(
xr.testing.assert_allclose(
pred1.features, sample_detection_dataset["features"].sel(detection=0)
)
@ -317,20 +304,21 @@ def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
10 + 3 / 2,
200 + 12 / 2,
]
np.testing.assert_allclose(
xr.testing.assert_allclose(
pred2.class_scores,
sample_detection_dataset["classes"].sel(detection=1),
)
np.testing.assert_allclose(
xr.testing.assert_allclose(
pred2.features, sample_detection_dataset["features"].sel(detection=1)
)
def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets):
def test_convert_xr_dataset_empty(
empty_detection_dataset, dummy_geometry_builder
):
"""Test conversion of an empty dataset."""
raw_predictions = convert_xr_dataset_to_raw_prediction(
empty_detection_dataset,
dummy_targets.decode_roi,
empty_detection_dataset, dummy_geometry_builder
)
assert isinstance(raw_predictions, list)
assert len(raw_predictions) == 0
@ -339,7 +327,8 @@ def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets):
def test_convert_raw_to_sound_event_basic(
sample_raw_predictions,
sample_recording,
dummy_targets,
dummy_sound_event_decoder,
generic_tags,
):
"""Test basic conversion, default threshold, multi-label."""
@ -348,7 +337,8 @@ def test_convert_raw_to_sound_event_basic(
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
targets=dummy_targets,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
)
assert isinstance(se_pred, data.SoundEventPrediction)
@ -367,11 +357,10 @@ def test_convert_raw_to_sound_event_basic(
)
assert feat_dict["batdetect2:f0"] == 7.0
generic_tags = dummy_targets.generic_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85),
("dwc:scientificName", "Myotis", 0.43),
("soundevent:category", "noise", 0.85),
("soundevent:species", "Myotis", 0.43),
}
actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -380,7 +369,10 @@ def test_convert_raw_to_sound_event_basic(
def test_convert_raw_to_sound_event_thresholding(
sample_raw_predictions, sample_recording, dummy_targets
sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
):
"""Test effect of classification threshold."""
raw_pred = sample_raw_predictions[0]
@ -389,15 +381,15 @@ def test_convert_raw_to_sound_event_thresholding(
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
targets=dummy_targets,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=high_threshold,
top_class_only=False,
)
generic_tags = dummy_targets.generic_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85),
("soundevent:category", "noise", 0.85),
}
actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -408,7 +400,8 @@ def test_convert_raw_to_sound_event_thresholding(
def test_convert_raw_to_sound_event_no_threshold(
sample_raw_predictions,
sample_recording,
dummy_targets,
dummy_sound_event_decoder,
generic_tags,
):
"""Test when classification_threshold is None."""
raw_pred = sample_raw_predictions[2]
@ -416,16 +409,16 @@ def test_convert_raw_to_sound_event_no_threshold(
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
targets=dummy_targets,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=None,
top_class_only=False,
)
generic_tags = dummy_targets.generic_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
("dwc:scientificName", "Myotis", 0.05),
("category", "noise", 0.02),
("soundevent:species", "Myotis", 0.05),
("soundevent:category", "noise", 0.02),
}
actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -436,7 +429,8 @@ def test_convert_raw_to_sound_event_no_threshold(
def test_convert_raw_to_sound_event_top_class(
sample_raw_predictions,
sample_recording,
dummy_targets,
dummy_sound_event_decoder,
generic_tags,
):
"""Test top_class_only=True behavior."""
raw_pred = sample_raw_predictions[0]
@ -444,15 +438,15 @@ def test_convert_raw_to_sound_event_top_class(
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
targets=dummy_targets,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=True,
)
generic_tags = dummy_targets.generic_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85),
("soundevent:category", "noise", 0.85),
}
actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -463,7 +457,8 @@ def test_convert_raw_to_sound_event_top_class(
def test_convert_raw_to_sound_event_all_below_threshold(
sample_raw_predictions,
sample_recording,
dummy_targets,
dummy_sound_event_decoder,
generic_tags,
):
"""Test when all class scores are below the default threshold."""
raw_pred = sample_raw_predictions[2]
@ -471,12 +466,12 @@ def test_convert_raw_to_sound_event_all_below_threshold(
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
targets=dummy_targets,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=False,
)
generic_tags = dummy_targets.generic_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
}
@ -489,13 +484,15 @@ def test_convert_raw_to_sound_event_all_below_threshold(
def test_convert_raw_list_to_clip_basic(
sample_raw_predictions,
sample_clip,
dummy_targets,
dummy_sound_event_decoder,
generic_tags,
):
"""Test converting a list of RawPredictions to a ClipPrediction."""
clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=sample_raw_predictions,
clip=sample_clip,
targets=dummy_targets,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=False,
)
@ -518,19 +515,23 @@ def test_convert_raw_list_to_clip_basic(
(pt.tag.term.name, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[2].tags
}
generic_tags = dummy_targets.generic_class_tags
expected_tags3 = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
}
assert se_pred3_tags == expected_tags3
def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
def test_convert_raw_list_to_clip_empty(
sample_clip,
dummy_sound_event_decoder,
generic_tags,
):
"""Test converting an empty list of RawPredictions."""
clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=[],
clip=sample_clip,
targets=dummy_targets,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
)
assert isinstance(clip_pred, data.ClipPrediction)
@ -541,14 +542,16 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
def test_convert_raw_list_to_clip_passes_args(
sample_raw_predictions,
sample_clip,
dummy_targets,
dummy_sound_event_decoder,
generic_tags,
):
"""Test that arguments like top_class_only are passed through."""
clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=sample_raw_predictions,
clip=sample_clip,
targets=dummy_targets,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=True,
)
@ -559,18 +562,16 @@ def test_convert_raw_list_to_clip_passes_args(
(pt.tag.term.name, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[0].tags
}
generic_tags = dummy_targets.generic_class_tags
expected_tags1 = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85),
("soundevent:category", "noise", 0.85),
}
assert se_pred1_tags == expected_tags1
def test_get_generic_tags_basic(dummy_targets):
def test_get_generic_tags_basic(generic_tags):
"""Test creation of generic tags with score."""
detection_score = 0.75
generic_tags = dummy_targets.generic_class_tags
predicted_tags = get_generic_tags(
detection_score=detection_score, generic_class_tags=generic_tags
)
@ -588,19 +589,17 @@ def test_get_prediction_features_basic():
coords={"feature": ["feat1", "feat2", "feat3"]},
dims=["feature"],
)
features = get_prediction_features(feature_data.values)
features = get_prediction_features(feature_data)
assert len(features) == 3
for feature, feat_name, feat_value in zip(
features,
["f0", "f1", "f2"],
[1.1, 2.2, 3.3],
features, ["feat1", "feat2", "feat3"], [1.1, 2.2, 3.3]
):
assert isinstance(feature, data.Feature)
assert feature.term.name == f"batdetect2:{feat_name}"
assert feature.value == feat_value
def test_get_class_tags_basic(dummy_targets):
def test_get_class_tags_basic(dummy_sound_event_decoder):
"""Test creation of class tags based on scores and decoder."""
class_scores = xr.DataArray(
[0.6, 0.2, 0.9],
@ -608,8 +607,8 @@ def test_get_class_tags_basic(dummy_targets):
dims=["category"],
)
predicted_tags = get_class_tags(
class_scores=class_scores.values,
targets=dummy_targets,
class_scores=class_scores,
sound_event_decoder=dummy_sound_event_decoder,
)
assert len(predicted_tags) == 3
tag_values = [pt.tag.value for pt in predicted_tags]
@ -623,7 +622,7 @@ def test_get_class_tags_basic(dummy_targets):
assert 0.9 in tag_scores
def test_get_class_tags_thresholding(dummy_targets):
def test_get_class_tags_thresholding(dummy_sound_event_decoder):
"""Test class tag creation with a threshold."""
class_scores = xr.DataArray(
[0.6, 0.2, 0.9],
@ -632,8 +631,8 @@ def test_get_class_tags_thresholding(dummy_targets):
)
threshold = 0.5
predicted_tags = get_class_tags(
class_scores=class_scores.values,
targets=dummy_targets,
class_scores=class_scores,
sound_event_decoder=dummy_sound_event_decoder,
threshold=threshold,
)
@ -644,7 +643,7 @@ def test_get_class_tags_thresholding(dummy_targets):
assert "uncertain" in tag_values
def test_get_class_tags_top_class_only(dummy_targets):
def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
"""Test class tag creation with top_class_only."""
class_scores = xr.DataArray(
[0.6, 0.2, 0.9],
@ -652,8 +651,8 @@ def test_get_class_tags_top_class_only(dummy_targets):
dims=["category"],
)
predicted_tags = get_class_tags(
class_scores=class_scores.values,
targets=dummy_targets,
class_scores=class_scores,
sound_event_decoder=dummy_sound_event_decoder,
top_class_only=True,
)
@ -662,11 +661,11 @@ def test_get_class_tags_top_class_only(dummy_targets):
assert predicted_tags[0].score == 0.9
def test_get_class_tags_empty(dummy_targets):
def test_get_class_tags_empty(dummy_sound_event_decoder):
"""Test with empty class scores."""
class_scores = xr.DataArray([], coords={"category": []}, dims=["category"])
predicted_tags = get_class_tags(
class_scores=class_scores.values,
targets=dummy_targets,
class_scores=class_scores,
sound_event_decoder=dummy_sound_event_decoder,
)
assert len(predicted_tags) == 0

View File

@ -5,7 +5,6 @@ from uuid import uuid4
import pytest
from pydantic import ValidationError
from soundevent import data
from soundevent.terms import get_term
from batdetect2.targets.classes import (
DEFAULT_SPECIES_LIST,
@ -22,19 +21,26 @@ from batdetect2.targets.classes import (
load_decoder_from_config,
load_encoder_from_config,
)
from batdetect2.targets.terms import TagInfo
from batdetect2.targets.terms import TagInfo, TermRegistry
@pytest.fixture
def sample_annotation(
sound_event: data.SoundEvent,
sample_term_registry: TermRegistry,
) -> data.SoundEventAnnotation:
"""Fixture for a sample SoundEventAnnotation."""
return data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
data.Tag(
term=sample_term_registry.get_term("species"),
value="Pipistrellus pipistrellus",
),
data.Tag(
term=sample_term_registry.get_term("quality"),
value="Good",
),
],
)
@ -130,33 +136,59 @@ def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
def test_is_target_class_match_all(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
):
tags = {
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
),
data.Tag(term=sample_term_registry["quality"], value="Good"),
}
assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
)
}
assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
tags = {
data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii"
)
}
assert is_target_class(sample_annotation, tags, match_all=True) is False
def test_is_target_class_match_any(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
):
tags = {
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
),
data.Tag(term=sample_term_registry["quality"], value="Good"),
}
assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
)
}
assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
tags = {
data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii"
)
}
assert is_target_class(sample_annotation, tags, match_all=False) is False
@ -176,6 +208,7 @@ def test_get_class_names_from_config():
def test_build_encoder_from_config(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
):
config = ClassesConfig(
classes=[
@ -187,18 +220,25 @@ def test_build_encoder_from_config(
)
]
)
encoder = build_sound_event_encoder(config)
encoder = build_sound_event_encoder(
config,
term_registry=sample_term_registry,
)
result = encoder(sample_annotation)
assert result == "pippip"
config = ClassesConfig(classes=[])
encoder = build_sound_event_encoder(config)
encoder = build_sound_event_encoder(
config,
term_registry=sample_term_registry,
)
result = encoder(sample_annotation)
assert result is None
def test_load_encoder_from_config_valid(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
create_temp_yaml: Callable[[str], Path],
):
yaml_content = """
@ -209,7 +249,10 @@ def test_load_encoder_from_config_valid(
value: Pipistrellus pipistrellus
"""
temp_yaml_path = create_temp_yaml(yaml_content)
encoder = load_encoder_from_config(temp_yaml_path)
encoder = load_encoder_from_config(
temp_yaml_path,
term_registry=sample_term_registry,
)
# We cannot directly compare the function, so we test it.
result = encoder(sample_annotation) # type: ignore
assert result == "pippip"
@ -217,6 +260,7 @@ def test_load_encoder_from_config_valid(
def test_load_encoder_from_config_invalid(
create_temp_yaml: Callable[[str], Path],
sample_term_registry: TermRegistry,
):
yaml_content = """
classes:
@ -227,7 +271,10 @@ def test_load_encoder_from_config_invalid(
"""
temp_yaml_path = create_temp_yaml(yaml_content)
with pytest.raises(KeyError):
load_encoder_from_config(temp_yaml_path)
load_encoder_from_config(
temp_yaml_path,
term_registry=sample_term_registry,
)
def test_get_default_class_name():
@ -244,7 +291,7 @@ def test_get_default_classes():
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
def test_build_decoder_from_config():
def test_build_decoder_from_config(sample_term_registry: TermRegistry):
config = ClassesConfig(
classes=[
TargetClass(
@ -257,10 +304,12 @@ def test_build_decoder_from_config():
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_sound_event_decoder(config)
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry
)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == get_term("event")
assert tags[0].term == sample_term_registry["call_type"]
assert tags[0].value == "Echolocation"
# Test when output_tags is None, should fall back to tags
@ -275,25 +324,32 @@ def test_build_decoder_from_config():
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_sound_event_decoder(config)
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry
)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == get_term("species")
assert tags[0].term == sample_term_registry["species"]
assert tags[0].value == "Pipistrellus pipistrellus"
# Test raise_on_unmapped=True
decoder = build_sound_event_decoder(config, raise_on_unmapped=True)
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry, raise_on_unmapped=True
)
with pytest.raises(ValueError):
decoder("unknown_class")
# Test raise_on_unmapped=False
decoder = build_sound_event_decoder(config, raise_on_unmapped=False)
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry, raise_on_unmapped=False
)
tags = decoder("unknown_class")
assert len(tags) == 0
def test_load_decoder_from_config_valid(
create_temp_yaml: Callable[[str], Path],
sample_term_registry: TermRegistry,
):
yaml_content = """
classes:
@ -310,15 +366,17 @@ def test_load_decoder_from_config_valid(
"""
temp_yaml_path = create_temp_yaml(yaml_content)
decoder = load_decoder_from_config(
temp_yaml_path,
temp_yaml_path, term_registry=sample_term_registry
)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == get_term("call_type")
assert tags[0].term == sample_term_registry["call_type"]
assert tags[0].value == "Echolocation"
def test_build_generic_class_tags_from_config():
def test_build_generic_class_tags_from_config(
sample_term_registry: TermRegistry,
):
config = ClassesConfig(
classes=[
TargetClass(
@ -333,9 +391,11 @@ def test_build_generic_class_tags_from_config():
TagInfo(key="call_type", value="Echolocation"),
],
)
generic_tags = build_generic_class_tags(config)
generic_tags = build_generic_class_tags(
config, term_registry=sample_term_registry
)
assert len(generic_tags) == 2
assert generic_tags[0].term == get_term("order")
assert generic_tags[0].term == sample_term_registry["order"]
assert generic_tags[0].value == "Chiroptera"
assert generic_tags[1].term == get_term("call_type")
assert generic_tags[1].term == sample_term_registry["call_type"]
assert generic_tags[1].value == "Echolocation"

View File

@ -80,6 +80,7 @@ def test_generated_heatmaps_have_correct_dimensions(
def test_generated_heatmap_are_non_zero_at_correct_positions(
sample_target_config: TargetConfig,
sample_term_registry: TermRegistry,
pippip_tag: TagInfo,
):
config = sample_target_config.model_copy(
@ -91,7 +92,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
)
)
targets = build_targets(config)
targets = build_targets(config, term_registry=sample_term_registry)
spec = xr.DataArray(
data=np.random.rand(100, 100),
@ -112,7 +113,12 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
coordinates=[10, 10, 20, 20],
),
),
tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore
tags=[
data.Tag(
term=sample_term_registry[pippip_tag.key],
value=pippip_tag.value,
)
],
)
],
)

View File

@ -2,12 +2,12 @@ import pytest
import torch
import xarray as xr
from soundevent import data
from soundevent.terms import get_term
from batdetect2.models.types import ModelOutput
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.targets.terms import get_term_from_key
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.preprocess import generate_train_example
@ -15,6 +15,7 @@ from batdetect2.train.preprocess import generate_train_example
@pytest.fixture
def build_from_config(
create_temp_yaml,
sample_term_registry,
):
def build(yaml_content):
config_path = create_temp_yaml(yaml_content)
@ -30,7 +31,9 @@ def build_from_config(
field="postprocessing",
)
targets = build_targets(targets_config)
targets = build_targets(
targets_config, term_registry=sample_term_registry
)
preprocessor = build_preprocessor(preprocessing_config)
labeller = build_clip_labeler(
targets=targets,
@ -51,6 +54,7 @@ def build_from_config(
# TODO: better name
def test_generated_train_example_has_expected_outputs(
build_from_config,
sample_term_registry,
recording,
):
yaml_content = """
@ -74,11 +78,10 @@ def test_generated_train_example_has_expected_outputs(
_, preprocessor, labeller, _ = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
],
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
)
clip_annotation = data.ClipAnnotation(
clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
@ -105,6 +108,7 @@ def test_generated_train_example_has_expected_outputs(
def test_encoding_decoding_roundtrip_recovers_object(
build_from_config,
sample_term_registry,
recording,
):
yaml_content = """
@ -127,11 +131,10 @@ def test_encoding_decoding_roundtrip_recovers_object(
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
],
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
)
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
@ -168,16 +171,14 @@ def test_encoding_decoding_roundtrip_recovers_object(
assert len(recovered.tags) == 2
predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
None,
iter(t for t in recovered.tags if t.tag.term == species), None
)
assert predicted_species_tag is not None
assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
None,
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
)
assert predicted_order_tag is not None
assert predicted_order_tag.score == 1
@ -186,6 +187,7 @@ def test_encoding_decoding_roundtrip_recovers_object(
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
build_from_config,
sample_term_registry,
recording,
):
yaml_content = """
@ -215,9 +217,10 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
tags=[data.Tag(term=species, value="Myotis myotis")],
)
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
@ -254,16 +257,14 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
assert len(recovered.tags) == 2
predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
None,
iter(t for t in recovered.tags if t.tag.term == species), None
)
assert predicted_species_tag is not None
assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Myotis myotis"
predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
None,
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
)
assert predicted_order_tag is not None
assert predicted_order_tag.score == 1