Using matching and affinity functions from soundevent

This commit is contained in:
mbsantiago 2025-12-12 19:25:01 +00:00
parent 113f438e74
commit f71fe0c2e2
12 changed files with 341 additions and 532 deletions

View File

@ -62,6 +62,10 @@ class BaseConfig(BaseModel):
) )
) )
@classmethod
def from_yaml(cls, yaml_str: str):
return cls.model_validate(yaml.safe_load(yaml_str))
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)

View File

@ -2,75 +2,98 @@ from typing import Annotated, Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import compute_affinity from soundevent.geometry import (
from soundevent.geometry import compute_interval_overlap buffer_geometry,
compute_bbox_iou,
compute_geometric_iou,
compute_temporal_closeness,
compute_temporal_iou,
)
from batdetect2.core.configs import BaseConfig from batdetect2.core import BaseConfig, Registry
from batdetect2.core.registries import Registry from batdetect2.typing import AffinityFunction, RawPrediction
from batdetect2.typing.evaluate import AffinityFunction
affinity_functions: Registry[AffinityFunction, []] = Registry( affinity_functions: Registry[AffinityFunction, []] = Registry(
"matching_strategy" "affinity_function"
) )
class TimeAffinityConfig(BaseConfig): class TimeAffinityConfig(BaseConfig):
name: Literal["time_affinity"] = "time_affinity" name: Literal["time_affinity"] = "time_affinity"
time_buffer: float = 0.01 position: Literal["start", "end", "center"] | float = "start"
max_distance: float = 0.01
class TimeAffinity(AffinityFunction): class TimeAffinity(AffinityFunction):
def __init__(self, time_buffer: float): def __init__(
self.time_buffer = time_buffer self,
max_distance: float = 0.01,
position: Literal["start", "end", "center"] | float = "start",
):
if position == "start":
position = 0
elif position == "end":
position = 1
elif position == "center":
position = 0.5
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): self.position = position
return compute_timestamp_affinity( self.max_distance = max_distance
geometry1, geometry2, time_buffer=self.time_buffer
def __call__(
self,
detection: RawPrediction,
ground_truth: data.SoundEventAnnotation,
) -> float:
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
return compute_temporal_closeness(
target_geometry,
source_geometry,
ratio=self.position,
max_distance=self.max_distance,
) )
@affinity_functions.register(TimeAffinityConfig) @affinity_functions.register(TimeAffinityConfig)
@staticmethod @staticmethod
def from_config(config: TimeAffinityConfig): def from_config(config: TimeAffinityConfig):
return TimeAffinity(time_buffer=config.time_buffer) return TimeAffinity(
max_distance=config.max_distance,
position=config.position,
def compute_timestamp_affinity( )
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeStamp)
assert isinstance(geometry2, data.TimeStamp)
start_time1 = geometry1.coordinates
start_time2 = geometry2.coordinates
a = min(start_time1, start_time2)
b = max(start_time1, start_time2)
if b - a >= 2 * time_buffer:
return 0
intersection = a - b + 2 * time_buffer
union = b - a + 2 * time_buffer
return intersection / union
class IntervalIOUConfig(BaseConfig): class IntervalIOUConfig(BaseConfig):
name: Literal["interval_iou"] = "interval_iou" name: Literal["interval_iou"] = "interval_iou"
time_buffer: float = 0.01 time_buffer: float = 0.0
class IntervalIOU(AffinityFunction): class IntervalIOU(AffinityFunction):
def __init__(self, time_buffer: float): def __init__(self, time_buffer: float):
if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
self.time_buffer = time_buffer self.time_buffer = time_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): def __call__(
return compute_interval_iou( self,
geometry1, detection: RawPrediction,
geometry2, ground_truth: data.SoundEventAnnotation,
time_buffer=self.time_buffer, ) -> float:
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
if self.time_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
) )
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
)
return compute_temporal_iou(target_geometry, source_geometry)
@affinity_functions.register(IntervalIOUConfig) @affinity_functions.register(IntervalIOUConfig)
@staticmethod @staticmethod
@ -78,64 +101,44 @@ class IntervalIOU(AffinityFunction):
return IntervalIOU(time_buffer=config.time_buffer) return IntervalIOU(time_buffer=config.time_buffer)
def compute_interval_iou(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeInterval)
assert isinstance(geometry2, data.TimeInterval)
start_time1, end_time1 = geometry1.coordinates
start_time2, end_time2 = geometry1.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
union = (
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
)
if union == 0:
return 0
return intersection / union
class BBoxIOUConfig(BaseConfig): class BBoxIOUConfig(BaseConfig):
name: Literal["bbox_iou"] = "bbox_iou" name: Literal["bbox_iou"] = "bbox_iou"
time_buffer: float = 0.01 time_buffer: float = 0.0
freq_buffer: float = 1000 freq_buffer: float = 0.0
class BBoxIOU(AffinityFunction): class BBoxIOU(AffinityFunction):
def __init__(self, time_buffer: float, freq_buffer: float): def __init__(self, time_buffer: float, freq_buffer: float):
if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
if freq_buffer < 0:
raise ValueError("freq_buffer must be non-negative")
self.time_buffer = time_buffer self.time_buffer = time_buffer
self.freq_buffer = freq_buffer self.freq_buffer = freq_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): def __call__(
if not isinstance(geometry1, data.BoundingBox): self,
raise TypeError( prediction: RawPrediction,
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}" gt: data.SoundEventAnnotation,
):
target_geometry = gt.sound_event.geometry
source_geometry = prediction.geometry
if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
) )
if not isinstance(geometry2, data.BoundingBox): return compute_bbox_iou(target_geometry, source_geometry)
raise TypeError(
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
)
return bbox_iou(
geometry1,
geometry2,
time_buffer=self.time_buffer,
freq_buffer=self.freq_buffer,
)
@affinity_functions.register(BBoxIOUConfig) @affinity_functions.register(BBoxIOUConfig)
@staticmethod @staticmethod
@ -146,65 +149,44 @@ class BBoxIOU(AffinityFunction):
) )
def bbox_iou(
geometry1: data.BoundingBox,
geometry2: data.BoundingBox,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
low_freq1 -= freq_buffer
low_freq2 -= freq_buffer
high_freq1 += freq_buffer
high_freq2 += freq_buffer
time_intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
freq_intersection = max(
0,
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
)
intersection = time_intersection * freq_intersection
if intersection == 0:
return 0
union = (
(end_time1 - start_time1) * (high_freq1 - low_freq1)
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
- intersection
)
return intersection / union
class GeometricIOUConfig(BaseConfig): class GeometricIOUConfig(BaseConfig):
name: Literal["geometric_iou"] = "geometric_iou" name: Literal["geometric_iou"] = "geometric_iou"
time_buffer: float = 0.01 time_buffer: float = 0.0
freq_buffer: float = 1000 freq_buffer: float = 0.0
class GeometricIOU(AffinityFunction): class GeometricIOU(AffinityFunction):
def __init__(self, time_buffer: float): def __init__(self, time_buffer: float = 0, freq_buffer: float = 0):
self.time_buffer = time_buffer if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): if freq_buffer < 0:
return compute_affinity( raise ValueError("freq_buffer must be non-negative")
geometry1,
geometry2, self.time_buffer = time_buffer
time_buffer=self.time_buffer, self.freq_buffer = freq_buffer
def __call__(
self,
prediction: RawPrediction,
gt: data.SoundEventAnnotation,
):
target_geometry = gt.sound_event.geometry
source_geometry = prediction.geometry
if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
) )
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
return compute_geometric_iou(target_geometry, source_geometry)
@affinity_functions.register(GeometricIOUConfig) @affinity_functions.register(GeometricIOUConfig)
@staticmethod @staticmethod
@ -213,7 +195,10 @@ class GeometricIOU(AffinityFunction):
AffinityConfig = Annotated[ AffinityConfig = Annotated[
TimeAffinityConfig | IntervalIOUConfig | BBoxIOUConfig | GeometricIOUConfig, TimeAffinityConfig
| IntervalIOUConfig
| BBoxIOUConfig
| GeometricIOUConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -31,93 +31,6 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"]
matching_strategies = Registry("matching_strategy") matching_strategies = Registry("matching_strategy")
def match(
sound_event_annotations: Sequence[data.SoundEventAnnotation],
raw_predictions: Sequence[RawPrediction],
clip: data.Clip,
scores: Sequence[float] | None = None,
targets: TargetProtocol | None = None,
matcher: MatcherProtocol | None = None,
) -> ClipMatches:
if matcher is None:
matcher = build_matcher()
if targets is None:
targets = build_targets()
target_geometries: List[data.Geometry] = [ # type: ignore
sound_event_annotation.sound_event.geometry
for sound_event_annotation in sound_event_annotations
]
predicted_geometries = [
raw_prediction.geometry for raw_prediction in raw_predictions
]
if scores is None:
scores = [
raw_prediction.detection_score
for raw_prediction in raw_predictions
]
matches = []
for source_idx, target_idx, affinity in matcher(
ground_truth=target_geometries,
predictions=predicted_geometries,
scores=scores,
):
target = (
sound_event_annotations[target_idx]
if target_idx is not None
else None
)
prediction = (
raw_predictions[source_idx] if source_idx is not None else None
)
gt_det = target_idx is not None
gt_class = targets.encode_class(target) if target is not None else None
gt_geometry = (
target_geometries[target_idx] if target_idx is not None else None
)
pred_score = float(prediction.detection_score) if prediction else 0
pred_geometry = (
predicted_geometries[source_idx]
if source_idx is not None
else None
)
class_scores = (
{
class_name: score
for class_name, score in zip(
targets.class_names,
prediction.class_scores, strict=False,
)
}
if prediction is not None
else {}
)
matches.append(
MatchEvaluation(
clip=clip,
sound_event_annotation=target,
gt_det=gt_det,
gt_class=gt_class,
gt_geometry=gt_geometry,
pred_score=pred_score,
pred_class_scores=class_scores,
pred_geometry=pred_geometry,
affinity=affinity,
)
)
return ClipMatches(clip=clip, matches=matches)
class StartTimeMatchConfig(BaseConfig): class StartTimeMatchConfig(BaseConfig):
name: Literal["start_time_match"] = "start_time_match" name: Literal["start_time_match"] = "start_time_match"
distance_threshold: float = 0.01 distance_threshold: float = 0.01
@ -514,99 +427,9 @@ class OptimalMatcher(MatcherProtocol):
MatchConfig = Annotated[ MatchConfig = Annotated[
GreedyMatchConfig | StartTimeMatchConfig | OptimalMatchConfig | GreedyAffinityMatchConfig, GreedyMatchConfig
| StartTimeMatchConfig
| OptimalMatchConfig
| GreedyAffinityMatchConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]
def compute_affinity_matrix(
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
affinity_function: AffinityFunction,
time_scale: float = 1,
frequency_scale: float = 1,
) -> np.ndarray:
# Scale geometries if necessary
if time_scale != 1 or frequency_scale != 1:
ground_truth = [
scale_geometry(geometry, time_scale, frequency_scale)
for geometry in ground_truth
]
predictions = [
scale_geometry(geometry, time_scale, frequency_scale)
for geometry in predictions
]
affinity_matrix = np.zeros((len(ground_truth), len(predictions)))
for gt_idx, gt_geometry in enumerate(ground_truth):
for pred_idx, pred_geometry in enumerate(predictions):
affinity = affinity_function(
gt_geometry,
pred_geometry,
)
affinity_matrix[gt_idx, pred_idx] = affinity
return affinity_matrix
def select_optimal_matches(
affinity_matrix: np.ndarray,
affinity_threshold: float = 0.5,
) -> Iterable[Tuple[int | None, int | None, float]]:
num_gt, num_pred = affinity_matrix.shape
gts = set(range(num_gt))
preds = set(range(num_pred))
assiged_rows, assigned_columns = linear_sum_assignment(
affinity_matrix,
maximize=True,
)
for gt_idx, pred_idx in zip(assiged_rows, assigned_columns, strict=False):
affinity = float(affinity_matrix[gt_idx, pred_idx])
if affinity <= affinity_threshold:
continue
yield gt_idx, pred_idx, affinity
gts.remove(gt_idx)
preds.remove(pred_idx)
for gt_idx in gts:
yield gt_idx, None, 0
for pred_idx in preds:
yield None, pred_idx, 0
def select_greedy_matches(
affinity_matrix: np.ndarray,
affinity_threshold: float = 0.5,
) -> Iterable[Tuple[int | None, int | None, float]]:
num_gt, num_pred = affinity_matrix.shape
unmatched_pred = set(range(num_pred))
for gt_idx in range(num_gt):
row = affinity_matrix[gt_idx]
top_pred = int(np.argmax(row))
top_affinity = float(row[top_pred])
if (
top_affinity <= affinity_threshold
or top_pred not in unmatched_pred
):
yield None, gt_idx, 0
continue
unmatched_pred.remove(top_pred)
yield top_pred, gt_idx, top_affinity
for pred_idx in unmatched_pred:
yield pred_idx, None, 0
def build_matcher(config: MatchConfig | None = None) -> MatcherProtocol:
config = config or StartTimeMatchConfig()
return matching_strategies.build(config)

View File

@ -210,7 +210,10 @@ class DetectionPrecision:
DetectionMetricConfig = Annotated[ DetectionMetricConfig = Annotated[
DetectionAveragePrecisionConfig | DetectionROCAUCConfig | DetectionRecallConfig | DetectionPrecisionConfig, DetectionAveragePrecisionConfig
| DetectionROCAUCConfig
| DetectionRecallConfig
| DetectionPrecisionConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -14,16 +14,15 @@ from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig from batdetect2.core import BaseConfig, Registry
from batdetect2.core.registries import Registry from batdetect2.evaluate.affinity import AffinityConfig, TimeAffinityConfig
from batdetect2.evaluate.match import ( from batdetect2.typing import (
MatchConfig, AffinityFunction,
StartTimeMatchConfig, BatDetect2Prediction,
build_matcher, EvaluatorProtocol,
RawPrediction,
TargetProtocol,
) )
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [ __all__ = [
"BaseTaskConfig", "BaseTaskConfig",
@ -40,39 +39,34 @@ T_Output = TypeVar("T_Output")
class BaseTaskConfig(BaseConfig): class BaseTaskConfig(BaseConfig):
prefix: str prefix: str
ignore_start_end: float = 0.01
matching_strategy: MatchConfig = Field(
default_factory=StartTimeMatchConfig
)
class BaseTask(EvaluatorProtocol, Generic[T_Output]): class BaseTask(EvaluatorProtocol, Generic[T_Output]):
targets: TargetProtocol targets: TargetProtocol
matcher: MatcherProtocol
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]] metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
ignore_start_end: float
prefix: str prefix: str
ignore_start_end: float
def __init__( def __init__(
self, self,
matcher: MatcherProtocol,
targets: TargetProtocol, targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
prefix: str, prefix: str,
plots: List[
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
]
| None = None,
ignore_start_end: float = 0.01, ignore_start_end: float = 0.01,
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
): ):
self.matcher = matcher self.prefix = prefix
self.targets = targets
self.metrics = metrics self.metrics = metrics
self.plots = plots or [] self.plots = plots or []
self.targets = targets
self.prefix = prefix
self.ignore_start_end = ignore_start_end self.ignore_start_end = ignore_start_end
def compute_metrics( def compute_metrics(
@ -100,7 +94,9 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
) -> List[T_Output]: ) -> List[T_Output]:
return [ return [
self.evaluate_clip(clip_annotation, preds) self.evaluate_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions, strict=False) for clip_annotation, preds in zip(
clip_annotations, predictions, strict=False
)
] ]
def evaluate_clip( def evaluate_clip(
@ -118,9 +114,6 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
return False return False
geometry = sound_event_annotation.sound_event.geometry geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
return is_in_bounds( return is_in_bounds(
geometry, geometry,
clip, clip,
@ -138,25 +131,40 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
self.ignore_start_end, self.ignore_start_end,
) )
@classmethod
def build( class BaseSEDTaskConfig(BaseTaskConfig):
cls, affinity: AffinityConfig = Field(default_factory=TimeAffinityConfig)
config: BaseTaskConfig, affinity_threshold: float = 0
strict_match: bool = True
class BaseSEDTask(BaseTask[T_Output]):
affinity: AffinityFunction
def __init__(
self,
prefix: str,
targets: TargetProtocol, targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None, affinity: AffinityFunction,
**kwargs, plots: List[
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
]
| None = None,
affinity_threshold: float = 0,
ignore_start_end: float = 0.01,
strict_match: bool = True,
): ):
matcher = build_matcher(config.matching_strategy) super().__init__(
return cls( prefix=prefix,
matcher=matcher,
targets=targets,
metrics=metrics, metrics=metrics,
plots=plots, plots=plots,
prefix=config.prefix, targets=targets,
ignore_start_end=config.ignore_start_end, ignore_start_end=ignore_start_end,
**kwargs,
) )
self.affinity = affinity
self.affinity_threshold = affinity_threshold
self.strict_match = strict_match
def is_in_bounds( def is_in_bounds(

View File

@ -1,11 +1,11 @@
from typing import ( from functools import partial
List, from typing import Literal
Literal,
)
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.classification import ( from batdetect2.evaluate.metrics.classification import (
ClassificationAveragePrecisionConfig, ClassificationAveragePrecisionConfig,
ClassificationMetricConfig, ClassificationMetricConfig,
@ -18,24 +18,28 @@ from batdetect2.evaluate.plots.classification import (
build_classification_plotter, build_classification_plotter,
) )
from batdetect2.evaluate.tasks.base import ( from batdetect2.evaluate.tasks.base import (
BaseTask, BaseSEDTask,
BaseTaskConfig, BaseSEDTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import BatDetect2Prediction, TargetProtocol from batdetect2.typing import (
BatDetect2Prediction,
RawPrediction,
TargetProtocol,
)
class ClassificationTaskConfig(BaseTaskConfig): class ClassificationTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_classification"] = "sound_event_classification" name: Literal["sound_event_classification"] = "sound_event_classification"
prefix: str = "classification" prefix: str = "classification"
metrics: List[ClassificationMetricConfig] = Field( metrics: list[ClassificationMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()] default_factory=lambda: [ClassificationAveragePrecisionConfig()]
) )
plots: List[ClassificationPlotConfig] = Field(default_factory=list) plots: list[ClassificationPlotConfig] = Field(default_factory=list)
include_generics: bool = True include_generics: bool = True
class ClassificationTask(BaseTask[ClipEval]): class ClassificationTask(BaseSEDTask[ClipEval]):
def __init__( def __init__(
self, self,
*args, *args,
@ -73,40 +77,39 @@ class ClassificationTask(BaseTask[ClipEval]):
gts = [ gts = [
sound_event sound_event
for sound_event in all_gts for sound_event in all_gts
if self.is_class(sound_event, class_name) if is_target_class(
sound_event,
class_name,
self.targets,
include_generics=self.include_generics,
)
] ]
scores = [float(pred.class_scores[class_idx]) for pred in preds]
matches = [] matches = []
for pred_idx, gt_idx, _ in self.matcher( for match in match_detections_and_gts(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore detections=preds,
predictions=[pred.geometry for pred in preds], ground_truths=gts,
scores=scores, affinity=self.affinity,
score=partial(get_class_score, class_idx=class_idx),
strict_match=self.strict_match,
): ):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
true_class = ( true_class = (
self.targets.encode_class(gt) if gt is not None else None self.targets.encode_class(match.annotation)
if match.annotation is not None
else None
) )
score = (
float(pred.class_scores[class_idx])
if pred is not None
else 0
)
matches.append( matches.append(
MatchEval( MatchEval(
clip=clip, clip=clip,
gt=gt, gt=match.annotation,
pred=pred, pred=match.prediction,
is_prediction=pred is not None, is_prediction=match.prediction is not None,
is_ground_truth=gt is not None, is_ground_truth=match.annotation is not None,
is_generic=gt is not None and true_class is None, is_generic=match.annotation is not None
and true_class is None,
true_class=true_class, true_class=true_class,
score=score, score=match.prediction_score,
) )
) )
@ -114,20 +117,6 @@ class ClassificationTask(BaseTask[ClipEval]):
return ClipEval(clip=clip, matches=per_class_matches) return ClipEval(clip=clip, matches=per_class_matches)
def is_class(
self,
sound_event: data.SoundEventAnnotation,
class_name: str,
) -> bool:
sound_event_class = self.targets.encode_class(sound_event)
if sound_event_class is None and self.include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name
@tasks_registry.register(ClassificationTaskConfig) @tasks_registry.register(ClassificationTaskConfig)
@staticmethod @staticmethod
def from_config( def from_config(
@ -142,9 +131,32 @@ class ClassificationTask(BaseTask[ClipEval]):
build_classification_plotter(plot, targets) build_classification_plotter(plot, targets)
for plot in config.plots for plot in config.plots
] ]
return ClassificationTask.build( affinity = build_affinity_function(config.affinity)
config=config, return ClassificationTask(
affinity=affinity,
prefix=config.prefix,
plots=plots, plots=plots,
targets=targets, targets=targets,
metrics=metrics, metrics=metrics,
strict_match=config.strict_match,
) )
def get_class_score(pred: RawPrediction, class_idx: int) -> float:
return pred.class_scores[class_idx]
def is_target_class(
sound_event: data.SoundEventAnnotation,
class_name: str,
targets: TargetProtocol,
include_generics: bool = True,
) -> bool:
sound_event_class = targets.encode_class(sound_event)
if sound_event_class is None and include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name

View File

@ -1,5 +1,5 @@
from collections import defaultdict from collections import defaultdict
from typing import List, Literal from typing import Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -19,19 +19,18 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig, BaseTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import TargetProtocol from batdetect2.typing import BatDetect2Prediction, TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
class ClipClassificationTaskConfig(BaseTaskConfig): class ClipClassificationTaskConfig(BaseTaskConfig):
name: Literal["clip_classification"] = "clip_classification" name: Literal["clip_classification"] = "clip_classification"
prefix: str = "clip_classification" prefix: str = "clip_classification"
metrics: List[ClipClassificationMetricConfig] = Field( metrics: list[ClipClassificationMetricConfig] = Field(
default_factory=lambda: [ default_factory=lambda: [
ClipClassificationAveragePrecisionConfig(), ClipClassificationAveragePrecisionConfig(),
] ]
) )
plots: List[ClipClassificationPlotConfig] = Field(default_factory=list) plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
class ClipClassificationTask(BaseTask[ClipEval]): class ClipClassificationTask(BaseTask[ClipEval]):
@ -78,8 +77,8 @@ class ClipClassificationTask(BaseTask[ClipEval]):
build_clip_classification_plotter(plot, targets) build_clip_classification_plotter(plot, targets)
for plot in config.plots for plot in config.plots
] ]
return ClipClassificationTask.build( return ClipClassificationTask(
config=config, prefix=config.prefix,
plots=plots, plots=plots,
metrics=metrics, metrics=metrics,
targets=targets, targets=targets,

View File

@ -1,4 +1,4 @@
from typing import List, Literal from typing import Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -18,19 +18,18 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig, BaseTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import TargetProtocol from batdetect2.typing import BatDetect2Prediction, TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
class ClipDetectionTaskConfig(BaseTaskConfig): class ClipDetectionTaskConfig(BaseTaskConfig):
name: Literal["clip_detection"] = "clip_detection" name: Literal["clip_detection"] = "clip_detection"
prefix: str = "clip_detection" prefix: str = "clip_detection"
metrics: List[ClipDetectionMetricConfig] = Field( metrics: list[ClipDetectionMetricConfig] = Field(
default_factory=lambda: [ default_factory=lambda: [
ClipDetectionAveragePrecisionConfig(), ClipDetectionAveragePrecisionConfig(),
] ]
) )
plots: List[ClipDetectionPlotConfig] = Field(default_factory=list) plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
class ClipDetectionTask(BaseTask[ClipEval]): class ClipDetectionTask(BaseTask[ClipEval]):
@ -69,8 +68,8 @@ class ClipDetectionTask(BaseTask[ClipEval]):
build_clip_detection_plotter(plot, targets) build_clip_detection_plotter(plot, targets)
for plot in config.plots for plot in config.plots
] ]
return ClipDetectionTask.build( return ClipDetectionTask(
config=config, prefix=config.prefix,
metrics=metrics, metrics=metrics,
targets=targets, targets=targets,
plots=plots, plots=plots,

View File

@ -1,8 +1,10 @@
from typing import List, Literal from typing import Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.detection import ( from batdetect2.evaluate.metrics.detection import (
ClipEval, ClipEval,
DetectionAveragePrecisionConfig, DetectionAveragePrecisionConfig,
@ -15,24 +17,24 @@ from batdetect2.evaluate.plots.detection import (
build_detection_plotter, build_detection_plotter,
) )
from batdetect2.evaluate.tasks.base import ( from batdetect2.evaluate.tasks.base import (
BaseTask, BaseSEDTask,
BaseTaskConfig, BaseSEDTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import TargetProtocol from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.typing.postprocess import BatDetect2Prediction
class DetectionTaskConfig(BaseTaskConfig): class DetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_detection"] = "sound_event_detection" name: Literal["sound_event_detection"] = "sound_event_detection"
prefix: str = "detection" prefix: str = "detection"
metrics: List[DetectionMetricConfig] = Field( metrics: list[DetectionMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()] default_factory=lambda: [DetectionAveragePrecisionConfig()]
) )
plots: List[DetectionPlotConfig] = Field(default_factory=list) plots: list[DetectionPlotConfig] = Field(default_factory=list)
class DetectionTask(BaseTask[ClipEval]): class DetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
@ -50,24 +52,22 @@ class DetectionTask(BaseTask[ClipEval]):
for pred in prediction.predictions for pred in prediction.predictions
if self.include_prediction(pred, clip) if self.include_prediction(pred, clip)
] ]
scores = [pred.detection_score for pred in preds]
matches = [] matches = []
for pred_idx, gt_idx, _ in self.matcher( for match in match_detections_and_gts(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore detections=preds,
predictions=[pred.geometry for pred in preds], ground_truths=gts,
scores=scores, affinity=self.affinity,
score=lambda pred: pred.detection_score,
strict_match=self.strict_match,
): ):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
matches.append( matches.append(
MatchEval( MatchEval(
gt=gt, gt=match.annotation,
pred=pred, pred=match.prediction,
is_prediction=pred is not None, is_prediction=match.prediction is not None,
is_ground_truth=gt is not None, is_ground_truth=match.annotation is not None,
score=pred.detection_score if pred is not None else 0, score=match.prediction_score,
) )
) )
@ -83,9 +83,12 @@ class DetectionTask(BaseTask[ClipEval]):
plots = [ plots = [
build_detection_plotter(plot, targets) for plot in config.plots build_detection_plotter(plot, targets) for plot in config.plots
] ]
return DetectionTask.build( affinity = build_affinity_function(config.affinity)
config=config, return DetectionTask(
prefix=config.prefix,
affinity=affinity,
metrics=metrics, metrics=metrics,
targets=targets, targets=targets,
plots=plots, plots=plots,
strict_match=config.strict_match,
) )

View File

@ -1,8 +1,10 @@
from typing import List, Literal from typing import Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.top_class import ( from batdetect2.evaluate.metrics.top_class import (
ClipEval, ClipEval,
MatchEval, MatchEval,
@ -15,24 +17,23 @@ from batdetect2.evaluate.plots.top_class import (
build_top_class_plotter, build_top_class_plotter,
) )
from batdetect2.evaluate.tasks.base import ( from batdetect2.evaluate.tasks.base import (
BaseTask, BaseSEDTask,
BaseTaskConfig, BaseSEDTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import TargetProtocol from batdetect2.typing import BatDetect2Prediction, TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
class TopClassDetectionTaskConfig(BaseTaskConfig): class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["top_class_detection"] = "top_class_detection" name: Literal["top_class_detection"] = "top_class_detection"
prefix: str = "top_class" prefix: str = "top_class"
metrics: List[TopClassMetricConfig] = Field( metrics: list[TopClassMetricConfig] = Field(
default_factory=lambda: [TopClassAveragePrecisionConfig()] default_factory=lambda: [TopClassAveragePrecisionConfig()]
) )
plots: List[TopClassPlotConfig] = Field(default_factory=list) plots: list[TopClassPlotConfig] = Field(default_factory=list)
class TopClassDetectionTask(BaseTask[ClipEval]): class TopClassDetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
@ -50,18 +51,17 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
for pred in prediction.predictions for pred in prediction.predictions
if self.include_prediction(pred, clip) if self.include_prediction(pred, clip)
] ]
# Take the highest score for each prediction
scores = [pred.class_scores.max() for pred in preds]
matches = [] matches = []
for pred_idx, gt_idx, _ in self.matcher( for match in match_detections_and_gts(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore ground_truths=gts,
predictions=[pred.geometry for pred in preds], detections=preds,
scores=scores, affinity=self.affinity,
score=lambda pred: pred.class_scores.max(),
strict_match=self.strict_match,
): ):
gt = gts[gt_idx] if gt_idx is not None else None gt = match.annotation
pred = preds[pred_idx] if pred_idx is not None else None pred = match.prediction
true_class = ( true_class = (
self.targets.encode_class(gt) if gt is not None else None self.targets.encode_class(gt) if gt is not None else None
) )
@ -69,11 +69,6 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
class_idx = ( class_idx = (
pred.class_scores.argmax() if pred is not None else None pred.class_scores.argmax() if pred is not None else None
) )
score = (
float(pred.class_scores[class_idx]) if pred is not None else 0
)
pred_class = ( pred_class = (
self.targets.class_names[class_idx] self.targets.class_names[class_idx]
if class_idx is not None if class_idx is not None
@ -90,7 +85,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
true_class=true_class, true_class=true_class,
is_generic=gt is not None and true_class is None, is_generic=gt is not None and true_class is None,
pred_class=pred_class, pred_class=pred_class,
score=score, score=match.prediction_score,
) )
) )
@ -106,9 +101,12 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
plots = [ plots = [
build_top_class_plotter(plot, targets) for plot in config.plots build_top_class_plotter(plot, targets) for plot in config.plots
] ]
return TopClassDetectionTask.build( affinity = build_affinity_function(config.affinity)
config=config, return TopClassDetectionTask(
prefix=config.prefix,
plots=plots, plots=plots,
metrics=metrics, metrics=metrics,
targets=targets, targets=targets,
affinity=affinity,
strict_match=config.strict_match,
) )

View File

@ -81,11 +81,11 @@ class MatcherProtocol(Protocol):
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True) Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
class AffinityFunction(Protocol, Generic[Geom]): class AffinityFunction(Protocol):
def __call__( def __call__(
self, self,
geometry1: Geom, detection: RawPrediction,
geometry2: Geom, ground_truth: data.SoundEventAnnotation,
) -> float: ... ) -> float: ...

View File

@ -28,13 +28,13 @@ def test_has_tag(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore tags=[data.Tag(key="species", value="Myotis myotis")],
) )
assert condition(sound_event_annotation) assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore tags=[data.Tag(key="species", value="Eptesicus fuscus")],
) )
assert not condition(sound_event_annotation) assert not condition(sound_event_annotation)
@ -51,15 +51,15 @@ def test_has_all_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore tags=[data.Tag(key="species", value="Myotis myotis")],
) )
assert not condition(sound_event_annotation) assert not condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[ tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore data.Tag(key="species", value="Eptesicus fuscus"),
data.Tag(key="event", value="Echolocation"), # type: ignore data.Tag(key="event", value="Echolocation"),
], ],
) )
assert not condition(sound_event_annotation) assert not condition(sound_event_annotation)
@ -67,8 +67,8 @@ def test_has_all_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[ tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore data.Tag(key="species", value="Myotis myotis"),
data.Tag(key="event", value="Echolocation"), # type: ignore data.Tag(key="event", value="Echolocation"),
], ],
) )
assert condition(sound_event_annotation) assert condition(sound_event_annotation)
@ -76,9 +76,9 @@ def test_has_all_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[ tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore data.Tag(key="species", value="Myotis myotis"),
data.Tag(key="event", value="Echolocation"), # type: ignore data.Tag(key="event", value="Echolocation"),
data.Tag(key="sex", value="Female"), # type: ignore data.Tag(key="sex", value="Female"),
], ],
) )
assert condition(sound_event_annotation) assert condition(sound_event_annotation)
@ -96,15 +96,15 @@ def test_has_any_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore tags=[data.Tag(key="species", value="Myotis myotis")],
) )
assert condition(sound_event_annotation) assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[ tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore data.Tag(key="species", value="Eptesicus fuscus"),
data.Tag(key="event", value="Echolocation"), # type: ignore data.Tag(key="event", value="Echolocation"),
], ],
) )
assert condition(sound_event_annotation) assert condition(sound_event_annotation)
@ -112,8 +112,8 @@ def test_has_any_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[ tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore data.Tag(key="species", value="Myotis myotis"),
data.Tag(key="event", value="Echolocation"), # type: ignore data.Tag(key="event", value="Echolocation"),
], ],
) )
assert condition(sound_event_annotation) assert condition(sound_event_annotation)
@ -121,8 +121,8 @@ def test_has_any_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[ tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore data.Tag(key="species", value="Eptesicus fuscus"),
data.Tag(key="event", value="Social"), # type: ignore data.Tag(key="event", value="Social"),
], ],
) )
assert not condition(sound_event_annotation) assert not condition(sound_event_annotation)
@ -140,21 +140,21 @@ def test_not(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore tags=[data.Tag(key="species", value="Myotis myotis")],
) )
assert not condition(sound_event_annotation) assert not condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore tags=[data.Tag(key="species", value="Eptesicus fuscus")],
) )
assert condition(sound_event_annotation) assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation( sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[ tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore data.Tag(key="species", value="Myotis myotis"),
data.Tag(key="event", value="Echolocation"), # type: ignore data.Tag(key="event", value="Echolocation"),
], ],
) )
assert not condition(sound_event_annotation) assert not condition(sound_event_annotation)
@ -402,31 +402,6 @@ def test_has_tags_fails_if_empty():
""") """)
def test_frequency_is_false_if_no_geometry(recording: data.Recording):
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: eq
hertz: 200
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(geometry=None, recording=recording)
)
assert not condition(se)
def test_duration_is_false_if_no_geometry(recording: data.Recording):
condition = build_condition_from_str("""
name: duration
operator: eq
seconds: 1
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(geometry=None, recording=recording)
)
assert not condition(se)
def test_all_of(recording: data.Recording): def test_all_of(recording: data.Recording):
condition = build_condition_from_str(""" condition = build_condition_from_str("""
name: all_of name: all_of
@ -444,7 +419,7 @@ def test_all_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 0.5]), geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording, recording=recording,
), ),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore tags=[data.Tag(key="species", value="Myotis myotis")],
) )
assert condition(se) assert condition(se)
@ -453,7 +428,7 @@ def test_all_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 2]), geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording, recording=recording,
), ),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore tags=[data.Tag(key="species", value="Myotis myotis")],
) )
assert not condition(se) assert not condition(se)
@ -462,7 +437,7 @@ def test_all_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 0.5]), geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording, recording=recording,
), ),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore tags=[data.Tag(key="species", value="Eptesicus fuscus")],
) )
assert not condition(se) assert not condition(se)
@ -484,7 +459,7 @@ def test_any_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 2]), geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording, recording=recording,
), ),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore tags=[data.Tag(key="species", value="Eptesicus fuscus")],
) )
assert not condition(se) assert not condition(se)
@ -493,7 +468,7 @@ def test_any_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 0.5]), geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording, recording=recording,
), ),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore tags=[data.Tag(key="species", value="Myotis myotis")],
) )
assert condition(se) assert condition(se)
@ -502,7 +477,7 @@ def test_any_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 2]), geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording, recording=recording,
), ),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore tags=[data.Tag(key="species", value="Myotis myotis")],
) )
assert condition(se) assert condition(se)
@ -511,6 +486,6 @@ def test_any_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 0.5]), geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording, recording=recording,
), ),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore tags=[data.Tag(key="species", value="Eptesicus fuscus")],
) )
assert condition(se) assert condition(se)