Compare commits

..

No commits in common. "c7b110feebff0bb7fcde361484435b9bd809687a" and "b997a122f1e41e70bf686a8ed72e6fce776c3dd0" have entirely different histories.

12 changed files with 302 additions and 265 deletions

View File

@ -85,7 +85,6 @@ dev = [
"ty>=0.0.1a12", "ty>=0.0.1a12",
"rust-just>=1.40.0", "rust-just>=1.40.0",
"pandas-stubs>=2.2.2.240807", "pandas-stubs>=2.2.2.240807",
"python-lsp-server>=1.13.0",
] ]
dvclive = ["dvclive>=3.48.2"] dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"] mlflow = ["mlflow>=3.1.1"]

View File

@ -179,7 +179,7 @@ def plot_false_positive_match(
plt.text( plt.text(
start_time, start_time,
high_freq, high_freq,
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ", f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,

View File

@ -543,9 +543,7 @@ class Postprocessor(PostprocessorProtocol):
] ]
def get_sound_event_predictions( def get_sound_event_predictions(
self, self, output: ModelOutput, clips: List[data.Clip]
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[BatDetect2Prediction]]: ) -> List[List[BatDetect2Prediction]]:
raw_predictions = self.get_raw_predictions(output, clips) raw_predictions = self.get_raw_predictions(output, clips)
return [ return [
@ -555,7 +553,8 @@ class Postprocessor(PostprocessorProtocol):
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction( sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw, raw,
recording=clip.recording, recording=clip.recording,
targets=self.targets, sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold, classification_threshold=self.config.classification_threshold,
), ),
) )
@ -591,7 +590,8 @@ class Postprocessor(PostprocessorProtocol):
convert_raw_predictions_to_clip_prediction( convert_raw_predictions_to_clip_prediction(
prediction, prediction,
clip, clip,
targets=self.targets, sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold, classification_threshold=self.config.classification_threshold,
) )
for prediction, clip in zip(raw_predictions, clips) for prediction, clip in zip(raw_predictions, clips)

View File

