mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 09:29:33 +01:00
Compare commits
6 Commits
b997a122f1
...
c7b110feeb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7b110feeb | ||
|
|
7d92ec772b | ||
|
|
0bb703f3c1 | ||
|
|
81c7b68b0b | ||
|
|
59aaf07af5 | ||
|
|
51d0a49da9 |
@ -85,6 +85,7 @@ dev = [
|
|||||||
"ty>=0.0.1a12",
|
"ty>=0.0.1a12",
|
||||||
"rust-just>=1.40.0",
|
"rust-just>=1.40.0",
|
||||||
"pandas-stubs>=2.2.2.240807",
|
"pandas-stubs>=2.2.2.240807",
|
||||||
|
"python-lsp-server>=1.13.0",
|
||||||
]
|
]
|
||||||
dvclive = ["dvclive>=3.48.2"]
|
dvclive = ["dvclive>=3.48.2"]
|
||||||
mlflow = ["mlflow>=3.1.1"]
|
mlflow = ["mlflow>=3.1.1"]
|
||||||
|
|||||||
@ -179,7 +179,7 @@ def plot_false_positive_match(
|
|||||||
plt.text(
|
plt.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
|
|||||||
@ -543,7 +543,9 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def get_sound_event_predictions(
|
def get_sound_event_predictions(
|
||||||
self, output: ModelOutput, clips: List[data.Clip]
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
) -> List[List[BatDetect2Prediction]]:
|
) -> List[List[BatDetect2Prediction]]:
|
||||||
raw_predictions = self.get_raw_predictions(output, clips)
|
raw_predictions = self.get_raw_predictions(output, clips)
|
||||||
return [
|
return [
|
||||||
@ -553,8 +555,7 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
|
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
|
||||||
raw,
|
raw,
|
||||||
recording=clip.recording,
|
recording=clip.recording,
|
||||||
sound_event_decoder=self.targets.decode_class,
|
targets=self.targets,
|
||||||
generic_class_tags=self.targets.generic_class_tags,
|
|
||||||
classification_threshold=self.config.classification_threshold,
|
classification_threshold=self.config.classification_threshold,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -590,8 +591,7 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
convert_raw_predictions_to_clip_prediction(
|
convert_raw_predictions_to_clip_prediction(
|
||||||
prediction,
|
prediction,
|
||||||
clip,
|
clip,
|
||||||
sound_event_decoder=self.targets.decode_class,
|
targets=self.targets,
|
||||||
generic_class_tags=self.targets.generic_class_tags,
|
|
||||||
classification_threshold=self.config.classification_threshold,
|
classification_threshold=self.config.classification_threshold,
|
||||||
)
|
)
|
||||||
for prediction, clip in zip(raw_predictions, clips)
|
for prediction, clip in zip(raw_predictions, clips)
|
||||||
|
|||||||
@ -33,8 +33,7 @@ import xarray as xr
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
||||||
from batdetect2.targets.classes import SoundEventDecoder
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.utils.arrays import iterate_over_array
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"convert_xr_dataset_to_raw_prediction",
|
"convert_xr_dataset_to_raw_prediction",
|
||||||
@ -92,25 +91,30 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
"""
|
"""
|
||||||
detections = []
|
detections = []
|
||||||
|
|
||||||
for det_num in range(detection_dataset.sizes["detection"]):
|
categories = detection_dataset.category.values
|
||||||
det_info = detection_dataset.sel(detection=det_num)
|
|
||||||
|
|
||||||
highest_scoring_class = det_info.coords["category"][
|
for score, class_scores, time, freq, dims, feats in zip(
|
||||||
det_info["classes"].argmax()
|
detection_dataset["scores"].values,
|
||||||
].item()
|
detection_dataset["classes"].values,
|
||||||
|
detection_dataset["time"].values,
|
||||||
|
detection_dataset["frequency"].values,
|
||||||
|
detection_dataset["dimensions"].values,
|
||||||
|
detection_dataset["features"].values,
|
||||||
|
):
|
||||||
|
highest_scoring_class = categories[class_scores.argmax()]
|
||||||
|
|
||||||
geom = geometry_decoder(
|
geom = geometry_decoder(
|
||||||
(det_info.time, det_info.frequency),
|
(time, freq),
|
||||||
det_info.dimensions,
|
dims,
|
||||||
class_name=highest_scoring_class,
|
class_name=highest_scoring_class,
|
||||||
)
|
)
|
||||||
|
|
||||||
detections.append(
|
detections.append(
|
||||||
RawPrediction(
|
RawPrediction(
|
||||||
detection_score=det_info.scores,
|
detection_score=score,
|
||||||
geometry=geom,
|
geometry=geom,
|
||||||
class_scores=det_info.classes,
|
class_scores=class_scores,
|
||||||
features=det_info.features,
|
features=feats,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -120,8 +124,7 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
def convert_raw_predictions_to_clip_prediction(
|
def convert_raw_predictions_to_clip_prediction(
|
||||||
raw_predictions: List[RawPrediction],
|
raw_predictions: List[RawPrediction],
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
sound_event_decoder: SoundEventDecoder,
|
targets: TargetProtocol,
|
||||||
generic_class_tags: List[data.Tag],
|
|
||||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
) -> data.ClipPrediction:
|
) -> data.ClipPrediction:
|
||||||
@ -160,8 +163,7 @@ def convert_raw_predictions_to_clip_prediction(
|
|||||||
convert_raw_prediction_to_sound_event_prediction(
|
convert_raw_prediction_to_sound_event_prediction(
|
||||||
prediction,
|
prediction,
|
||||||
recording=clip.recording,
|
recording=clip.recording,
|
||||||
sound_event_decoder=sound_event_decoder,
|
targets=targets,
|
||||||
generic_class_tags=generic_class_tags,
|
|
||||||
classification_threshold=classification_threshold,
|
classification_threshold=classification_threshold,
|
||||||
top_class_only=top_class_only,
|
top_class_only=top_class_only,
|
||||||
)
|
)
|
||||||
@ -173,8 +175,7 @@ def convert_raw_predictions_to_clip_prediction(
|
|||||||
def convert_raw_prediction_to_sound_event_prediction(
|
def convert_raw_prediction_to_sound_event_prediction(
|
||||||
raw_prediction: RawPrediction,
|
raw_prediction: RawPrediction,
|
||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
sound_event_decoder: SoundEventDecoder,
|
targets: TargetProtocol,
|
||||||
generic_class_tags: List[data.Tag],
|
|
||||||
classification_threshold: Optional[
|
classification_threshold: Optional[
|
||||||
float
|
float
|
||||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
@ -251,11 +252,11 @@ def convert_raw_prediction_to_sound_event_prediction(
|
|||||||
tags = [
|
tags = [
|
||||||
*get_generic_tags(
|
*get_generic_tags(
|
||||||
raw_prediction.detection_score,
|
raw_prediction.detection_score,
|
||||||
generic_class_tags=generic_class_tags,
|
generic_class_tags=targets.generic_class_tags,
|
||||||
),
|
),
|
||||||
*get_class_tags(
|
*get_class_tags(
|
||||||
raw_prediction.class_scores,
|
raw_prediction.class_scores,
|
||||||
sound_event_decoder,
|
targets=targets,
|
||||||
top_class_only=top_class_only,
|
top_class_only=top_class_only,
|
||||||
threshold=classification_threshold,
|
threshold=classification_threshold,
|
||||||
),
|
),
|
||||||
@ -297,7 +298,7 @@ def get_generic_tags(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
||||||
"""Convert an extracted feature vector DataArray into soundevent Features.
|
"""Convert an extracted feature vector DataArray into soundevent Features.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -320,19 +321,19 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
|||||||
return [
|
return [
|
||||||
data.Feature(
|
data.Feature(
|
||||||
term=data.Term(
|
term=data.Term(
|
||||||
name=f"batdetect2:{feat_name}",
|
name=f"batdetect2:f{index}",
|
||||||
label=feat_name,
|
label=f"BatDetect Feature {index}",
|
||||||
definition="Automatically extracted features by BatDetect2",
|
definition="Automatically extracted features by BatDetect2",
|
||||||
),
|
),
|
||||||
value=value,
|
value=value,
|
||||||
)
|
)
|
||||||
for feat_name, value in iterate_over_array(features)
|
for index, value in enumerate(features)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_class_tags(
|
def get_class_tags(
|
||||||
class_scores: xr.DataArray,
|
class_scores: np.ndarray,
|
||||||
sound_event_decoder: SoundEventDecoder,
|
targets: TargetProtocol,
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
) -> List[data.PredictedTag]:
|
) -> List[data.PredictedTag]:
|
||||||
@ -367,11 +368,13 @@ def get_class_tags(
|
|||||||
"""
|
"""
|
||||||
tags = []
|
tags = []
|
||||||
|
|
||||||
if threshold is not None:
|
for class_name, score in _iterate_sorted(
|
||||||
class_scores = class_scores.where(class_scores > threshold, drop=True)
|
class_scores, targets.class_names
|
||||||
|
):
|
||||||
|
if threshold is not None and score < threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
for class_name, score in _iterate_sorted(class_scores):
|
class_tags = targets.decode_class(class_name)
|
||||||
class_tags = sound_event_decoder(class_name)
|
|
||||||
|
|
||||||
for tag in class_tags:
|
for tag in class_tags:
|
||||||
tags.append(
|
tags.append(
|
||||||
@ -387,9 +390,7 @@ def get_class_tags(
|
|||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
|
||||||
def _iterate_sorted(array: xr.DataArray):
|
def _iterate_sorted(array: np.ndarray, class_names: List[str]):
|
||||||
dim_name = array.dims[0]
|
indices = np.argsort(-array)
|
||||||
coords = array.coords[dim_name].values
|
|
||||||
indices = np.argsort(-array.values)
|
|
||||||
for index in indices:
|
for index in indices:
|
||||||
yield str(coords[index]), float(array.values[index])
|
yield str(class_names[index]), float(array[index])
|
||||||
|
|||||||
@ -14,6 +14,7 @@ system that deal with model predictions.
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, NamedTuple, Optional, Protocol
|
from typing import List, NamedTuple, Optional, Protocol
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -72,8 +73,8 @@ class RawPrediction(NamedTuple):
|
|||||||
|
|
||||||
geometry: data.Geometry
|
geometry: data.Geometry
|
||||||
detection_score: float
|
detection_score: float
|
||||||
class_scores: xr.DataArray
|
class_scores: np.ndarray
|
||||||
features: xr.DataArray
|
features: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -91,6 +91,7 @@ terms.register_term_set(
|
|||||||
"individual": individual.name,
|
"individual": individual.name,
|
||||||
"event": call_type.name,
|
"event": call_type.name,
|
||||||
"source": data_source.name,
|
"source": data_source.name,
|
||||||
|
"call_type": call_type.name,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
override_existing=True,
|
override_existing=True,
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
@ -15,7 +16,10 @@ from batdetect2.evaluate.match import (
|
|||||||
)
|
)
|
||||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||||
from batdetect2.plotting.evaluation import plot_example_gallery
|
from batdetect2.plotting.evaluation import plot_example_gallery
|
||||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
from batdetect2.postprocess.types import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
PostprocessorProtocol,
|
||||||
|
)
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
@ -114,33 +118,51 @@ class ValidationMetrics(Callback):
|
|||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
dataloader_idx: int = 0,
|
dataloader_idx: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
dataset = self.get_dataset(trainer)
|
self._matches.extend(
|
||||||
|
_get_batch_clips_and_predictions(
|
||||||
clip_annotations = [
|
batch,
|
||||||
_get_subclip(
|
outputs,
|
||||||
dataset.get_clip_annotation(example_id),
|
dataset=self.get_dataset(trainer),
|
||||||
start_time=start_time.item(),
|
postprocessor=pl_module.postprocessor,
|
||||||
end_time=end_time.item(),
|
|
||||||
targets=pl_module.targets,
|
targets=pl_module.targets,
|
||||||
)
|
)
|
||||||
for example_id, start_time, end_time in zip(
|
|
||||||
batch.idx,
|
|
||||||
batch.start_time,
|
|
||||||
batch.end_time,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
|
||||||
|
|
||||||
raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
|
|
||||||
outputs,
|
|
||||||
clips,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_batch_clips_and_predictions(
|
||||||
|
batch: TrainExample,
|
||||||
|
outputs: ModelOutput,
|
||||||
|
dataset: LabeledDataset,
|
||||||
|
postprocessor: PostprocessorProtocol,
|
||||||
|
targets: TargetProtocol,
|
||||||
|
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
|
||||||
|
clip_annotations = [
|
||||||
|
_get_subclip(
|
||||||
|
dataset.get_clip_annotation(example_id),
|
||||||
|
start_time=start_time.item(),
|
||||||
|
end_time=end_time.item(),
|
||||||
|
targets=targets,
|
||||||
|
)
|
||||||
|
for example_id, start_time, end_time in zip(
|
||||||
|
batch.idx,
|
||||||
|
batch.start_time,
|
||||||
|
batch.end_time,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
||||||
|
|
||||||
|
raw_predictions = postprocessor.get_sound_event_predictions(
|
||||||
|
outputs,
|
||||||
|
clips,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
(clip_annotation, clip_predictions)
|
||||||
for clip_annotation, clip_predictions in zip(
|
for clip_annotation, clip_predictions in zip(
|
||||||
clip_annotations, raw_predictions
|
clip_annotations, raw_predictions
|
||||||
):
|
)
|
||||||
self._matches.append((clip_annotation, clip_predictions))
|
]
|
||||||
|
|
||||||
|
|
||||||
def _match_all_collected_examples(
|
def _match_all_collected_examples(
|
||||||
@ -150,7 +172,8 @@ def _match_all_collected_examples(
|
|||||||
) -> List[MatchEvaluation]:
|
) -> List[MatchEvaluation]:
|
||||||
logger.info("Matching all annotations and predictions")
|
logger.info("Matching all annotations and predictions")
|
||||||
|
|
||||||
with Pool() as p:
|
cpu_count = os.cpu_count() or 1
|
||||||
|
with Pool(processes=min(cpu_count, 4)) as p:
|
||||||
matches = p.starmap(
|
matches = p.starmap(
|
||||||
partial(
|
partial(
|
||||||
match_sound_events_and_raw_predictions,
|
match_sound_events_and_raw_predictions,
|
||||||
|
|||||||
@ -254,11 +254,13 @@ class TestLoadBatDetect2Files:
|
|||||||
assert clip_ann.clip.recording.duration == 5.0
|
assert clip_ann.clip.recording.duration == 5.0
|
||||||
assert len(clip_ann.sound_events) == 1
|
assert len(clip_ann.sound_events) == 1
|
||||||
assert clip_ann.notes[0].message == "Standard notes."
|
assert clip_ann.notes[0].message == "Standard notes."
|
||||||
clip_tag = data.find_tag(clip_ann.tags, "Class")
|
clip_tag = data.find_tag(clip_ann.tags, term_label="Class")
|
||||||
assert clip_tag is not None
|
assert clip_tag is not None
|
||||||
assert clip_tag.value == "Myotis"
|
assert clip_tag.value == "Myotis"
|
||||||
|
|
||||||
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
|
recording_tag = data.find_tag(
|
||||||
|
clip_ann.clip.recording.tags, term_label="Class"
|
||||||
|
)
|
||||||
assert recording_tag is not None
|
assert recording_tag is not None
|
||||||
assert recording_tag.value == "Myotis"
|
assert recording_tag.value == "Myotis"
|
||||||
|
|
||||||
@ -271,15 +273,15 @@ class TestLoadBatDetect2Files:
|
|||||||
40000,
|
40000,
|
||||||
]
|
]
|
||||||
|
|
||||||
se_class_tag = data.find_tag(se_ann.tags, "Class")
|
se_class_tag = data.find_tag(se_ann.tags, term_label="Class")
|
||||||
assert se_class_tag is not None
|
assert se_class_tag is not None
|
||||||
assert se_class_tag.value == "Myotis"
|
assert se_class_tag.value == "Myotis"
|
||||||
|
|
||||||
se_event_tag = data.find_tag(se_ann.tags, "Call Type")
|
se_event_tag = data.find_tag(se_ann.tags, term_label="Call Type")
|
||||||
assert se_event_tag is not None
|
assert se_event_tag is not None
|
||||||
assert se_event_tag.value == "Echolocation"
|
assert se_event_tag.value == "Echolocation"
|
||||||
|
|
||||||
se_individual_tag = data.find_tag(se_ann.tags, "Individual")
|
se_individual_tag = data.find_tag(se_ann.tags, term_label="Individual")
|
||||||
assert se_individual_tag is not None
|
assert se_individual_tag is not None
|
||||||
assert se_individual_tag.value == "0"
|
assert se_individual_tag.value == "0"
|
||||||
|
|
||||||
@ -439,7 +441,7 @@ class TestLoadBatDetect2Merged:
|
|||||||
assert clip_ann.clip.recording.duration == 5.0
|
assert clip_ann.clip.recording.duration == 5.0
|
||||||
assert len(clip_ann.sound_events) == 1
|
assert len(clip_ann.sound_events) == 1
|
||||||
|
|
||||||
clip_class_tag = data.find_tag(clip_ann.tags, "Class")
|
clip_class_tag = data.find_tag(clip_ann.tags, term_label="Class")
|
||||||
assert clip_class_tag is not None
|
assert clip_class_tag is not None
|
||||||
assert clip_class_tag.value == "Myotis"
|
assert clip_class_tag.value == "Myotis"
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -16,35 +16,11 @@ from batdetect2.postprocess.decoding import (
|
|||||||
get_prediction_features,
|
get_prediction_features,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess.types import RawPrediction
|
from batdetect2.postprocess.types import RawPrediction
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dummy_geometry_builder():
|
def dummy_targets() -> TargetProtocol:
|
||||||
"""A simple GeometryBuilder that creates a BBox around the point."""
|
|
||||||
|
|
||||||
def _builder(
|
|
||||||
position: Tuple[float, float],
|
|
||||||
dimensions: xr.DataArray,
|
|
||||||
class_name: Optional[str] = None,
|
|
||||||
) -> data.BoundingBox:
|
|
||||||
time, freq = position
|
|
||||||
width = dimensions.sel(dimension="width").item()
|
|
||||||
height = dimensions.sel(dimension="height").item()
|
|
||||||
return data.BoundingBox(
|
|
||||||
coordinates=[
|
|
||||||
time - width / 2,
|
|
||||||
freq - height / 2,
|
|
||||||
time + width / 2,
|
|
||||||
freq + height / 2,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return _builder
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def dummy_sound_event_decoder():
|
|
||||||
"""A simple SoundEventDecoder mapping names to tags."""
|
|
||||||
tag_map = {
|
tag_map = {
|
||||||
"bat": [
|
"bat": [
|
||||||
data.Tag(term=data.term_from_key(key="species"), value="Myotis")
|
data.Tag(term=data.term_from_key(key="species"), value="Myotis")
|
||||||
@ -57,18 +33,56 @@ def dummy_sound_event_decoder():
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
def _decoder(class_name: str) -> List[data.Tag]:
|
class DummyTargets(TargetProtocol):
|
||||||
return tag_map.get(class_name.lower(), [])
|
class_names = [
|
||||||
|
"bat",
|
||||||
|
"noise",
|
||||||
|
"unknown",
|
||||||
|
]
|
||||||
|
|
||||||
return _decoder
|
dimension_names = ["width", "height"]
|
||||||
|
|
||||||
|
generic_class_tags = [
|
||||||
|
data.Tag(
|
||||||
|
term=data.term_from_key(key="detector"), value="batdetect2"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
@pytest.fixture
|
def filter(self, sound_event: data.SoundEventAnnotation):
|
||||||
def generic_tags() -> List[data.Tag]:
|
return True
|
||||||
"""Sample generic tags."""
|
|
||||||
return [
|
def transform(self, sound_event: data.SoundEventAnnotation):
|
||||||
data.Tag(term=data.term_from_key(key="detector"), value="batdetect2")
|
return sound_event
|
||||||
]
|
|
||||||
|
def encode_class(
|
||||||
|
self, sound_event: data.SoundEventAnnotation
|
||||||
|
) -> Optional[str]:
|
||||||
|
return "bat"
|
||||||
|
|
||||||
|
def decode_class(self, class_label: str) -> List[data.Tag]:
|
||||||
|
return tag_map.get(class_label.lower(), [])
|
||||||
|
|
||||||
|
def encode_roi(self, sound_event: data.SoundEventAnnotation):
|
||||||
|
return np.array([0.0, 0.0]), np.array([0.0, 0.0])
|
||||||
|
|
||||||
|
def decode_roi(
|
||||||
|
self,
|
||||||
|
position,
|
||||||
|
size: np.ndarray,
|
||||||
|
class_name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
time, freq = position
|
||||||
|
width, height = size
|
||||||
|
return data.BoundingBox(
|
||||||
|
coordinates=[
|
||||||
|
time - width / 2,
|
||||||
|
freq - height / 2,
|
||||||
|
time + width / 2,
|
||||||
|
freq + height / 2,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return DummyTargets()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -156,7 +170,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
|||||||
"""Creates an empty detection dataset with correct structure."""
|
"""Creates an empty detection dataset with correct structure."""
|
||||||
detection_coords = {
|
detection_coords = {
|
||||||
"time": ("detection", np.array([], dtype=np.float64)),
|
"time": ("detection", np.array([], dtype=np.float64)),
|
||||||
"freq": ("detection", np.array([], dtype=np.float64)),
|
"frequency": ("detection", np.array([], dtype=np.float64)),
|
||||||
}
|
}
|
||||||
scores = xr.DataArray(
|
scores = xr.DataArray(
|
||||||
np.array([], dtype=np.float64),
|
np.array([], dtype=np.float64),
|
||||||
@ -184,7 +198,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
|||||||
)
|
)
|
||||||
return xr.Dataset(
|
return xr.Dataset(
|
||||||
{
|
{
|
||||||
"score": scores,
|
"scores": scores,
|
||||||
"dimensions": dimensions,
|
"dimensions": dimensions,
|
||||||
"classes": classes,
|
"classes": classes,
|
||||||
"features": features,
|
"features": features,
|
||||||
@ -215,8 +229,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
300 + 16 / 2,
|
300 + 16 / 2,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
class_scores=pred1_classes,
|
class_scores=pred1_classes.values,
|
||||||
features=pred1_features,
|
features=pred1_features.values,
|
||||||
)
|
)
|
||||||
|
|
||||||
pred2_classes = xr.DataArray(
|
pred2_classes = xr.DataArray(
|
||||||
@ -237,8 +251,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
200 + 12 / 2,
|
200 + 12 / 2,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
class_scores=pred2_classes,
|
class_scores=pred2_classes.values,
|
||||||
features=pred2_features,
|
features=pred2_features.values,
|
||||||
)
|
)
|
||||||
|
|
||||||
pred3_classes = xr.DataArray(
|
pred3_classes = xr.DataArray(
|
||||||
@ -259,18 +273,17 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
60.0,
|
60.0,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
class_scores=pred3_classes,
|
class_scores=pred3_classes.values,
|
||||||
features=pred3_features,
|
features=pred3_features.values,
|
||||||
)
|
)
|
||||||
return [pred1, pred2, pred3]
|
return [pred1, pred2, pred3]
|
||||||
|
|
||||||
|
|
||||||
def test_convert_xr_dataset_basic(
|
def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
|
||||||
sample_detection_dataset, dummy_geometry_builder
|
|
||||||
):
|
|
||||||
"""Test basic conversion of a dataset to RawPrediction list."""
|
"""Test basic conversion of a dataset to RawPrediction list."""
|
||||||
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
||||||
sample_detection_dataset, dummy_geometry_builder
|
sample_detection_dataset,
|
||||||
|
dummy_targets.decode_roi,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(raw_predictions, list)
|
assert isinstance(raw_predictions, list)
|
||||||
@ -286,11 +299,11 @@ def test_convert_xr_dataset_basic(
|
|||||||
20 + 7 / 2,
|
20 + 7 / 2,
|
||||||
300 + 16 / 2,
|
300 + 16 / 2,
|
||||||
]
|
]
|
||||||
xr.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
pred1.class_scores,
|
pred1.class_scores,
|
||||||
sample_detection_dataset["classes"].sel(detection=0),
|
sample_detection_dataset["classes"].sel(detection=0),
|
||||||
)
|
)
|
||||||
xr.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
pred1.features, sample_detection_dataset["features"].sel(detection=0)
|
pred1.features, sample_detection_dataset["features"].sel(detection=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -304,21 +317,20 @@ def test_convert_xr_dataset_basic(
|
|||||||
10 + 3 / 2,
|
10 + 3 / 2,
|
||||||
200 + 12 / 2,
|
200 + 12 / 2,
|
||||||
]
|
]
|
||||||
xr.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
pred2.class_scores,
|
pred2.class_scores,
|
||||||
sample_detection_dataset["classes"].sel(detection=1),
|
sample_detection_dataset["classes"].sel(detection=1),
|
||||||
)
|
)
|
||||||
xr.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
pred2.features, sample_detection_dataset["features"].sel(detection=1)
|
pred2.features, sample_detection_dataset["features"].sel(detection=1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_convert_xr_dataset_empty(
|
def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets):
|
||||||
empty_detection_dataset, dummy_geometry_builder
|
|
||||||
):
|
|
||||||
"""Test conversion of an empty dataset."""
|
"""Test conversion of an empty dataset."""
|
||||||
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
||||||
empty_detection_dataset, dummy_geometry_builder
|
empty_detection_dataset,
|
||||||
|
dummy_targets.decode_roi,
|
||||||
)
|
)
|
||||||
assert isinstance(raw_predictions, list)
|
assert isinstance(raw_predictions, list)
|
||||||
assert len(raw_predictions) == 0
|
assert len(raw_predictions) == 0
|
||||||
@ -327,8 +339,7 @@ def test_convert_xr_dataset_empty(
|
|||||||
def test_convert_raw_to_sound_event_basic(
|
def test_convert_raw_to_sound_event_basic(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions,
|
||||||
sample_recording,
|
sample_recording,
|
||||||
dummy_sound_event_decoder,
|
dummy_targets,
|
||||||
generic_tags,
|
|
||||||
):
|
):
|
||||||
"""Test basic conversion, default threshold, multi-label."""
|
"""Test basic conversion, default threshold, multi-label."""
|
||||||
|
|
||||||
@ -337,8 +348,7 @@ def test_convert_raw_to_sound_event_basic(
|
|||||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||||
raw_prediction=raw_pred,
|
raw_prediction=raw_pred,
|
||||||
recording=sample_recording,
|
recording=sample_recording,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
generic_class_tags=generic_tags,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(se_pred, data.SoundEventPrediction)
|
assert isinstance(se_pred, data.SoundEventPrediction)
|
||||||
@ -357,10 +367,11 @@ def test_convert_raw_to_sound_event_basic(
|
|||||||
)
|
)
|
||||||
assert feat_dict["batdetect2:f0"] == 7.0
|
assert feat_dict["batdetect2:f0"] == 7.0
|
||||||
|
|
||||||
|
generic_tags = dummy_targets.generic_class_tags
|
||||||
expected_tags = {
|
expected_tags = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||||
("soundevent:category", "noise", 0.85),
|
("category", "noise", 0.85),
|
||||||
("soundevent:species", "Myotis", 0.43),
|
("dwc:scientificName", "Myotis", 0.43),
|
||||||
}
|
}
|
||||||
actual_tags = {
|
actual_tags = {
|
||||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||||
@ -369,10 +380,7 @@ def test_convert_raw_to_sound_event_basic(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_thresholding(
|
def test_convert_raw_to_sound_event_thresholding(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions, sample_recording, dummy_targets
|
||||||
sample_recording,
|
|
||||||
dummy_sound_event_decoder,
|
|
||||||
generic_tags,
|
|
||||||
):
|
):
|
||||||
"""Test effect of classification threshold."""
|
"""Test effect of classification threshold."""
|
||||||
raw_pred = sample_raw_predictions[0]
|
raw_pred = sample_raw_predictions[0]
|
||||||
@ -381,15 +389,15 @@ def test_convert_raw_to_sound_event_thresholding(
|
|||||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||||
raw_prediction=raw_pred,
|
raw_prediction=raw_pred,
|
||||||
recording=sample_recording,
|
recording=sample_recording,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
generic_class_tags=generic_tags,
|
|
||||||
classification_threshold=high_threshold,
|
classification_threshold=high_threshold,
|
||||||
top_class_only=False,
|
top_class_only=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
generic_tags = dummy_targets.generic_class_tags
|
||||||
expected_tags = {
|
expected_tags = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||||
("soundevent:category", "noise", 0.85),
|
("category", "noise", 0.85),
|
||||||
}
|
}
|
||||||
actual_tags = {
|
actual_tags = {
|
||||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||||
@ -400,8 +408,7 @@ def test_convert_raw_to_sound_event_thresholding(
|
|||||||
def test_convert_raw_to_sound_event_no_threshold(
|
def test_convert_raw_to_sound_event_no_threshold(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions,
|
||||||
sample_recording,
|
sample_recording,
|
||||||
dummy_sound_event_decoder,
|
dummy_targets,
|
||||||
generic_tags,
|
|
||||||
):
|
):
|
||||||
"""Test when classification_threshold is None."""
|
"""Test when classification_threshold is None."""
|
||||||
raw_pred = sample_raw_predictions[2]
|
raw_pred = sample_raw_predictions[2]
|
||||||
@ -409,16 +416,16 @@ def test_convert_raw_to_sound_event_no_threshold(
|
|||||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||||
raw_prediction=raw_pred,
|
raw_prediction=raw_pred,
|
||||||
recording=sample_recording,
|
recording=sample_recording,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
generic_class_tags=generic_tags,
|
|
||||||
classification_threshold=None,
|
classification_threshold=None,
|
||||||
top_class_only=False,
|
top_class_only=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
generic_tags = dummy_targets.generic_class_tags
|
||||||
expected_tags = {
|
expected_tags = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||||
("soundevent:species", "Myotis", 0.05),
|
("dwc:scientificName", "Myotis", 0.05),
|
||||||
("soundevent:category", "noise", 0.02),
|
("category", "noise", 0.02),
|
||||||
}
|
}
|
||||||
actual_tags = {
|
actual_tags = {
|
||||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||||
@ -429,8 +436,7 @@ def test_convert_raw_to_sound_event_no_threshold(
|
|||||||
def test_convert_raw_to_sound_event_top_class(
|
def test_convert_raw_to_sound_event_top_class(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions,
|
||||||
sample_recording,
|
sample_recording,
|
||||||
dummy_sound_event_decoder,
|
dummy_targets,
|
||||||
generic_tags,
|
|
||||||
):
|
):
|
||||||
"""Test top_class_only=True behavior."""
|
"""Test top_class_only=True behavior."""
|
||||||
raw_pred = sample_raw_predictions[0]
|
raw_pred = sample_raw_predictions[0]
|
||||||
@ -438,15 +444,15 @@ def test_convert_raw_to_sound_event_top_class(
|
|||||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||||
raw_prediction=raw_pred,
|
raw_prediction=raw_pred,
|
||||||
recording=sample_recording,
|
recording=sample_recording,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
generic_class_tags=generic_tags,
|
|
||||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only=True,
|
top_class_only=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
generic_tags = dummy_targets.generic_class_tags
|
||||||
expected_tags = {
|
expected_tags = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||||
("soundevent:category", "noise", 0.85),
|
("category", "noise", 0.85),
|
||||||
}
|
}
|
||||||
actual_tags = {
|
actual_tags = {
|
||||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||||
@ -457,8 +463,7 @@ def test_convert_raw_to_sound_event_top_class(
|
|||||||
def test_convert_raw_to_sound_event_all_below_threshold(
|
def test_convert_raw_to_sound_event_all_below_threshold(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions,
|
||||||
sample_recording,
|
sample_recording,
|
||||||
dummy_sound_event_decoder,
|
dummy_targets,
|
||||||
generic_tags,
|
|
||||||
):
|
):
|
||||||
"""Test when all class scores are below the default threshold."""
|
"""Test when all class scores are below the default threshold."""
|
||||||
raw_pred = sample_raw_predictions[2]
|
raw_pred = sample_raw_predictions[2]
|
||||||
@ -466,12 +471,12 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
|||||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||||
raw_prediction=raw_pred,
|
raw_prediction=raw_pred,
|
||||||
recording=sample_recording,
|
recording=sample_recording,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
generic_class_tags=generic_tags,
|
|
||||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only=False,
|
top_class_only=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
generic_tags = dummy_targets.generic_class_tags
|
||||||
expected_tags = {
|
expected_tags = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||||
}
|
}
|
||||||
@ -484,15 +489,13 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
|||||||
def test_convert_raw_list_to_clip_basic(
|
def test_convert_raw_list_to_clip_basic(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions,
|
||||||
sample_clip,
|
sample_clip,
|
||||||
dummy_sound_event_decoder,
|
dummy_targets,
|
||||||
generic_tags,
|
|
||||||
):
|
):
|
||||||
"""Test converting a list of RawPredictions to a ClipPrediction."""
|
"""Test converting a list of RawPredictions to a ClipPrediction."""
|
||||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||||
raw_predictions=sample_raw_predictions,
|
raw_predictions=sample_raw_predictions,
|
||||||
clip=sample_clip,
|
clip=sample_clip,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
generic_class_tags=generic_tags,
|
|
||||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only=False,
|
top_class_only=False,
|
||||||
)
|
)
|
||||||
@ -515,23 +518,19 @@ def test_convert_raw_list_to_clip_basic(
|
|||||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||||
for pt in clip_pred.sound_events[2].tags
|
for pt in clip_pred.sound_events[2].tags
|
||||||
}
|
}
|
||||||
|
generic_tags = dummy_targets.generic_class_tags
|
||||||
expected_tags3 = {
|
expected_tags3 = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||||
}
|
}
|
||||||
assert se_pred3_tags == expected_tags3
|
assert se_pred3_tags == expected_tags3
|
||||||
|
|
||||||
|
|
||||||
def test_convert_raw_list_to_clip_empty(
|
def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
|
||||||
sample_clip,
|
|
||||||
dummy_sound_event_decoder,
|
|
||||||
generic_tags,
|
|
||||||
):
|
|
||||||
"""Test converting an empty list of RawPredictions."""
|
"""Test converting an empty list of RawPredictions."""
|
||||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||||
raw_predictions=[],
|
raw_predictions=[],
|
||||||
clip=sample_clip,
|
clip=sample_clip,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
generic_class_tags=generic_tags,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(clip_pred, data.ClipPrediction)
|
assert isinstance(clip_pred, data.ClipPrediction)
|
||||||
@ -542,16 +541,14 @@ def test_convert_raw_list_to_clip_empty(
|
|||||||
def test_convert_raw_list_to_clip_passes_args(
|
def test_convert_raw_list_to_clip_passes_args(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions,
|
||||||
sample_clip,
|
sample_clip,
|
||||||
dummy_sound_event_decoder,
|
dummy_targets,
|
||||||
generic_tags,
|
|
||||||
):
|
):
|
||||||
"""Test that arguments like top_class_only are passed through."""
|
"""Test that arguments like top_class_only are passed through."""
|
||||||
|
|
||||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||||
raw_predictions=sample_raw_predictions,
|
raw_predictions=sample_raw_predictions,
|
||||||
clip=sample_clip,
|
clip=sample_clip,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
generic_class_tags=generic_tags,
|
|
||||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only=True,
|
top_class_only=True,
|
||||||
)
|
)
|
||||||
@ -562,16 +559,18 @@ def test_convert_raw_list_to_clip_passes_args(
|
|||||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||||
for pt in clip_pred.sound_events[0].tags
|
for pt in clip_pred.sound_events[0].tags
|
||||||
}
|
}
|
||||||
|
generic_tags = dummy_targets.generic_class_tags
|
||||||
expected_tags1 = {
|
expected_tags1 = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||||
("soundevent:category", "noise", 0.85),
|
("category", "noise", 0.85),
|
||||||
}
|
}
|
||||||
assert se_pred1_tags == expected_tags1
|
assert se_pred1_tags == expected_tags1
|
||||||
|
|
||||||
|
|
||||||
def test_get_generic_tags_basic(generic_tags):
|
def test_get_generic_tags_basic(dummy_targets):
|
||||||
"""Test creation of generic tags with score."""
|
"""Test creation of generic tags with score."""
|
||||||
detection_score = 0.75
|
detection_score = 0.75
|
||||||
|
generic_tags = dummy_targets.generic_class_tags
|
||||||
predicted_tags = get_generic_tags(
|
predicted_tags = get_generic_tags(
|
||||||
detection_score=detection_score, generic_class_tags=generic_tags
|
detection_score=detection_score, generic_class_tags=generic_tags
|
||||||
)
|
)
|
||||||
@ -589,17 +588,19 @@ def test_get_prediction_features_basic():
|
|||||||
coords={"feature": ["feat1", "feat2", "feat3"]},
|
coords={"feature": ["feat1", "feat2", "feat3"]},
|
||||||
dims=["feature"],
|
dims=["feature"],
|
||||||
)
|
)
|
||||||
features = get_prediction_features(feature_data)
|
features = get_prediction_features(feature_data.values)
|
||||||
assert len(features) == 3
|
assert len(features) == 3
|
||||||
for feature, feat_name, feat_value in zip(
|
for feature, feat_name, feat_value in zip(
|
||||||
features, ["feat1", "feat2", "feat3"], [1.1, 2.2, 3.3]
|
features,
|
||||||
|
["f0", "f1", "f2"],
|
||||||
|
[1.1, 2.2, 3.3],
|
||||||
):
|
):
|
||||||
assert isinstance(feature, data.Feature)
|
assert isinstance(feature, data.Feature)
|
||||||
assert feature.term.name == f"batdetect2:{feat_name}"
|
assert feature.term.name == f"batdetect2:{feat_name}"
|
||||||
assert feature.value == feat_value
|
assert feature.value == feat_value
|
||||||
|
|
||||||
|
|
||||||
def test_get_class_tags_basic(dummy_sound_event_decoder):
|
def test_get_class_tags_basic(dummy_targets):
|
||||||
"""Test creation of class tags based on scores and decoder."""
|
"""Test creation of class tags based on scores and decoder."""
|
||||||
class_scores = xr.DataArray(
|
class_scores = xr.DataArray(
|
||||||
[0.6, 0.2, 0.9],
|
[0.6, 0.2, 0.9],
|
||||||
@ -607,8 +608,8 @@ def test_get_class_tags_basic(dummy_sound_event_decoder):
|
|||||||
dims=["category"],
|
dims=["category"],
|
||||||
)
|
)
|
||||||
predicted_tags = get_class_tags(
|
predicted_tags = get_class_tags(
|
||||||
class_scores=class_scores,
|
class_scores=class_scores.values,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
)
|
)
|
||||||
assert len(predicted_tags) == 3
|
assert len(predicted_tags) == 3
|
||||||
tag_values = [pt.tag.value for pt in predicted_tags]
|
tag_values = [pt.tag.value for pt in predicted_tags]
|
||||||
@ -622,7 +623,7 @@ def test_get_class_tags_basic(dummy_sound_event_decoder):
|
|||||||
assert 0.9 in tag_scores
|
assert 0.9 in tag_scores
|
||||||
|
|
||||||
|
|
||||||
def test_get_class_tags_thresholding(dummy_sound_event_decoder):
|
def test_get_class_tags_thresholding(dummy_targets):
|
||||||
"""Test class tag creation with a threshold."""
|
"""Test class tag creation with a threshold."""
|
||||||
class_scores = xr.DataArray(
|
class_scores = xr.DataArray(
|
||||||
[0.6, 0.2, 0.9],
|
[0.6, 0.2, 0.9],
|
||||||
@ -631,8 +632,8 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder):
|
|||||||
)
|
)
|
||||||
threshold = 0.5
|
threshold = 0.5
|
||||||
predicted_tags = get_class_tags(
|
predicted_tags = get_class_tags(
|
||||||
class_scores=class_scores,
|
class_scores=class_scores.values,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
threshold=threshold,
|
threshold=threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -643,7 +644,7 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder):
|
|||||||
assert "uncertain" in tag_values
|
assert "uncertain" in tag_values
|
||||||
|
|
||||||
|
|
||||||
def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
|
def test_get_class_tags_top_class_only(dummy_targets):
|
||||||
"""Test class tag creation with top_class_only."""
|
"""Test class tag creation with top_class_only."""
|
||||||
class_scores = xr.DataArray(
|
class_scores = xr.DataArray(
|
||||||
[0.6, 0.2, 0.9],
|
[0.6, 0.2, 0.9],
|
||||||
@ -651,8 +652,8 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
|
|||||||
dims=["category"],
|
dims=["category"],
|
||||||
)
|
)
|
||||||
predicted_tags = get_class_tags(
|
predicted_tags = get_class_tags(
|
||||||
class_scores=class_scores,
|
class_scores=class_scores.values,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
top_class_only=True,
|
top_class_only=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -661,11 +662,11 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
|
|||||||
assert predicted_tags[0].score == 0.9
|
assert predicted_tags[0].score == 0.9
|
||||||
|
|
||||||
|
|
||||||
def test_get_class_tags_empty(dummy_sound_event_decoder):
|
def test_get_class_tags_empty(dummy_targets):
|
||||||
"""Test with empty class scores."""
|
"""Test with empty class scores."""
|
||||||
class_scores = xr.DataArray([], coords={"category": []}, dims=["category"])
|
class_scores = xr.DataArray([], coords={"category": []}, dims=["category"])
|
||||||
predicted_tags = get_class_tags(
|
predicted_tags = get_class_tags(
|
||||||
class_scores=class_scores,
|
class_scores=class_scores.values,
|
||||||
sound_event_decoder=dummy_sound_event_decoder,
|
targets=dummy_targets,
|
||||||
)
|
)
|
||||||
assert len(predicted_tags) == 0
|
assert len(predicted_tags) == 0
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from uuid import uuid4
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
from soundevent.terms import get_term
|
||||||
|
|
||||||
from batdetect2.targets.classes import (
|
from batdetect2.targets.classes import (
|
||||||
DEFAULT_SPECIES_LIST,
|
DEFAULT_SPECIES_LIST,
|
||||||
@ -21,26 +22,19 @@ from batdetect2.targets.classes import (
|
|||||||
load_decoder_from_config,
|
load_decoder_from_config,
|
||||||
load_encoder_from_config,
|
load_encoder_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
from batdetect2.targets.terms import TagInfo
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_annotation(
|
def sample_annotation(
|
||||||
sound_event: data.SoundEvent,
|
sound_event: data.SoundEvent,
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
) -> data.SoundEventAnnotation:
|
) -> data.SoundEventAnnotation:
|
||||||
"""Fixture for a sample SoundEventAnnotation."""
|
"""Fixture for a sample SoundEventAnnotation."""
|
||||||
return data.SoundEventAnnotation(
|
return data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||||
term=sample_term_registry.get_term("species"),
|
data.Tag(key="quality", value="Good"), # type: ignore
|
||||||
value="Pipistrellus pipistrellus",
|
|
||||||
),
|
|
||||||
data.Tag(
|
|
||||||
term=sample_term_registry.get_term("quality"),
|
|
||||||
value="Good",
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -136,59 +130,33 @@ def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
|
|||||||
|
|
||||||
def test_is_target_class_match_all(
|
def test_is_target_class_match_all(
|
||||||
sample_annotation: data.SoundEventAnnotation,
|
sample_annotation: data.SoundEventAnnotation,
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
):
|
):
|
||||||
tags = {
|
tags = {
|
||||||
data.Tag(
|
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||||
term=sample_term_registry["species"],
|
data.Tag(key="quality", value="Good"), # type: ignore
|
||||||
value="Pipistrellus pipistrellus",
|
|
||||||
),
|
|
||||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
|
||||||
}
|
}
|
||||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||||
|
|
||||||
tags = {
|
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
||||||
data.Tag(
|
|
||||||
term=sample_term_registry["species"],
|
|
||||||
value="Pipistrellus pipistrellus",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||||
|
|
||||||
tags = {
|
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
||||||
data.Tag(
|
|
||||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
assert is_target_class(sample_annotation, tags, match_all=True) is False
|
assert is_target_class(sample_annotation, tags, match_all=True) is False
|
||||||
|
|
||||||
|
|
||||||
def test_is_target_class_match_any(
|
def test_is_target_class_match_any(
|
||||||
sample_annotation: data.SoundEventAnnotation,
|
sample_annotation: data.SoundEventAnnotation,
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
):
|
):
|
||||||
tags = {
|
tags = {
|
||||||
data.Tag(
|
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||||
term=sample_term_registry["species"],
|
data.Tag(key="quality", value="Good"), # type: ignore
|
||||||
value="Pipistrellus pipistrellus",
|
|
||||||
),
|
|
||||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
|
||||||
}
|
}
|
||||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||||
|
|
||||||
tags = {
|
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
||||||
data.Tag(
|
|
||||||
term=sample_term_registry["species"],
|
|
||||||
value="Pipistrellus pipistrellus",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||||
|
|
||||||
tags = {
|
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
||||||
data.Tag(
|
|
||||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
assert is_target_class(sample_annotation, tags, match_all=False) is False
|
assert is_target_class(sample_annotation, tags, match_all=False) is False
|
||||||
|
|
||||||
|
|
||||||
@ -208,7 +176,6 @@ def test_get_class_names_from_config():
|
|||||||
|
|
||||||
def test_build_encoder_from_config(
|
def test_build_encoder_from_config(
|
||||||
sample_annotation: data.SoundEventAnnotation,
|
sample_annotation: data.SoundEventAnnotation,
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
):
|
):
|
||||||
config = ClassesConfig(
|
config = ClassesConfig(
|
||||||
classes=[
|
classes=[
|
||||||
@ -220,25 +187,18 @@ def test_build_encoder_from_config(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
encoder = build_sound_event_encoder(
|
encoder = build_sound_event_encoder(config)
|
||||||
config,
|
|
||||||
term_registry=sample_term_registry,
|
|
||||||
)
|
|
||||||
result = encoder(sample_annotation)
|
result = encoder(sample_annotation)
|
||||||
assert result == "pippip"
|
assert result == "pippip"
|
||||||
|
|
||||||
config = ClassesConfig(classes=[])
|
config = ClassesConfig(classes=[])
|
||||||
encoder = build_sound_event_encoder(
|
encoder = build_sound_event_encoder(config)
|
||||||
config,
|
|
||||||
term_registry=sample_term_registry,
|
|
||||||
)
|
|
||||||
result = encoder(sample_annotation)
|
result = encoder(sample_annotation)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
def test_load_encoder_from_config_valid(
|
def test_load_encoder_from_config_valid(
|
||||||
sample_annotation: data.SoundEventAnnotation,
|
sample_annotation: data.SoundEventAnnotation,
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
create_temp_yaml: Callable[[str], Path],
|
create_temp_yaml: Callable[[str], Path],
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
@ -249,10 +209,7 @@ def test_load_encoder_from_config_valid(
|
|||||||
value: Pipistrellus pipistrellus
|
value: Pipistrellus pipistrellus
|
||||||
"""
|
"""
|
||||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||||
encoder = load_encoder_from_config(
|
encoder = load_encoder_from_config(temp_yaml_path)
|
||||||
temp_yaml_path,
|
|
||||||
term_registry=sample_term_registry,
|
|
||||||
)
|
|
||||||
# We cannot directly compare the function, so we test it.
|
# We cannot directly compare the function, so we test it.
|
||||||
result = encoder(sample_annotation) # type: ignore
|
result = encoder(sample_annotation) # type: ignore
|
||||||
assert result == "pippip"
|
assert result == "pippip"
|
||||||
@ -260,7 +217,6 @@ def test_load_encoder_from_config_valid(
|
|||||||
|
|
||||||
def test_load_encoder_from_config_invalid(
|
def test_load_encoder_from_config_invalid(
|
||||||
create_temp_yaml: Callable[[str], Path],
|
create_temp_yaml: Callable[[str], Path],
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
classes:
|
classes:
|
||||||
@ -271,10 +227,7 @@ def test_load_encoder_from_config_invalid(
|
|||||||
"""
|
"""
|
||||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
load_encoder_from_config(
|
load_encoder_from_config(temp_yaml_path)
|
||||||
temp_yaml_path,
|
|
||||||
term_registry=sample_term_registry,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_default_class_name():
|
def test_get_default_class_name():
|
||||||
@ -291,7 +244,7 @@ def test_get_default_classes():
|
|||||||
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
|
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
|
||||||
|
|
||||||
|
|
||||||
def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
def test_build_decoder_from_config():
|
||||||
config = ClassesConfig(
|
config = ClassesConfig(
|
||||||
classes=[
|
classes=[
|
||||||
TargetClass(
|
TargetClass(
|
||||||
@ -304,12 +257,10 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
|||||||
],
|
],
|
||||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||||
)
|
)
|
||||||
decoder = build_sound_event_decoder(
|
decoder = build_sound_event_decoder(config)
|
||||||
config, term_registry=sample_term_registry
|
|
||||||
)
|
|
||||||
tags = decoder("pippip")
|
tags = decoder("pippip")
|
||||||
assert len(tags) == 1
|
assert len(tags) == 1
|
||||||
assert tags[0].term == sample_term_registry["call_type"]
|
assert tags[0].term == get_term("event")
|
||||||
assert tags[0].value == "Echolocation"
|
assert tags[0].value == "Echolocation"
|
||||||
|
|
||||||
# Test when output_tags is None, should fall back to tags
|
# Test when output_tags is None, should fall back to tags
|
||||||
@ -324,32 +275,25 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
|||||||
],
|
],
|
||||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||||
)
|
)
|
||||||
decoder = build_sound_event_decoder(
|
decoder = build_sound_event_decoder(config)
|
||||||
config, term_registry=sample_term_registry
|
|
||||||
)
|
|
||||||
tags = decoder("pippip")
|
tags = decoder("pippip")
|
||||||
assert len(tags) == 1
|
assert len(tags) == 1
|
||||||
assert tags[0].term == sample_term_registry["species"]
|
assert tags[0].term == get_term("species")
|
||||||
assert tags[0].value == "Pipistrellus pipistrellus"
|
assert tags[0].value == "Pipistrellus pipistrellus"
|
||||||
|
|
||||||
# Test raise_on_unmapped=True
|
# Test raise_on_unmapped=True
|
||||||
decoder = build_sound_event_decoder(
|
decoder = build_sound_event_decoder(config, raise_on_unmapped=True)
|
||||||
config, term_registry=sample_term_registry, raise_on_unmapped=True
|
|
||||||
)
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
decoder("unknown_class")
|
decoder("unknown_class")
|
||||||
|
|
||||||
# Test raise_on_unmapped=False
|
# Test raise_on_unmapped=False
|
||||||
decoder = build_sound_event_decoder(
|
decoder = build_sound_event_decoder(config, raise_on_unmapped=False)
|
||||||
config, term_registry=sample_term_registry, raise_on_unmapped=False
|
|
||||||
)
|
|
||||||
tags = decoder("unknown_class")
|
tags = decoder("unknown_class")
|
||||||
assert len(tags) == 0
|
assert len(tags) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_load_decoder_from_config_valid(
|
def test_load_decoder_from_config_valid(
|
||||||
create_temp_yaml: Callable[[str], Path],
|
create_temp_yaml: Callable[[str], Path],
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
classes:
|
classes:
|
||||||
@ -366,17 +310,15 @@ def test_load_decoder_from_config_valid(
|
|||||||
"""
|
"""
|
||||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||||
decoder = load_decoder_from_config(
|
decoder = load_decoder_from_config(
|
||||||
temp_yaml_path, term_registry=sample_term_registry
|
temp_yaml_path,
|
||||||
)
|
)
|
||||||
tags = decoder("pippip")
|
tags = decoder("pippip")
|
||||||
assert len(tags) == 1
|
assert len(tags) == 1
|
||||||
assert tags[0].term == sample_term_registry["call_type"]
|
assert tags[0].term == get_term("call_type")
|
||||||
assert tags[0].value == "Echolocation"
|
assert tags[0].value == "Echolocation"
|
||||||
|
|
||||||
|
|
||||||
def test_build_generic_class_tags_from_config(
|
def test_build_generic_class_tags_from_config():
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
):
|
|
||||||
config = ClassesConfig(
|
config = ClassesConfig(
|
||||||
classes=[
|
classes=[
|
||||||
TargetClass(
|
TargetClass(
|
||||||
@ -391,11 +333,9 @@ def test_build_generic_class_tags_from_config(
|
|||||||
TagInfo(key="call_type", value="Echolocation"),
|
TagInfo(key="call_type", value="Echolocation"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
generic_tags = build_generic_class_tags(
|
generic_tags = build_generic_class_tags(config)
|
||||||
config, term_registry=sample_term_registry
|
|
||||||
)
|
|
||||||
assert len(generic_tags) == 2
|
assert len(generic_tags) == 2
|
||||||
assert generic_tags[0].term == sample_term_registry["order"]
|
assert generic_tags[0].term == get_term("order")
|
||||||
assert generic_tags[0].value == "Chiroptera"
|
assert generic_tags[0].value == "Chiroptera"
|
||||||
assert generic_tags[1].term == sample_term_registry["call_type"]
|
assert generic_tags[1].term == get_term("call_type")
|
||||||
assert generic_tags[1].value == "Echolocation"
|
assert generic_tags[1].value == "Echolocation"
|
||||||
|
|||||||
@ -80,7 +80,6 @@ def test_generated_heatmaps_have_correct_dimensions(
|
|||||||
|
|
||||||
def test_generated_heatmap_are_non_zero_at_correct_positions(
|
def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||||
sample_target_config: TargetConfig,
|
sample_target_config: TargetConfig,
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
pippip_tag: TagInfo,
|
pippip_tag: TagInfo,
|
||||||
):
|
):
|
||||||
config = sample_target_config.model_copy(
|
config = sample_target_config.model_copy(
|
||||||
@ -92,7 +91,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
targets = build_targets(config, term_registry=sample_term_registry)
|
targets = build_targets(config)
|
||||||
|
|
||||||
spec = xr.DataArray(
|
spec = xr.DataArray(
|
||||||
data=np.random.rand(100, 100),
|
data=np.random.rand(100, 100),
|
||||||
@ -113,12 +112,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
|||||||
coordinates=[10, 10, 20, 20],
|
coordinates=[10, 10, 20, 20],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
tags=[
|
tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore
|
||||||
data.Tag(
|
|
||||||
term=sample_term_registry[pippip_tag.key],
|
|
||||||
value=pippip_tag.value,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -2,12 +2,12 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
from soundevent.terms import get_term
|
||||||
|
|
||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
||||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
from batdetect2.targets import build_targets, load_target_config
|
||||||
from batdetect2.targets.terms import get_term_from_key
|
|
||||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||||
from batdetect2.train.preprocess import generate_train_example
|
from batdetect2.train.preprocess import generate_train_example
|
||||||
|
|
||||||
@ -15,7 +15,6 @@ from batdetect2.train.preprocess import generate_train_example
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def build_from_config(
|
def build_from_config(
|
||||||
create_temp_yaml,
|
create_temp_yaml,
|
||||||
sample_term_registry,
|
|
||||||
):
|
):
|
||||||
def build(yaml_content):
|
def build(yaml_content):
|
||||||
config_path = create_temp_yaml(yaml_content)
|
config_path = create_temp_yaml(yaml_content)
|
||||||
@ -31,9 +30,7 @@ def build_from_config(
|
|||||||
field="postprocessing",
|
field="postprocessing",
|
||||||
)
|
)
|
||||||
|
|
||||||
targets = build_targets(
|
targets = build_targets(targets_config)
|
||||||
targets_config, term_registry=sample_term_registry
|
|
||||||
)
|
|
||||||
preprocessor = build_preprocessor(preprocessing_config)
|
preprocessor = build_preprocessor(preprocessing_config)
|
||||||
labeller = build_clip_labeler(
|
labeller = build_clip_labeler(
|
||||||
targets=targets,
|
targets=targets,
|
||||||
@ -54,7 +51,6 @@ def build_from_config(
|
|||||||
# TODO: better name
|
# TODO: better name
|
||||||
def test_generated_train_example_has_expected_outputs(
|
def test_generated_train_example_has_expected_outputs(
|
||||||
build_from_config,
|
build_from_config,
|
||||||
sample_term_registry,
|
|
||||||
recording,
|
recording,
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
@ -78,10 +74,11 @@ def test_generated_train_example_has_expected_outputs(
|
|||||||
_, preprocessor, labeller, _ = build_from_config(yaml_content)
|
_, preprocessor, labeller, _ = build_from_config(yaml_content)
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
|
||||||
se1 = data.SoundEventAnnotation(
|
se1 = data.SoundEventAnnotation(
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
tags=[
|
||||||
|
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||||
|
],
|
||||||
)
|
)
|
||||||
clip_annotation = data.ClipAnnotation(
|
clip_annotation = data.ClipAnnotation(
|
||||||
clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
|
clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
|
||||||
@ -108,7 +105,6 @@ def test_generated_train_example_has_expected_outputs(
|
|||||||
|
|
||||||
def test_encoding_decoding_roundtrip_recovers_object(
|
def test_encoding_decoding_roundtrip_recovers_object(
|
||||||
build_from_config,
|
build_from_config,
|
||||||
sample_term_registry,
|
|
||||||
recording,
|
recording,
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
@ -131,10 +127,11 @@ def test_encoding_decoding_roundtrip_recovers_object(
|
|||||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
|
||||||
se1 = data.SoundEventAnnotation(
|
se1 = data.SoundEventAnnotation(
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
tags=[
|
||||||
|
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||||
|
],
|
||||||
)
|
)
|
||||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||||
@ -171,14 +168,16 @@ def test_encoding_decoding_roundtrip_recovers_object(
|
|||||||
assert len(recovered.tags) == 2
|
assert len(recovered.tags) == 2
|
||||||
|
|
||||||
predicted_species_tag = next(
|
predicted_species_tag = next(
|
||||||
iter(t for t in recovered.tags if t.tag.term == species), None
|
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
assert predicted_species_tag is not None
|
assert predicted_species_tag is not None
|
||||||
assert predicted_species_tag.score == 1
|
assert predicted_species_tag.score == 1
|
||||||
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
|
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
|
||||||
|
|
||||||
predicted_order_tag = next(
|
predicted_order_tag = next(
|
||||||
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
|
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
assert predicted_order_tag is not None
|
assert predicted_order_tag is not None
|
||||||
assert predicted_order_tag.score == 1
|
assert predicted_order_tag.score == 1
|
||||||
@ -187,7 +186,6 @@ def test_encoding_decoding_roundtrip_recovers_object(
|
|||||||
|
|
||||||
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
||||||
build_from_config,
|
build_from_config,
|
||||||
sample_term_registry,
|
|
||||||
recording,
|
recording,
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
@ -217,10 +215,9 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
|||||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
|
||||||
se1 = data.SoundEventAnnotation(
|
se1 = data.SoundEventAnnotation(
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
tags=[data.Tag(term=species, value="Myotis myotis")],
|
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||||
)
|
)
|
||||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||||
@ -257,14 +254,16 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
|||||||
assert len(recovered.tags) == 2
|
assert len(recovered.tags) == 2
|
||||||
|
|
||||||
predicted_species_tag = next(
|
predicted_species_tag = next(
|
||||||
iter(t for t in recovered.tags if t.tag.term == species), None
|
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
assert predicted_species_tag is not None
|
assert predicted_species_tag is not None
|
||||||
assert predicted_species_tag.score == 1
|
assert predicted_species_tag.score == 1
|
||||||
assert predicted_species_tag.tag.value == "Myotis myotis"
|
assert predicted_species_tag.tag.value == "Myotis myotis"
|
||||||
|
|
||||||
predicted_order_tag = next(
|
predicted_order_tag = next(
|
||||||
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
|
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
assert predicted_order_tag is not None
|
assert predicted_order_tag is not None
|
||||||
assert predicted_order_tag.score == 1
|
assert predicted_order_tag.score == 1
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user