Added bbox iou affinity function

This commit is contained in:
mbsantiago 2025-09-28 16:08:21 +01:00
parent 10865ee600
commit c9f0c5c431

View File

@ -3,6 +3,7 @@ from typing import Annotated, Literal, Optional, Union
from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.geometry import compute_interval_overlap
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
@ -93,9 +94,11 @@ def compute_interval_iou(
end_time1 += time_buffer
end_time2 += time_buffer
intersection = max(
0, min(end_time1, end_time2) - max(start_time1, start_time2)
intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
union = (
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
)
@ -106,6 +109,86 @@ def compute_interval_iou(
return intersection / union
class BBoxIOUConfig(BaseConfig):
name: Literal["bbox_iou"] = "bbox_iou"
time_buffer: float = 0.01
freq_buffer: float = 1000
class BBoxIOU(AffinityFunction):
def __init__(self, time_buffer: float, freq_buffer: float):
self.time_buffer = time_buffer
self.freq_buffer = freq_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
if not isinstance(geometry1, data.BoundingBox):
raise TypeError(
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
)
if not isinstance(geometry2, data.BoundingBox):
raise TypeError(
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
)
return bbox_iou(
geometry1,
geometry2,
time_buffer=self.time_buffer,
freq_buffer=self.freq_buffer,
)
@affinity_functions.register(BBoxIOUConfig)
@staticmethod
def from_config(config: BBoxIOUConfig):
return BBoxIOU(
time_buffer=config.time_buffer,
freq_buffer=config.freq_buffer,
)
def bbox_iou(
geometry1: data.BoundingBox,
geometry2: data.BoundingBox,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
low_freq1 -= freq_buffer
low_freq2 -= freq_buffer
high_freq1 += freq_buffer
high_freq2 += freq_buffer
time_intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
freq_intersection = max(
0,
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
)
intersection = time_intersection * freq_intersection
if intersection == 0:
return 0
union = (
(end_time1 - start_time1) * (high_freq1 - low_freq1)
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
- intersection
)
return intersection / union
class GeometricIOUConfig(BaseConfig):
name: Literal["geometric_iou"] = "geometric_iou"
time_buffer: float = 0.01
@ -133,6 +216,7 @@ AffinityConfig = Annotated[
Union[
TimeAffinityConfig,
IntervalIOUConfig,
BBoxIOUConfig,
GeometricIOUConfig,
],
Field(discriminator="name"),