@ -33,7 +33,8 @@ import xarray as xr
from soundevent import data from soundevent import data
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.classes import SoundEventDecoder
from batdetect2.utils.arrays import iterate_over_array
__all__ = [ __all__ = [
"convert_xr_dataset_to_raw_prediction", "convert_xr_dataset_to_raw_prediction",
@ -91,30 +92,25 @@ def convert_xr_dataset_to_raw_prediction(
""" """
detections = [] detections = []
categories = detection_dataset.category.values for det_num in range(detection_dataset.sizes["detection"]):
det_info = detection_dataset.sel(detection=det_num)
for score, class_scores, time, freq, dims, feats in zip( highest_scoring_class = det_info.coords["category"][
detection_dataset["scores"].values, det_info["classes"].argmax()
detection_dataset["classes"].values, ].item()
detection_dataset["time"].values,
detection_dataset["frequency"].values,
detection_dataset["dimensions"].values,
detection_dataset["features"].values,
):
highest_scoring_class = categories[class_scores.argmax()]
geom = geometry_decoder( geom = geometry_decoder(
(time, freq), (det_info.time, det_info.frequency),
dims, det_info.dimensions,
class_name=highest_scoring_class, class_name=highest_scoring_class,
) )
detections.append( detections.append(
RawPrediction( RawPrediction(
detection_score=score, detection_score=det_info.scores,
geometry=geom, geometry=geom,
class_scores=class_scores, class_scores=det_info.classes,
features=feats, features=det_info.features,
) )
) )
@ -124,7 +120,8 @@ def convert_xr_dataset_to_raw_prediction(
def convert_raw_predictions_to_clip_prediction( def convert_raw_predictions_to_clip_prediction(
raw_predictions: List[RawPrediction], raw_predictions: List[RawPrediction],
clip: data.Clip, clip: data.Clip,
targets: TargetProtocol, sound_event_decoder: SoundEventDecoder,
generic_class_tags: List[data.Tag],
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False, top_class_only: bool = False,
) -> data.ClipPrediction: ) -> data.ClipPrediction:
@ -163,7 +160,8 @@ def convert_raw_predictions_to_clip_prediction(
convert_raw_prediction_to_sound_event_prediction( convert_raw_prediction_to_sound_event_prediction(
prediction, prediction,
recording=clip.recording, recording=clip.recording,
targets=targets, sound_event_decoder=sound_event_decoder,
generic_class_tags=generic_class_tags,
classification_threshold=classification_threshold, classification_threshold=classification_threshold,
top_class_only=top_class_only, top_class_only=top_class_only,
) )
@ -175,7 +173,8 @@ def convert_raw_predictions_to_clip_prediction(
def convert_raw_prediction_to_sound_event_prediction( def convert_raw_prediction_to_sound_event_prediction(
raw_prediction: RawPrediction, raw_prediction: RawPrediction,
recording: data.Recording, recording: data.Recording,
targets: TargetProtocol, sound_event_decoder: SoundEventDecoder,
generic_class_tags: List[data.Tag],
classification_threshold: Optional[ classification_threshold: Optional[
float float
] = DEFAULT_CLASSIFICATION_THRESHOLD, ] = DEFAULT_CLASSIFICATION_THRESHOLD,
@ -252,11 +251,11 @@ def convert_raw_prediction_to_sound_event_prediction(
tags = [ tags = [
*get_generic_tags( *get_generic_tags(
raw_prediction.detection_score, raw_prediction.detection_score,
generic_class_tags=targets.generic_class_tags, generic_class_tags=generic_class_tags,
), ),
*get_class_tags( *get_class_tags(
raw_prediction.class_scores, raw_prediction.class_scores,
targets=targets, sound_event_decoder,
top_class_only=top_class_only, top_class_only=top_class_only,
threshold=classification_threshold, threshold=classification_threshold,
), ),
@ -298,7 +297,7 @@ def get_generic_tags(
] ]
def get_prediction_features(features: np.ndarray) -> List[data.Feature]: def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
"""Convert an extracted feature vector DataArray into soundevent Features. """Convert an extracted feature vector DataArray into soundevent Features.
Parameters Parameters
@ -321,19 +320,19 @@ def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
return [ return [
data.Feature( data.Feature(
term=data.Term( term=data.Term(
name=f"batdetect2:f{index}", name=f"batdetect2:{feat_name}",
label=f"BatDetect Feature {index}", label=feat_name,
definition="Automatically extracted features by BatDetect2", definition="Automatically extracted features by BatDetect2",
), ),
value=value, value=value,
) )
for index, value in enumerate(features) for feat_name, value in iterate_over_array(features)
] ]
def get_class_tags( def get_class_tags(
class_scores: np.ndarray, class_scores: xr.DataArray,
targets: TargetProtocol, sound_event_decoder: SoundEventDecoder,
top_class_only: bool = False, top_class_only: bool = False,
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD, threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.PredictedTag]: ) -> List[data.PredictedTag]:
@ -368,13 +367,11 @@ def get_class_tags(
""" """
tags = [] tags = []
for class_name, score in _iterate_sorted( if threshold is not None:
class_scores, targets.class_names class_scores = class_scores.where(class_scores > threshold, drop=True)
):
if threshold is not None and score < threshold:
continue
class_tags = targets.decode_class(class_name) for class_name, score in _iterate_sorted(class_scores):
class_tags = sound_event_decoder(class_name)
for tag in class_tags: for tag in class_tags:
tags.append( tags.append(
@ -390,7 +387,9 @@ def get_class_tags(
return tags return tags
def _iterate_sorted(array: np.ndarray, class_names: List[str]): def _iterate_sorted(array: xr.DataArray):
indices = np.argsort(-array) dim_name = array.dims[0]
coords = array.coords[dim_name].values
indices = np.argsort(-array.values)
for index in indices: for index in indices:
yield str(class_names[index]), float(array[index]) yield str(coords[index]), float(array.values[index])

View File

@ -14,7 +14,6 @@ system that deal with model predictions.
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol from typing import List, NamedTuple, Optional, Protocol
import numpy as np
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
@ -73,8 +72,8 @@ class RawPrediction(NamedTuple):
geometry: data.Geometry geometry: data.Geometry
detection_score: float detection_score: float
class_scores: np.ndarray class_scores: xr.DataArray
features: np.ndarray features: xr.DataArray
@dataclass @dataclass

View File

@ -91,7 +91,6 @@ terms.register_term_set(
"individual": individual.name, "individual": individual.name,
"event": call_type.name, "event": call_type.name,
"source": data_source.name, "source": data_source.name,
"call_type": call_type.name,
}, },
), ),
override_existing=True, override_existing=True,

View File

@ -1,4 +1,3 @@
import os
from functools import partial from functools import partial
from multiprocessing import Pool from multiprocessing import Pool
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@ -16,10 +15,7 @@ from batdetect2.evaluate.match import (
) )
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
from batdetect2.plotting.evaluation import plot_example_gallery from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess.types import ( from batdetect2.postprocess.types import BatDetect2Prediction
BatDetect2Prediction,
PostprocessorProtocol,
)
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
@ -118,51 +114,33 @@ class ValidationMetrics(Callback):
batch_idx: int, batch_idx: int,
dataloader_idx: int = 0, dataloader_idx: int = 0,
) -> None: ) -> None:
self._matches.extend( dataset = self.get_dataset(trainer)
_get_batch_clips_and_predictions(
batch, clip_annotations = [
outputs, _get_subclip(
dataset=self.get_dataset(trainer), dataset.get_clip_annotation(example_id),
postprocessor=pl_module.postprocessor, start_time=start_time.item(),
end_time=end_time.item(),
targets=pl_module.targets, targets=pl_module.targets,
) )
for example_id, start_time, end_time in zip(
batch.idx,
batch.start_time,
batch.end_time,
)
]
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
outputs,
clips,
) )
def _get_batch_clips_and_predictions(
batch: TrainExample,
outputs: ModelOutput,
dataset: LabeledDataset,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
clip_annotations = [
_get_subclip(
dataset.get_clip_annotation(example_id),
start_time=start_time.item(),
end_time=end_time.item(),
targets=targets,
)
for example_id, start_time, end_time in zip(
batch.idx,
batch.start_time,
batch.end_time,
)
]
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
raw_predictions = postprocessor.get_sound_event_predictions(
outputs,
clips,
)
return [
(clip_annotation, clip_predictions)
for clip_annotation, clip_predictions in zip( for clip_annotation, clip_predictions in zip(
clip_annotations, raw_predictions clip_annotations, raw_predictions
) ):
] self._matches.append((clip_annotation, clip_predictions))
def _match_all_collected_examples( def _match_all_collected_examples(
@ -172,8 +150,7 @@ def _match_all_collected_examples(
) -> List[MatchEvaluation]: ) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions") logger.info("Matching all annotations and predictions")
cpu_count = os.cpu_count() or 1 with Pool() as p:
with Pool(processes=min(cpu_count, 4)) as p:
matches = p.starmap( matches = p.starmap(
partial( partial(
match_sound_events_and_raw_predictions, match_sound_events_and_raw_predictions,

View File

@ -254,13 +254,11 @@ class TestLoadBatDetect2Files:
assert clip_ann.clip.recording.duration == 5.0 assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1 assert len(clip_ann.sound_events) == 1
assert clip_ann.notes[0].message == "Standard notes." assert clip_ann.notes[0].message == "Standard notes."
clip_tag = data.find_tag(clip_ann.tags, term_label="Class") clip_tag = data.find_tag(clip_ann.tags, "Class")
assert clip_tag is not None assert clip_tag is not None
assert clip_tag.value == "Myotis" assert clip_tag.value == "Myotis"
recording_tag = data.find_tag( recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
clip_ann.clip.recording.tags, term_label="Class"
)
assert recording_tag is not None assert recording_tag is not None
assert recording_tag.value == "Myotis" assert recording_tag.value == "Myotis"
@ -273,15 +271,15 @@ class TestLoadBatDetect2Files:
40000, 40000,
] ]
se_class_tag = data.find_tag(se_ann.tags, term_label="Class") se_class_tag = data.find_tag(se_ann.tags, "Class")
assert se_class_tag is not None assert se_class_tag is not None
assert se_class_tag.value == "Myotis" assert se_class_tag.value == "Myotis"
se_event_tag = data.find_tag(se_ann.tags, term_label="Call Type") se_event_tag = data.find_tag(se_ann.tags, "Call Type")
assert se_event_tag is not None assert se_event_tag is not None
assert se_event_tag.value == "Echolocation" assert se_event_tag.value == "Echolocation"
se_individual_tag = data.find_tag(se_ann.tags, term_label="Individual") se_individual_tag = data.find_tag(se_ann.tags, "Individual")
assert se_individual_tag is not None assert se_individual_tag is not None
assert se_individual_tag.value == "0" assert se_individual_tag.value == "0"
@ -441,7 +439,7 @@ class TestLoadBatDetect2Merged:
assert clip_ann.clip.recording.duration == 5.0 assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1 assert len(clip_ann.sound_events) == 1
clip_class_tag = data.find_tag(clip_ann.tags, term_label="Class") clip_class_tag = data.find_tag(clip_ann.tags, "Class")
assert clip_class_tag is not None assert clip_class_tag is not None
assert clip_class_tag.value == "Myotis" assert clip_class_tag.value == "Myotis"

View File

@ -1,5 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional, Tuple
import numpy as np import numpy as np
import pytest import pytest
@ -16,11 +16,35 @@ from batdetect2.postprocess.decoding import (
get_prediction_features, get_prediction_features,
) )
from batdetect2.postprocess.types import RawPrediction from batdetect2.postprocess.types import RawPrediction
from batdetect2.targets.types import TargetProtocol
@pytest.fixture @pytest.fixture
def dummy_targets() -> TargetProtocol: def dummy_geometry_builder():
"""A simple GeometryBuilder that creates a BBox around the point."""
def _builder(
position: Tuple[float, float],
dimensions: xr.DataArray,
class_name: Optional[str] = None,
) -> data.BoundingBox:
time, freq = position
width = dimensions.sel(dimension="width").item()
height = dimensions.sel(dimension="height").item()
return data.BoundingBox(
coordinates=[
time - width / 2,
freq - height / 2,
time + width / 2,
freq + height / 2,
]
)
return _builder
@pytest.fixture
def dummy_sound_event_decoder():
"""A simple SoundEventDecoder mapping names to tags."""
tag_map = { tag_map = {
"bat": [ "bat": [
data.Tag(term=data.term_from_key(key="species"), value="Myotis") data.Tag(term=data.term_from_key(key="species"), value="Myotis")
@ -33,56 +57,18 @@ def dummy_targets() -> TargetProtocol:
], ],
} }
class DummyTargets(TargetProtocol): def _decoder(class_name: str) -> List[data.Tag]:
class_names = [ return tag_map.get(class_name.lower(), [])
"bat",
"noise",
"unknown",
]
dimension_names = ["width", "height"] return _decoder
generic_class_tags = [
data.Tag(
term=data.term_from_key(key="detector"), value="batdetect2"
)
]
def filter(self, sound_event: data.SoundEventAnnotation): @pytest.fixture
return True def generic_tags() -> List[data.Tag]:
"""Sample generic tags."""
def transform(self, sound_event: data.SoundEventAnnotation): return [
return sound_event data.Tag(term=data.term_from_key(key="detector"), value="batdetect2")
]
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
return "bat"
def decode_class(self, class_label: str) -> List[data.Tag]:
return tag_map.get(class_label.lower(), [])
def encode_roi(self, sound_event: data.SoundEventAnnotation):
return np.array([0.0, 0.0]), np.array([0.0, 0.0])
def decode_roi(
self,
position,
size: np.ndarray,
class_name: Optional[str] = None,
):
time, freq = position
width, height = size
return data.BoundingBox(
coordinates=[
time - width / 2,
freq - height / 2,
time + width / 2,
freq + height / 2,
]
)
return DummyTargets()
@pytest.fixture @pytest.fixture
@ -170,7 +156,7 @@ def empty_detection_dataset() -> xr.Dataset:
"""Creates an empty detection dataset with correct structure.""" """Creates an empty detection dataset with correct structure."""
detection_coords = { detection_coords = {
"time": ("detection", np.array([], dtype=np.float64)), "time": ("detection", np.array([], dtype=np.float64)),
"frequency": ("detection", np.array([], dtype=np.float64)), "freq": ("detection", np.array([], dtype=np.float64)),
} }
scores = xr.DataArray( scores = xr.DataArray(
np.array([], dtype=np.float64), np.array([], dtype=np.float64),
@ -198,7 +184,7 @@ def empty_detection_dataset() -> xr.Dataset:
) )
return xr.Dataset( return xr.Dataset(
{ {
"scores": scores, "score": scores,
"dimensions": dimensions, "dimensions": dimensions,
"classes": classes, "classes": classes,
"features": features, "features": features,
@ -229,8 +215,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
300 + 16 / 2, 300 + 16 / 2,
] ]
), ),
class_scores=pred1_classes.values, class_scores=pred1_classes,
features=pred1_features.values, features=pred1_features,
) )
pred2_classes = xr.DataArray( pred2_classes = xr.DataArray(
@ -251,8 +237,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
200 + 12 / 2, 200 + 12 / 2,
] ]
), ),
class_scores=pred2_classes.values, class_scores=pred2_classes,
features=pred2_features.values, features=pred2_features,
) )
pred3_classes = xr.DataArray( pred3_classes = xr.DataArray(
@ -273,17 +259,18 @@ def sample_raw_predictions() -> List[RawPrediction]:
60.0, 60.0,
] ]
), ),
class_scores=pred3_classes.values, class_scores=pred3_classes,
features=pred3_features.values, features=pred3_features,
) )
return [pred1, pred2, pred3] return [pred1, pred2, pred3]
def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets): def test_convert_xr_dataset_basic(
sample_detection_dataset, dummy_geometry_builder
):
"""Test basic conversion of a dataset to RawPrediction list.""" """Test basic conversion of a dataset to RawPrediction list."""
raw_predictions = convert_xr_dataset_to_raw_prediction( raw_predictions = convert_xr_dataset_to_raw_prediction(
sample_detection_dataset, sample_detection_dataset, dummy_geometry_builder
dummy_targets.decode_roi,
) )
assert isinstance(raw_predictions, list) assert isinstance(raw_predictions, list)
@ -299,11 +286,11 @@ def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
20 + 7 / 2, 20 + 7 / 2,
300 + 16 / 2, 300 + 16 / 2,
] ]
np.testing.assert_allclose( xr.testing.assert_allclose(
pred1.class_scores, pred1.class_scores,
sample_detection_dataset["classes"].sel(detection=0), sample_detection_dataset["classes"].sel(detection=0),
) )
np.testing.assert_allclose( xr.testing.assert_allclose(
pred1.features, sample_detection_dataset["features"].sel(detection=0) pred1.features, sample_detection_dataset["features"].sel(detection=0)
) )
@ -317,20 +304,21 @@ def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
10 + 3 / 2, 10 + 3 / 2,
200 + 12 / 2, 200 + 12 / 2,
] ]
np.testing.assert_allclose( xr.testing.assert_allclose(
pred2.class_scores, pred2.class_scores,
sample_detection_dataset["classes"].sel(detection=1), sample_detection_dataset["classes"].sel(detection=1),
) )
np.testing.assert_allclose( xr.testing.assert_allclose(
pred2.features, sample_detection_dataset["features"].sel(detection=1) pred2.features, sample_detection_dataset["features"].sel(detection=1)
) )
def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets): def test_convert_xr_dataset_empty(
empty_detection_dataset, dummy_geometry_builder
):
"""Test conversion of an empty dataset.""" """Test conversion of an empty dataset."""
raw_predictions = convert_xr_dataset_to_raw_prediction( raw_predictions = convert_xr_dataset_to_raw_prediction(
empty_detection_dataset, empty_detection_dataset, dummy_geometry_builder
dummy_targets.decode_roi,
) )
assert isinstance(raw_predictions, list) assert isinstance(raw_predictions, list)
assert len(raw_predictions) == 0 assert len(raw_predictions) == 0
@ -339,7 +327,8 @@ def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets):
def test_convert_raw_to_sound_event_basic( def test_convert_raw_to_sound_event_basic(
sample_raw_predictions, sample_raw_predictions,
sample_recording, sample_recording,
dummy_targets, dummy_sound_event_decoder,
generic_tags,
): ):
"""Test basic conversion, default threshold, multi-label.""" """Test basic conversion, default threshold, multi-label."""
@ -348,7 +337,8 @@ def test_convert_raw_to_sound_event_basic(
se_pred = convert_raw_prediction_to_sound_event_prediction( se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred, raw_prediction=raw_pred,
recording=sample_recording, recording=sample_recording,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
) )
assert isinstance(se_pred, data.SoundEventPrediction) assert isinstance(se_pred, data.SoundEventPrediction)
@ -367,11 +357,10 @@ def test_convert_raw_to_sound_event_basic(
) )
assert feat_dict["batdetect2:f0"] == 7.0 assert feat_dict["batdetect2:f0"] == 7.0
generic_tags = dummy_targets.generic_class_tags
expected_tags = { expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9), (generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85), ("soundevent:category", "noise", 0.85),
("dwc:scientificName", "Myotis", 0.43), ("soundevent:species", "Myotis", 0.43),
} }
actual_tags = { actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -380,7 +369,10 @@ def test_convert_raw_to_sound_event_basic(
def test_convert_raw_to_sound_event_thresholding( def test_convert_raw_to_sound_event_thresholding(
sample_raw_predictions, sample_recording, dummy_targets sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
): ):
"""Test effect of classification threshold.""" """Test effect of classification threshold."""
raw_pred = sample_raw_predictions[0] raw_pred = sample_raw_predictions[0]
@ -389,15 +381,15 @@ def test_convert_raw_to_sound_event_thresholding(
se_pred = convert_raw_prediction_to_sound_event_prediction( se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred, raw_prediction=raw_pred,
recording=sample_recording, recording=sample_recording,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=high_threshold, classification_threshold=high_threshold,
top_class_only=False, top_class_only=False,
) )
generic_tags = dummy_targets.generic_class_tags
expected_tags = { expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9), (generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85), ("soundevent:category", "noise", 0.85),
} }
actual_tags = { actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -408,7 +400,8 @@ def test_convert_raw_to_sound_event_thresholding(
def test_convert_raw_to_sound_event_no_threshold( def test_convert_raw_to_sound_event_no_threshold(
sample_raw_predictions, sample_raw_predictions,
sample_recording, sample_recording,
dummy_targets, dummy_sound_event_decoder,
generic_tags,
): ):
"""Test when classification_threshold is None.""" """Test when classification_threshold is None."""
raw_pred = sample_raw_predictions[2] raw_pred = sample_raw_predictions[2]
@ -416,16 +409,16 @@ def test_convert_raw_to_sound_event_no_threshold(
se_pred = convert_raw_prediction_to_sound_event_prediction( se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred, raw_prediction=raw_pred,
recording=sample_recording, recording=sample_recording,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=None, classification_threshold=None,
top_class_only=False, top_class_only=False,
) )
generic_tags = dummy_targets.generic_class_tags
expected_tags = { expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15), (generic_tags[0].term.name, generic_tags[0].value, 0.15),
("dwc:scientificName", "Myotis", 0.05), ("soundevent:species", "Myotis", 0.05),
("category", "noise", 0.02), ("soundevent:category", "noise", 0.02),
} }
actual_tags = { actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -436,7 +429,8 @@ def test_convert_raw_to_sound_event_no_threshold(
def test_convert_raw_to_sound_event_top_class( def test_convert_raw_to_sound_event_top_class(
sample_raw_predictions, sample_raw_predictions,
sample_recording, sample_recording,
dummy_targets, dummy_sound_event_decoder,
generic_tags,
): ):
"""Test top_class_only=True behavior.""" """Test top_class_only=True behavior."""
raw_pred = sample_raw_predictions[0] raw_pred = sample_raw_predictions[0]
@ -444,15 +438,15 @@ def test_convert_raw_to_sound_event_top_class(
se_pred = convert_raw_prediction_to_sound_event_prediction( se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred, raw_prediction=raw_pred,
recording=sample_recording, recording=sample_recording,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=True, top_class_only=True,
) )
generic_tags = dummy_targets.generic_class_tags
expected_tags = { expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9), (generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85), ("soundevent:category", "noise", 0.85),
} }
actual_tags = { actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -463,7 +457,8 @@ def test_convert_raw_to_sound_event_top_class(
def test_convert_raw_to_sound_event_all_below_threshold( def test_convert_raw_to_sound_event_all_below_threshold(
sample_raw_predictions, sample_raw_predictions,
sample_recording, sample_recording,
dummy_targets, dummy_sound_event_decoder,
generic_tags,
): ):
"""Test when all class scores are below the default threshold.""" """Test when all class scores are below the default threshold."""
raw_pred = sample_raw_predictions[2] raw_pred = sample_raw_predictions[2]
@ -471,12 +466,12 @@ def test_convert_raw_to_sound_event_all_below_threshold(
se_pred = convert_raw_prediction_to_sound_event_prediction( se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred, raw_prediction=raw_pred,
recording=sample_recording, recording=sample_recording,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=False, top_class_only=False,
) )
generic_tags = dummy_targets.generic_class_tags
expected_tags = { expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15), (generic_tags[0].term.name, generic_tags[0].value, 0.15),
} }
@ -489,13 +484,15 @@ def test_convert_raw_to_sound_event_all_below_threshold(
def test_convert_raw_list_to_clip_basic( def test_convert_raw_list_to_clip_basic(
sample_raw_predictions, sample_raw_predictions,
sample_clip, sample_clip,
dummy_targets, dummy_sound_event_decoder,
generic_tags,
): ):
"""Test converting a list of RawPredictions to a ClipPrediction.""" """Test converting a list of RawPredictions to a ClipPrediction."""
clip_pred = convert_raw_predictions_to_clip_prediction( clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=sample_raw_predictions, raw_predictions=sample_raw_predictions,
clip=sample_clip, clip=sample_clip,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=False, top_class_only=False,
) )
@ -518,19 +515,23 @@ def test_convert_raw_list_to_clip_basic(
(pt.tag.term.name, pt.tag.value, pt.score) (pt.tag.term.name, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[2].tags for pt in clip_pred.sound_events[2].tags
} }
generic_tags = dummy_targets.generic_class_tags
expected_tags3 = { expected_tags3 = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15), (generic_tags[0].term.name, generic_tags[0].value, 0.15),
} }
assert se_pred3_tags == expected_tags3 assert se_pred3_tags == expected_tags3
def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets): def test_convert_raw_list_to_clip_empty(
sample_clip,
dummy_sound_event_decoder,
generic_tags,
):
"""Test converting an empty list of RawPredictions.""" """Test converting an empty list of RawPredictions."""
clip_pred = convert_raw_predictions_to_clip_prediction( clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=[], raw_predictions=[],
clip=sample_clip, clip=sample_clip,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
) )
assert isinstance(clip_pred, data.ClipPrediction) assert isinstance(clip_pred, data.ClipPrediction)
@ -541,14 +542,16 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
def test_convert_raw_list_to_clip_passes_args( def test_convert_raw_list_to_clip_passes_args(
sample_raw_predictions, sample_raw_predictions,
sample_clip, sample_clip,
dummy_targets, dummy_sound_event_decoder,
generic_tags,
): ):
"""Test that arguments like top_class_only are passed through.""" """Test that arguments like top_class_only are passed through."""
clip_pred = convert_raw_predictions_to_clip_prediction( clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=sample_raw_predictions, raw_predictions=sample_raw_predictions,
clip=sample_clip, clip=sample_clip,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=True, top_class_only=True,
) )
@ -559,18 +562,16 @@ def test_convert_raw_list_to_clip_passes_args(
(pt.tag.term.name, pt.tag.value, pt.score) (pt.tag.term.name, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[0].tags for pt in clip_pred.sound_events[0].tags
} }
generic_tags = dummy_targets.generic_class_tags
expected_tags1 = { expected_tags1 = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9), (generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85), ("soundevent:category", "noise", 0.85),
} }
assert se_pred1_tags == expected_tags1 assert se_pred1_tags == expected_tags1
def test_get_generic_tags_basic(dummy_targets): def test_get_generic_tags_basic(generic_tags):
"""Test creation of generic tags with score.""" """Test creation of generic tags with score."""
detection_score = 0.75 detection_score = 0.75
generic_tags = dummy_targets.generic_class_tags
predicted_tags = get_generic_tags( predicted_tags = get_generic_tags(
detection_score=detection_score, generic_class_tags=generic_tags detection_score=detection_score, generic_class_tags=generic_tags
) )
@ -588,19 +589,17 @@ def test_get_prediction_features_basic():
coords={"feature": ["feat1", "feat2", "feat3"]}, coords={"feature": ["feat1", "feat2", "feat3"]},
dims=["feature"], dims=["feature"],
) )
features = get_prediction_features(feature_data.values) features = get_prediction_features(feature_data)
assert len(features) == 3 assert len(features) == 3
for feature, feat_name, feat_value in zip( for feature, feat_name, feat_value in zip(
features, features, ["feat1", "feat2", "feat3"], [1.1, 2.2, 3.3]
["f0", "f1", "f2"],
[1.1, 2.2, 3.3],
): ):
assert isinstance(feature, data.Feature) assert isinstance(feature, data.Feature)
assert feature.term.name == f"batdetect2:{feat_name}" assert feature.term.name == f"batdetect2:{feat_name}"
assert feature.value == feat_value assert feature.value == feat_value
def test_get_class_tags_basic(dummy_targets): def test_get_class_tags_basic(dummy_sound_event_decoder):
"""Test creation of class tags based on scores and decoder.""" """Test creation of class tags based on scores and decoder."""
class_scores = xr.DataArray( class_scores = xr.DataArray(
[0.6, 0.2, 0.9], [0.6, 0.2, 0.9],
@ -608,8 +607,8 @@ def test_get_class_tags_basic(dummy_targets):
dims=["category"], dims=["category"],
) )
predicted_tags = get_class_tags( predicted_tags = get_class_tags(
class_scores=class_scores.values, class_scores=class_scores,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
) )
assert len(predicted_tags) == 3 assert len(predicted_tags) == 3
tag_values = [pt.tag.value for pt in predicted_tags] tag_values = [pt.tag.value for pt in predicted_tags]
@ -623,7 +622,7 @@ def test_get_class_tags_basic(dummy_targets):
assert 0.9 in tag_scores assert 0.9 in tag_scores
def test_get_class_tags_thresholding(dummy_targets): def test_get_class_tags_thresholding(dummy_sound_event_decoder):
"""Test class tag creation with a threshold.""" """Test class tag creation with a threshold."""
class_scores = xr.DataArray( class_scores = xr.DataArray(
[0.6, 0.2, 0.9], [0.6, 0.2, 0.9],
@ -632,8 +631,8 @@ def test_get_class_tags_thresholding(dummy_targets):
) )
threshold = 0.5 threshold = 0.5
predicted_tags = get_class_tags( predicted_tags = get_class_tags(
class_scores=class_scores.values, class_scores=class_scores,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
threshold=threshold, threshold=threshold,
) )
@ -644,7 +643,7 @@ def test_get_class_tags_thresholding(dummy_targets):
assert "uncertain" in tag_values assert "uncertain" in tag_values
def test_get_class_tags_top_class_only(dummy_targets): def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
"""Test class tag creation with top_class_only.""" """Test class tag creation with top_class_only."""
class_scores = xr.DataArray( class_scores = xr.DataArray(
[0.6, 0.2, 0.9], [0.6, 0.2, 0.9],
@ -652,8 +651,8 @@ def test_get_class_tags_top_class_only(dummy_targets):
dims=["category"], dims=["category"],
) )
predicted_tags = get_class_tags( predicted_tags = get_class_tags(
class_scores=class_scores.values, class_scores=class_scores,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
top_class_only=True, top_class_only=True,
) )
@ -662,11 +661,11 @@ def test_get_class_tags_top_class_only(dummy_targets):
assert predicted_tags[0].score == 0.9 assert predicted_tags[0].score == 0.9
def test_get_class_tags_empty(dummy_targets): def test_get_class_tags_empty(dummy_sound_event_decoder):
"""Test with empty class scores.""" """Test with empty class scores."""
class_scores = xr.DataArray([], coords={"category": []}, dims=["category"]) class_scores = xr.DataArray([], coords={"category": []}, dims=["category"])
predicted_tags = get_class_tags( predicted_tags = get_class_tags(
class_scores=class_scores.values, class_scores=class_scores,
targets=dummy_targets, sound_event_decoder=dummy_sound_event_decoder,
) )
assert len(predicted_tags) == 0 assert len(predicted_tags) == 0

