mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Added bbox iou affinity function
This commit is contained in:
parent
10865ee600
commit
c9f0c5c431
@ -3,6 +3,7 @@ from typing import Annotated, Literal, Optional, Union
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import compute_affinity
|
||||
from soundevent.geometry import compute_interval_overlap
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
@ -93,9 +94,11 @@ def compute_interval_iou(
|
||||
end_time1 += time_buffer
|
||||
end_time2 += time_buffer
|
||||
|
||||
intersection = max(
|
||||
0, min(end_time1, end_time2) - max(start_time1, start_time2)
|
||||
intersection = compute_interval_overlap(
|
||||
(start_time1, end_time1),
|
||||
(start_time2, end_time2),
|
||||
)
|
||||
|
||||
union = (
|
||||
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
|
||||
)
|
||||
@ -106,6 +109,86 @@ def compute_interval_iou(
|
||||
return intersection / union
|
||||
|
||||
|
||||
class BBoxIOUConfig(BaseConfig):
|
||||
name: Literal["bbox_iou"] = "bbox_iou"
|
||||
time_buffer: float = 0.01
|
||||
freq_buffer: float = 1000
|
||||
|
||||
|
||||
class BBoxIOU(AffinityFunction):
|
||||
def __init__(self, time_buffer: float, freq_buffer: float):
|
||||
self.time_buffer = time_buffer
|
||||
self.freq_buffer = freq_buffer
|
||||
|
||||
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
||||
if not isinstance(geometry1, data.BoundingBox):
|
||||
raise TypeError(
|
||||
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
|
||||
)
|
||||
|
||||
if not isinstance(geometry2, data.BoundingBox):
|
||||
raise TypeError(
|
||||
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
|
||||
)
|
||||
return bbox_iou(
|
||||
geometry1,
|
||||
geometry2,
|
||||
time_buffer=self.time_buffer,
|
||||
freq_buffer=self.freq_buffer,
|
||||
)
|
||||
|
||||
@affinity_functions.register(BBoxIOUConfig)
|
||||
@staticmethod
|
||||
def from_config(config: BBoxIOUConfig):
|
||||
return BBoxIOU(
|
||||
time_buffer=config.time_buffer,
|
||||
freq_buffer=config.freq_buffer,
|
||||
)
|
||||
|
||||
|
||||
def bbox_iou(
|
||||
geometry1: data.BoundingBox,
|
||||
geometry2: data.BoundingBox,
|
||||
time_buffer: float = 0.01,
|
||||
freq_buffer: float = 1000,
|
||||
) -> float:
|
||||
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
|
||||
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
|
||||
|
||||
start_time1 -= time_buffer
|
||||
start_time2 -= time_buffer
|
||||
end_time1 += time_buffer
|
||||
end_time2 += time_buffer
|
||||
|
||||
low_freq1 -= freq_buffer
|
||||
low_freq2 -= freq_buffer
|
||||
high_freq1 += freq_buffer
|
||||
high_freq2 += freq_buffer
|
||||
|
||||
time_intersection = compute_interval_overlap(
|
||||
(start_time1, end_time1),
|
||||
(start_time2, end_time2),
|
||||
)
|
||||
|
||||
freq_intersection = max(
|
||||
0,
|
||||
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
|
||||
)
|
||||
|
||||
intersection = time_intersection * freq_intersection
|
||||
|
||||
if intersection == 0:
|
||||
return 0
|
||||
|
||||
union = (
|
||||
(end_time1 - start_time1) * (high_freq1 - low_freq1)
|
||||
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
|
||||
- intersection
|
||||
)
|
||||
|
||||
return intersection / union
|
||||
|
||||
|
||||
class GeometricIOUConfig(BaseConfig):
|
||||
name: Literal["geometric_iou"] = "geometric_iou"
|
||||
time_buffer: float = 0.01
|
||||
@ -133,6 +216,7 @@ AffinityConfig = Annotated[
|
||||
Union[
|
||||
TimeAffinityConfig,
|
||||
IntervalIOUConfig,
|
||||
BBoxIOUConfig,
|
||||
GeometricIOUConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user