mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Added bbox iou affinity function
This commit is contained in:
parent
10865ee600
commit
c9f0c5c431
@ -3,6 +3,7 @@ from typing import Annotated, Literal, Optional, Union
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import compute_affinity
|
from soundevent.evaluation import compute_affinity
|
||||||
|
from soundevent.geometry import compute_interval_overlap
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.core.registries import Registry
|
from batdetect2.core.registries import Registry
|
||||||
@ -93,9 +94,11 @@ def compute_interval_iou(
|
|||||||
end_time1 += time_buffer
|
end_time1 += time_buffer
|
||||||
end_time2 += time_buffer
|
end_time2 += time_buffer
|
||||||
|
|
||||||
intersection = max(
|
intersection = compute_interval_overlap(
|
||||||
0, min(end_time1, end_time2) - max(start_time1, start_time2)
|
(start_time1, end_time1),
|
||||||
|
(start_time2, end_time2),
|
||||||
)
|
)
|
||||||
|
|
||||||
union = (
|
union = (
|
||||||
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
|
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
|
||||||
)
|
)
|
||||||
@ -106,6 +109,86 @@ def compute_interval_iou(
|
|||||||
return intersection / union
|
return intersection / union
|
||||||
|
|
||||||
|
|
||||||
|
class BBoxIOUConfig(BaseConfig):
|
||||||
|
name: Literal["bbox_iou"] = "bbox_iou"
|
||||||
|
time_buffer: float = 0.01
|
||||||
|
freq_buffer: float = 1000
|
||||||
|
|
||||||
|
|
||||||
|
class BBoxIOU(AffinityFunction):
|
||||||
|
def __init__(self, time_buffer: float, freq_buffer: float):
|
||||||
|
self.time_buffer = time_buffer
|
||||||
|
self.freq_buffer = freq_buffer
|
||||||
|
|
||||||
|
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
||||||
|
if not isinstance(geometry1, data.BoundingBox):
|
||||||
|
raise TypeError(
|
||||||
|
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(geometry2, data.BoundingBox):
|
||||||
|
raise TypeError(
|
||||||
|
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
|
||||||
|
)
|
||||||
|
return bbox_iou(
|
||||||
|
geometry1,
|
||||||
|
geometry2,
|
||||||
|
time_buffer=self.time_buffer,
|
||||||
|
freq_buffer=self.freq_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
@affinity_functions.register(BBoxIOUConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: BBoxIOUConfig):
|
||||||
|
return BBoxIOU(
|
||||||
|
time_buffer=config.time_buffer,
|
||||||
|
freq_buffer=config.freq_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def bbox_iou(
|
||||||
|
geometry1: data.BoundingBox,
|
||||||
|
geometry2: data.BoundingBox,
|
||||||
|
time_buffer: float = 0.01,
|
||||||
|
freq_buffer: float = 1000,
|
||||||
|
) -> float:
|
||||||
|
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
|
||||||
|
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
|
||||||
|
|
||||||
|
start_time1 -= time_buffer
|
||||||
|
start_time2 -= time_buffer
|
||||||
|
end_time1 += time_buffer
|
||||||
|
end_time2 += time_buffer
|
||||||
|
|
||||||
|
low_freq1 -= freq_buffer
|
||||||
|
low_freq2 -= freq_buffer
|
||||||
|
high_freq1 += freq_buffer
|
||||||
|
high_freq2 += freq_buffer
|
||||||
|
|
||||||
|
time_intersection = compute_interval_overlap(
|
||||||
|
(start_time1, end_time1),
|
||||||
|
(start_time2, end_time2),
|
||||||
|
)
|
||||||
|
|
||||||
|
freq_intersection = max(
|
||||||
|
0,
|
||||||
|
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
|
||||||
|
)
|
||||||
|
|
||||||
|
intersection = time_intersection * freq_intersection
|
||||||
|
|
||||||
|
if intersection == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
union = (
|
||||||
|
(end_time1 - start_time1) * (high_freq1 - low_freq1)
|
||||||
|
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
|
||||||
|
- intersection
|
||||||
|
)
|
||||||
|
|
||||||
|
return intersection / union
|
||||||
|
|
||||||
|
|
||||||
class GeometricIOUConfig(BaseConfig):
|
class GeometricIOUConfig(BaseConfig):
|
||||||
name: Literal["geometric_iou"] = "geometric_iou"
|
name: Literal["geometric_iou"] = "geometric_iou"
|
||||||
time_buffer: float = 0.01
|
time_buffer: float = 0.01
|
||||||
@ -133,6 +216,7 @@ AffinityConfig = Annotated[
|
|||||||
Union[
|
Union[
|
||||||
TimeAffinityConfig,
|
TimeAffinityConfig,
|
||||||
IntervalIOUConfig,
|
IntervalIOUConfig,
|
||||||
|
BBoxIOUConfig,
|
||||||
GeometricIOUConfig,
|
GeometricIOUConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user