mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Compare commits
6 Commits
b997a122f1
...
c7b110feeb
| Author | SHA1 | Date | |
|---|---|---|---|
| | c7b110feeb | | |
| | 7d92ec772b | | |
| | 0bb703f3c1 | | |
| | 81c7b68b0b | | |
| | 59aaf07af5 | | |
| | 51d0a49da9 | | |
@ -85,6 +85,7 @@ dev = [
|
||||
"ty>=0.0.1a12",
|
||||
"rust-just>=1.40.0",
|
||||
"pandas-stubs>=2.2.2.240807",
|
||||
"python-lsp-server>=1.13.0",
|
||||
]
|
||||
dvclive = ["dvclive>=3.48.2"]
|
||||
mlflow = ["mlflow>=3.1.1"]
|
||||
|
||||
@ -179,7 +179,7 @@ def plot_false_positive_match(
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
|
||||
@ -543,7 +543,9 @@ class Postprocessor(PostprocessorProtocol):
|
||||
]
|
||||
|
||||
def get_sound_event_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[List[BatDetect2Prediction]]:
|
||||
raw_predictions = self.get_raw_predictions(output, clips)
|
||||
return [
|
||||
@ -553,8 +555,7 @@ class Postprocessor(PostprocessorProtocol):
|
||||
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
|
||||
raw,
|
||||
recording=clip.recording,
|
||||
sound_event_decoder=self.targets.decode_class,
|
||||
generic_class_tags=self.targets.generic_class_tags,
|
||||
targets=self.targets,
|
||||
classification_threshold=self.config.classification_threshold,
|
||||
),
|
||||
)
|
||||
@ -590,8 +591,7 @@ class Postprocessor(PostprocessorProtocol):
|
||||
convert_raw_predictions_to_clip_prediction(
|
||||
prediction,
|
||||
clip,
|
||||
sound_event_decoder=self.targets.decode_class,
|
||||
generic_class_tags=self.targets.generic_class_tags,
|
||||
targets=self.targets,
|
||||
classification_threshold=self.config.classification_threshold,
|
||||
)
|
||||
for prediction, clip in zip(raw_predictions, clips)
|
||||
|
||||
@ -33,8 +33,7 @@ import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
||||
from batdetect2.targets.classes import SoundEventDecoder
|
||||
from batdetect2.utils.arrays import iterate_over_array
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"convert_xr_dataset_to_raw_prediction",
|
||||
@ -92,25 +91,30 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
"""
|
||||
detections = []
|
||||
|
||||
for det_num in range(detection_dataset.sizes["detection"]):
|
||||
det_info = detection_dataset.sel(detection=det_num)
|
||||
categories = detection_dataset.category.values
|
||||
|
||||
highest_scoring_class = det_info.coords["category"][
|
||||
det_info["classes"].argmax()
|
||||
].item()
|
||||
for score, class_scores, time, freq, dims, feats in zip(
|
||||
detection_dataset["scores"].values,
|
||||
detection_dataset["classes"].values,
|
||||
detection_dataset["time"].values,
|
||||
detection_dataset["frequency"].values,
|
||||
detection_dataset["dimensions"].values,
|
||||
detection_dataset["features"].values,
|
||||
):
|
||||
highest_scoring_class = categories[class_scores.argmax()]
|
||||
|
||||
geom = geometry_decoder(
|
||||
(det_info.time, det_info.frequency),
|
||||
det_info.dimensions,
|
||||
(time, freq),
|
||||
dims,
|
||||
class_name=highest_scoring_class,
|
||||
)
|
||||
|
||||
detections.append(
|
||||
RawPrediction(
|
||||
detection_score=det_info.scores,
|
||||
detection_score=score,
|
||||
geometry=geom,
|
||||
class_scores=det_info.classes,
|
||||
features=det_info.features,
|
||||
class_scores=class_scores,
|
||||
features=feats,
|
||||
)
|
||||
)
|
||||
|
||||
@ -120,8 +124,7 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
def convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions: List[RawPrediction],
|
||||
clip: data.Clip,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
generic_class_tags: List[data.Tag],
|
||||
targets: TargetProtocol,
|
||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
) -> data.ClipPrediction:
|
||||
@ -160,8 +163,7 @@ def convert_raw_predictions_to_clip_prediction(
|
||||
convert_raw_prediction_to_sound_event_prediction(
|
||||
prediction,
|
||||
recording=clip.recording,
|
||||
sound_event_decoder=sound_event_decoder,
|
||||
generic_class_tags=generic_class_tags,
|
||||
targets=targets,
|
||||
classification_threshold=classification_threshold,
|
||||
top_class_only=top_class_only,
|
||||
)
|
||||
@ -173,8 +175,7 @@ def convert_raw_predictions_to_clip_prediction(
|
||||
def convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction: RawPrediction,
|
||||
recording: data.Recording,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
generic_class_tags: List[data.Tag],
|
||||
targets: TargetProtocol,
|
||||
classification_threshold: Optional[
|
||||
float
|
||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
@ -251,11 +252,11 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
tags = [
|
||||
*get_generic_tags(
|
||||
raw_prediction.detection_score,
|
||||
generic_class_tags=generic_class_tags,
|
||||
generic_class_tags=targets.generic_class_tags,
|
||||
),
|
||||
*get_class_tags(
|
||||
raw_prediction.class_scores,
|
||||
sound_event_decoder,
|
||||
targets=targets,
|
||||
top_class_only=top_class_only,
|
||||
threshold=classification_threshold,
|
||||
),
|
||||
@ -297,7 +298,7 @@ def get_generic_tags(
|
||||
]
|
||||
|
||||
|
||||
def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
||||
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
||||
"""Convert an extracted feature vector DataArray into soundevent Features.
|
||||
|
||||
Parameters
|
||||
@ -320,19 +321,19 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
||||
return [
|
||||
data.Feature(
|
||||
term=data.Term(
|
||||
name=f"batdetect2:{feat_name}",
|
||||
label=feat_name,
|
||||
name=f"batdetect2:f{index}",
|
||||
label=f"BatDetect Feature {index}",
|
||||
definition="Automatically extracted features by BatDetect2",
|
||||
),
|
||||
value=value,
|
||||
)
|
||||
for feat_name, value in iterate_over_array(features)
|
||||
for index, value in enumerate(features)
|
||||
]
|
||||
|
||||
|
||||
def get_class_tags(
|
||||
class_scores: xr.DataArray,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
class_scores: np.ndarray,
|
||||
targets: TargetProtocol,
|
||||
top_class_only: bool = False,
|
||||
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
) -> List[data.PredictedTag]:
|
||||
@ -367,11 +368,13 @@ def get_class_tags(
|
||||
"""
|
||||
tags = []
|
||||
|
||||
if threshold is not None:
|
||||
class_scores = class_scores.where(class_scores > threshold, drop=True)
|
||||
for class_name, score in _iterate_sorted(
|
||||
class_scores, targets.class_names
|
||||
):
|
||||
if threshold is not None and score < threshold:
|
||||
continue
|
||||
|
||||
for class_name, score in _iterate_sorted(class_scores):
|
||||
class_tags = sound_event_decoder(class_name)
|
||||
class_tags = targets.decode_class(class_name)
|
||||
|
||||
for tag in class_tags:
|
||||
tags.append(
|
||||
@ -387,9 +390,7 @@ def get_class_tags(
|
||||
return tags
|
||||
|
||||
|
||||
def _iterate_sorted(array: xr.DataArray):
|
||||
dim_name = array.dims[0]
|
||||
coords = array.coords[dim_name].values
|
||||
indices = np.argsort(-array.values)
|
||||
def _iterate_sorted(array: np.ndarray, class_names: List[str]):
|
||||
indices = np.argsort(-array)
|
||||
for index in indices:
|
||||
yield str(coords[index]), float(array.values[index])
|
||||
yield str(class_names[index]), float(array[index])
|
||||
|
||||
@ -14,6 +14,7 @@ system that deal with model predictions.
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional, Protocol
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
@ -72,8 +73,8 @@ class RawPrediction(NamedTuple):
|
||||
|
||||
geometry: data.Geometry
|
||||
detection_score: float
|
||||
class_scores: xr.DataArray
|
||||
features: xr.DataArray
|
||||
class_scores: np.ndarray
|
||||
features: np.ndarray
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -91,6 +91,7 @@ terms.register_term_set(
|
||||
"individual": individual.name,
|
||||
"event": call_type.name,
|
||||
"source": data_source.name,
|
||||
"call_type": call_type.name,
|
||||
},
|
||||
),
|
||||
override_existing=True,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from functools import partial
|
||||
from multiprocessing import Pool
|
||||
from typing import List, Optional, Tuple
|
||||
@ -15,7 +16,10 @@ from batdetect2.evaluate.match import (
|
||||
)
|
||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||
from batdetect2.plotting.evaluation import plot_example_gallery
|
||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||
from batdetect2.postprocess.types import (
|
||||
BatDetect2Prediction,
|
||||
PostprocessorProtocol,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
@ -114,33 +118,51 @@ class ValidationMetrics(Callback):
|
||||
batch_idx: int,
|
||||
dataloader_idx: int = 0,
|
||||
) -> None:
|
||||
dataset = self.get_dataset(trainer)
|
||||
|
||||
clip_annotations = [
|
||||
_get_subclip(
|
||||
dataset.get_clip_annotation(example_id),
|
||||
start_time=start_time.item(),
|
||||
end_time=end_time.item(),
|
||||
self._matches.extend(
|
||||
_get_batch_clips_and_predictions(
|
||||
batch,
|
||||
outputs,
|
||||
dataset=self.get_dataset(trainer),
|
||||
postprocessor=pl_module.postprocessor,
|
||||
targets=pl_module.targets,
|
||||
)
|
||||
for example_id, start_time, end_time in zip(
|
||||
batch.idx,
|
||||
batch.start_time,
|
||||
batch.end_time,
|
||||
)
|
||||
]
|
||||
|
||||
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
||||
|
||||
raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
|
||||
outputs,
|
||||
clips,
|
||||
)
|
||||
|
||||
|
||||
def _get_batch_clips_and_predictions(
|
||||
batch: TrainExample,
|
||||
outputs: ModelOutput,
|
||||
dataset: LabeledDataset,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
targets: TargetProtocol,
|
||||
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
|
||||
clip_annotations = [
|
||||
_get_subclip(
|
||||
dataset.get_clip_annotation(example_id),
|
||||
start_time=start_time.item(),
|
||||
end_time=end_time.item(),
|
||||
targets=targets,
|
||||
)
|
||||
for example_id, start_time, end_time in zip(
|
||||
batch.idx,
|
||||
batch.start_time,
|
||||
batch.end_time,
|
||||
)
|
||||
]
|
||||
|
||||
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
||||
|
||||
raw_predictions = postprocessor.get_sound_event_predictions(
|
||||
outputs,
|
||||
clips,
|
||||
)
|
||||
|
||||
return [
|
||||
(clip_annotation, clip_predictions)
|
||||
for clip_annotation, clip_predictions in zip(
|
||||
clip_annotations, raw_predictions
|
||||
):
|
||||
self._matches.append((clip_annotation, clip_predictions))
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _match_all_collected_examples(
|
||||
@ -150,7 +172,8 @@ def _match_all_collected_examples(
|
||||
) -> List[MatchEvaluation]:
|
||||
logger.info("Matching all annotations and predictions")
|
||||
|
||||
with Pool() as p:
|
||||
cpu_count = os.cpu_count() or 1
|
||||
with Pool(processes=min(cpu_count, 4)) as p:
|
||||
matches = p.starmap(
|
||||
partial(
|
||||
match_sound_events_and_raw_predictions,
|
||||
|
||||
@ -254,11 +254,13 @@ class TestLoadBatDetect2Files:
|
||||
assert clip_ann.clip.recording.duration == 5.0
|
||||
assert len(clip_ann.sound_events) == 1
|
||||
assert clip_ann.notes[0].message == "Standard notes."
|
||||
clip_tag = data.find_tag(clip_ann.tags, "Class")
|
||||
clip_tag = data.find_tag(clip_ann.tags, term_label="Class")
|
||||
assert clip_tag is not None
|
||||
assert clip_tag.value == "Myotis"
|
||||
|
||||
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
|
||||
recording_tag = data.find_tag(
|
||||
clip_ann.clip.recording.tags, term_label="Class"
|
||||
)
|
||||
assert recording_tag is not None
|
||||
assert recording_tag.value == "Myotis"
|
||||
|
||||
@ -271,15 +273,15 @@ class TestLoadBatDetect2Files:
|
||||
40000,
|
||||
]
|
||||
|
||||
se_class_tag = data.find_tag(se_ann.tags, "Class")
|
||||
se_class_tag = data.find_tag(se_ann.tags, term_label="Class")
|
||||
assert se_class_tag is not None
|
||||
assert se_class_tag.value == "Myotis"
|
||||
|
||||
se_event_tag = data.find_tag(se_ann.tags, "Call Type")
|
||||
se_event_tag = data.find_tag(se_ann.tags, term_label="Call Type")
|
||||
assert se_event_tag is not None
|
||||
assert se_event_tag.value == "Echolocation"
|
||||
|
||||
se_individual_tag = data.find_tag(se_ann.tags, "Individual")
|
||||
se_individual_tag = data.find_tag(se_ann.tags, term_label="Individual")
|
||||
assert se_individual_tag is not None
|
||||
assert se_individual_tag.value == "0"
|
||||
|
||||
@ -439,7 +441,7 @@ class TestLoadBatDetect2Merged:
|
||||
assert clip_ann.clip.recording.duration == 5.0
|
||||
assert len(clip_ann.sound_events) == 1
|
||||
|
||||
clip_class_tag = data.find_tag(clip_ann.tags, "Class")
|
||||
clip_class_tag = data.find_tag(clip_ann.tags, term_label="Class")
|
||||
assert clip_class_tag is not None
|
||||
assert clip_class_tag.value == "Myotis"
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -16,35 +16,11 @@ from batdetect2.postprocess.decoding import (
|
||||
get_prediction_features,
|
||||
)
|
||||
from batdetect2.postprocess.types import RawPrediction
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_geometry_builder():
|
||||
"""A simple GeometryBuilder that creates a BBox around the point."""
|
||||
|
||||
def _builder(
|
||||
position: Tuple[float, float],
|
||||
dimensions: xr.DataArray,
|
||||
class_name: Optional[str] = None,
|
||||
) -> data.BoundingBox:
|
||||
time, freq = position
|
||||
width = dimensions.sel(dimension="width").item()
|
||||
height = dimensions.sel(dimension="height").item()
|
||||
return data.BoundingBox(
|
||||
coordinates=[
|
||||
time - width / 2,
|
||||
freq - height / 2,
|
||||
time + width / 2,
|
||||
freq + height / 2,
|
||||
]
|
||||
)
|
||||
|
||||
return _builder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_sound_event_decoder():
|
||||
"""A simple SoundEventDecoder mapping names to tags."""
|
||||
def dummy_targets() -> TargetProtocol:
|
||||
tag_map = {
|
||||
"bat": [
|
||||
data.Tag(term=data.term_from_key(key="species"), value="Myotis")
|
||||
@ -57,18 +33,56 @@ def dummy_sound_event_decoder():
|
||||
],
|
||||
}
|
||||
|
||||
def _decoder(class_name: str) -> List[data.Tag]:
|
||||
return tag_map.get(class_name.lower(), [])
|
||||
class DummyTargets(TargetProtocol):
|
||||
class_names = [
|
||||
"bat",
|
||||
"noise",
|
||||
"unknown",
|
||||
]
|
||||
|
||||
return _decoder
|
||||
dimension_names = ["width", "height"]
|
||||
|
||||
generic_class_tags = [
|
||||
data.Tag(
|
||||
term=data.term_from_key(key="detector"), value="batdetect2"
|
||||
)
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def generic_tags() -> List[data.Tag]:
|
||||
"""Sample generic tags."""
|
||||
return [
|
||||
data.Tag(term=data.term_from_key(key="detector"), value="batdetect2")
|
||||
]
|
||||
def filter(self, sound_event: data.SoundEventAnnotation):
|
||||
return True
|
||||
|
||||
def transform(self, sound_event: data.SoundEventAnnotation):
|
||||
return sound_event
|
||||
|
||||
def encode_class(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> Optional[str]:
|
||||
return "bat"
|
||||
|
||||
def decode_class(self, class_label: str) -> List[data.Tag]:
|
||||
return tag_map.get(class_label.lower(), [])
|
||||
|
||||
def encode_roi(self, sound_event: data.SoundEventAnnotation):
|
||||
return np.array([0.0, 0.0]), np.array([0.0, 0.0])
|
||||
|
||||
def decode_roi(
|
||||
self,
|
||||
position,
|
||||
size: np.ndarray,
|
||||
class_name: Optional[str] = None,
|
||||
):
|
||||
time, freq = position
|
||||
width, height = size
|
||||
return data.BoundingBox(
|
||||
coordinates=[
|
||||
time - width / 2,
|
||||
freq - height / 2,
|
||||
time + width / 2,
|
||||
freq + height / 2,
|
||||
]
|
||||
)
|
||||
|
||||
return DummyTargets()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -156,7 +170,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
||||
"""Creates an empty detection dataset with correct structure."""
|
||||
detection_coords = {
|
||||
"time": ("detection", np.array([], dtype=np.float64)),
|
||||
"freq": ("detection", np.array([], dtype=np.float64)),
|
||||
"frequency": ("detection", np.array([], dtype=np.float64)),
|
||||
}
|
||||
scores = xr.DataArray(
|
||||
np.array([], dtype=np.float64),
|
||||
@ -184,7 +198,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
||||
)
|
||||
return xr.Dataset(
|
||||
{
|
||||
"score": scores,
|
||||
"scores": scores,
|
||||
"dimensions": dimensions,
|
||||
"classes": classes,
|
||||
"features": features,
|
||||
@ -215,8 +229,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
300 + 16 / 2,
|
||||
]
|
||||
),
|
||||
class_scores=pred1_classes,
|
||||
features=pred1_features,
|
||||
class_scores=pred1_classes.values,
|
||||
features=pred1_features.values,
|
||||
)
|
||||
|
||||
pred2_classes = xr.DataArray(
|
||||
@ -237,8 +251,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
200 + 12 / 2,
|
||||
]
|
||||
),
|
||||
class_scores=pred2_classes,
|
||||
features=pred2_features,
|
||||
class_scores=pred2_classes.values,
|
||||
features=pred2_features.values,
|
||||
)
|
||||
|
||||
pred3_classes = xr.DataArray(
|
||||
@ -259,18 +273,17 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
60.0,
|
||||
]
|
||||
),
|
||||
class_scores=pred3_classes,
|
||||
features=pred3_features,
|
||||
class_scores=pred3_classes.values,
|
||||
features=pred3_features.values,
|
||||
)
|
||||
return [pred1, pred2, pred3]
|
||||
|
||||
|
||||
def test_convert_xr_dataset_basic(
|
||||
sample_detection_dataset, dummy_geometry_builder
|
||||
):
|
||||
def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
|
||||
"""Test basic conversion of a dataset to RawPrediction list."""
|
||||
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
||||
sample_detection_dataset, dummy_geometry_builder
|
||||
sample_detection_dataset,
|
||||
dummy_targets.decode_roi,
|
||||
)
|
||||
|
||||
assert isinstance(raw_predictions, list)
|
||||
@ -286,11 +299,11 @@ def test_convert_xr_dataset_basic(
|
||||
20 + 7 / 2,
|
||||
300 + 16 / 2,
|
||||
]
|
||||
xr.testing.assert_allclose(
|
||||
np.testing.assert_allclose(
|
||||
pred1.class_scores,
|
||||
sample_detection_dataset["classes"].sel(detection=0),
|
||||
)
|
||||
xr.testing.assert_allclose(
|
||||
np.testing.assert_allclose(
|
||||
pred1.features, sample_detection_dataset["features"].sel(detection=0)
|
||||
)
|
||||
|
||||
@ -304,21 +317,20 @@ def test_convert_xr_dataset_basic(
|
||||
10 + 3 / 2,
|
||||
200 + 12 / 2,
|
||||
]
|
||||
xr.testing.assert_allclose(
|
||||
np.testing.assert_allclose(
|
||||
pred2.class_scores,
|
||||
sample_detection_dataset["classes"].sel(detection=1),
|
||||
)
|
||||
xr.testing.assert_allclose(
|
||||
np.testing.assert_allclose(
|
||||
pred2.features, sample_detection_dataset["features"].sel(detection=1)
|
||||
)
|
||||
|
||||
|
||||
def test_convert_xr_dataset_empty(
|
||||
empty_detection_dataset, dummy_geometry_builder
|
||||
):
|
||||
def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets):
|
||||
"""Test conversion of an empty dataset."""
|
||||
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
||||
empty_detection_dataset, dummy_geometry_builder
|
||||
empty_detection_dataset,
|
||||
dummy_targets.decode_roi,
|
||||
)
|
||||
assert isinstance(raw_predictions, list)
|
||||
assert len(raw_predictions) == 0
|
||||
@ -327,8 +339,7 @@ def test_convert_xr_dataset_empty(
|
||||
def test_convert_raw_to_sound_event_basic(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test basic conversion, default threshold, multi-label."""
|
||||
|
||||
@ -337,8 +348,7 @@ def test_convert_raw_to_sound_event_basic(
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
)
|
||||
|
||||
assert isinstance(se_pred, data.SoundEventPrediction)
|
||||
@ -357,10 +367,11 @@ def test_convert_raw_to_sound_event_basic(
|
||||
)
|
||||
assert feat_dict["batdetect2:f0"] == 7.0
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
("soundevent:species", "Myotis", 0.43),
|
||||
("category", "noise", 0.85),
|
||||
("dwc:scientificName", "Myotis", 0.43),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
@ -369,10 +380,7 @@ def test_convert_raw_to_sound_event_basic(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_thresholding(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
sample_raw_predictions, sample_recording, dummy_targets
|
||||
):
|
||||
"""Test effect of classification threshold."""
|
||||
raw_pred = sample_raw_predictions[0]
|
||||
@ -381,15 +389,15 @@ def test_convert_raw_to_sound_event_thresholding(
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=high_threshold,
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
("category", "noise", 0.85),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
@ -400,8 +408,7 @@ def test_convert_raw_to_sound_event_thresholding(
|
||||
def test_convert_raw_to_sound_event_no_threshold(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test when classification_threshold is None."""
|
||||
raw_pred = sample_raw_predictions[2]
|
||||
@ -409,16 +416,16 @@ def test_convert_raw_to_sound_event_no_threshold(
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=None,
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
("soundevent:species", "Myotis", 0.05),
|
||||
("soundevent:category", "noise", 0.02),
|
||||
("dwc:scientificName", "Myotis", 0.05),
|
||||
("category", "noise", 0.02),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
@ -429,8 +436,7 @@ def test_convert_raw_to_sound_event_no_threshold(
|
||||
def test_convert_raw_to_sound_event_top_class(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test top_class_only=True behavior."""
|
||||
raw_pred = sample_raw_predictions[0]
|
||||
@ -438,15 +444,15 @@ def test_convert_raw_to_sound_event_top_class(
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=True,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
("category", "noise", 0.85),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
@ -457,8 +463,7 @@ def test_convert_raw_to_sound_event_top_class(
|
||||
def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test when all class scores are below the default threshold."""
|
||||
raw_pred = sample_raw_predictions[2]
|
||||
@ -466,12 +471,12 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
}
|
||||
@ -484,15 +489,13 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
def test_convert_raw_list_to_clip_basic(
|
||||
sample_raw_predictions,
|
||||
sample_clip,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test converting a list of RawPredictions to a ClipPrediction."""
|
||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions=sample_raw_predictions,
|
||||
clip=sample_clip,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=False,
|
||||
)
|
||||
@ -515,23 +518,19 @@ def test_convert_raw_list_to_clip_basic(
|
||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||
for pt in clip_pred.sound_events[2].tags
|
||||
}
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags3 = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
}
|
||||
assert se_pred3_tags == expected_tags3
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_empty(
|
||||
sample_clip,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
):
|
||||
def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
|
||||
"""Test converting an empty list of RawPredictions."""
|
||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions=[],
|
||||
clip=sample_clip,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
)
|
||||
|
||||
assert isinstance(clip_pred, data.ClipPrediction)
|
||||
@ -542,16 +541,14 @@ def test_convert_raw_list_to_clip_empty(
|
||||
def test_convert_raw_list_to_clip_passes_args(
|
||||
sample_raw_predictions,
|
||||
sample_clip,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test that arguments like top_class_only are passed through."""
|
||||
|
||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions=sample_raw_predictions,
|
||||
clip=sample_clip,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=True,
|
||||
)
|
||||
@ -562,16 +559,18 @@ def test_convert_raw_list_to_clip_passes_args(
|
||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||
for pt in clip_pred.sound_events[0].tags
|
||||
}
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags1 = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
("category", "noise", 0.85),
|
||||
}
|
||||
assert se_pred1_tags == expected_tags1
|
||||
|
||||
|
||||
def test_get_generic_tags_basic(generic_tags):
|
||||
def test_get_generic_tags_basic(dummy_targets):
|
||||
"""Test creation of generic tags with score."""
|
||||
detection_score = 0.75
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
predicted_tags = get_generic_tags(
|
||||
detection_score=detection_score, generic_class_tags=generic_tags
|
||||
)
|
||||
@ -589,17 +588,19 @@ def test_get_prediction_features_basic():
|
||||
coords={"feature": ["feat1", "feat2", "feat3"]},
|
||||
dims=["feature"],
|
||||
)
|
||||
features = get_prediction_features(feature_data)
|
||||
features = get_prediction_features(feature_data.values)
|
||||
assert len(features) == 3
|
||||
for feature, feat_name, feat_value in zip(
|
||||
features, ["feat1", "feat2", "feat3"], [1.1, 2.2, 3.3]
|
||||
features,
|
||||
["f0", "f1", "f2"],
|
||||
[1.1, 2.2, 3.3],
|
||||
):
|
||||
assert isinstance(feature, data.Feature)
|
||||
assert feature.term.name == f"batdetect2:{feat_name}"
|
||||
assert feature.value == feat_value
|
||||
|
||||
|
||||
def test_get_class_tags_basic(dummy_sound_event_decoder):
|
||||
def test_get_class_tags_basic(dummy_targets):
|
||||
"""Test creation of class tags based on scores and decoder."""
|
||||
class_scores = xr.DataArray(
|
||||
[0.6, 0.2, 0.9],
|
||||
@ -607,8 +608,8 @@ def test_get_class_tags_basic(dummy_sound_event_decoder):
|
||||
dims=["category"],
|
||||
)
|
||||
predicted_tags = get_class_tags(
|
||||
class_scores=class_scores,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
class_scores=class_scores.values,
|
||||
targets=dummy_targets,
|
||||
)
|
||||
assert len(predicted_tags) == 3
|
||||
tag_values = [pt.tag.value for pt in predicted_tags]
|
||||
@ -622,7 +623,7 @@ def test_get_class_tags_basic(dummy_sound_event_decoder):
|
||||
assert 0.9 in tag_scores
|
||||
|
||||
|
||||
def test_get_class_tags_thresholding(dummy_sound_event_decoder):
|
||||
def test_get_class_tags_thresholding(dummy_targets):
|
||||
"""Test class tag creation with a threshold."""
|
||||
class_scores = xr.DataArray(
|
||||
[0.6, 0.2, 0.9],
|
||||
@ -631,8 +632,8 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder):
|
||||
)
|
||||
threshold = 0.5
|
||||
predicted_tags = get_class_tags(
|
||||
class_scores=class_scores,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
class_scores=class_scores.values,
|
||||
targets=dummy_targets,
|
||||
threshold=threshold,
|
||||
)
|
||||
|
||||
@ -643,7 +644,7 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder):
|
||||
assert "uncertain" in tag_values
|
||||
|
||||
|
||||
def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
|
||||
def test_get_class_tags_top_class_only(dummy_targets):
|
||||
"""Test class tag creation with top_class_only."""
|
||||
class_scores = xr.DataArray(
|
||||
[0.6, 0.2, 0.9],
|
||||
@ -651,8 +652,8 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
|
||||
dims=["category"],
|
||||
)
|
||||
predicted_tags = get_class_tags(
|
||||
class_scores=class_scores,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
class_scores=class_scores.values,
|
||||
targets=dummy_targets,
|
||||
top_class_only=True,
|
||||
)
|
||||
|
||||
@ -661,11 +662,11 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
|
||||
assert predicted_tags[0].score == 0.9
|
||||
|
||||
|
||||
def test_get_class_tags_empty(dummy_sound_event_decoder):
|
||||
def test_get_class_tags_empty(dummy_targets):
|
||||
"""Test with empty class scores."""
|
||||
class_scores = xr.DataArray([], coords={"category": []}, dims=["category"])
|
||||
predicted_tags = get_class_tags(
|
||||
class_scores=class_scores,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
class_scores=class_scores.values,
|
||||
targets=dummy_targets,
|
||||
)
|
||||
assert len(predicted_tags) == 0
|
||||
|
||||
@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from soundevent import data
|
||||
from soundevent.terms import get_term
|
||||
|
||||
from batdetect2.targets.classes import (
|
||||
DEFAULT_SPECIES_LIST,
|
||||
@ -21,26 +22,19 @@ from batdetect2.targets.classes import (
|
||||
load_decoder_from_config,
|
||||
load_encoder_from_config,
|
||||
)
|
||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||
from batdetect2.targets.terms import TagInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_annotation(
|
||||
sound_event: data.SoundEvent,
|
||||
sample_term_registry: TermRegistry,
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Fixture for a sample SoundEventAnnotation."""
|
||||
return data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=sample_term_registry.get_term("species"),
|
||||
value="Pipistrellus pipistrellus",
|
||||
),
|
||||
data.Tag(
|
||||
term=sample_term_registry.get_term("quality"),
|
||||
value="Good",
|
||||
),
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
data.Tag(key="quality", value="Good"), # type: ignore
|
||||
],
|
||||
)
|
||||
|
||||
@ -136,59 +130,33 @@ def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
|
||||
|
||||
def test_is_target_class_match_all(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
),
|
||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
data.Tag(key="quality", value="Good"), # type: ignore
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
)
|
||||
}
|
||||
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
||||
)
|
||||
}
|
||||
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is False
|
||||
|
||||
|
||||
def test_is_target_class_match_any(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
),
|
||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
data.Tag(key="quality", value="Good"), # type: ignore
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
)
|
||||
}
|
||||
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
||||
)
|
||||
}
|
||||
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is False
|
||||
|
||||
|
||||
@ -208,7 +176,6 @@ def test_get_class_names_from_config():
|
||||
|
||||
def test_build_encoder_from_config(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
@ -220,25 +187,18 @@ def test_build_encoder_from_config(
|
||||
)
|
||||
]
|
||||
)
|
||||
encoder = build_sound_event_encoder(
|
||||
config,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
encoder = build_sound_event_encoder(config)
|
||||
result = encoder(sample_annotation)
|
||||
assert result == "pippip"
|
||||
|
||||
config = ClassesConfig(classes=[])
|
||||
encoder = build_sound_event_encoder(
|
||||
config,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
encoder = build_sound_event_encoder(config)
|
||||
result = encoder(sample_annotation)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_load_encoder_from_config_valid(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
yaml_content = """
|
||||
@ -249,10 +209,7 @@ def test_load_encoder_from_config_valid(
|
||||
value: Pipistrellus pipistrellus
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
encoder = load_encoder_from_config(
|
||||
temp_yaml_path,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
encoder = load_encoder_from_config(temp_yaml_path)
|
||||
# We cannot directly compare the function, so we test it.
|
||||
result = encoder(sample_annotation) # type: ignore
|
||||
assert result == "pippip"
|
||||
@ -260,7 +217,6 @@ def test_load_encoder_from_config_valid(
|
||||
|
||||
def test_load_encoder_from_config_invalid(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
yaml_content = """
|
||||
classes:
|
||||
@ -271,10 +227,7 @@ def test_load_encoder_from_config_invalid(
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
with pytest.raises(KeyError):
|
||||
load_encoder_from_config(
|
||||
temp_yaml_path,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
load_encoder_from_config(temp_yaml_path)
|
||||
|
||||
|
||||
def test_get_default_class_name():
|
||||
@ -291,7 +244,7 @@ def test_get_default_classes():
|
||||
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
|
||||
|
||||
|
||||
def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
||||
def test_build_decoder_from_config():
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
@ -304,12 +257,10 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
||||
],
|
||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||
)
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
decoder = build_sound_event_decoder(config)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == sample_term_registry["call_type"]
|
||||
assert tags[0].term == get_term("event")
|
||||
assert tags[0].value == "Echolocation"
|
||||
|
||||
# Test when output_tags is None, should fall back to tags
|
||||
@ -324,32 +275,25 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
||||
],
|
||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||
)
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
decoder = build_sound_event_decoder(config)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == sample_term_registry["species"]
|
||||
assert tags[0].term == get_term("species")
|
||||
assert tags[0].value == "Pipistrellus pipistrellus"
|
||||
|
||||
# Test raise_on_unmapped=True
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry, raise_on_unmapped=True
|
||||
)
|
||||
decoder = build_sound_event_decoder(config, raise_on_unmapped=True)
|
||||
with pytest.raises(ValueError):
|
||||
decoder("unknown_class")
|
||||
|
||||
# Test raise_on_unmapped=False
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry, raise_on_unmapped=False
|
||||
)
|
||||
decoder = build_sound_event_decoder(config, raise_on_unmapped=False)
|
||||
tags = decoder("unknown_class")
|
||||
assert len(tags) == 0
|
||||
|
||||
|
||||
def test_load_decoder_from_config_valid(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
yaml_content = """
|
||||
classes:
|
||||
@ -366,17 +310,15 @@ def test_load_decoder_from_config_valid(
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
decoder = load_decoder_from_config(
|
||||
temp_yaml_path, term_registry=sample_term_registry
|
||||
temp_yaml_path,
|
||||
)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == sample_term_registry["call_type"]
|
||||
assert tags[0].term == get_term("call_type")
|
||||
assert tags[0].value == "Echolocation"
|
||||
|
||||
|
||||
def test_build_generic_class_tags_from_config(
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
def test_build_generic_class_tags_from_config():
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
@ -391,11 +333,9 @@ def test_build_generic_class_tags_from_config(
|
||||
TagInfo(key="call_type", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
generic_tags = build_generic_class_tags(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
generic_tags = build_generic_class_tags(config)
|
||||
assert len(generic_tags) == 2
|
||||
assert generic_tags[0].term == sample_term_registry["order"]
|
||||
assert generic_tags[0].term == get_term("order")
|
||||
assert generic_tags[0].value == "Chiroptera"
|
||||
assert generic_tags[1].term == sample_term_registry["call_type"]
|
||||
assert generic_tags[1].term == get_term("call_type")
|
||||
assert generic_tags[1].value == "Echolocation"
|
||||
|
||||
@ -80,7 +80,6 @@ def test_generated_heatmaps_have_correct_dimensions(
|
||||
|
||||
def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
sample_target_config: TargetConfig,
|
||||
sample_term_registry: TermRegistry,
|
||||
pippip_tag: TagInfo,
|
||||
):
|
||||
config = sample_target_config.model_copy(
|
||||
@ -92,7 +91,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
)
|
||||
)
|
||||
|
||||
targets = build_targets(config, term_registry=sample_term_registry)
|
||||
targets = build_targets(config)
|
||||
|
||||
spec = xr.DataArray(
|
||||
data=np.random.rand(100, 100),
|
||||
@ -113,12 +112,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
coordinates=[10, 10, 20, 20],
|
||||
),
|
||||
),
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=sample_term_registry[pippip_tag.key],
|
||||
value=pippip_tag.value,
|
||||
)
|
||||
],
|
||||
tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@ -2,12 +2,12 @@ import pytest
|
||||
import torch
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
from soundevent.terms import get_term
|
||||
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||
from batdetect2.targets import build_targets, load_target_config
|
||||
from batdetect2.targets.terms import get_term_from_key
|
||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||
from batdetect2.train.preprocess import generate_train_example
|
||||
|
||||
@ -15,7 +15,6 @@ from batdetect2.train.preprocess import generate_train_example
|
||||
@pytest.fixture
|
||||
def build_from_config(
|
||||
create_temp_yaml,
|
||||
sample_term_registry,
|
||||
):
|
||||
def build(yaml_content):
|
||||
config_path = create_temp_yaml(yaml_content)
|
||||
@ -31,9 +30,7 @@ def build_from_config(
|
||||
field="postprocessing",
|
||||
)
|
||||
|
||||
targets = build_targets(
|
||||
targets_config, term_registry=sample_term_registry
|
||||
)
|
||||
targets = build_targets(targets_config)
|
||||
preprocessor = build_preprocessor(preprocessing_config)
|
||||
labeller = build_clip_labeler(
|
||||
targets=targets,
|
||||
@ -54,7 +51,6 @@ def build_from_config(
|
||||
# TODO: better name
|
||||
def test_generated_train_example_has_expected_outputs(
|
||||
build_from_config,
|
||||
sample_term_registry,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
@ -78,10 +74,11 @@ def test_generated_train_example_has_expected_outputs(
|
||||
_, preprocessor, labeller, _ = build_from_config(yaml_content)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||
tags=[
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
],
|
||||
)
|
||||
clip_annotation = data.ClipAnnotation(
|
||||
clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
|
||||
@ -108,7 +105,6 @@ def test_generated_train_example_has_expected_outputs(
|
||||
|
||||
def test_encoding_decoding_roundtrip_recovers_object(
|
||||
build_from_config,
|
||||
sample_term_registry,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
@ -131,10 +127,11 @@ def test_encoding_decoding_roundtrip_recovers_object(
|
||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||
tags=[
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
],
|
||||
)
|
||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||
@ -171,14 +168,16 @@ def test_encoding_decoding_roundtrip_recovers_object(
|
||||
assert len(recovered.tags) == 2
|
||||
|
||||
predicted_species_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term == species), None
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
||||
None,
|
||||
)
|
||||
assert predicted_species_tag is not None
|
||||
assert predicted_species_tag.score == 1
|
||||
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
|
||||
|
||||
predicted_order_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
||||
None,
|
||||
)
|
||||
assert predicted_order_tag is not None
|
||||
assert predicted_order_tag.score == 1
|
||||
@ -187,7 +186,6 @@ def test_encoding_decoding_roundtrip_recovers_object(
|
||||
|
||||
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
||||
build_from_config,
|
||||
sample_term_registry,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
@ -217,10 +215,9 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Myotis myotis")],
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
)
|
||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||
@ -257,14 +254,16 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
||||
assert len(recovered.tags) == 2
|
||||
|
||||
predicted_species_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term == species), None
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
||||
None,
|
||||
)
|
||||
assert predicted_species_tag is not None
|
||||
assert predicted_species_tag.score == 1
|
||||
assert predicted_species_tag.tag.value == "Myotis myotis"
|
||||
|
||||
predicted_order_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
||||
None,
|
||||
)
|
||||
assert predicted_order_tag is not None
|
||||
assert predicted_order_tag.score == 1
|
||||
|
||||
Loading…
Reference in New Issue
Block a user