View File

@ -5,7 +5,6 @@ from uuid import uuid4
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from soundevent import data from soundevent import data
from soundevent.terms import get_term
from batdetect2.targets.classes import ( from batdetect2.targets.classes import (
DEFAULT_SPECIES_LIST, DEFAULT_SPECIES_LIST,
@ -22,19 +21,26 @@ from batdetect2.targets.classes import (
load_decoder_from_config, load_decoder_from_config,
load_encoder_from_config, load_encoder_from_config,
) )
from batdetect2.targets.terms import TagInfo from batdetect2.targets.terms import TagInfo, TermRegistry
@pytest.fixture @pytest.fixture
def sample_annotation( def sample_annotation(
sound_event: data.SoundEvent, sound_event: data.SoundEvent,
sample_term_registry: TermRegistry,
) -> data.SoundEventAnnotation: ) -> data.SoundEventAnnotation:
"""Fixture for a sample SoundEventAnnotation.""" """Fixture for a sample SoundEventAnnotation."""
return data.SoundEventAnnotation( return data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[ tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore data.Tag(
data.Tag(key="quality", value="Good"), # type: ignore term=sample_term_registry.get_term("species"),
value="Pipistrellus pipistrellus",
),
data.Tag(
term=sample_term_registry.get_term("quality"),
value="Good",
),
], ],
) )
@ -130,33 +136,59 @@ def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
def test_is_target_class_match_all( def test_is_target_class_match_all(
sample_annotation: data.SoundEventAnnotation, sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
): ):
tags = { tags = {
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore data.Tag(
data.Tag(key="quality", value="Good"), # type: ignore term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
),
data.Tag(term=sample_term_registry["quality"], value="Good"),
} }
assert is_target_class(sample_annotation, tags, match_all=True) is True assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
)
}
assert is_target_class(sample_annotation, tags, match_all=True) is True assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore tags = {
data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii"
)
}
assert is_target_class(sample_annotation, tags, match_all=True) is False assert is_target_class(sample_annotation, tags, match_all=True) is False
def test_is_target_class_match_any( def test_is_target_class_match_any(
sample_annotation: data.SoundEventAnnotation, sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
): ):
tags = { tags = {
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore data.Tag(
data.Tag(key="quality", value="Good"), # type: ignore term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
),
data.Tag(term=sample_term_registry["quality"], value="Good"),
} }
assert is_target_class(sample_annotation, tags, match_all=False) is True assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
)
}
assert is_target_class(sample_annotation, tags, match_all=False) is True assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore tags = {
data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii"
)
}
assert is_target_class(sample_annotation, tags, match_all=False) is False assert is_target_class(sample_annotation, tags, match_all=False) is False
@ -176,6 +208,7 @@ def test_get_class_names_from_config():
def test_build_encoder_from_config( def test_build_encoder_from_config(
sample_annotation: data.SoundEventAnnotation, sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
): ):
config = ClassesConfig( config = ClassesConfig(
classes=[ classes=[
@ -187,18 +220,25 @@ def test_build_encoder_from_config(
) )
] ]
) )
encoder = build_sound_event_encoder(config) encoder = build_sound_event_encoder(
config,
term_registry=sample_term_registry,
)
result = encoder(sample_annotation) result = encoder(sample_annotation)
assert result == "pippip" assert result == "pippip"
config = ClassesConfig(classes=[]) config = ClassesConfig(classes=[])
encoder = build_sound_event_encoder(config) encoder = build_sound_event_encoder(
config,
term_registry=sample_term_registry,
)
result = encoder(sample_annotation) result = encoder(sample_annotation)
assert result is None assert result is None
def test_load_encoder_from_config_valid( def test_load_encoder_from_config_valid(
sample_annotation: data.SoundEventAnnotation, sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
create_temp_yaml: Callable[[str], Path], create_temp_yaml: Callable[[str], Path],
): ):
yaml_content = """ yaml_content = """
@ -209,7 +249,10 @@ def test_load_encoder_from_config_valid(
value: Pipistrellus pipistrellus value: Pipistrellus pipistrellus
""" """
temp_yaml_path = create_temp_yaml(yaml_content) temp_yaml_path = create_temp_yaml(yaml_content)
encoder = load_encoder_from_config(temp_yaml_path) encoder = load_encoder_from_config(
temp_yaml_path,
term_registry=sample_term_registry,
)
# We cannot directly compare the function, so we test it. # We cannot directly compare the function, so we test it.
result = encoder(sample_annotation) # type: ignore result = encoder(sample_annotation) # type: ignore
assert result == "pippip" assert result == "pippip"
@ -217,6 +260,7 @@ def test_load_encoder_from_config_valid(
def test_load_encoder_from_config_invalid( def test_load_encoder_from_config_invalid(
create_temp_yaml: Callable[[str], Path], create_temp_yaml: Callable[[str], Path],
sample_term_registry: TermRegistry,
): ):
yaml_content = """ yaml_content = """
classes: classes:
@ -227,7 +271,10 @@ def test_load_encoder_from_config_invalid(
""" """
temp_yaml_path = create_temp_yaml(yaml_content) temp_yaml_path = create_temp_yaml(yaml_content)
with pytest.raises(KeyError): with pytest.raises(KeyError):
load_encoder_from_config(temp_yaml_path) load_encoder_from_config(
temp_yaml_path,
term_registry=sample_term_registry,
)
def test_get_default_class_name(): def test_get_default_class_name():
@ -244,7 +291,7 @@ def test_get_default_classes():
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0] assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
def test_build_decoder_from_config(): def test_build_decoder_from_config(sample_term_registry: TermRegistry):
config = ClassesConfig( config = ClassesConfig(
classes=[ classes=[
TargetClass( TargetClass(
@ -257,10 +304,12 @@ def test_build_decoder_from_config():
], ],
generic_class=[TagInfo(key="order", value="Chiroptera")], generic_class=[TagInfo(key="order", value="Chiroptera")],
) )
decoder = build_sound_event_decoder(config) decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry
)
tags = decoder("pippip") tags = decoder("pippip")
assert len(tags) == 1 assert len(tags) == 1
assert tags[0].term == get_term("event") assert tags[0].term == sample_term_registry["call_type"]
assert tags[0].value == "Echolocation" assert tags[0].value == "Echolocation"
# Test when output_tags is None, should fall back to tags # Test when output_tags is None, should fall back to tags
@ -275,25 +324,32 @@ def test_build_decoder_from_config():
], ],
generic_class=[TagInfo(key="order", value="Chiroptera")], generic_class=[TagInfo(key="order", value="Chiroptera")],
) )
decoder = build_sound_event_decoder(config) decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry
)
tags = decoder("pippip") tags = decoder("pippip")
assert len(tags) == 1 assert len(tags) == 1
assert tags[0].term == get_term("species") assert tags[0].term == sample_term_registry["species"]
assert tags[0].value == "Pipistrellus pipistrellus" assert tags[0].value == "Pipistrellus pipistrellus"
# Test raise_on_unmapped=True # Test raise_on_unmapped=True
decoder = build_sound_event_decoder(config, raise_on_unmapped=True) decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry, raise_on_unmapped=True
)
with pytest.raises(ValueError): with pytest.raises(ValueError):
decoder("unknown_class") decoder("unknown_class")
# Test raise_on_unmapped=False # Test raise_on_unmapped=False
decoder = build_sound_event_decoder(config, raise_on_unmapped=False) decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry, raise_on_unmapped=False
)
tags = decoder("unknown_class") tags = decoder("unknown_class")
assert len(tags) == 0 assert len(tags) == 0
def test_load_decoder_from_config_valid( def test_load_decoder_from_config_valid(
create_temp_yaml: Callable[[str], Path], create_temp_yaml: Callable[[str], Path],
sample_term_registry: TermRegistry,
): ):
yaml_content = """ yaml_content = """
classes: classes:
@ -310,15 +366,17 @@ def test_load_decoder_from_config_valid(
""" """
temp_yaml_path = create_temp_yaml(yaml_content) temp_yaml_path = create_temp_yaml(yaml_content)
decoder = load_decoder_from_config( decoder = load_decoder_from_config(
temp_yaml_path, temp_yaml_path, term_registry=sample_term_registry
) )
tags = decoder("pippip") tags = decoder("pippip")
assert len(tags) == 1 assert len(tags) == 1
assert tags[0].term == get_term("call_type") assert tags[0].term == sample_term_registry["call_type"]
assert tags[0].value == "Echolocation" assert tags[0].value == "Echolocation"
def test_build_generic_class_tags_from_config(): def test_build_generic_class_tags_from_config(
sample_term_registry: TermRegistry,
):
config = ClassesConfig( config = ClassesConfig(
classes=[ classes=[
TargetClass( TargetClass(
@ -333,9 +391,11 @@ def test_build_generic_class_tags_from_config():
TagInfo(key="call_type", value="Echolocation"), TagInfo(key="call_type", value="Echolocation"),
], ],
) )
generic_tags = build_generic_class_tags(config) generic_tags = build_generic_class_tags(
config, term_registry=sample_term_registry
)
assert len(generic_tags) == 2 assert len(generic_tags) == 2
assert generic_tags[0].term == get_term("order") assert generic_tags[0].term == sample_term_registry["order"]
assert generic_tags[0].value == "Chiroptera" assert generic_tags[0].value == "Chiroptera"
assert generic_tags[1].term == get_term("call_type") assert generic_tags[1].term == sample_term_registry["call_type"]
assert generic_tags[1].value == "Echolocation" assert generic_tags[1].value == "Echolocation"

View File

@ -80,6 +80,7 @@ def test_generated_heatmaps_have_correct_dimensions(
def test_generated_heatmap_are_non_zero_at_correct_positions( def test_generated_heatmap_are_non_zero_at_correct_positions(
sample_target_config: TargetConfig, sample_target_config: TargetConfig,
sample_term_registry: TermRegistry,
pippip_tag: TagInfo, pippip_tag: TagInfo,
): ):
config = sample_target_config.model_copy( config = sample_target_config.model_copy(
@ -91,7 +92,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
) )
) )
targets = build_targets(config) targets = build_targets(config, term_registry=sample_term_registry)
spec = xr.DataArray( spec = xr.DataArray(
data=np.random.rand(100, 100), data=np.random.rand(100, 100),
@ -112,7 +113,12 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
coordinates=[10, 10, 20, 20], coordinates=[10, 10, 20, 20],
), ),
), ),
tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore tags=[
data.Tag(
term=sample_term_registry[pippip_tag.key],
value=pippip_tag.value,
)
],
) )
], ],
) )

