mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Using matching and affinity functions from soundevent
This commit is contained in:
parent
113f438e74
commit
f71fe0c2e2
@ -62,6 +62,10 @@ class BaseConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, yaml_str: str):
|
||||||
|
return cls.model_validate(yaml.safe_load(yaml_str))
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|||||||
@ -2,75 +2,98 @@ from typing import Annotated, Literal
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import compute_affinity
|
from soundevent.geometry import (
|
||||||
from soundevent.geometry import compute_interval_overlap
|
buffer_geometry,
|
||||||
|
compute_bbox_iou,
|
||||||
|
compute_geometric_iou,
|
||||||
|
compute_temporal_closeness,
|
||||||
|
compute_temporal_iou,
|
||||||
|
)
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.core.registries import Registry
|
from batdetect2.typing import AffinityFunction, RawPrediction
|
||||||
from batdetect2.typing.evaluate import AffinityFunction
|
|
||||||
|
|
||||||
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
||||||
"matching_strategy"
|
"affinity_function"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TimeAffinityConfig(BaseConfig):
|
class TimeAffinityConfig(BaseConfig):
|
||||||
name: Literal["time_affinity"] = "time_affinity"
|
name: Literal["time_affinity"] = "time_affinity"
|
||||||
time_buffer: float = 0.01
|
position: Literal["start", "end", "center"] | float = "start"
|
||||||
|
max_distance: float = 0.01
|
||||||
|
|
||||||
|
|
||||||
class TimeAffinity(AffinityFunction):
|
class TimeAffinity(AffinityFunction):
|
||||||
def __init__(self, time_buffer: float):
|
def __init__(
|
||||||
self.time_buffer = time_buffer
|
self,
|
||||||
|
max_distance: float = 0.01,
|
||||||
|
position: Literal["start", "end", "center"] | float = "start",
|
||||||
|
):
|
||||||
|
if position == "start":
|
||||||
|
position = 0
|
||||||
|
elif position == "end":
|
||||||
|
position = 1
|
||||||
|
elif position == "center":
|
||||||
|
position = 0.5
|
||||||
|
|
||||||
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
self.position = position
|
||||||
return compute_timestamp_affinity(
|
self.max_distance = max_distance
|
||||||
geometry1, geometry2, time_buffer=self.time_buffer
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
detection: RawPrediction,
|
||||||
|
ground_truth: data.SoundEventAnnotation,
|
||||||
|
) -> float:
|
||||||
|
target_geometry = ground_truth.sound_event.geometry
|
||||||
|
source_geometry = detection.geometry
|
||||||
|
return compute_temporal_closeness(
|
||||||
|
target_geometry,
|
||||||
|
source_geometry,
|
||||||
|
ratio=self.position,
|
||||||
|
max_distance=self.max_distance,
|
||||||
)
|
)
|
||||||
|
|
||||||
@affinity_functions.register(TimeAffinityConfig)
|
@affinity_functions.register(TimeAffinityConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(config: TimeAffinityConfig):
|
def from_config(config: TimeAffinityConfig):
|
||||||
return TimeAffinity(time_buffer=config.time_buffer)
|
return TimeAffinity(
|
||||||
|
max_distance=config.max_distance,
|
||||||
|
position=config.position,
|
||||||
def compute_timestamp_affinity(
|
)
|
||||||
geometry1: data.Geometry,
|
|
||||||
geometry2: data.Geometry,
|
|
||||||
time_buffer: float = 0.01,
|
|
||||||
) -> float:
|
|
||||||
assert isinstance(geometry1, data.TimeStamp)
|
|
||||||
assert isinstance(geometry2, data.TimeStamp)
|
|
||||||
|
|
||||||
start_time1 = geometry1.coordinates
|
|
||||||
start_time2 = geometry2.coordinates
|
|
||||||
|
|
||||||
a = min(start_time1, start_time2)
|
|
||||||
b = max(start_time1, start_time2)
|
|
||||||
|
|
||||||
if b - a >= 2 * time_buffer:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
intersection = a - b + 2 * time_buffer
|
|
||||||
union = b - a + 2 * time_buffer
|
|
||||||
return intersection / union
|
|
||||||
|
|
||||||
|
|
||||||
class IntervalIOUConfig(BaseConfig):
|
class IntervalIOUConfig(BaseConfig):
|
||||||
name: Literal["interval_iou"] = "interval_iou"
|
name: Literal["interval_iou"] = "interval_iou"
|
||||||
time_buffer: float = 0.01
|
time_buffer: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
class IntervalIOU(AffinityFunction):
|
class IntervalIOU(AffinityFunction):
|
||||||
def __init__(self, time_buffer: float):
|
def __init__(self, time_buffer: float):
|
||||||
|
if time_buffer < 0:
|
||||||
|
raise ValueError("time_buffer must be non-negative")
|
||||||
|
|
||||||
self.time_buffer = time_buffer
|
self.time_buffer = time_buffer
|
||||||
|
|
||||||
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
def __call__(
|
||||||
return compute_interval_iou(
|
self,
|
||||||
geometry1,
|
detection: RawPrediction,
|
||||||
geometry2,
|
ground_truth: data.SoundEventAnnotation,
|
||||||
time_buffer=self.time_buffer,
|
) -> float:
|
||||||
|
target_geometry = ground_truth.sound_event.geometry
|
||||||
|
source_geometry = detection.geometry
|
||||||
|
|
||||||
|
if self.time_buffer > 0:
|
||||||
|
target_geometry = buffer_geometry(
|
||||||
|
target_geometry,
|
||||||
|
time=self.time_buffer,
|
||||||
)
|
)
|
||||||
|
source_geometry = buffer_geometry(
|
||||||
|
source_geometry,
|
||||||
|
time=self.time_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return compute_temporal_iou(target_geometry, source_geometry)
|
||||||
|
|
||||||
@affinity_functions.register(IntervalIOUConfig)
|
@affinity_functions.register(IntervalIOUConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -78,64 +101,44 @@ class IntervalIOU(AffinityFunction):
|
|||||||
return IntervalIOU(time_buffer=config.time_buffer)
|
return IntervalIOU(time_buffer=config.time_buffer)
|
||||||
|
|
||||||
|
|
||||||
def compute_interval_iou(
|
|
||||||
geometry1: data.Geometry,
|
|
||||||
geometry2: data.Geometry,
|
|
||||||
time_buffer: float = 0.01,
|
|
||||||
) -> float:
|
|
||||||
assert isinstance(geometry1, data.TimeInterval)
|
|
||||||
assert isinstance(geometry2, data.TimeInterval)
|
|
||||||
|
|
||||||
start_time1, end_time1 = geometry1.coordinates
|
|
||||||
start_time2, end_time2 = geometry1.coordinates
|
|
||||||
|
|
||||||
start_time1 -= time_buffer
|
|
||||||
start_time2 -= time_buffer
|
|
||||||
end_time1 += time_buffer
|
|
||||||
end_time2 += time_buffer
|
|
||||||
|
|
||||||
intersection = compute_interval_overlap(
|
|
||||||
(start_time1, end_time1),
|
|
||||||
(start_time2, end_time2),
|
|
||||||
)
|
|
||||||
|
|
||||||
union = (
|
|
||||||
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
|
|
||||||
)
|
|
||||||
|
|
||||||
if union == 0:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
return intersection / union
|
|
||||||
|
|
||||||
|
|
||||||
class BBoxIOUConfig(BaseConfig):
|
class BBoxIOUConfig(BaseConfig):
|
||||||
name: Literal["bbox_iou"] = "bbox_iou"
|
name: Literal["bbox_iou"] = "bbox_iou"
|
||||||
time_buffer: float = 0.01
|
time_buffer: float = 0.0
|
||||||
freq_buffer: float = 1000
|
freq_buffer: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
class BBoxIOU(AffinityFunction):
|
class BBoxIOU(AffinityFunction):
|
||||||
def __init__(self, time_buffer: float, freq_buffer: float):
|
def __init__(self, time_buffer: float, freq_buffer: float):
|
||||||
|
if time_buffer < 0:
|
||||||
|
raise ValueError("time_buffer must be non-negative")
|
||||||
|
|
||||||
|
if freq_buffer < 0:
|
||||||
|
raise ValueError("freq_buffer must be non-negative")
|
||||||
|
|
||||||
self.time_buffer = time_buffer
|
self.time_buffer = time_buffer
|
||||||
self.freq_buffer = freq_buffer
|
self.freq_buffer = freq_buffer
|
||||||
|
|
||||||
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
def __call__(
|
||||||
if not isinstance(geometry1, data.BoundingBox):
|
self,
|
||||||
raise TypeError(
|
prediction: RawPrediction,
|
||||||
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
|
gt: data.SoundEventAnnotation,
|
||||||
|
):
|
||||||
|
target_geometry = gt.sound_event.geometry
|
||||||
|
source_geometry = prediction.geometry
|
||||||
|
|
||||||
|
if self.time_buffer > 0 or self.freq_buffer > 0:
|
||||||
|
target_geometry = buffer_geometry(
|
||||||
|
target_geometry,
|
||||||
|
time=self.time_buffer,
|
||||||
|
freq=self.freq_buffer,
|
||||||
|
)
|
||||||
|
source_geometry = buffer_geometry(
|
||||||
|
source_geometry,
|
||||||
|
time=self.time_buffer,
|
||||||
|
freq=self.freq_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(geometry2, data.BoundingBox):
|
return compute_bbox_iou(target_geometry, source_geometry)
|
||||||
raise TypeError(
|
|
||||||
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
|
|
||||||
)
|
|
||||||
return bbox_iou(
|
|
||||||
geometry1,
|
|
||||||
geometry2,
|
|
||||||
time_buffer=self.time_buffer,
|
|
||||||
freq_buffer=self.freq_buffer,
|
|
||||||
)
|
|
||||||
|
|
||||||
@affinity_functions.register(BBoxIOUConfig)
|
@affinity_functions.register(BBoxIOUConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -146,65 +149,44 @@ class BBoxIOU(AffinityFunction):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def bbox_iou(
|
|
||||||
geometry1: data.BoundingBox,
|
|
||||||
geometry2: data.BoundingBox,
|
|
||||||
time_buffer: float = 0.01,
|
|
||||||
freq_buffer: float = 1000,
|
|
||||||
) -> float:
|
|
||||||
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
|
|
||||||
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
|
|
||||||
|
|
||||||
start_time1 -= time_buffer
|
|
||||||
start_time2 -= time_buffer
|
|
||||||
end_time1 += time_buffer
|
|
||||||
end_time2 += time_buffer
|
|
||||||
|
|
||||||
low_freq1 -= freq_buffer
|
|
||||||
low_freq2 -= freq_buffer
|
|
||||||
high_freq1 += freq_buffer
|
|
||||||
high_freq2 += freq_buffer
|
|
||||||
|
|
||||||
time_intersection = compute_interval_overlap(
|
|
||||||
(start_time1, end_time1),
|
|
||||||
(start_time2, end_time2),
|
|
||||||
)
|
|
||||||
|
|
||||||
freq_intersection = max(
|
|
||||||
0,
|
|
||||||
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
|
|
||||||
)
|
|
||||||
|
|
||||||
intersection = time_intersection * freq_intersection
|
|
||||||
|
|
||||||
if intersection == 0:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
union = (
|
|
||||||
(end_time1 - start_time1) * (high_freq1 - low_freq1)
|
|
||||||
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
|
|
||||||
- intersection
|
|
||||||
)
|
|
||||||
|
|
||||||
return intersection / union
|
|
||||||
|
|
||||||
|
|
||||||
class GeometricIOUConfig(BaseConfig):
|
class GeometricIOUConfig(BaseConfig):
|
||||||
name: Literal["geometric_iou"] = "geometric_iou"
|
name: Literal["geometric_iou"] = "geometric_iou"
|
||||||
time_buffer: float = 0.01
|
time_buffer: float = 0.0
|
||||||
freq_buffer: float = 1000
|
freq_buffer: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
class GeometricIOU(AffinityFunction):
|
class GeometricIOU(AffinityFunction):
|
||||||
def __init__(self, time_buffer: float):
|
def __init__(self, time_buffer: float = 0, freq_buffer: float = 0):
|
||||||
self.time_buffer = time_buffer
|
if time_buffer < 0:
|
||||||
|
raise ValueError("time_buffer must be non-negative")
|
||||||
|
|
||||||
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
if freq_buffer < 0:
|
||||||
return compute_affinity(
|
raise ValueError("freq_buffer must be non-negative")
|
||||||
geometry1,
|
|
||||||
geometry2,
|
self.time_buffer = time_buffer
|
||||||
time_buffer=self.time_buffer,
|
self.freq_buffer = freq_buffer
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prediction: RawPrediction,
|
||||||
|
gt: data.SoundEventAnnotation,
|
||||||
|
):
|
||||||
|
target_geometry = gt.sound_event.geometry
|
||||||
|
source_geometry = prediction.geometry
|
||||||
|
|
||||||
|
if self.time_buffer > 0 or self.freq_buffer > 0:
|
||||||
|
target_geometry = buffer_geometry(
|
||||||
|
target_geometry,
|
||||||
|
time=self.time_buffer,
|
||||||
|
freq=self.freq_buffer,
|
||||||
)
|
)
|
||||||
|
source_geometry = buffer_geometry(
|
||||||
|
source_geometry,
|
||||||
|
time=self.time_buffer,
|
||||||
|
freq=self.freq_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return compute_geometric_iou(target_geometry, source_geometry)
|
||||||
|
|
||||||
@affinity_functions.register(GeometricIOUConfig)
|
@affinity_functions.register(GeometricIOUConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -213,7 +195,10 @@ class GeometricIOU(AffinityFunction):
|
|||||||
|
|
||||||
|
|
||||||
AffinityConfig = Annotated[
|
AffinityConfig = Annotated[
|
||||||
TimeAffinityConfig | IntervalIOUConfig | BBoxIOUConfig | GeometricIOUConfig,
|
TimeAffinityConfig
|
||||||
|
| IntervalIOUConfig
|
||||||
|
| BBoxIOUConfig
|
||||||
|
| GeometricIOUConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -31,93 +31,6 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"]
|
|||||||
matching_strategies = Registry("matching_strategy")
|
matching_strategies = Registry("matching_strategy")
|
||||||
|
|
||||||
|
|
||||||
def match(
|
|
||||||
sound_event_annotations: Sequence[data.SoundEventAnnotation],
|
|
||||||
raw_predictions: Sequence[RawPrediction],
|
|
||||||
clip: data.Clip,
|
|
||||||
scores: Sequence[float] | None = None,
|
|
||||||
targets: TargetProtocol | None = None,
|
|
||||||
matcher: MatcherProtocol | None = None,
|
|
||||||
) -> ClipMatches:
|
|
||||||
if matcher is None:
|
|
||||||
matcher = build_matcher()
|
|
||||||
|
|
||||||
if targets is None:
|
|
||||||
targets = build_targets()
|
|
||||||
|
|
||||||
target_geometries: List[data.Geometry] = [ # type: ignore
|
|
||||||
sound_event_annotation.sound_event.geometry
|
|
||||||
for sound_event_annotation in sound_event_annotations
|
|
||||||
]
|
|
||||||
|
|
||||||
predicted_geometries = [
|
|
||||||
raw_prediction.geometry for raw_prediction in raw_predictions
|
|
||||||
]
|
|
||||||
|
|
||||||
if scores is None:
|
|
||||||
scores = [
|
|
||||||
raw_prediction.detection_score
|
|
||||||
for raw_prediction in raw_predictions
|
|
||||||
]
|
|
||||||
|
|
||||||
matches = []
|
|
||||||
|
|
||||||
for source_idx, target_idx, affinity in matcher(
|
|
||||||
ground_truth=target_geometries,
|
|
||||||
predictions=predicted_geometries,
|
|
||||||
scores=scores,
|
|
||||||
):
|
|
||||||
target = (
|
|
||||||
sound_event_annotations[target_idx]
|
|
||||||
if target_idx is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
prediction = (
|
|
||||||
raw_predictions[source_idx] if source_idx is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
gt_det = target_idx is not None
|
|
||||||
gt_class = targets.encode_class(target) if target is not None else None
|
|
||||||
gt_geometry = (
|
|
||||||
target_geometries[target_idx] if target_idx is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
pred_score = float(prediction.detection_score) if prediction else 0
|
|
||||||
pred_geometry = (
|
|
||||||
predicted_geometries[source_idx]
|
|
||||||
if source_idx is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
class_scores = (
|
|
||||||
{
|
|
||||||
class_name: score
|
|
||||||
for class_name, score in zip(
|
|
||||||
targets.class_names,
|
|
||||||
prediction.class_scores, strict=False,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if prediction is not None
|
|
||||||
else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
matches.append(
|
|
||||||
MatchEvaluation(
|
|
||||||
clip=clip,
|
|
||||||
sound_event_annotation=target,
|
|
||||||
gt_det=gt_det,
|
|
||||||
gt_class=gt_class,
|
|
||||||
gt_geometry=gt_geometry,
|
|
||||||
pred_score=pred_score,
|
|
||||||
pred_class_scores=class_scores,
|
|
||||||
pred_geometry=pred_geometry,
|
|
||||||
affinity=affinity,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return ClipMatches(clip=clip, matches=matches)
|
|
||||||
|
|
||||||
|
|
||||||
class StartTimeMatchConfig(BaseConfig):
|
class StartTimeMatchConfig(BaseConfig):
|
||||||
name: Literal["start_time_match"] = "start_time_match"
|
name: Literal["start_time_match"] = "start_time_match"
|
||||||
distance_threshold: float = 0.01
|
distance_threshold: float = 0.01
|
||||||
@ -514,99 +427,9 @@ class OptimalMatcher(MatcherProtocol):
|
|||||||
|
|
||||||
|
|
||||||
MatchConfig = Annotated[
|
MatchConfig = Annotated[
|
||||||
GreedyMatchConfig | StartTimeMatchConfig | OptimalMatchConfig | GreedyAffinityMatchConfig,
|
GreedyMatchConfig
|
||||||
|
| StartTimeMatchConfig
|
||||||
|
| OptimalMatchConfig
|
||||||
|
| GreedyAffinityMatchConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def compute_affinity_matrix(
|
|
||||||
ground_truth: Sequence[data.Geometry],
|
|
||||||
predictions: Sequence[data.Geometry],
|
|
||||||
affinity_function: AffinityFunction,
|
|
||||||
time_scale: float = 1,
|
|
||||||
frequency_scale: float = 1,
|
|
||||||
) -> np.ndarray:
|
|
||||||
# Scale geometries if necessary
|
|
||||||
if time_scale != 1 or frequency_scale != 1:
|
|
||||||
ground_truth = [
|
|
||||||
scale_geometry(geometry, time_scale, frequency_scale)
|
|
||||||
for geometry in ground_truth
|
|
||||||
]
|
|
||||||
|
|
||||||
predictions = [
|
|
||||||
scale_geometry(geometry, time_scale, frequency_scale)
|
|
||||||
for geometry in predictions
|
|
||||||
]
|
|
||||||
|
|
||||||
affinity_matrix = np.zeros((len(ground_truth), len(predictions)))
|
|
||||||
for gt_idx, gt_geometry in enumerate(ground_truth):
|
|
||||||
for pred_idx, pred_geometry in enumerate(predictions):
|
|
||||||
affinity = affinity_function(
|
|
||||||
gt_geometry,
|
|
||||||
pred_geometry,
|
|
||||||
)
|
|
||||||
affinity_matrix[gt_idx, pred_idx] = affinity
|
|
||||||
|
|
||||||
return affinity_matrix
|
|
||||||
|
|
||||||
|
|
||||||
def select_optimal_matches(
|
|
||||||
affinity_matrix: np.ndarray,
|
|
||||||
affinity_threshold: float = 0.5,
|
|
||||||
) -> Iterable[Tuple[int | None, int | None, float]]:
|
|
||||||
num_gt, num_pred = affinity_matrix.shape
|
|
||||||
gts = set(range(num_gt))
|
|
||||||
preds = set(range(num_pred))
|
|
||||||
|
|
||||||
assiged_rows, assigned_columns = linear_sum_assignment(
|
|
||||||
affinity_matrix,
|
|
||||||
maximize=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for gt_idx, pred_idx in zip(assiged_rows, assigned_columns, strict=False):
|
|
||||||
affinity = float(affinity_matrix[gt_idx, pred_idx])
|
|
||||||
|
|
||||||
if affinity <= affinity_threshold:
|
|
||||||
continue
|
|
||||||
|
|
||||||
yield gt_idx, pred_idx, affinity
|
|
||||||
gts.remove(gt_idx)
|
|
||||||
preds.remove(pred_idx)
|
|
||||||
|
|
||||||
for gt_idx in gts:
|
|
||||||
yield gt_idx, None, 0
|
|
||||||
|
|
||||||
for pred_idx in preds:
|
|
||||||
yield None, pred_idx, 0
|
|
||||||
|
|
||||||
|
|
||||||
def select_greedy_matches(
|
|
||||||
affinity_matrix: np.ndarray,
|
|
||||||
affinity_threshold: float = 0.5,
|
|
||||||
) -> Iterable[Tuple[int | None, int | None, float]]:
|
|
||||||
num_gt, num_pred = affinity_matrix.shape
|
|
||||||
unmatched_pred = set(range(num_pred))
|
|
||||||
|
|
||||||
for gt_idx in range(num_gt):
|
|
||||||
row = affinity_matrix[gt_idx]
|
|
||||||
|
|
||||||
top_pred = int(np.argmax(row))
|
|
||||||
top_affinity = float(row[top_pred])
|
|
||||||
|
|
||||||
if (
|
|
||||||
top_affinity <= affinity_threshold
|
|
||||||
or top_pred not in unmatched_pred
|
|
||||||
):
|
|
||||||
yield None, gt_idx, 0
|
|
||||||
continue
|
|
||||||
|
|
||||||
unmatched_pred.remove(top_pred)
|
|
||||||
yield top_pred, gt_idx, top_affinity
|
|
||||||
|
|
||||||
for pred_idx in unmatched_pred:
|
|
||||||
yield pred_idx, None, 0
|
|
||||||
|
|
||||||
|
|
||||||
def build_matcher(config: MatchConfig | None = None) -> MatcherProtocol:
|
|
||||||
config = config or StartTimeMatchConfig()
|
|
||||||
return matching_strategies.build(config)
|
|
||||||
|
|||||||
@ -210,7 +210,10 @@ class DetectionPrecision:
|
|||||||
|
|
||||||
|
|
||||||
DetectionMetricConfig = Annotated[
|
DetectionMetricConfig = Annotated[
|
||||||
DetectionAveragePrecisionConfig | DetectionROCAUCConfig | DetectionRecallConfig | DetectionPrecisionConfig,
|
DetectionAveragePrecisionConfig
|
||||||
|
| DetectionROCAUCConfig
|
||||||
|
| DetectionRecallConfig
|
||||||
|
| DetectionPrecisionConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -14,16 +14,15 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.core.registries import Registry
|
from batdetect2.evaluate.affinity import AffinityConfig, TimeAffinityConfig
|
||||||
from batdetect2.evaluate.match import (
|
from batdetect2.typing import (
|
||||||
MatchConfig,
|
AffinityFunction,
|
||||||
StartTimeMatchConfig,
|
BatDetect2Prediction,
|
||||||
build_matcher,
|
EvaluatorProtocol,
|
||||||
|
RawPrediction,
|
||||||
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
|
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
|
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseTaskConfig",
|
"BaseTaskConfig",
|
||||||
@ -40,39 +39,34 @@ T_Output = TypeVar("T_Output")
|
|||||||
|
|
||||||
class BaseTaskConfig(BaseConfig):
|
class BaseTaskConfig(BaseConfig):
|
||||||
prefix: str
|
prefix: str
|
||||||
ignore_start_end: float = 0.01
|
|
||||||
matching_strategy: MatchConfig = Field(
|
|
||||||
default_factory=StartTimeMatchConfig
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||||
targets: TargetProtocol
|
targets: TargetProtocol
|
||||||
|
|
||||||
matcher: MatcherProtocol
|
|
||||||
|
|
||||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
|
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
|
||||||
|
|
||||||
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
|
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
|
||||||
|
|
||||||
ignore_start_end: float
|
|
||||||
|
|
||||||
prefix: str
|
prefix: str
|
||||||
|
|
||||||
|
ignore_start_end: float
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
matcher: MatcherProtocol,
|
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
|
plots: List[
|
||||||
|
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
|
||||||
|
]
|
||||||
|
| None = None,
|
||||||
ignore_start_end: float = 0.01,
|
ignore_start_end: float = 0.01,
|
||||||
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
|
|
||||||
):
|
):
|
||||||
self.matcher = matcher
|
self.prefix = prefix
|
||||||
|
self.targets = targets
|
||||||
self.metrics = metrics
|
self.metrics = metrics
|
||||||
self.plots = plots or []
|
self.plots = plots or []
|
||||||
self.targets = targets
|
|
||||||
self.prefix = prefix
|
|
||||||
self.ignore_start_end = ignore_start_end
|
self.ignore_start_end = ignore_start_end
|
||||||
|
|
||||||
def compute_metrics(
|
def compute_metrics(
|
||||||
@ -100,7 +94,9 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
) -> List[T_Output]:
|
) -> List[T_Output]:
|
||||||
return [
|
return [
|
||||||
self.evaluate_clip(clip_annotation, preds)
|
self.evaluate_clip(clip_annotation, preds)
|
||||||
for clip_annotation, preds in zip(clip_annotations, predictions, strict=False)
|
for clip_annotation, preds in zip(
|
||||||
|
clip_annotations, predictions, strict=False
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
@ -118,9 +114,6 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
geometry = sound_event_annotation.sound_event.geometry
|
geometry = sound_event_annotation.sound_event.geometry
|
||||||
if geometry is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return is_in_bounds(
|
return is_in_bounds(
|
||||||
geometry,
|
geometry,
|
||||||
clip,
|
clip,
|
||||||
@ -138,25 +131,40 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
self.ignore_start_end,
|
self.ignore_start_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def build(
|
class BaseSEDTaskConfig(BaseTaskConfig):
|
||||||
cls,
|
affinity: AffinityConfig = Field(default_factory=TimeAffinityConfig)
|
||||||
config: BaseTaskConfig,
|
affinity_threshold: float = 0
|
||||||
|
strict_match: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSEDTask(BaseTask[T_Output]):
|
||||||
|
affinity: AffinityFunction
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
||||||
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
|
affinity: AffinityFunction,
|
||||||
**kwargs,
|
plots: List[
|
||||||
|
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
|
||||||
|
]
|
||||||
|
| None = None,
|
||||||
|
affinity_threshold: float = 0,
|
||||||
|
ignore_start_end: float = 0.01,
|
||||||
|
strict_match: bool = True,
|
||||||
):
|
):
|
||||||
matcher = build_matcher(config.matching_strategy)
|
super().__init__(
|
||||||
return cls(
|
prefix=prefix,
|
||||||
matcher=matcher,
|
|
||||||
targets=targets,
|
|
||||||
metrics=metrics,
|
metrics=metrics,
|
||||||
plots=plots,
|
plots=plots,
|
||||||
prefix=config.prefix,
|
targets=targets,
|
||||||
ignore_start_end=config.ignore_start_end,
|
ignore_start_end=ignore_start_end,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
self.affinity = affinity
|
||||||
|
self.affinity_threshold = affinity_threshold
|
||||||
|
self.strict_match = strict_match
|
||||||
|
|
||||||
|
|
||||||
def is_in_bounds(
|
def is_in_bounds(
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
from typing import (
|
from functools import partial
|
||||||
List,
|
from typing import Literal
|
||||||
Literal,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
from soundevent.evaluation import match_detections_and_gts
|
||||||
|
|
||||||
|
from batdetect2.evaluate.affinity import build_affinity_function
|
||||||
from batdetect2.evaluate.metrics.classification import (
|
from batdetect2.evaluate.metrics.classification import (
|
||||||
ClassificationAveragePrecisionConfig,
|
ClassificationAveragePrecisionConfig,
|
||||||
ClassificationMetricConfig,
|
ClassificationMetricConfig,
|
||||||
@ -18,24 +18,28 @@ from batdetect2.evaluate.plots.classification import (
|
|||||||
build_classification_plotter,
|
build_classification_plotter,
|
||||||
)
|
)
|
||||||
from batdetect2.evaluate.tasks.base import (
|
from batdetect2.evaluate.tasks.base import (
|
||||||
BaseTask,
|
BaseSEDTask,
|
||||||
BaseTaskConfig,
|
BaseSEDTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
from batdetect2.typing import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
RawPrediction,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClassificationTaskConfig(BaseTaskConfig):
|
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
||||||
name: Literal["sound_event_classification"] = "sound_event_classification"
|
name: Literal["sound_event_classification"] = "sound_event_classification"
|
||||||
prefix: str = "classification"
|
prefix: str = "classification"
|
||||||
metrics: List[ClassificationMetricConfig] = Field(
|
metrics: list[ClassificationMetricConfig] = Field(
|
||||||
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
|
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
|
||||||
)
|
)
|
||||||
plots: List[ClassificationPlotConfig] = Field(default_factory=list)
|
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
|
||||||
include_generics: bool = True
|
include_generics: bool = True
|
||||||
|
|
||||||
|
|
||||||
class ClassificationTask(BaseTask[ClipEval]):
|
class ClassificationTask(BaseSEDTask[ClipEval]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
@ -73,40 +77,39 @@ class ClassificationTask(BaseTask[ClipEval]):
|
|||||||
gts = [
|
gts = [
|
||||||
sound_event
|
sound_event
|
||||||
for sound_event in all_gts
|
for sound_event in all_gts
|
||||||
if self.is_class(sound_event, class_name)
|
if is_target_class(
|
||||||
|
sound_event,
|
||||||
|
class_name,
|
||||||
|
self.targets,
|
||||||
|
include_generics=self.include_generics,
|
||||||
|
)
|
||||||
]
|
]
|
||||||
scores = [float(pred.class_scores[class_idx]) for pred in preds]
|
|
||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
|
|
||||||
for pred_idx, gt_idx, _ in self.matcher(
|
for match in match_detections_and_gts(
|
||||||
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
|
detections=preds,
|
||||||
predictions=[pred.geometry for pred in preds],
|
ground_truths=gts,
|
||||||
scores=scores,
|
affinity=self.affinity,
|
||||||
|
score=partial(get_class_score, class_idx=class_idx),
|
||||||
|
strict_match=self.strict_match,
|
||||||
):
|
):
|
||||||
gt = gts[gt_idx] if gt_idx is not None else None
|
|
||||||
pred = preds[pred_idx] if pred_idx is not None else None
|
|
||||||
|
|
||||||
true_class = (
|
true_class = (
|
||||||
self.targets.encode_class(gt) if gt is not None else None
|
self.targets.encode_class(match.annotation)
|
||||||
|
if match.annotation is not None
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
score = (
|
|
||||||
float(pred.class_scores[class_idx])
|
|
||||||
if pred is not None
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
matches.append(
|
matches.append(
|
||||||
MatchEval(
|
MatchEval(
|
||||||
clip=clip,
|
clip=clip,
|
||||||
gt=gt,
|
gt=match.annotation,
|
||||||
pred=pred,
|
pred=match.prediction,
|
||||||
is_prediction=pred is not None,
|
is_prediction=match.prediction is not None,
|
||||||
is_ground_truth=gt is not None,
|
is_ground_truth=match.annotation is not None,
|
||||||
is_generic=gt is not None and true_class is None,
|
is_generic=match.annotation is not None
|
||||||
|
and true_class is None,
|
||||||
true_class=true_class,
|
true_class=true_class,
|
||||||
score=score,
|
score=match.prediction_score,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -114,20 +117,6 @@ class ClassificationTask(BaseTask[ClipEval]):
|
|||||||
|
|
||||||
return ClipEval(clip=clip, matches=per_class_matches)
|
return ClipEval(clip=clip, matches=per_class_matches)
|
||||||
|
|
||||||
def is_class(
|
|
||||||
self,
|
|
||||||
sound_event: data.SoundEventAnnotation,
|
|
||||||
class_name: str,
|
|
||||||
) -> bool:
|
|
||||||
sound_event_class = self.targets.encode_class(sound_event)
|
|
||||||
|
|
||||||
if sound_event_class is None and self.include_generics:
|
|
||||||
# Sound events that are generic could be of the given
|
|
||||||
# class
|
|
||||||
return True
|
|
||||||
|
|
||||||
return sound_event_class == class_name
|
|
||||||
|
|
||||||
@tasks_registry.register(ClassificationTaskConfig)
|
@tasks_registry.register(ClassificationTaskConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
@ -142,9 +131,32 @@ class ClassificationTask(BaseTask[ClipEval]):
|
|||||||
build_classification_plotter(plot, targets)
|
build_classification_plotter(plot, targets)
|
||||||
for plot in config.plots
|
for plot in config.plots
|
||||||
]
|
]
|
||||||
return ClassificationTask.build(
|
affinity = build_affinity_function(config.affinity)
|
||||||
config=config,
|
return ClassificationTask(
|
||||||
|
affinity=affinity,
|
||||||
|
prefix=config.prefix,
|
||||||
plots=plots,
|
plots=plots,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
metrics=metrics,
|
metrics=metrics,
|
||||||
|
strict_match=config.strict_match,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_score(pred: RawPrediction, class_idx: int) -> float:
|
||||||
|
return pred.class_scores[class_idx]
|
||||||
|
|
||||||
|
|
||||||
|
def is_target_class(
|
||||||
|
sound_event: data.SoundEventAnnotation,
|
||||||
|
class_name: str,
|
||||||
|
targets: TargetProtocol,
|
||||||
|
include_generics: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
sound_event_class = targets.encode_class(sound_event)
|
||||||
|
|
||||||
|
if sound_event_class is None and include_generics:
|
||||||
|
# Sound events that are generic could be of the given
|
||||||
|
# class
|
||||||
|
return True
|
||||||
|
|
||||||
|
return sound_event_class == class_name
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List, Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -19,19 +19,18 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseTaskConfig,
|
BaseTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
|
||||||
|
|
||||||
|
|
||||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||||
name: Literal["clip_classification"] = "clip_classification"
|
name: Literal["clip_classification"] = "clip_classification"
|
||||||
prefix: str = "clip_classification"
|
prefix: str = "clip_classification"
|
||||||
metrics: List[ClipClassificationMetricConfig] = Field(
|
metrics: list[ClipClassificationMetricConfig] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
ClipClassificationAveragePrecisionConfig(),
|
ClipClassificationAveragePrecisionConfig(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)
|
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class ClipClassificationTask(BaseTask[ClipEval]):
|
class ClipClassificationTask(BaseTask[ClipEval]):
|
||||||
@ -78,8 +77,8 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
|||||||
build_clip_classification_plotter(plot, targets)
|
build_clip_classification_plotter(plot, targets)
|
||||||
for plot in config.plots
|
for plot in config.plots
|
||||||
]
|
]
|
||||||
return ClipClassificationTask.build(
|
return ClipClassificationTask(
|
||||||
config=config,
|
prefix=config.prefix,
|
||||||
plots=plots,
|
plots=plots,
|
||||||
metrics=metrics,
|
metrics=metrics,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -18,19 +18,18 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseTaskConfig,
|
BaseTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||||
name: Literal["clip_detection"] = "clip_detection"
|
name: Literal["clip_detection"] = "clip_detection"
|
||||||
prefix: str = "clip_detection"
|
prefix: str = "clip_detection"
|
||||||
metrics: List[ClipDetectionMetricConfig] = Field(
|
metrics: list[ClipDetectionMetricConfig] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
ClipDetectionAveragePrecisionConfig(),
|
ClipDetectionAveragePrecisionConfig(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)
|
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionTask(BaseTask[ClipEval]):
|
class ClipDetectionTask(BaseTask[ClipEval]):
|
||||||
@ -69,8 +68,8 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
|||||||
build_clip_detection_plotter(plot, targets)
|
build_clip_detection_plotter(plot, targets)
|
||||||
for plot in config.plots
|
for plot in config.plots
|
||||||
]
|
]
|
||||||
return ClipDetectionTask.build(
|
return ClipDetectionTask(
|
||||||
config=config,
|
prefix=config.prefix,
|
||||||
metrics=metrics,
|
metrics=metrics,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
plots=plots,
|
plots=plots,
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
from typing import List, Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
from soundevent.evaluation import match_detections_and_gts
|
||||||
|
|
||||||
|
from batdetect2.evaluate.affinity import build_affinity_function
|
||||||
from batdetect2.evaluate.metrics.detection import (
|
from batdetect2.evaluate.metrics.detection import (
|
||||||
ClipEval,
|
ClipEval,
|
||||||
DetectionAveragePrecisionConfig,
|
DetectionAveragePrecisionConfig,
|
||||||
@ -15,24 +17,24 @@ from batdetect2.evaluate.plots.detection import (
|
|||||||
build_detection_plotter,
|
build_detection_plotter,
|
||||||
)
|
)
|
||||||
from batdetect2.evaluate.tasks.base import (
|
from batdetect2.evaluate.tasks.base import (
|
||||||
BaseTask,
|
BaseSEDTask,
|
||||||
BaseTaskConfig,
|
BaseSEDTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.typing import TargetProtocol
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||||
|
|
||||||
|
|
||||||
class DetectionTaskConfig(BaseTaskConfig):
|
class DetectionTaskConfig(BaseSEDTaskConfig):
|
||||||
name: Literal["sound_event_detection"] = "sound_event_detection"
|
name: Literal["sound_event_detection"] = "sound_event_detection"
|
||||||
prefix: str = "detection"
|
prefix: str = "detection"
|
||||||
metrics: List[DetectionMetricConfig] = Field(
|
metrics: list[DetectionMetricConfig] = Field(
|
||||||
default_factory=lambda: [DetectionAveragePrecisionConfig()]
|
default_factory=lambda: [DetectionAveragePrecisionConfig()]
|
||||||
)
|
)
|
||||||
plots: List[DetectionPlotConfig] = Field(default_factory=list)
|
plots: list[DetectionPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class DetectionTask(BaseTask[ClipEval]):
|
class DetectionTask(BaseSEDTask[ClipEval]):
|
||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
@ -50,24 +52,22 @@ class DetectionTask(BaseTask[ClipEval]):
|
|||||||
for pred in prediction.predictions
|
for pred in prediction.predictions
|
||||||
if self.include_prediction(pred, clip)
|
if self.include_prediction(pred, clip)
|
||||||
]
|
]
|
||||||
scores = [pred.detection_score for pred in preds]
|
|
||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
for pred_idx, gt_idx, _ in self.matcher(
|
for match in match_detections_and_gts(
|
||||||
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
|
detections=preds,
|
||||||
predictions=[pred.geometry for pred in preds],
|
ground_truths=gts,
|
||||||
scores=scores,
|
affinity=self.affinity,
|
||||||
|
score=lambda pred: pred.detection_score,
|
||||||
|
strict_match=self.strict_match,
|
||||||
):
|
):
|
||||||
gt = gts[gt_idx] if gt_idx is not None else None
|
|
||||||
pred = preds[pred_idx] if pred_idx is not None else None
|
|
||||||
|
|
||||||
matches.append(
|
matches.append(
|
||||||
MatchEval(
|
MatchEval(
|
||||||
gt=gt,
|
gt=match.annotation,
|
||||||
pred=pred,
|
pred=match.prediction,
|
||||||
is_prediction=pred is not None,
|
is_prediction=match.prediction is not None,
|
||||||
is_ground_truth=gt is not None,
|
is_ground_truth=match.annotation is not None,
|
||||||
score=pred.detection_score if pred is not None else 0,
|
score=match.prediction_score,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -83,9 +83,12 @@ class DetectionTask(BaseTask[ClipEval]):
|
|||||||
plots = [
|
plots = [
|
||||||
build_detection_plotter(plot, targets) for plot in config.plots
|
build_detection_plotter(plot, targets) for plot in config.plots
|
||||||
]
|
]
|
||||||
return DetectionTask.build(
|
affinity = build_affinity_function(config.affinity)
|
||||||
config=config,
|
return DetectionTask(
|
||||||
|
prefix=config.prefix,
|
||||||
|
affinity=affinity,
|
||||||
metrics=metrics,
|
metrics=metrics,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
plots=plots,
|
plots=plots,
|
||||||
|
strict_match=config.strict_match,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
from typing import List, Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
from soundevent.evaluation import match_detections_and_gts
|
||||||
|
|
||||||
|
from batdetect2.evaluate.affinity import build_affinity_function
|
||||||
from batdetect2.evaluate.metrics.top_class import (
|
from batdetect2.evaluate.metrics.top_class import (
|
||||||
ClipEval,
|
ClipEval,
|
||||||
MatchEval,
|
MatchEval,
|
||||||
@ -15,24 +17,23 @@ from batdetect2.evaluate.plots.top_class import (
|
|||||||
build_top_class_plotter,
|
build_top_class_plotter,
|
||||||
)
|
)
|
||||||
from batdetect2.evaluate.tasks.base import (
|
from batdetect2.evaluate.tasks.base import (
|
||||||
BaseTask,
|
BaseSEDTask,
|
||||||
BaseTaskConfig,
|
BaseSEDTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
|
||||||
|
|
||||||
|
|
||||||
class TopClassDetectionTaskConfig(BaseTaskConfig):
|
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
||||||
name: Literal["top_class_detection"] = "top_class_detection"
|
name: Literal["top_class_detection"] = "top_class_detection"
|
||||||
prefix: str = "top_class"
|
prefix: str = "top_class"
|
||||||
metrics: List[TopClassMetricConfig] = Field(
|
metrics: list[TopClassMetricConfig] = Field(
|
||||||
default_factory=lambda: [TopClassAveragePrecisionConfig()]
|
default_factory=lambda: [TopClassAveragePrecisionConfig()]
|
||||||
)
|
)
|
||||||
plots: List[TopClassPlotConfig] = Field(default_factory=list)
|
plots: list[TopClassPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class TopClassDetectionTask(BaseTask[ClipEval]):
|
class TopClassDetectionTask(BaseSEDTask[ClipEval]):
|
||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
@ -50,18 +51,17 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
|||||||
for pred in prediction.predictions
|
for pred in prediction.predictions
|
||||||
if self.include_prediction(pred, clip)
|
if self.include_prediction(pred, clip)
|
||||||
]
|
]
|
||||||
# Take the highest score for each prediction
|
|
||||||
scores = [pred.class_scores.max() for pred in preds]
|
|
||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
for pred_idx, gt_idx, _ in self.matcher(
|
for match in match_detections_and_gts(
|
||||||
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
|
ground_truths=gts,
|
||||||
predictions=[pred.geometry for pred in preds],
|
detections=preds,
|
||||||
scores=scores,
|
affinity=self.affinity,
|
||||||
|
score=lambda pred: pred.class_scores.max(),
|
||||||
|
strict_match=self.strict_match,
|
||||||
):
|
):
|
||||||
gt = gts[gt_idx] if gt_idx is not None else None
|
gt = match.annotation
|
||||||
pred = preds[pred_idx] if pred_idx is not None else None
|
pred = match.prediction
|
||||||
|
|
||||||
true_class = (
|
true_class = (
|
||||||
self.targets.encode_class(gt) if gt is not None else None
|
self.targets.encode_class(gt) if gt is not None else None
|
||||||
)
|
)
|
||||||
@ -69,11 +69,6 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
|||||||
class_idx = (
|
class_idx = (
|
||||||
pred.class_scores.argmax() if pred is not None else None
|
pred.class_scores.argmax() if pred is not None else None
|
||||||
)
|
)
|
||||||
|
|
||||||
score = (
|
|
||||||
float(pred.class_scores[class_idx]) if pred is not None else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
pred_class = (
|
pred_class = (
|
||||||
self.targets.class_names[class_idx]
|
self.targets.class_names[class_idx]
|
||||||
if class_idx is not None
|
if class_idx is not None
|
||||||
@ -90,7 +85,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
|||||||
true_class=true_class,
|
true_class=true_class,
|
||||||
is_generic=gt is not None and true_class is None,
|
is_generic=gt is not None and true_class is None,
|
||||||
pred_class=pred_class,
|
pred_class=pred_class,
|
||||||
score=score,
|
score=match.prediction_score,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -106,9 +101,12 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
|||||||
plots = [
|
plots = [
|
||||||
build_top_class_plotter(plot, targets) for plot in config.plots
|
build_top_class_plotter(plot, targets) for plot in config.plots
|
||||||
]
|
]
|
||||||
return TopClassDetectionTask.build(
|
affinity = build_affinity_function(config.affinity)
|
||||||
config=config,
|
return TopClassDetectionTask(
|
||||||
|
prefix=config.prefix,
|
||||||
plots=plots,
|
plots=plots,
|
||||||
metrics=metrics,
|
metrics=metrics,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
affinity=affinity,
|
||||||
|
strict_match=config.strict_match,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -81,11 +81,11 @@ class MatcherProtocol(Protocol):
|
|||||||
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
||||||
|
|
||||||
|
|
||||||
class AffinityFunction(Protocol, Generic[Geom]):
|
class AffinityFunction(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
geometry1: Geom,
|
detection: RawPrediction,
|
||||||
geometry2: Geom,
|
ground_truth: data.SoundEventAnnotation,
|
||||||
) -> float: ...
|
) -> float: ...
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -28,13 +28,13 @@ def test_has_tag(sound_event: data.SoundEvent):
|
|||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
)
|
)
|
||||||
assert condition(sound_event_annotation)
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||||
)
|
)
|
||||||
assert not condition(sound_event_annotation)
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
@ -51,15 +51,15 @@ def test_has_all_tags(sound_event: data.SoundEvent):
|
|||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
)
|
)
|
||||||
assert not condition(sound_event_annotation)
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
data.Tag(key="event", value="Echolocation"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert not condition(sound_event_annotation)
|
assert not condition(sound_event_annotation)
|
||||||
@ -67,8 +67,8 @@ def test_has_all_tags(sound_event: data.SoundEvent):
|
|||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
data.Tag(key="species", value="Myotis myotis"),
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
data.Tag(key="event", value="Echolocation"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert condition(sound_event_annotation)
|
assert condition(sound_event_annotation)
|
||||||
@ -76,9 +76,9 @@ def test_has_all_tags(sound_event: data.SoundEvent):
|
|||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
data.Tag(key="species", value="Myotis myotis"),
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
data.Tag(key="event", value="Echolocation"),
|
||||||
data.Tag(key="sex", value="Female"), # type: ignore
|
data.Tag(key="sex", value="Female"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert condition(sound_event_annotation)
|
assert condition(sound_event_annotation)
|
||||||
@ -96,15 +96,15 @@ def test_has_any_tags(sound_event: data.SoundEvent):
|
|||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
)
|
)
|
||||||
assert condition(sound_event_annotation)
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
data.Tag(key="event", value="Echolocation"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert condition(sound_event_annotation)
|
assert condition(sound_event_annotation)
|
||||||
@ -112,8 +112,8 @@ def test_has_any_tags(sound_event: data.SoundEvent):
|
|||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
data.Tag(key="species", value="Myotis myotis"),
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
data.Tag(key="event", value="Echolocation"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert condition(sound_event_annotation)
|
assert condition(sound_event_annotation)
|
||||||
@ -121,8 +121,8 @@ def test_has_any_tags(sound_event: data.SoundEvent):
|
|||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||||
data.Tag(key="event", value="Social"), # type: ignore
|
data.Tag(key="event", value="Social"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert not condition(sound_event_annotation)
|
assert not condition(sound_event_annotation)
|
||||||
@ -140,21 +140,21 @@ def test_not(sound_event: data.SoundEvent):
|
|||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
)
|
)
|
||||||
assert not condition(sound_event_annotation)
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||||
)
|
)
|
||||||
assert condition(sound_event_annotation)
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
data.Tag(key="species", value="Myotis myotis"),
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
data.Tag(key="event", value="Echolocation"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert not condition(sound_event_annotation)
|
assert not condition(sound_event_annotation)
|
||||||
@ -402,31 +402,6 @@ def test_has_tags_fails_if_empty():
|
|||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
def test_frequency_is_false_if_no_geometry(recording: data.Recording):
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: low
|
|
||||||
operator: eq
|
|
||||||
hertz: 200
|
|
||||||
""")
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(geometry=None, recording=recording)
|
|
||||||
)
|
|
||||||
assert not condition(se)
|
|
||||||
|
|
||||||
|
|
||||||
def test_duration_is_false_if_no_geometry(recording: data.Recording):
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: duration
|
|
||||||
operator: eq
|
|
||||||
seconds: 1
|
|
||||||
""")
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(geometry=None, recording=recording)
|
|
||||||
)
|
|
||||||
assert not condition(se)
|
|
||||||
|
|
||||||
|
|
||||||
def test_all_of(recording: data.Recording):
|
def test_all_of(recording: data.Recording):
|
||||||
condition = build_condition_from_str("""
|
condition = build_condition_from_str("""
|
||||||
name: all_of
|
name: all_of
|
||||||
@ -444,7 +419,7 @@ def test_all_of(recording: data.Recording):
|
|||||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
recording=recording,
|
recording=recording,
|
||||||
),
|
),
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
)
|
)
|
||||||
assert condition(se)
|
assert condition(se)
|
||||||
|
|
||||||
@ -453,7 +428,7 @@ def test_all_of(recording: data.Recording):
|
|||||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||||
recording=recording,
|
recording=recording,
|
||||||
),
|
),
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
)
|
)
|
||||||
assert not condition(se)
|
assert not condition(se)
|
||||||
|
|
||||||
@ -462,7 +437,7 @@ def test_all_of(recording: data.Recording):
|
|||||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
recording=recording,
|
recording=recording,
|
||||||
),
|
),
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||||
)
|
)
|
||||||
assert not condition(se)
|
assert not condition(se)
|
||||||
|
|
||||||
@ -484,7 +459,7 @@ def test_any_of(recording: data.Recording):
|
|||||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||||
recording=recording,
|
recording=recording,
|
||||||
),
|
),
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||||
)
|
)
|
||||||
assert not condition(se)
|
assert not condition(se)
|
||||||
|
|
||||||
@ -493,7 +468,7 @@ def test_any_of(recording: data.Recording):
|
|||||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
recording=recording,
|
recording=recording,
|
||||||
),
|
),
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
)
|
)
|
||||||
assert condition(se)
|
assert condition(se)
|
||||||
|
|
||||||
@ -502,7 +477,7 @@ def test_any_of(recording: data.Recording):
|
|||||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||||
recording=recording,
|
recording=recording,
|
||||||
),
|
),
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
)
|
)
|
||||||
assert condition(se)
|
assert condition(se)
|
||||||
|
|
||||||
@ -511,6 +486,6 @@ def test_any_of(recording: data.Recording):
|
|||||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
recording=recording,
|
recording=recording,
|
||||||
),
|
),
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||||
)
|
)
|
||||||
assert condition(se)
|
assert condition(se)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user