Using matching and affinity functions from soundevent

This commit is contained in:
mbsantiago 2025-12-12 19:25:01 +00:00
parent 113f438e74
commit f71fe0c2e2
12 changed files with 341 additions and 532 deletions

View File

@ -62,6 +62,10 @@ class BaseConfig(BaseModel):
)
)
@classmethod
def from_yaml(cls, yaml_str: str):
return cls.model_validate(yaml.safe_load(yaml_str))
T = TypeVar("T", bound=BaseModel)

View File

@ -2,75 +2,98 @@ from typing import Annotated, Literal
from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.geometry import compute_interval_overlap
from soundevent.geometry import (
buffer_geometry,
compute_bbox_iou,
compute_geometric_iou,
compute_temporal_closeness,
compute_temporal_iou,
)
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.typing.evaluate import AffinityFunction
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import AffinityFunction, RawPrediction
affinity_functions: Registry[AffinityFunction, []] = Registry(
"matching_strategy"
"affinity_function"
)
class TimeAffinityConfig(BaseConfig):
name: Literal["time_affinity"] = "time_affinity"
time_buffer: float = 0.01
position: Literal["start", "end", "center"] | float = "start"
max_distance: float = 0.01
class TimeAffinity(AffinityFunction):
def __init__(self, time_buffer: float):
self.time_buffer = time_buffer
def __init__(
self,
max_distance: float = 0.01,
position: Literal["start", "end", "center"] | float = "start",
):
if position == "start":
position = 0
elif position == "end":
position = 1
elif position == "center":
position = 0.5
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
return compute_timestamp_affinity(
geometry1, geometry2, time_buffer=self.time_buffer
self.position = position
self.max_distance = max_distance
def __call__(
self,
detection: RawPrediction,
ground_truth: data.SoundEventAnnotation,
) -> float:
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
return compute_temporal_closeness(
target_geometry,
source_geometry,
ratio=self.position,
max_distance=self.max_distance,
)
@affinity_functions.register(TimeAffinityConfig)
@staticmethod
def from_config(config: TimeAffinityConfig):
return TimeAffinity(time_buffer=config.time_buffer)
def compute_timestamp_affinity(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeStamp)
assert isinstance(geometry2, data.TimeStamp)
start_time1 = geometry1.coordinates
start_time2 = geometry2.coordinates
a = min(start_time1, start_time2)
b = max(start_time1, start_time2)
if b - a >= 2 * time_buffer:
return 0
intersection = a - b + 2 * time_buffer
union = b - a + 2 * time_buffer
return intersection / union
return TimeAffinity(
max_distance=config.max_distance,
position=config.position,
)
class IntervalIOUConfig(BaseConfig):
name: Literal["interval_iou"] = "interval_iou"
time_buffer: float = 0.01
time_buffer: float = 0.0
class IntervalIOU(AffinityFunction):
def __init__(self, time_buffer: float):
if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
self.time_buffer = time_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
return compute_interval_iou(
geometry1,
geometry2,
time_buffer=self.time_buffer,
def __call__(
self,
detection: RawPrediction,
ground_truth: data.SoundEventAnnotation,
) -> float:
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
if self.time_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
)
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
)
return compute_temporal_iou(target_geometry, source_geometry)
@affinity_functions.register(IntervalIOUConfig)
@staticmethod
@ -78,64 +101,44 @@ class IntervalIOU(AffinityFunction):
return IntervalIOU(time_buffer=config.time_buffer)
def compute_interval_iou(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeInterval)
assert isinstance(geometry2, data.TimeInterval)
start_time1, end_time1 = geometry1.coordinates
start_time2, end_time2 = geometry1.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
union = (
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
)
if union == 0:
return 0
return intersection / union
class BBoxIOUConfig(BaseConfig):
name: Literal["bbox_iou"] = "bbox_iou"
time_buffer: float = 0.01
freq_buffer: float = 1000
time_buffer: float = 0.0
freq_buffer: float = 0.0
class BBoxIOU(AffinityFunction):
def __init__(self, time_buffer: float, freq_buffer: float):
if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
if freq_buffer < 0:
raise ValueError("freq_buffer must be non-negative")
self.time_buffer = time_buffer
self.freq_buffer = freq_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
if not isinstance(geometry1, data.BoundingBox):
raise TypeError(
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
def __call__(
self,
prediction: RawPrediction,
gt: data.SoundEventAnnotation,
):
target_geometry = gt.sound_event.geometry
source_geometry = prediction.geometry
if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
if not isinstance(geometry2, data.BoundingBox):
raise TypeError(
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
)
return bbox_iou(
geometry1,
geometry2,
time_buffer=self.time_buffer,
freq_buffer=self.freq_buffer,
)
return compute_bbox_iou(target_geometry, source_geometry)
@affinity_functions.register(BBoxIOUConfig)
@staticmethod
@ -146,65 +149,44 @@ class BBoxIOU(AffinityFunction):
)
def bbox_iou(
geometry1: data.BoundingBox,
geometry2: data.BoundingBox,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
low_freq1 -= freq_buffer
low_freq2 -= freq_buffer
high_freq1 += freq_buffer
high_freq2 += freq_buffer
time_intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
freq_intersection = max(
0,
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
)
intersection = time_intersection * freq_intersection
if intersection == 0:
return 0
union = (
(end_time1 - start_time1) * (high_freq1 - low_freq1)
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
- intersection
)
return intersection / union
class GeometricIOUConfig(BaseConfig):
name: Literal["geometric_iou"] = "geometric_iou"
time_buffer: float = 0.01
freq_buffer: float = 1000
time_buffer: float = 0.0
freq_buffer: float = 0.0
class GeometricIOU(AffinityFunction):
def __init__(self, time_buffer: float):
self.time_buffer = time_buffer
def __init__(self, time_buffer: float = 0, freq_buffer: float = 0):
if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
return compute_affinity(
geometry1,
geometry2,
time_buffer=self.time_buffer,
if freq_buffer < 0:
raise ValueError("freq_buffer must be non-negative")
self.time_buffer = time_buffer
self.freq_buffer = freq_buffer
def __call__(
self,
prediction: RawPrediction,
gt: data.SoundEventAnnotation,
):
target_geometry = gt.sound_event.geometry
source_geometry = prediction.geometry
if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
return compute_geometric_iou(target_geometry, source_geometry)
@affinity_functions.register(GeometricIOUConfig)
@staticmethod
@ -213,7 +195,10 @@ class GeometricIOU(AffinityFunction):
AffinityConfig = Annotated[
TimeAffinityConfig | IntervalIOUConfig | BBoxIOUConfig | GeometricIOUConfig,
TimeAffinityConfig
| IntervalIOUConfig
| BBoxIOUConfig
| GeometricIOUConfig,
Field(discriminator="name"),
]

View File

@ -31,93 +31,6 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"]
matching_strategies = Registry("matching_strategy")
def match(
sound_event_annotations: Sequence[data.SoundEventAnnotation],
raw_predictions: Sequence[RawPrediction],
clip: data.Clip,
scores: Sequence[float] | None = None,
targets: TargetProtocol | None = None,
matcher: MatcherProtocol | None = None,
) -> ClipMatches:
if matcher is None:
matcher = build_matcher()
if targets is None:
targets = build_targets()
target_geometries: List[data.Geometry] = [ # type: ignore
sound_event_annotation.sound_event.geometry
for sound_event_annotation in sound_event_annotations
]
predicted_geometries = [
raw_prediction.geometry for raw_prediction in raw_predictions
]
if scores is None:
scores = [
raw_prediction.detection_score
for raw_prediction in raw_predictions
]
matches = []
for source_idx, target_idx, affinity in matcher(
ground_truth=target_geometries,
predictions=predicted_geometries,
scores=scores,
):
target = (
sound_event_annotations[target_idx]
if target_idx is not None
else None
)
prediction = (
raw_predictions[source_idx] if source_idx is not None else None
)
gt_det = target_idx is not None
gt_class = targets.encode_class(target) if target is not None else None
gt_geometry = (
target_geometries[target_idx] if target_idx is not None else None
)
pred_score = float(prediction.detection_score) if prediction else 0
pred_geometry = (
predicted_geometries[source_idx]
if source_idx is not None
else None
)
class_scores = (
{
class_name: score
for class_name, score in zip(
targets.class_names,
prediction.class_scores, strict=False,
)
}
if prediction is not None
else {}
)
matches.append(
MatchEvaluation(
clip=clip,
sound_event_annotation=target,
gt_det=gt_det,
gt_class=gt_class,
gt_geometry=gt_geometry,
pred_score=pred_score,
pred_class_scores=class_scores,
pred_geometry=pred_geometry,
affinity=affinity,
)
)
return ClipMatches(clip=clip, matches=matches)
class StartTimeMatchConfig(BaseConfig):
name: Literal["start_time_match"] = "start_time_match"
distance_threshold: float = 0.01
@ -514,99 +427,9 @@ class OptimalMatcher(MatcherProtocol):
MatchConfig = Annotated[
GreedyMatchConfig | StartTimeMatchConfig | OptimalMatchConfig | GreedyAffinityMatchConfig,
GreedyMatchConfig
| StartTimeMatchConfig
| OptimalMatchConfig
| GreedyAffinityMatchConfig,
Field(discriminator="name"),
]
def compute_affinity_matrix(
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
affinity_function: AffinityFunction,
time_scale: float = 1,
frequency_scale: float = 1,
) -> np.ndarray:
# Scale geometries if necessary
if time_scale != 1 or frequency_scale != 1:
ground_truth = [
scale_geometry(geometry, time_scale, frequency_scale)
for geometry in ground_truth
]
predictions = [
scale_geometry(geometry, time_scale, frequency_scale)
for geometry in predictions
]
affinity_matrix = np.zeros((len(ground_truth), len(predictions)))
for gt_idx, gt_geometry in enumerate(ground_truth):
for pred_idx, pred_geometry in enumerate(predictions):
affinity = affinity_function(
gt_geometry,
pred_geometry,
)
affinity_matrix[gt_idx, pred_idx] = affinity
return affinity_matrix
def select_optimal_matches(
affinity_matrix: np.ndarray,
affinity_threshold: float = 0.5,
) -> Iterable[Tuple[int | None, int | None, float]]:
num_gt, num_pred = affinity_matrix.shape
gts = set(range(num_gt))
preds = set(range(num_pred))
assiged_rows, assigned_columns = linear_sum_assignment(
affinity_matrix,
maximize=True,
)
for gt_idx, pred_idx in zip(assiged_rows, assigned_columns, strict=False):
affinity = float(affinity_matrix[gt_idx, pred_idx])
if affinity <= affinity_threshold:
continue
yield gt_idx, pred_idx, affinity
gts.remove(gt_idx)
preds.remove(pred_idx)
for gt_idx in gts:
yield gt_idx, None, 0
for pred_idx in preds:
yield None, pred_idx, 0
def select_greedy_matches(
affinity_matrix: np.ndarray,
affinity_threshold: float = 0.5,
) -> Iterable[Tuple[int | None, int | None, float]]:
num_gt, num_pred = affinity_matrix.shape
unmatched_pred = set(range(num_pred))
for gt_idx in range(num_gt):
row = affinity_matrix[gt_idx]
top_pred = int(np.argmax(row))
top_affinity = float(row[top_pred])
if (
top_affinity <= affinity_threshold
or top_pred not in unmatched_pred
):
yield None, gt_idx, 0
continue
unmatched_pred.remove(top_pred)
yield top_pred, gt_idx, top_affinity
for pred_idx in unmatched_pred:
yield pred_idx, None, 0
def build_matcher(config: MatchConfig | None = None) -> MatcherProtocol:
config = config or StartTimeMatchConfig()
return matching_strategies.build(config)

View File

@ -210,7 +210,10 @@ class DetectionPrecision:
DetectionMetricConfig = Annotated[
DetectionAveragePrecisionConfig | DetectionROCAUCConfig | DetectionRecallConfig | DetectionPrecisionConfig,
DetectionAveragePrecisionConfig
| DetectionROCAUCConfig
| DetectionRecallConfig
| DetectionPrecisionConfig,
Field(discriminator="name"),
]

View File

@ -14,16 +14,15 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.match import (
MatchConfig,
StartTimeMatchConfig,
build_matcher,
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.affinity import AffinityConfig, TimeAffinityConfig
from batdetect2.typing import (
AffinityFunction,
BatDetect2Prediction,
EvaluatorProtocol,
RawPrediction,
TargetProtocol,
)
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"BaseTaskConfig",
@ -40,39 +39,34 @@ T_Output = TypeVar("T_Output")
class BaseTaskConfig(BaseConfig):
prefix: str
ignore_start_end: float = 0.01
matching_strategy: MatchConfig = Field(
default_factory=StartTimeMatchConfig
)
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
targets: TargetProtocol
matcher: MatcherProtocol
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
ignore_start_end: float
prefix: str
ignore_start_end: float
def __init__(
self,
matcher: MatcherProtocol,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
prefix: str,
plots: List[
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
]
| None = None,
ignore_start_end: float = 0.01,
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
):
self.matcher = matcher
self.prefix = prefix
self.targets = targets
self.metrics = metrics
self.plots = plots or []
self.targets = targets
self.prefix = prefix
self.ignore_start_end = ignore_start_end
def compute_metrics(
@ -100,7 +94,9 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
) -> List[T_Output]:
return [
self.evaluate_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions, strict=False)
for clip_annotation, preds in zip(
clip_annotations, predictions, strict=False
)
]
def evaluate_clip(
@ -118,9 +114,6 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
return False
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
return is_in_bounds(
geometry,
clip,
@ -138,25 +131,40 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
self.ignore_start_end,
)
@classmethod
def build(
cls,
config: BaseTaskConfig,
class BaseSEDTaskConfig(BaseTaskConfig):
affinity: AffinityConfig = Field(default_factory=TimeAffinityConfig)
affinity_threshold: float = 0
strict_match: bool = True
class BaseSEDTask(BaseTask[T_Output]):
affinity: AffinityFunction
def __init__(
self,
prefix: str,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None,
**kwargs,
affinity: AffinityFunction,
plots: List[
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
]
| None = None,
affinity_threshold: float = 0,
ignore_start_end: float = 0.01,
strict_match: bool = True,
):
matcher = build_matcher(config.matching_strategy)
return cls(
matcher=matcher,
targets=targets,
super().__init__(
prefix=prefix,
metrics=metrics,
plots=plots,
prefix=config.prefix,
ignore_start_end=config.ignore_start_end,
**kwargs,
targets=targets,
ignore_start_end=ignore_start_end,
)
self.affinity = affinity
self.affinity_threshold = affinity_threshold
self.strict_match = strict_match
def is_in_bounds(

View File

@ -1,11 +1,11 @@
from typing import (
List,
Literal,
)
from functools import partial
from typing import Literal
from pydantic import Field
from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.classification import (
ClassificationAveragePrecisionConfig,
ClassificationMetricConfig,
@ -18,24 +18,28 @@ from batdetect2.evaluate.plots.classification import (
build_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
BaseSEDTask,
BaseSEDTaskConfig,
tasks_registry,
)
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
from batdetect2.typing import (
BatDetect2Prediction,
RawPrediction,
TargetProtocol,
)
class ClassificationTaskConfig(BaseTaskConfig):
class ClassificationTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_classification"] = "sound_event_classification"
prefix: str = "classification"
metrics: List[ClassificationMetricConfig] = Field(
metrics: list[ClassificationMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
)
plots: List[ClassificationPlotConfig] = Field(default_factory=list)
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
include_generics: bool = True
class ClassificationTask(BaseTask[ClipEval]):
class ClassificationTask(BaseSEDTask[ClipEval]):
def __init__(
self,
*args,
@ -73,40 +77,39 @@ class ClassificationTask(BaseTask[ClipEval]):
gts = [
sound_event
for sound_event in all_gts
if self.is_class(sound_event, class_name)
if is_target_class(
sound_event,
class_name,
self.targets,
include_generics=self.include_generics,
)
]
scores = [float(pred.class_scores[class_idx]) for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
for match in match_detections_and_gts(
detections=preds,
ground_truths=gts,
affinity=self.affinity,
score=partial(get_class_score, class_idx=class_idx),
strict_match=self.strict_match,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
true_class = (
self.targets.encode_class(gt) if gt is not None else None
self.targets.encode_class(match.annotation)
if match.annotation is not None
else None
)
score = (
float(pred.class_scores[class_idx])
if pred is not None
else 0
)
matches.append(
MatchEval(
clip=clip,
gt=gt,
pred=pred,
is_prediction=pred is not None,
is_ground_truth=gt is not None,
is_generic=gt is not None and true_class is None,
gt=match.annotation,
pred=match.prediction,
is_prediction=match.prediction is not None,
is_ground_truth=match.annotation is not None,
is_generic=match.annotation is not None
and true_class is None,
true_class=true_class,
score=score,
score=match.prediction_score,
)
)
@ -114,20 +117,6 @@ class ClassificationTask(BaseTask[ClipEval]):
return ClipEval(clip=clip, matches=per_class_matches)
def is_class(
self,
sound_event: data.SoundEventAnnotation,
class_name: str,
) -> bool:
sound_event_class = self.targets.encode_class(sound_event)
if sound_event_class is None and self.include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name
@tasks_registry.register(ClassificationTaskConfig)
@staticmethod
def from_config(
@ -142,9 +131,32 @@ class ClassificationTask(BaseTask[ClipEval]):
build_classification_plotter(plot, targets)
for plot in config.plots
]
return ClassificationTask.build(
config=config,
affinity = build_affinity_function(config.affinity)
return ClassificationTask(
affinity=affinity,
prefix=config.prefix,
plots=plots,
targets=targets,
metrics=metrics,
strict_match=config.strict_match,
)
def get_class_score(pred: RawPrediction, class_idx: int) -> float:
return pred.class_scores[class_idx]
def is_target_class(
sound_event: data.SoundEventAnnotation,
class_name: str,
targets: TargetProtocol,
include_generics: bool = True,
) -> bool:
sound_event_class = targets.encode_class(sound_event)
if sound_event_class is None and include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import List, Literal
from typing import Literal
from pydantic import Field
from soundevent import data
@ -19,19 +19,18 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
class ClipClassificationTaskConfig(BaseTaskConfig):
name: Literal["clip_classification"] = "clip_classification"
prefix: str = "clip_classification"
metrics: List[ClipClassificationMetricConfig] = Field(
metrics: list[ClipClassificationMetricConfig] = Field(
default_factory=lambda: [
ClipClassificationAveragePrecisionConfig(),
]
)
plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
class ClipClassificationTask(BaseTask[ClipEval]):
@ -78,8 +77,8 @@ class ClipClassificationTask(BaseTask[ClipEval]):
build_clip_classification_plotter(plot, targets)
for plot in config.plots
]
return ClipClassificationTask.build(
config=config,
return ClipClassificationTask(
prefix=config.prefix,
plots=plots,
metrics=metrics,
targets=targets,

View File

@ -1,4 +1,4 @@
from typing import List, Literal
from typing import Literal
from pydantic import Field
from soundevent import data
@ -18,19 +18,18 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
class ClipDetectionTaskConfig(BaseTaskConfig):
name: Literal["clip_detection"] = "clip_detection"
prefix: str = "clip_detection"
metrics: List[ClipDetectionMetricConfig] = Field(
metrics: list[ClipDetectionMetricConfig] = Field(
default_factory=lambda: [
ClipDetectionAveragePrecisionConfig(),
]
)
plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
class ClipDetectionTask(BaseTask[ClipEval]):
@ -69,8 +68,8 @@ class ClipDetectionTask(BaseTask[ClipEval]):
build_clip_detection_plotter(plot, targets)
for plot in config.plots
]
return ClipDetectionTask.build(
config=config,
return ClipDetectionTask(
prefix=config.prefix,
metrics=metrics,
targets=targets,
plots=plots,

View File

@ -1,8 +1,10 @@
from typing import List, Literal
from typing import Literal
from pydantic import Field
from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.detection import (
ClipEval,
DetectionAveragePrecisionConfig,
@ -15,24 +17,24 @@ from batdetect2.evaluate.plots.detection import (
build_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
BaseSEDTask,
BaseSEDTaskConfig,
tasks_registry,
)
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
class DetectionTaskConfig(BaseTaskConfig):
class DetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_detection"] = "sound_event_detection"
prefix: str = "detection"
metrics: List[DetectionMetricConfig] = Field(
metrics: list[DetectionMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()]
)
plots: List[DetectionPlotConfig] = Field(default_factory=list)
plots: list[DetectionPlotConfig] = Field(default_factory=list)
class DetectionTask(BaseTask[ClipEval]):
class DetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
@ -50,24 +52,22 @@ class DetectionTask(BaseTask[ClipEval]):
for pred in prediction.predictions
if self.include_prediction(pred, clip)
]
scores = [pred.detection_score for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
for match in match_detections_and_gts(
detections=preds,
ground_truths=gts,
affinity=self.affinity,
score=lambda pred: pred.detection_score,
strict_match=self.strict_match,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
matches.append(
MatchEval(
gt=gt,
pred=pred,
is_prediction=pred is not None,
is_ground_truth=gt is not None,
score=pred.detection_score if pred is not None else 0,
gt=match.annotation,
pred=match.prediction,
is_prediction=match.prediction is not None,
is_ground_truth=match.annotation is not None,
score=match.prediction_score,
)
)
@ -83,9 +83,12 @@ class DetectionTask(BaseTask[ClipEval]):
plots = [
build_detection_plotter(plot, targets) for plot in config.plots
]
return DetectionTask.build(
config=config,
affinity = build_affinity_function(config.affinity)
return DetectionTask(
prefix=config.prefix,
affinity=affinity,
metrics=metrics,
targets=targets,
plots=plots,
strict_match=config.strict_match,
)

View File

@ -1,8 +1,10 @@
from typing import List, Literal
from typing import Literal
from pydantic import Field
from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.top_class import (
ClipEval,
MatchEval,
@ -15,24 +17,23 @@ from batdetect2.evaluate.plots.top_class import (
build_top_class_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
BaseSEDTask,
BaseSEDTaskConfig,
tasks_registry,
)
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
class TopClassDetectionTaskConfig(BaseTaskConfig):
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["top_class_detection"] = "top_class_detection"
prefix: str = "top_class"
metrics: List[TopClassMetricConfig] = Field(
metrics: list[TopClassMetricConfig] = Field(
default_factory=lambda: [TopClassAveragePrecisionConfig()]
)
plots: List[TopClassPlotConfig] = Field(default_factory=list)
plots: list[TopClassPlotConfig] = Field(default_factory=list)
class TopClassDetectionTask(BaseTask[ClipEval]):
class TopClassDetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
@ -50,18 +51,17 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
for pred in prediction.predictions
if self.include_prediction(pred, clip)
]
# Take the highest score for each prediction
scores = [pred.class_scores.max() for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
for match in match_detections_and_gts(
ground_truths=gts,
detections=preds,
affinity=self.affinity,
score=lambda pred: pred.class_scores.max(),
strict_match=self.strict_match,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
gt = match.annotation
pred = match.prediction
true_class = (
self.targets.encode_class(gt) if gt is not None else None
)
@ -69,11 +69,6 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
class_idx = (
pred.class_scores.argmax() if pred is not None else None
)
score = (
float(pred.class_scores[class_idx]) if pred is not None else 0
)
pred_class = (
self.targets.class_names[class_idx]
if class_idx is not None
@ -90,7 +85,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
true_class=true_class,
is_generic=gt is not None and true_class is None,
pred_class=pred_class,
score=score,
score=match.prediction_score,
)
)
@ -106,9 +101,12 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
plots = [
build_top_class_plotter(plot, targets) for plot in config.plots
]
return TopClassDetectionTask.build(
config=config,
affinity = build_affinity_function(config.affinity)
return TopClassDetectionTask(
prefix=config.prefix,
plots=plots,
metrics=metrics,
targets=targets,
affinity=affinity,
strict_match=config.strict_match,
)

View File

@ -81,11 +81,11 @@ class MatcherProtocol(Protocol):
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
class AffinityFunction(Protocol, Generic[Geom]):
class AffinityFunction(Protocol):
def __call__(
self,
geometry1: Geom,
geometry2: Geom,
detection: RawPrediction,
ground_truth: data.SoundEventAnnotation,
) -> float: ...

View File

@ -28,13 +28,13 @@ def test_has_tag(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
tags=[data.Tag(key="species", value="Myotis myotis")],
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
)
assert not condition(sound_event_annotation)
@ -51,15 +51,15 @@ def test_has_all_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
tags=[data.Tag(key="species", value="Myotis myotis")],
)
assert not condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
data.Tag(key="species", value="Eptesicus fuscus"),
data.Tag(key="event", value="Echolocation"),
],
)
assert not condition(sound_event_annotation)
@ -67,8 +67,8 @@ def test_has_all_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
data.Tag(key="species", value="Myotis myotis"),
data.Tag(key="event", value="Echolocation"),
],
)
assert condition(sound_event_annotation)
@ -76,9 +76,9 @@ def test_has_all_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
data.Tag(key="sex", value="Female"), # type: ignore
data.Tag(key="species", value="Myotis myotis"),
data.Tag(key="event", value="Echolocation"),
data.Tag(key="sex", value="Female"),
],
)
assert condition(sound_event_annotation)
@ -96,15 +96,15 @@ def test_has_any_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
tags=[data.Tag(key="species", value="Myotis myotis")],
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
data.Tag(key="species", value="Eptesicus fuscus"),
data.Tag(key="event", value="Echolocation"),
],
)
assert condition(sound_event_annotation)
@ -112,8 +112,8 @@ def test_has_any_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
data.Tag(key="species", value="Myotis myotis"),
data.Tag(key="event", value="Echolocation"),
],
)
assert condition(sound_event_annotation)
@ -121,8 +121,8 @@ def test_has_any_tags(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
data.Tag(key="event", value="Social"), # type: ignore
data.Tag(key="species", value="Eptesicus fuscus"),
data.Tag(key="event", value="Social"),
],
)
assert not condition(sound_event_annotation)
@ -140,21 +140,21 @@ def test_not(sound_event: data.SoundEvent):
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
tags=[data.Tag(key="species", value="Myotis myotis")],
)
assert not condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
data.Tag(key="species", value="Myotis myotis"),
data.Tag(key="event", value="Echolocation"),
],
)
assert not condition(sound_event_annotation)
@ -402,31 +402,6 @@ def test_has_tags_fails_if_empty():
""")
def test_frequency_is_false_if_no_geometry(recording: data.Recording):
    """A frequency-based condition is falsy when the sound event has no geometry."""
    condition = build_condition_from_str("""
name: frequency
boundary: low
operator: eq
hertz: 200
""")
    # A sound event without any geometry cannot satisfy a frequency bound.
    annotation = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=None),
    )
    assert not condition(annotation)
def test_duration_is_false_if_no_geometry(recording: data.Recording):
    """A duration-based condition is falsy when the sound event has no geometry."""
    condition = build_condition_from_str("""
name: duration
operator: eq
seconds: 1
""")
    # Without a geometry there is no time span, so no duration can match.
    annotation = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=None),
    )
    assert not condition(annotation)
def test_all_of(recording: data.Recording):
condition = build_condition_from_str("""
name: all_of
@ -444,7 +419,7 @@ def test_all_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
tags=[data.Tag(key="species", value="Myotis myotis")],
)
assert condition(se)
@ -453,7 +428,7 @@ def test_all_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
tags=[data.Tag(key="species", value="Myotis myotis")],
)
assert not condition(se)
@ -462,7 +437,7 @@ def test_all_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
)
assert not condition(se)
@ -484,7 +459,7 @@ def test_any_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording,
),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
)
assert not condition(se)
@ -493,7 +468,7 @@ def test_any_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
tags=[data.Tag(key="species", value="Myotis myotis")],
)
assert condition(se)
@ -502,7 +477,7 @@ def test_any_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
tags=[data.Tag(key="species", value="Myotis myotis")],
)
assert condition(se)
@ -511,6 +486,6 @@ def test_any_of(recording: data.Recording):
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
)
assert condition(se)