View File

@ -2,12 +2,12 @@ import pytest
import torch import torch
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from soundevent.terms import get_term
from batdetect2.models.types import ModelOutput from batdetect2.models.types import ModelOutput
from batdetect2.postprocess import build_postprocessor, load_postprocess_config from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config from batdetect2.targets import build_targets, load_target_config
from batdetect2.targets.terms import get_term_from_key
from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.preprocess import generate_train_example from batdetect2.train.preprocess import generate_train_example
@ -15,6 +15,7 @@ from batdetect2.train.preprocess import generate_train_example
@pytest.fixture @pytest.fixture
def build_from_config( def build_from_config(
create_temp_yaml, create_temp_yaml,
sample_term_registry,
): ):
def build(yaml_content): def build(yaml_content):
config_path = create_temp_yaml(yaml_content) config_path = create_temp_yaml(yaml_content)
@ -30,7 +31,9 @@ def build_from_config(
field="postprocessing", field="postprocessing",
) )
targets = build_targets(targets_config) targets = build_targets(
targets_config, term_registry=sample_term_registry
)
preprocessor = build_preprocessor(preprocessing_config) preprocessor = build_preprocessor(preprocessing_config)
labeller = build_clip_labeler( labeller = build_clip_labeler(
targets=targets, targets=targets,
@ -51,6 +54,7 @@ def build_from_config(
# TODO: better name # TODO: better name
def test_generated_train_example_has_expected_outputs( def test_generated_train_example_has_expected_outputs(
build_from_config, build_from_config,
sample_term_registry,
recording, recording,
): ):
yaml_content = """ yaml_content = """
@ -74,11 +78,10 @@ def test_generated_train_example_has_expected_outputs(
_, preprocessor, labeller, _ = build_from_config(yaml_content) _, preprocessor, labeller, _ = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation( se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry), sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[ tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
],
) )
clip_annotation = data.ClipAnnotation( clip_annotation = data.ClipAnnotation(
clip=data.Clip(start_time=0, end_time=0.5, recording=recording), clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
@ -105,6 +108,7 @@ def test_generated_train_example_has_expected_outputs(
def test_encoding_decoding_roundtrip_recovers_object( def test_encoding_decoding_roundtrip_recovers_object(
build_from_config, build_from_config,
sample_term_registry,
recording, recording,
): ):
yaml_content = """ yaml_content = """
@ -127,11 +131,10 @@ def test_encoding_decoding_roundtrip_recovers_object(
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content) _, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000]) geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation( se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry), sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[ tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
],
) )
clip = data.Clip(start_time=0, end_time=0.5, recording=recording) clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
@ -168,16 +171,14 @@ def test_encoding_decoding_roundtrip_recovers_object(
assert len(recovered.tags) == 2 assert len(recovered.tags) == 2
predicted_species_tag = next( predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("species")), iter(t for t in recovered.tags if t.tag.term == species), None
None,
) )
assert predicted_species_tag is not None assert predicted_species_tag is not None
assert predicted_species_tag.score == 1 assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus" assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
predicted_order_tag = next( predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("order")), iter(t for t in recovered.tags if t.tag.term.label == "order"), None
None,
) )
assert predicted_order_tag is not None assert predicted_order_tag is not None
assert predicted_order_tag.score == 1 assert predicted_order_tag.score == 1
@ -186,6 +187,7 @@ def test_encoding_decoding_roundtrip_recovers_object(
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override( def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
build_from_config, build_from_config,
sample_term_registry,
recording, recording,
): ):
yaml_content = """ yaml_content = """
@ -215,9 +217,10 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content) _, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000]) geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation( se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry), sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore tags=[data.Tag(term=species, value="Myotis myotis")],
) )
clip = data.Clip(start_time=0, end_time=0.5, recording=recording) clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
@ -254,16 +257,14 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
assert len(recovered.tags) == 2 assert len(recovered.tags) == 2
predicted_species_tag = next( predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("species")), iter(t for t in recovered.tags if t.tag.term == species), None
None,
) )
assert predicted_species_tag is not None assert predicted_species_tag is not None
assert predicted_species_tag.score == 1 assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Myotis myotis" assert predicted_species_tag.tag.value == "Myotis myotis"
predicted_order_tag = next( predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("order")), iter(t for t in recovered.tags if t.tag.term.label == "order"), None
None,
) )
assert predicted_order_tag is not None assert predicted_order_tag is not None
assert predicted_order_tag.score == 1 assert predicted_order_tag.score == 1