mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Using matching and affinity functions from soundevent
This commit is contained in:
parent
113f438e74
commit
f71fe0c2e2
@ -62,6 +62,10 @@ class BaseConfig(BaseModel):
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, yaml_str: str):
|
||||
return cls.model_validate(yaml.safe_load(yaml_str))
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
@ -2,75 +2,98 @@ from typing import Annotated, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import compute_affinity
|
||||
from soundevent.geometry import compute_interval_overlap
|
||||
from soundevent.geometry import (
|
||||
buffer_geometry,
|
||||
compute_bbox_iou,
|
||||
compute_geometric_iou,
|
||||
compute_temporal_closeness,
|
||||
compute_temporal_iou,
|
||||
)
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.typing.evaluate import AffinityFunction
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.typing import AffinityFunction, RawPrediction
|
||||
|
||||
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
||||
"matching_strategy"
|
||||
"affinity_function"
|
||||
)
|
||||
|
||||
|
||||
class TimeAffinityConfig(BaseConfig):
|
||||
name: Literal["time_affinity"] = "time_affinity"
|
||||
time_buffer: float = 0.01
|
||||
position: Literal["start", "end", "center"] | float = "start"
|
||||
max_distance: float = 0.01
|
||||
|
||||
|
||||
class TimeAffinity(AffinityFunction):
|
||||
def __init__(self, time_buffer: float):
|
||||
self.time_buffer = time_buffer
|
||||
def __init__(
|
||||
self,
|
||||
max_distance: float = 0.01,
|
||||
position: Literal["start", "end", "center"] | float = "start",
|
||||
):
|
||||
if position == "start":
|
||||
position = 0
|
||||
elif position == "end":
|
||||
position = 1
|
||||
elif position == "center":
|
||||
position = 0.5
|
||||
|
||||
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
||||
return compute_timestamp_affinity(
|
||||
geometry1, geometry2, time_buffer=self.time_buffer
|
||||
self.position = position
|
||||
self.max_distance = max_distance
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
detection: RawPrediction,
|
||||
ground_truth: data.SoundEventAnnotation,
|
||||
) -> float:
|
||||
target_geometry = ground_truth.sound_event.geometry
|
||||
source_geometry = detection.geometry
|
||||
return compute_temporal_closeness(
|
||||
target_geometry,
|
||||
source_geometry,
|
||||
ratio=self.position,
|
||||
max_distance=self.max_distance,
|
||||
)
|
||||
|
||||
@affinity_functions.register(TimeAffinityConfig)
|
||||
@staticmethod
|
||||
def from_config(config: TimeAffinityConfig):
|
||||
return TimeAffinity(time_buffer=config.time_buffer)
|
||||
|
||||
|
||||
def compute_timestamp_affinity(
|
||||
geometry1: data.Geometry,
|
||||
geometry2: data.Geometry,
|
||||
time_buffer: float = 0.01,
|
||||
) -> float:
|
||||
assert isinstance(geometry1, data.TimeStamp)
|
||||
assert isinstance(geometry2, data.TimeStamp)
|
||||
|
||||
start_time1 = geometry1.coordinates
|
||||
start_time2 = geometry2.coordinates
|
||||
|
||||
a = min(start_time1, start_time2)
|
||||
b = max(start_time1, start_time2)
|
||||
|
||||
if b - a >= 2 * time_buffer:
|
||||
return 0
|
||||
|
||||
intersection = a - b + 2 * time_buffer
|
||||
union = b - a + 2 * time_buffer
|
||||
return intersection / union
|
||||
return TimeAffinity(
|
||||
max_distance=config.max_distance,
|
||||
position=config.position,
|
||||
)
|
||||
|
||||
|
||||
class IntervalIOUConfig(BaseConfig):
|
||||
name: Literal["interval_iou"] = "interval_iou"
|
||||
time_buffer: float = 0.01
|
||||
time_buffer: float = 0.0
|
||||
|
||||
|
||||
class IntervalIOU(AffinityFunction):
|
||||
def __init__(self, time_buffer: float):
|
||||
if time_buffer < 0:
|
||||
raise ValueError("time_buffer must be non-negative")
|
||||
|
||||
self.time_buffer = time_buffer
|
||||
|
||||
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
||||
return compute_interval_iou(
|
||||
geometry1,
|
||||
geometry2,
|
||||
time_buffer=self.time_buffer,
|
||||
def __call__(
|
||||
self,
|
||||
detection: RawPrediction,
|
||||
ground_truth: data.SoundEventAnnotation,
|
||||
) -> float:
|
||||
target_geometry = ground_truth.sound_event.geometry
|
||||
source_geometry = detection.geometry
|
||||
|
||||
if self.time_buffer > 0:
|
||||
target_geometry = buffer_geometry(
|
||||
target_geometry,
|
||||
time=self.time_buffer,
|
||||
)
|
||||
source_geometry = buffer_geometry(
|
||||
source_geometry,
|
||||
time=self.time_buffer,
|
||||
)
|
||||
|
||||
return compute_temporal_iou(target_geometry, source_geometry)
|
||||
|
||||
@affinity_functions.register(IntervalIOUConfig)
|
||||
@staticmethod
|
||||
@ -78,64 +101,44 @@ class IntervalIOU(AffinityFunction):
|
||||
return IntervalIOU(time_buffer=config.time_buffer)
|
||||
|
||||
|
||||
def compute_interval_iou(
|
||||
geometry1: data.Geometry,
|
||||
geometry2: data.Geometry,
|
||||
time_buffer: float = 0.01,
|
||||
) -> float:
|
||||
assert isinstance(geometry1, data.TimeInterval)
|
||||
assert isinstance(geometry2, data.TimeInterval)
|
||||
|
||||
start_time1, end_time1 = geometry1.coordinates
|
||||
start_time2, end_time2 = geometry1.coordinates
|
||||
|
||||
start_time1 -= time_buffer
|
||||
start_time2 -= time_buffer
|
||||
end_time1 += time_buffer
|
||||
end_time2 += time_buffer
|
||||
|
||||
intersection = compute_interval_overlap(
|
||||
(start_time1, end_time1),
|
||||
(start_time2, end_time2),
|
||||
)
|
||||
|
||||
union = (
|
||||
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
|
||||
)
|
||||
|
||||
if union == 0:
|
||||
return 0
|
||||
|
||||
return intersection / union
|
||||
|
||||
|
||||
class BBoxIOUConfig(BaseConfig):
|
||||
name: Literal["bbox_iou"] = "bbox_iou"
|
||||
time_buffer: float = 0.01
|
||||
freq_buffer: float = 1000
|
||||
time_buffer: float = 0.0
|
||||
freq_buffer: float = 0.0
|
||||
|
||||
|
||||
class BBoxIOU(AffinityFunction):
|
||||
def __init__(self, time_buffer: float, freq_buffer: float):
|
||||
if time_buffer < 0:
|
||||
raise ValueError("time_buffer must be non-negative")
|
||||
|
||||
if freq_buffer < 0:
|
||||
raise ValueError("freq_buffer must be non-negative")
|
||||
|
||||
self.time_buffer = time_buffer
|
||||
self.freq_buffer = freq_buffer
|
||||
|
||||
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
||||
if not isinstance(geometry1, data.BoundingBox):
|
||||
raise TypeError(
|
||||
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
|
||||
def __call__(
|
||||
self,
|
||||
prediction: RawPrediction,
|
||||
gt: data.SoundEventAnnotation,
|
||||
):
|
||||
target_geometry = gt.sound_event.geometry
|
||||
source_geometry = prediction.geometry
|
||||
|
||||
if self.time_buffer > 0 or self.freq_buffer > 0:
|
||||
target_geometry = buffer_geometry(
|
||||
target_geometry,
|
||||
time=self.time_buffer,
|
||||
freq=self.freq_buffer,
|
||||
)
|
||||
source_geometry = buffer_geometry(
|
||||
source_geometry,
|
||||
time=self.time_buffer,
|
||||
freq=self.freq_buffer,
|
||||
)
|
||||
|
||||
if not isinstance(geometry2, data.BoundingBox):
|
||||
raise TypeError(
|
||||
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
|
||||
)
|
||||
return bbox_iou(
|
||||
geometry1,
|
||||
geometry2,
|
||||
time_buffer=self.time_buffer,
|
||||
freq_buffer=self.freq_buffer,
|
||||
)
|
||||
return compute_bbox_iou(target_geometry, source_geometry)
|
||||
|
||||
@affinity_functions.register(BBoxIOUConfig)
|
||||
@staticmethod
|
||||
@ -146,65 +149,44 @@ class BBoxIOU(AffinityFunction):
|
||||
)
|
||||
|
||||
|
||||
def bbox_iou(
|
||||
geometry1: data.BoundingBox,
|
||||
geometry2: data.BoundingBox,
|
||||
time_buffer: float = 0.01,
|
||||
freq_buffer: float = 1000,
|
||||
) -> float:
|
||||
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
|
||||
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
|
||||
|
||||
start_time1 -= time_buffer
|
||||
start_time2 -= time_buffer
|
||||
end_time1 += time_buffer
|
||||
end_time2 += time_buffer
|
||||
|
||||
low_freq1 -= freq_buffer
|
||||
low_freq2 -= freq_buffer
|
||||
high_freq1 += freq_buffer
|
||||
high_freq2 += freq_buffer
|
||||
|
||||
time_intersection = compute_interval_overlap(
|
||||
(start_time1, end_time1),
|
||||
(start_time2, end_time2),
|
||||
)
|
||||
|
||||
freq_intersection = max(
|
||||
0,
|
||||
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
|
||||
)
|
||||
|
||||
intersection = time_intersection * freq_intersection
|
||||
|
||||
if intersection == 0:
|
||||
return 0
|
||||
|
||||
union = (
|
||||
(end_time1 - start_time1) * (high_freq1 - low_freq1)
|
||||
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
|
||||
- intersection
|
||||
)
|
||||
|
||||
return intersection / union
|
||||
|
||||
|
||||
class GeometricIOUConfig(BaseConfig):
|
||||
name: Literal["geometric_iou"] = "geometric_iou"
|
||||
time_buffer: float = 0.01
|
||||
freq_buffer: float = 1000
|
||||
time_buffer: float = 0.0
|
||||
freq_buffer: float = 0.0
|
||||
|
||||
|
||||
class GeometricIOU(AffinityFunction):
|
||||
def __init__(self, time_buffer: float):
|
||||
self.time_buffer = time_buffer
|
||||
def __init__(self, time_buffer: float = 0, freq_buffer: float = 0):
|
||||
if time_buffer < 0:
|
||||
raise ValueError("time_buffer must be non-negative")
|
||||
|
||||
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
||||
return compute_affinity(
|
||||
geometry1,
|
||||
geometry2,
|
||||
time_buffer=self.time_buffer,
|
||||
if freq_buffer < 0:
|
||||
raise ValueError("freq_buffer must be non-negative")
|
||||
|
||||
self.time_buffer = time_buffer
|
||||
self.freq_buffer = freq_buffer
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prediction: RawPrediction,
|
||||
gt: data.SoundEventAnnotation,
|
||||
):
|
||||
target_geometry = gt.sound_event.geometry
|
||||
source_geometry = prediction.geometry
|
||||
|
||||
if self.time_buffer > 0 or self.freq_buffer > 0:
|
||||
target_geometry = buffer_geometry(
|
||||
target_geometry,
|
||||
time=self.time_buffer,
|
||||
freq=self.freq_buffer,
|
||||
)
|
||||
source_geometry = buffer_geometry(
|
||||
source_geometry,
|
||||
time=self.time_buffer,
|
||||
freq=self.freq_buffer,
|
||||
)
|
||||
|
||||
return compute_geometric_iou(target_geometry, source_geometry)
|
||||
|
||||
@affinity_functions.register(GeometricIOUConfig)
|
||||
@staticmethod
|
||||
@ -213,7 +195,10 @@ class GeometricIOU(AffinityFunction):
|
||||
|
||||
|
||||
AffinityConfig = Annotated[
|
||||
TimeAffinityConfig | IntervalIOUConfig | BBoxIOUConfig | GeometricIOUConfig,
|
||||
TimeAffinityConfig
|
||||
| IntervalIOUConfig
|
||||
| BBoxIOUConfig
|
||||
| GeometricIOUConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -31,93 +31,6 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"]
|
||||
matching_strategies = Registry("matching_strategy")
|
||||
|
||||
|
||||
def match(
|
||||
sound_event_annotations: Sequence[data.SoundEventAnnotation],
|
||||
raw_predictions: Sequence[RawPrediction],
|
||||
clip: data.Clip,
|
||||
scores: Sequence[float] | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
matcher: MatcherProtocol | None = None,
|
||||
) -> ClipMatches:
|
||||
if matcher is None:
|
||||
matcher = build_matcher()
|
||||
|
||||
if targets is None:
|
||||
targets = build_targets()
|
||||
|
||||
target_geometries: List[data.Geometry] = [ # type: ignore
|
||||
sound_event_annotation.sound_event.geometry
|
||||
for sound_event_annotation in sound_event_annotations
|
||||
]
|
||||
|
||||
predicted_geometries = [
|
||||
raw_prediction.geometry for raw_prediction in raw_predictions
|
||||
]
|
||||
|
||||
if scores is None:
|
||||
scores = [
|
||||
raw_prediction.detection_score
|
||||
for raw_prediction in raw_predictions
|
||||
]
|
||||
|
||||
matches = []
|
||||
|
||||
for source_idx, target_idx, affinity in matcher(
|
||||
ground_truth=target_geometries,
|
||||
predictions=predicted_geometries,
|
||||
scores=scores,
|
||||
):
|
||||
target = (
|
||||
sound_event_annotations[target_idx]
|
||||
if target_idx is not None
|
||||
else None
|
||||
)
|
||||
prediction = (
|
||||
raw_predictions[source_idx] if source_idx is not None else None
|
||||
)
|
||||
|
||||
gt_det = target_idx is not None
|
||||
gt_class = targets.encode_class(target) if target is not None else None
|
||||
gt_geometry = (
|
||||
target_geometries[target_idx] if target_idx is not None else None
|
||||
)
|
||||
|
||||
pred_score = float(prediction.detection_score) if prediction else 0
|
||||
pred_geometry = (
|
||||
predicted_geometries[source_idx]
|
||||
if source_idx is not None
|
||||
else None
|
||||
)
|
||||
|
||||
class_scores = (
|
||||
{
|
||||
class_name: score
|
||||
for class_name, score in zip(
|
||||
targets.class_names,
|
||||
prediction.class_scores, strict=False,
|
||||
)
|
||||
}
|
||||
if prediction is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
matches.append(
|
||||
MatchEvaluation(
|
||||
clip=clip,
|
||||
sound_event_annotation=target,
|
||||
gt_det=gt_det,
|
||||
gt_class=gt_class,
|
||||
gt_geometry=gt_geometry,
|
||||
pred_score=pred_score,
|
||||
pred_class_scores=class_scores,
|
||||
pred_geometry=pred_geometry,
|
||||
affinity=affinity,
|
||||
)
|
||||
)
|
||||
|
||||
return ClipMatches(clip=clip, matches=matches)
|
||||
|
||||
|
||||
class StartTimeMatchConfig(BaseConfig):
|
||||
name: Literal["start_time_match"] = "start_time_match"
|
||||
distance_threshold: float = 0.01
|
||||
@ -514,99 +427,9 @@ class OptimalMatcher(MatcherProtocol):
|
||||
|
||||
|
||||
MatchConfig = Annotated[
|
||||
GreedyMatchConfig | StartTimeMatchConfig | OptimalMatchConfig | GreedyAffinityMatchConfig,
|
||||
GreedyMatchConfig
|
||||
| StartTimeMatchConfig
|
||||
| OptimalMatchConfig
|
||||
| GreedyAffinityMatchConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def compute_affinity_matrix(
|
||||
ground_truth: Sequence[data.Geometry],
|
||||
predictions: Sequence[data.Geometry],
|
||||
affinity_function: AffinityFunction,
|
||||
time_scale: float = 1,
|
||||
frequency_scale: float = 1,
|
||||
) -> np.ndarray:
|
||||
# Scale geometries if necessary
|
||||
if time_scale != 1 or frequency_scale != 1:
|
||||
ground_truth = [
|
||||
scale_geometry(geometry, time_scale, frequency_scale)
|
||||
for geometry in ground_truth
|
||||
]
|
||||
|
||||
predictions = [
|
||||
scale_geometry(geometry, time_scale, frequency_scale)
|
||||
for geometry in predictions
|
||||
]
|
||||
|
||||
affinity_matrix = np.zeros((len(ground_truth), len(predictions)))
|
||||
for gt_idx, gt_geometry in enumerate(ground_truth):
|
||||
for pred_idx, pred_geometry in enumerate(predictions):
|
||||
affinity = affinity_function(
|
||||
gt_geometry,
|
||||
pred_geometry,
|
||||
)
|
||||
affinity_matrix[gt_idx, pred_idx] = affinity
|
||||
|
||||
return affinity_matrix
|
||||
|
||||
|
||||
def select_optimal_matches(
|
||||
affinity_matrix: np.ndarray,
|
||||
affinity_threshold: float = 0.5,
|
||||
) -> Iterable[Tuple[int | None, int | None, float]]:
|
||||
num_gt, num_pred = affinity_matrix.shape
|
||||
gts = set(range(num_gt))
|
||||
preds = set(range(num_pred))
|
||||
|
||||
assiged_rows, assigned_columns = linear_sum_assignment(
|
||||
affinity_matrix,
|
||||
maximize=True,
|
||||
)
|
||||
|
||||
for gt_idx, pred_idx in zip(assiged_rows, assigned_columns, strict=False):
|
||||
affinity = float(affinity_matrix[gt_idx, pred_idx])
|
||||
|
||||
if affinity <= affinity_threshold:
|
||||
continue
|
||||
|
||||
yield gt_idx, pred_idx, affinity
|
||||
gts.remove(gt_idx)
|
||||
preds.remove(pred_idx)
|
||||
|
||||
for gt_idx in gts:
|
||||
yield gt_idx, None, 0
|
||||
|
||||
for pred_idx in preds:
|
||||
yield None, pred_idx, 0
|
||||
|
||||
|
||||
def select_greedy_matches(
|
||||
affinity_matrix: np.ndarray,
|
||||
affinity_threshold: float = 0.5,
|
||||
) -> Iterable[Tuple[int | None, int | None, float]]:
|
||||
num_gt, num_pred = affinity_matrix.shape
|
||||
unmatched_pred = set(range(num_pred))
|
||||
|
||||
for gt_idx in range(num_gt):
|
||||
row = affinity_matrix[gt_idx]
|
||||
|
||||
top_pred = int(np.argmax(row))
|
||||
top_affinity = float(row[top_pred])
|
||||
|
||||
if (
|
||||
top_affinity <= affinity_threshold
|
||||
or top_pred not in unmatched_pred
|
||||
):
|
||||
yield None, gt_idx, 0
|
||||
continue
|
||||
|
||||
unmatched_pred.remove(top_pred)
|
||||
yield top_pred, gt_idx, top_affinity
|
||||
|
||||
for pred_idx in unmatched_pred:
|
||||
yield pred_idx, None, 0
|
||||
|
||||
|
||||
def build_matcher(config: MatchConfig | None = None) -> MatcherProtocol:
|
||||
config = config or StartTimeMatchConfig()
|
||||
return matching_strategies.build(config)
|
||||
|
||||
@ -210,7 +210,10 @@ class DetectionPrecision:
|
||||
|
||||
|
||||
DetectionMetricConfig = Annotated[
|
||||
DetectionAveragePrecisionConfig | DetectionROCAUCConfig | DetectionRecallConfig | DetectionPrecisionConfig,
|
||||
DetectionAveragePrecisionConfig
|
||||
| DetectionROCAUCConfig
|
||||
| DetectionRecallConfig
|
||||
| DetectionPrecisionConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -14,16 +14,15 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.evaluate.match import (
|
||||
MatchConfig,
|
||||
StartTimeMatchConfig,
|
||||
build_matcher,
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.affinity import AffinityConfig, TimeAffinityConfig
|
||||
from batdetect2.typing import (
|
||||
AffinityFunction,
|
||||
BatDetect2Prediction,
|
||||
EvaluatorProtocol,
|
||||
RawPrediction,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"BaseTaskConfig",
|
||||
@ -40,39 +39,34 @@ T_Output = TypeVar("T_Output")
|
||||
|
||||
class BaseTaskConfig(BaseConfig):
|
||||
prefix: str
|
||||
ignore_start_end: float = 0.01
|
||||
matching_strategy: MatchConfig = Field(
|
||||
default_factory=StartTimeMatchConfig
|
||||
)
|
||||
|
||||
|
||||
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
targets: TargetProtocol
|
||||
|
||||
matcher: MatcherProtocol
|
||||
|
||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
|
||||
|
||||
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
|
||||
|
||||
ignore_start_end: float
|
||||
|
||||
prefix: str
|
||||
|
||||
ignore_start_end: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
matcher: MatcherProtocol,
|
||||
targets: TargetProtocol,
|
||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
||||
prefix: str,
|
||||
plots: List[
|
||||
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
|
||||
]
|
||||
| None = None,
|
||||
ignore_start_end: float = 0.01,
|
||||
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
|
||||
):
|
||||
self.matcher = matcher
|
||||
self.prefix = prefix
|
||||
self.targets = targets
|
||||
self.metrics = metrics
|
||||
self.plots = plots or []
|
||||
self.targets = targets
|
||||
self.prefix = prefix
|
||||
self.ignore_start_end = ignore_start_end
|
||||
|
||||
def compute_metrics(
|
||||
@ -100,7 +94,9 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
) -> List[T_Output]:
|
||||
return [
|
||||
self.evaluate_clip(clip_annotation, preds)
|
||||
for clip_annotation, preds in zip(clip_annotations, predictions, strict=False)
|
||||
for clip_annotation, preds in zip(
|
||||
clip_annotations, predictions, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
def evaluate_clip(
|
||||
@ -118,9 +114,6 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
return False
|
||||
|
||||
geometry = sound_event_annotation.sound_event.geometry
|
||||
if geometry is None:
|
||||
return False
|
||||
|
||||
return is_in_bounds(
|
||||
geometry,
|
||||
clip,
|
||||
@ -138,25 +131,40 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
self.ignore_start_end,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
config: BaseTaskConfig,
|
||||
|
||||
class BaseSEDTaskConfig(BaseTaskConfig):
|
||||
affinity: AffinityConfig = Field(default_factory=TimeAffinityConfig)
|
||||
affinity_threshold: float = 0
|
||||
strict_match: bool = True
|
||||
|
||||
|
||||
class BaseSEDTask(BaseTask[T_Output]):
|
||||
affinity: AffinityFunction
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
targets: TargetProtocol,
|
||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
||||
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
|
||||
**kwargs,
|
||||
affinity: AffinityFunction,
|
||||
plots: List[
|
||||
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
|
||||
]
|
||||
| None = None,
|
||||
affinity_threshold: float = 0,
|
||||
ignore_start_end: float = 0.01,
|
||||
strict_match: bool = True,
|
||||
):
|
||||
matcher = build_matcher(config.matching_strategy)
|
||||
return cls(
|
||||
matcher=matcher,
|
||||
targets=targets,
|
||||
super().__init__(
|
||||
prefix=prefix,
|
||||
metrics=metrics,
|
||||
plots=plots,
|
||||
prefix=config.prefix,
|
||||
ignore_start_end=config.ignore_start_end,
|
||||
**kwargs,
|
||||
targets=targets,
|
||||
ignore_start_end=ignore_start_end,
|
||||
)
|
||||
self.affinity = affinity
|
||||
self.affinity_threshold = affinity_threshold
|
||||
self.strict_match = strict_match
|
||||
|
||||
|
||||
def is_in_bounds(
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
from typing import (
|
||||
List,
|
||||
Literal,
|
||||
)
|
||||
from functools import partial
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_detections_and_gts
|
||||
|
||||
from batdetect2.evaluate.affinity import build_affinity_function
|
||||
from batdetect2.evaluate.metrics.classification import (
|
||||
ClassificationAveragePrecisionConfig,
|
||||
ClassificationMetricConfig,
|
||||
@ -18,24 +18,28 @@ from batdetect2.evaluate.plots.classification import (
|
||||
build_classification_plotter,
|
||||
)
|
||||
from batdetect2.evaluate.tasks.base import (
|
||||
BaseTask,
|
||||
BaseTaskConfig,
|
||||
BaseSEDTask,
|
||||
BaseSEDTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
RawPrediction,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
|
||||
class ClassificationTaskConfig(BaseTaskConfig):
|
||||
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["sound_event_classification"] = "sound_event_classification"
|
||||
prefix: str = "classification"
|
||||
metrics: List[ClassificationMetricConfig] = Field(
|
||||
metrics: list[ClassificationMetricConfig] = Field(
|
||||
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
|
||||
)
|
||||
plots: List[ClassificationPlotConfig] = Field(default_factory=list)
|
||||
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
|
||||
include_generics: bool = True
|
||||
|
||||
|
||||
class ClassificationTask(BaseTask[ClipEval]):
|
||||
class ClassificationTask(BaseSEDTask[ClipEval]):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
@ -73,40 +77,39 @@ class ClassificationTask(BaseTask[ClipEval]):
|
||||
gts = [
|
||||
sound_event
|
||||
for sound_event in all_gts
|
||||
if self.is_class(sound_event, class_name)
|
||||
if is_target_class(
|
||||
sound_event,
|
||||
class_name,
|
||||
self.targets,
|
||||
include_generics=self.include_generics,
|
||||
)
|
||||
]
|
||||
scores = [float(pred.class_scores[class_idx]) for pred in preds]
|
||||
|
||||
matches = []
|
||||
|
||||
for pred_idx, gt_idx, _ in self.matcher(
|
||||
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
|
||||
predictions=[pred.geometry for pred in preds],
|
||||
scores=scores,
|
||||
for match in match_detections_and_gts(
|
||||
detections=preds,
|
||||
ground_truths=gts,
|
||||
affinity=self.affinity,
|
||||
score=partial(get_class_score, class_idx=class_idx),
|
||||
strict_match=self.strict_match,
|
||||
):
|
||||
gt = gts[gt_idx] if gt_idx is not None else None
|
||||
pred = preds[pred_idx] if pred_idx is not None else None
|
||||
|
||||
true_class = (
|
||||
self.targets.encode_class(gt) if gt is not None else None
|
||||
self.targets.encode_class(match.annotation)
|
||||
if match.annotation is not None
|
||||
else None
|
||||
)
|
||||
|
||||
score = (
|
||||
float(pred.class_scores[class_idx])
|
||||
if pred is not None
|
||||
else 0
|
||||
)
|
||||
|
||||
matches.append(
|
||||
MatchEval(
|
||||
clip=clip,
|
||||
gt=gt,
|
||||
pred=pred,
|
||||
is_prediction=pred is not None,
|
||||
is_ground_truth=gt is not None,
|
||||
is_generic=gt is not None and true_class is None,
|
||||
gt=match.annotation,
|
||||
pred=match.prediction,
|
||||
is_prediction=match.prediction is not None,
|
||||
is_ground_truth=match.annotation is not None,
|
||||
is_generic=match.annotation is not None
|
||||
and true_class is None,
|
||||
true_class=true_class,
|
||||
score=score,
|
||||
score=match.prediction_score,
|
||||
)
|
||||
)
|
||||
|
||||
@ -114,20 +117,6 @@ class ClassificationTask(BaseTask[ClipEval]):
|
||||
|
||||
return ClipEval(clip=clip, matches=per_class_matches)
|
||||
|
||||
def is_class(
|
||||
self,
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
class_name: str,
|
||||
) -> bool:
|
||||
sound_event_class = self.targets.encode_class(sound_event)
|
||||
|
||||
if sound_event_class is None and self.include_generics:
|
||||
# Sound events that are generic could be of the given
|
||||
# class
|
||||
return True
|
||||
|
||||
return sound_event_class == class_name
|
||||
|
||||
@tasks_registry.register(ClassificationTaskConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
@ -142,9 +131,32 @@ class ClassificationTask(BaseTask[ClipEval]):
|
||||
build_classification_plotter(plot, targets)
|
||||
for plot in config.plots
|
||||
]
|
||||
return ClassificationTask.build(
|
||||
config=config,
|
||||
affinity = build_affinity_function(config.affinity)
|
||||
return ClassificationTask(
|
||||
affinity=affinity,
|
||||
prefix=config.prefix,
|
||||
plots=plots,
|
||||
targets=targets,
|
||||
metrics=metrics,
|
||||
strict_match=config.strict_match,
|
||||
)
|
||||
|
||||
|
||||
def get_class_score(pred: RawPrediction, class_idx: int) -> float:
|
||||
return pred.class_scores[class_idx]
|
||||
|
||||
|
||||
def is_target_class(
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
class_name: str,
|
||||
targets: TargetProtocol,
|
||||
include_generics: bool = True,
|
||||
) -> bool:
|
||||
sound_event_class = targets.encode_class(sound_event)
|
||||
|
||||
if sound_event_class is None and include_generics:
|
||||
# Sound events that are generic could be of the given
|
||||
# class
|
||||
return True
|
||||
|
||||
return sound_event_class == class_name
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import List, Literal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -19,19 +19,18 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||
|
||||
|
||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||
name: Literal["clip_classification"] = "clip_classification"
|
||||
prefix: str = "clip_classification"
|
||||
metrics: List[ClipClassificationMetricConfig] = Field(
|
||||
metrics: list[ClipClassificationMetricConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
ClipClassificationAveragePrecisionConfig(),
|
||||
]
|
||||
)
|
||||
plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)
|
||||
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ClipClassificationTask(BaseTask[ClipEval]):
|
||||
@ -78,8 +77,8 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
||||
build_clip_classification_plotter(plot, targets)
|
||||
for plot in config.plots
|
||||
]
|
||||
return ClipClassificationTask.build(
|
||||
config=config,
|
||||
return ClipClassificationTask(
|
||||
prefix=config.prefix,
|
||||
plots=plots,
|
||||
metrics=metrics,
|
||||
targets=targets,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Literal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -18,19 +18,18 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||
|
||||
|
||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||
name: Literal["clip_detection"] = "clip_detection"
|
||||
prefix: str = "clip_detection"
|
||||
metrics: List[ClipDetectionMetricConfig] = Field(
|
||||
metrics: list[ClipDetectionMetricConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
ClipDetectionAveragePrecisionConfig(),
|
||||
]
|
||||
)
|
||||
plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)
|
||||
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ClipDetectionTask(BaseTask[ClipEval]):
|
||||
@ -69,8 +68,8 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
||||
build_clip_detection_plotter(plot, targets)
|
||||
for plot in config.plots
|
||||
]
|
||||
return ClipDetectionTask.build(
|
||||
config=config,
|
||||
return ClipDetectionTask(
|
||||
prefix=config.prefix,
|
||||
metrics=metrics,
|
||||
targets=targets,
|
||||
plots=plots,
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from typing import List, Literal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_detections_and_gts
|
||||
|
||||
from batdetect2.evaluate.affinity import build_affinity_function
|
||||
from batdetect2.evaluate.metrics.detection import (
|
||||
ClipEval,
|
||||
DetectionAveragePrecisionConfig,
|
||||
@ -15,24 +17,24 @@ from batdetect2.evaluate.plots.detection import (
|
||||
build_detection_plotter,
|
||||
)
|
||||
from batdetect2.evaluate.tasks.base import (
|
||||
BaseTask,
|
||||
BaseTaskConfig,
|
||||
BaseSEDTask,
|
||||
BaseSEDTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
|
||||
|
||||
class DetectionTaskConfig(BaseTaskConfig):
|
||||
class DetectionTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["sound_event_detection"] = "sound_event_detection"
|
||||
prefix: str = "detection"
|
||||
metrics: List[DetectionMetricConfig] = Field(
|
||||
metrics: list[DetectionMetricConfig] = Field(
|
||||
default_factory=lambda: [DetectionAveragePrecisionConfig()]
|
||||
)
|
||||
plots: List[DetectionPlotConfig] = Field(default_factory=list)
|
||||
plots: list[DetectionPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DetectionTask(BaseTask[ClipEval]):
|
||||
class DetectionTask(BaseSEDTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
@ -50,24 +52,22 @@ class DetectionTask(BaseTask[ClipEval]):
|
||||
for pred in prediction.predictions
|
||||
if self.include_prediction(pred, clip)
|
||||
]
|
||||
scores = [pred.detection_score for pred in preds]
|
||||
|
||||
matches = []
|
||||
for pred_idx, gt_idx, _ in self.matcher(
|
||||
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
|
||||
predictions=[pred.geometry for pred in preds],
|
||||
scores=scores,
|
||||
for match in match_detections_and_gts(
|
||||
detections=preds,
|
||||
ground_truths=gts,
|
||||
affinity=self.affinity,
|
||||
score=lambda pred: pred.detection_score,
|
||||
strict_match=self.strict_match,
|
||||
):
|
||||
gt = gts[gt_idx] if gt_idx is not None else None
|
||||
pred = preds[pred_idx] if pred_idx is not None else None
|
||||
|
||||
matches.append(
|
||||
MatchEval(
|
||||
gt=gt,
|
||||
pred=pred,
|
||||
is_prediction=pred is not None,
|
||||
is_ground_truth=gt is not None,
|
||||
score=pred.detection_score if pred is not None else 0,
|
||||
gt=match.annotation,
|
||||
pred=match.prediction,
|
||||
is_prediction=match.prediction is not None,
|
||||
is_ground_truth=match.annotation is not None,
|
||||
score=match.prediction_score,
|
||||
)
|
||||
)
|
||||
|
||||
@ -83,9 +83,12 @@ class DetectionTask(BaseTask[ClipEval]):
|
||||
plots = [
|
||||
build_detection_plotter(plot, targets) for plot in config.plots
|
||||
]
|
||||
return DetectionTask.build(
|
||||
config=config,
|
||||
affinity = build_affinity_function(config.affinity)
|
||||
return DetectionTask(
|
||||
prefix=config.prefix,
|
||||
affinity=affinity,
|
||||
metrics=metrics,
|
||||
targets=targets,
|
||||
plots=plots,
|
||||
strict_match=config.strict_match,
|
||||
)
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from typing import List, Literal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_detections_and_gts
|
||||
|
||||
from batdetect2.evaluate.affinity import build_affinity_function
|
||||
from batdetect2.evaluate.metrics.top_class import (
|
||||
ClipEval,
|
||||
MatchEval,
|
||||
@ -15,24 +17,23 @@ from batdetect2.evaluate.plots.top_class import (
|
||||
build_top_class_plotter,
|
||||
)
|
||||
from batdetect2.evaluate.tasks.base import (
|
||||
BaseTask,
|
||||
BaseTaskConfig,
|
||||
BaseSEDTask,
|
||||
BaseSEDTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||
|
||||
|
||||
class TopClassDetectionTaskConfig(BaseTaskConfig):
|
||||
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["top_class_detection"] = "top_class_detection"
|
||||
prefix: str = "top_class"
|
||||
metrics: List[TopClassMetricConfig] = Field(
|
||||
metrics: list[TopClassMetricConfig] = Field(
|
||||
default_factory=lambda: [TopClassAveragePrecisionConfig()]
|
||||
)
|
||||
plots: List[TopClassPlotConfig] = Field(default_factory=list)
|
||||
plots: list[TopClassPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
class TopClassDetectionTask(BaseSEDTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
@ -50,18 +51,17 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
for pred in prediction.predictions
|
||||
if self.include_prediction(pred, clip)
|
||||
]
|
||||
# Take the highest score for each prediction
|
||||
scores = [pred.class_scores.max() for pred in preds]
|
||||
|
||||
matches = []
|
||||
for pred_idx, gt_idx, _ in self.matcher(
|
||||
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
|
||||
predictions=[pred.geometry for pred in preds],
|
||||
scores=scores,
|
||||
for match in match_detections_and_gts(
|
||||
ground_truths=gts,
|
||||
detections=preds,
|
||||
affinity=self.affinity,
|
||||
score=lambda pred: pred.class_scores.max(),
|
||||
strict_match=self.strict_match,
|
||||
):
|
||||
gt = gts[gt_idx] if gt_idx is not None else None
|
||||
pred = preds[pred_idx] if pred_idx is not None else None
|
||||
|
||||
gt = match.annotation
|
||||
pred = match.prediction
|
||||
true_class = (
|
||||
self.targets.encode_class(gt) if gt is not None else None
|
||||
)
|
||||
@ -69,11 +69,6 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
class_idx = (
|
||||
pred.class_scores.argmax() if pred is not None else None
|
||||
)
|
||||
|
||||
score = (
|
||||
float(pred.class_scores[class_idx]) if pred is not None else 0
|
||||
)
|
||||
|
||||
pred_class = (
|
||||
self.targets.class_names[class_idx]
|
||||
if class_idx is not None
|
||||
@ -90,7 +85,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
true_class=true_class,
|
||||
is_generic=gt is not None and true_class is None,
|
||||
pred_class=pred_class,
|
||||
score=score,
|
||||
score=match.prediction_score,
|
||||
)
|
||||
)
|
||||
|
||||
@ -106,9 +101,12 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
plots = [
|
||||
build_top_class_plotter(plot, targets) for plot in config.plots
|
||||
]
|
||||
return TopClassDetectionTask.build(
|
||||
config=config,
|
||||
affinity = build_affinity_function(config.affinity)
|
||||
return TopClassDetectionTask(
|
||||
prefix=config.prefix,
|
||||
plots=plots,
|
||||
metrics=metrics,
|
||||
targets=targets,
|
||||
affinity=affinity,
|
||||
strict_match=config.strict_match,
|
||||
)
|
||||
|
||||
@ -81,11 +81,11 @@ class MatcherProtocol(Protocol):
|
||||
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
||||
|
||||
|
||||
class AffinityFunction(Protocol, Generic[Geom]):
|
||||
class AffinityFunction(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
geometry1: Geom,
|
||||
geometry2: Geom,
|
||||
detection: RawPrediction,
|
||||
ground_truth: data.SoundEventAnnotation,
|
||||
) -> float: ...
|
||||
|
||||
|
||||
|
||||
@ -28,13 +28,13 @@ def test_has_tag(sound_event: data.SoundEvent):
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
@ -51,15 +51,15 @@ def test_has_all_tags(sound_event: data.SoundEvent):
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
@ -67,8 +67,8 @@ def test_has_all_tags(sound_event: data.SoundEvent):
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
data.Tag(key="species", value="Myotis myotis"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
@ -76,9 +76,9 @@ def test_has_all_tags(sound_event: data.SoundEvent):
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
data.Tag(key="sex", value="Female"), # type: ignore
|
||||
data.Tag(key="species", value="Myotis myotis"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
data.Tag(key="sex", value="Female"),
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
@ -96,15 +96,15 @@ def test_has_any_tags(sound_event: data.SoundEvent):
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
@ -112,8 +112,8 @@ def test_has_any_tags(sound_event: data.SoundEvent):
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
data.Tag(key="species", value="Myotis myotis"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
@ -121,8 +121,8 @@ def test_has_any_tags(sound_event: data.SoundEvent):
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
||||
data.Tag(key="event", value="Social"), # type: ignore
|
||||
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||
data.Tag(key="event", value="Social"),
|
||||
],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
@ -140,21 +140,21 @@ def test_not(sound_event: data.SoundEvent):
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
data.Tag(key="species", value="Myotis myotis"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
@ -402,31 +402,6 @@ def test_has_tags_fails_if_empty():
|
||||
""")
|
||||
|
||||
|
||||
def test_frequency_is_false_if_no_geometry(recording: data.Recording):
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: eq
|
||||
hertz: 200
|
||||
""")
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(geometry=None, recording=recording)
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
|
||||
def test_duration_is_false_if_no_geometry(recording: data.Recording):
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: eq
|
||||
seconds: 1
|
||||
""")
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(geometry=None, recording=recording)
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
|
||||
def test_all_of(recording: data.Recording):
|
||||
condition = build_condition_from_str("""
|
||||
name: all_of
|
||||
@ -444,7 +419,7 @@ def test_all_of(recording: data.Recording):
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert condition(se)
|
||||
|
||||
@ -453,7 +428,7 @@ def test_all_of(recording: data.Recording):
|
||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
@ -462,7 +437,7 @@ def test_all_of(recording: data.Recording):
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
@ -484,7 +459,7 @@ def test_any_of(recording: data.Recording):
|
||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
@ -493,7 +468,7 @@ def test_any_of(recording: data.Recording):
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert condition(se)
|
||||
|
||||
@ -502,7 +477,7 @@ def test_any_of(recording: data.Recording):
|
||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert condition(se)
|
||||
|
||||
@ -511,6 +486,6 @@ def test_any_of(recording: data.Recording):
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||
)
|
||||
assert condition(se)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user