mirror of https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00

Restructure eval metrics and plotting

This commit is contained in:
parent ec1c0ff020
commit e752e96b93
@@ -1,6 +1,7 @@
 from typing import Generic, Protocol, Type, TypeVar
 
 from pydantic import BaseModel
+from typing_extensions import ParamSpec
 
 __all__ = [
     "Registry",
@@ -8,26 +9,36 @@ __all__ = [
 
 T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
 T_Type = TypeVar("T_Type", covariant=True)
+P_Type = ParamSpec("P_Type")
 
 
-class LogicProtocol(Generic[T_Config, T_Type], Protocol):
-    """A generic protocol for the logic classes (conditions or transforms)."""
+class LogicProtocol(Generic[T_Config, T_Type, P_Type], Protocol):
+    """A generic protocol for the logic classes."""
 
     @classmethod
-    def from_config(cls, config: T_Config) -> T_Type: ...
+    def from_config(
+        cls,
+        config: T_Config,
+        *args: P_Type.args,
+        **kwargs: P_Type.kwargs,
+    ) -> T_Type: ...
 
 
-T_Proto = TypeVar("T_Proto", bound=LogicProtocol)
-
-
-class Registry(Generic[T_Type]):
+class Registry(Generic[T_Type, P_Type]):
     """A generic class to create and manage a registry of items."""
 
     def __init__(self, name: str):
         self._name = name
         self._registry = {}
 
-    def register(self, config_cls: Type[T_Config]):
+    def register(
+        self,
+        config_cls: Type[T_Config],
+        logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
+    ) -> None:
        """A decorator factory to register a new item."""
        fields = config_cls.model_fields
@@ -39,13 +50,14 @@ class Registry(Generic[T_Type]):
         if not isinstance(name, str):
             raise ValueError("'name' field must be a string literal.")
 
-        def decorator(logic_cls: Type[T_Proto]) -> Type[T_Proto]:
-            self._registry[name] = logic_cls
-            return logic_cls
+        self._registry[name] = logic_cls
 
-        return decorator
-
-    def build(self, config: BaseModel) -> T_Type:
+    def build(
+        self,
+        config: BaseModel,
+        *args: P_Type.args,
+        **kwargs: P_Type.kwargs,
+    ) -> T_Type:
         """Builds a logic instance from a config object."""
 
         name = getattr(config, "name")  # noqa: B009
@@ -58,4 +70,4 @@ class Registry(Generic[T_Type]):
             f"No {self._name} with name '{name}' is registered."
         )
 
-        return self._registry[name].from_config(config)
+        return self._registry[name].from_config(config, *args, **kwargs)
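
For orientation, a minimal sketch of how the reworked registry is meant to be used. The config class, logic class, and the extra build-time argument below (ThresholdConfig, Threshold, the class_names list) are hypothetical illustrations, not part of this commit:

from typing import List, Literal

from pydantic import BaseModel

from batdetect2.data._core import Registry


class ThresholdConfig(BaseModel):
    # Registry.register() requires a "name" field with a string literal default.
    name: Literal["threshold"] = "threshold"
    cutoff: float = 0.5


class Threshold:
    # Satisfies LogicProtocol: from_config takes the config plus the ParamSpec extras.
    def __init__(self, cutoff: float, class_names: List[str]):
        self.cutoff = cutoff
        self.class_names = class_names

    @classmethod
    def from_config(cls, config: ThresholdConfig, class_names: List[str]):
        return cls(cutoff=config.cutoff, class_names=class_names)


# Registry[Threshold, [List[str]]] means build() forwards one extra List[str]
# argument to from_config(); Registry[..., []] (as used below) means none.
registry: Registry[Threshold, [List[str]]] = Registry("threshold")
registry.register(ThresholdConfig, Threshold)

instance = registry.build(ThresholdConfig(cutoff=0.3), ["Myotis", "Pipistrellus"])
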
@@ -10,7 +10,7 @@ from batdetect2.data._core import Registry
 
 SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
 
-_conditions: Registry[SoundEventCondition] = Registry("condition")
+condition_registry: Registry[SoundEventCondition, []] = Registry("condition")
 
 
 class HasTagConfig(BaseConfig):
@@ -18,7 +18,6 @@ class HasTagConfig(BaseConfig):
     tag: data.Tag
 
 
-@_conditions.register(HasTagConfig)
 class HasTag:
     def __init__(self, tag: data.Tag):
         self.tag = tag
@@ -33,12 +32,14 @@ class HasTag:
         return cls(tag=config.tag)
 
 
+condition_registry.register(HasTagConfig, HasTag)
+
+
 class HasAllTagsConfig(BaseConfig):
     name: Literal["has_all_tags"] = "has_all_tags"
     tags: List[data.Tag]
 
 
-@_conditions.register(HasAllTagsConfig)
 class HasAllTags:
     def __init__(self, tags: List[data.Tag]):
         if not tags:
@@ -56,12 +57,14 @@ class HasAllTags:
         return cls(tags=config.tags)
 
 
+condition_registry.register(HasAllTagsConfig, HasAllTags)
+
+
 class HasAnyTagConfig(BaseConfig):
     name: Literal["has_any_tag"] = "has_any_tag"
     tags: List[data.Tag]
 
 
-@_conditions.register(HasAnyTagConfig)
 class HasAnyTag:
     def __init__(self, tags: List[data.Tag]):
         if not tags:
@@ -79,6 +82,8 @@ class HasAnyTag:
         return cls(tags=config.tags)
 
 
+condition_registry.register(HasAnyTagConfig, HasAnyTag)
+
 
 Operator = Literal["gt", "gte", "lt", "lte", "eq"]
 
@@ -109,7 +114,6 @@ def _build_comparator(
     raise ValueError(f"Invalid operator {operator}")
 
 
-@_conditions.register(DurationConfig)
 class Duration:
     def __init__(self, operator: Operator, seconds: float):
         self.operator = operator
@@ -135,6 +139,9 @@ class Duration:
         return cls(operator=config.operator, seconds=config.seconds)
 
 
+condition_registry.register(DurationConfig, Duration)
+
+
 class FrequencyConfig(BaseConfig):
     name: Literal["frequency"] = "frequency"
     boundary: Literal["low", "high"]
@@ -142,7 +149,6 @@ class FrequencyConfig(BaseConfig):
     hertz: float
 
 
-@_conditions.register(FrequencyConfig)
 class Frequency:
     def __init__(
         self,
@@ -184,12 +190,14 @@ class Frequency:
         )
 
 
+condition_registry.register(FrequencyConfig, Frequency)
+
+
 class AllOfConfig(BaseConfig):
     name: Literal["all_of"] = "all_of"
     conditions: Sequence["SoundEventConditionConfig"]
 
 
-@_conditions.register(AllOfConfig)
 class AllOf:
     def __init__(self, conditions: List[SoundEventCondition]):
         self.conditions = conditions
@@ -207,12 +215,14 @@ class AllOf:
         return cls(conditions)
 
 
+condition_registry.register(AllOfConfig, AllOf)
+
+
 class AnyOfConfig(BaseConfig):
     name: Literal["any_of"] = "any_of"
     conditions: List["SoundEventConditionConfig"]
 
 
-@_conditions.register(AnyOfConfig)
 class AnyOf:
     def __init__(self, conditions: List[SoundEventCondition]):
         self.conditions = conditions
@@ -230,12 +240,14 @@ class AnyOf:
         return cls(conditions)
 
 
+condition_registry.register(AnyOfConfig, AnyOf)
+
+
 class NotConfig(BaseConfig):
     name: Literal["not"] = "not"
     condition: "SoundEventConditionConfig"
 
 
-@_conditions.register(NotConfig)
 class Not:
     def __init__(self, condition: SoundEventCondition):
         self.condition = condition
@@ -251,6 +263,8 @@ class Not:
         return cls(condition)
 
 
+condition_registry.register(NotConfig, Not)
+
 
 SoundEventConditionConfig = Annotated[
     Union[
         HasTagConfig,
@@ -269,7 +283,7 @@ SoundEventConditionConfig = Annotated[
 def build_sound_event_condition(
     config: SoundEventConditionConfig,
 ) -> SoundEventCondition:
-    return _conditions.build(config)
+    return condition_registry.build(config)
 
 
 def filter_clip_annotation(
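
A small sketch of how these condition configs compose after the refactor. The tag key and values are invented for illustration, and the DurationConfig fields are inferred from Duration.from_config:

from soundevent import data

# Keep annotations tagged as a given species AND longer than 10 ms.
config = AllOfConfig(
    conditions=[
        HasTagConfig(tag=data.Tag(key="species", value="Myotis daubentonii")),
        DurationConfig(operator="gt", seconds=0.01),
    ]
)

condition = build_sound_event_condition(config)
# `condition` is a SoundEventCondition: a plain callable mapping a
# data.SoundEventAnnotation to a bool, ready for use in filtering.
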
@@ -17,7 +17,7 @@ SoundEventTransform = Callable[
     data.SoundEventAnnotation,
 ]
 
-_transforms: Registry[SoundEventTransform] = Registry("transform")
+transform_registry: Registry[SoundEventTransform, []] = Registry("transform")
 
 
 class SetFrequencyBoundConfig(BaseConfig):
@@ -26,7 +26,6 @@ class SetFrequencyBoundConfig(BaseConfig):
     hertz: float
 
 
-@_transforms.register(SetFrequencyBoundConfig)
 class SetFrequencyBound:
     def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"):
         self.hertz = hertz
@@ -69,13 +68,15 @@ class SetFrequencyBound:
         return cls(hertz=config.hertz, boundary=config.boundary)
 
 
+transform_registry.register(SetFrequencyBoundConfig, SetFrequencyBound)
+
+
 class ApplyIfConfig(BaseConfig):
     name: Literal["apply_if"] = "apply_if"
     transform: "SoundEventTransformConfig"
     condition: SoundEventConditionConfig
 
 
-@_transforms.register(ApplyIfConfig)
 class ApplyIf:
     def __init__(
         self,
@@ -101,13 +102,15 @@ class ApplyIf:
         return cls(condition=condition, transform=transform)
 
 
+transform_registry.register(ApplyIfConfig, ApplyIf)
+
+
 class ReplaceTagConfig(BaseConfig):
     name: Literal["replace_tag"] = "replace_tag"
     original: data.Tag
     replacement: data.Tag
 
 
-@_transforms.register(ReplaceTagConfig)
 class ReplaceTag:
     def __init__(
         self,
@@ -136,6 +139,9 @@ class ReplaceTag:
         return cls(original=config.original, replacement=config.replacement)
 
 
+transform_registry.register(ReplaceTagConfig, ReplaceTag)
+
+
 class MapTagValueConfig(BaseConfig):
     name: Literal["map_tag_value"] = "map_tag_value"
     tag_key: str
@@ -143,7 +149,6 @@ class MapTagValueConfig(BaseConfig):
     target_key: Optional[str] = None
 
 
-@_transforms.register(MapTagValueConfig)
 class MapTagValue:
     def __init__(
         self,
@@ -193,12 +198,14 @@ class MapTagValue:
         )
 
 
+transform_registry.register(MapTagValueConfig, MapTagValue)
+
+
 class ApplyAllConfig(BaseConfig):
     name: Literal["apply_all"] = "apply_all"
     steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
 
 
-@_transforms.register(ApplyAllConfig)
 class ApplyAll:
     def __init__(self, steps: List[SoundEventTransform]):
         self.steps = steps
@@ -218,6 +225,8 @@ class ApplyAll:
         return cls(steps)
 
 
+transform_registry.register(ApplyAllConfig, ApplyAll)
+
 
 SoundEventTransformConfig = Annotated[
     Union[
         SetFrequencyBoundConfig,
@@ -233,7 +242,7 @@ SoundEventTransformConfig = Annotated[
 def build_sound_event_transform(
     config: SoundEventTransformConfig,
 ) -> SoundEventTransform:
-    return _transforms.build(config)
+    return transform_registry.build(config)
 
 
 def transform_clip_annotation(
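
Likewise for transforms, a sketch of a config-driven pipeline; the tag names and values are invented for illustration:

from soundevent import data

# Normalize a legacy species label, then raise the low-frequency bound.
config = ApplyAllConfig(
    steps=[
        ReplaceTagConfig(
            original=data.Tag(key="species", value="Pip pip"),
            replacement=data.Tag(key="species", value="Pipistrellus pipistrellus"),
        ),
        SetFrequencyBoundConfig(boundary="low", hertz=15_000),
    ]
)

transform = build_sound_event_transform(config)
# `transform` maps a data.SoundEventAnnotation to a new, transformed annotation.
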
@@ -1,6 +1,9 @@
 from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
+from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
 
 __all__ = [
     "EvaluationConfig",
     "load_evaluation_config",
+    "Evaluator",
+    "build_evaluator",
 ]

src/batdetect2/evaluate/affinity.py (new file, 151 lines)
@@ -0,0 +1,151 @@
from typing import Annotated, Literal, Optional, Union

from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity

from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.typing.evaluate import AffinityFunction

affinity_functions: Registry[AffinityFunction, []] = Registry(
    "matching_strategy"
)


class TimeAffinityConfig(BaseConfig):
    name: Literal["time_affinity"] = "time_affinity"
    time_buffer: float = 0.01


class TimeAffinity(AffinityFunction):
    def __init__(self, time_buffer: float):
        self.time_buffer = time_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_timestamp_affinity(
            geometry1, geometry2, time_buffer=self.time_buffer
        )

    @classmethod
    def from_config(cls, config: TimeAffinityConfig):
        return cls(time_buffer=config.time_buffer)


affinity_functions.register(TimeAffinityConfig, TimeAffinity)


def compute_timestamp_affinity(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
) -> float:
    assert isinstance(geometry1, data.TimeStamp)
    assert isinstance(geometry2, data.TimeStamp)

    start_time1 = geometry1.coordinates
    start_time2 = geometry2.coordinates

    a = min(start_time1, start_time2)
    b = max(start_time1, start_time2)

    if b - a >= 2 * time_buffer:
        return 0

    intersection = a - b + 2 * time_buffer
    union = b - a + 2 * time_buffer
    return intersection / union


class IntervalIOUConfig(BaseConfig):
    name: Literal["interval_iou"] = "interval_iou"
    time_buffer: float = 0.01


class IntervalIOU(AffinityFunction):
    def __init__(self, time_buffer: float):
        self.time_buffer = time_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_interval_iou(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
        )

    @classmethod
    def from_config(cls, config: IntervalIOUConfig):
        return cls(time_buffer=config.time_buffer)


affinity_functions.register(IntervalIOUConfig, IntervalIOU)


def compute_interval_iou(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
) -> float:
    assert isinstance(geometry1, data.TimeInterval)
    assert isinstance(geometry2, data.TimeInterval)

    start_time1, end_time1 = geometry1.coordinates
    start_time2, end_time2 = geometry2.coordinates

    start_time1 -= time_buffer
    start_time2 -= time_buffer
    end_time1 += time_buffer
    end_time2 += time_buffer

    intersection = max(
        0, min(end_time1, end_time2) - max(start_time1, start_time2)
    )
    union = (
        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
    )

    if union == 0:
        return 0

    return intersection / union


class GeometricIOUConfig(BaseConfig):
    name: Literal["geometric_iou"] = "geometric_iou"
    time_buffer: float = 0.01
    freq_buffer: float = 1000


class GeometricIOU(AffinityFunction):
    def __init__(self, time_buffer: float, freq_buffer: float):
        self.time_buffer = time_buffer
        self.freq_buffer = freq_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_affinity(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
            freq_buffer=self.freq_buffer,
        )

    @classmethod
    def from_config(cls, config: GeometricIOUConfig):
        return cls(
            time_buffer=config.time_buffer,
            freq_buffer=config.freq_buffer,
        )


affinity_functions.register(GeometricIOUConfig, GeometricIOU)

AffinityConfig = Annotated[
    Union[
        TimeAffinityConfig,
        IntervalIOUConfig,
        GeometricIOUConfig,
    ],
    Field(discriminator="name"),
]


def build_affinity_function(
    config: Optional[AffinityConfig] = None,
) -> AffinityFunction:
    config = config or GeometricIOUConfig()
    return affinity_functions.build(config)
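
As a sanity check on the buffered interval IoU above: two 10 ms intervals offset by 5 ms, with the default 10 ms buffer, each get padded to 30 ms. A quick hand-worked computation (not part of the commit):

# Intervals (0.000, 0.010) and (0.005, 0.015), time_buffer = 0.01:
#   padded       -> (-0.010, 0.020) and (-0.005, 0.025)
#   intersection  = min(0.020, 0.025) - max(-0.010, -0.005) = 0.025
#   union         = 0.030 + 0.030 - 0.025 = 0.035
#   IoU           = 0.025 / 0.035 ≈ 0.714
iou = compute_interval_iou(
    data.TimeInterval(coordinates=[0.000, 0.010]),
    data.TimeInterval(coordinates=[0.005, 0.015]),
    time_buffer=0.01,
)
assert abs(iou - 0.025 / 0.035) < 1e-9
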
@@ -1,10 +1,16 @@
-from typing import Optional
+from typing import List, Optional
 
 from pydantic import Field
 from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
+from batdetect2.evaluate.metrics import (
+    ClassificationAPConfig,
+    DetectionAPConfig,
+    MetricConfig,
+)
+from batdetect2.evaluate.plots import ExampleGalleryConfig, PlotConfig
 
 __all__ = [
     "EvaluationConfig",
@@ -13,7 +19,19 @@ __all__ = [
 
 
 class EvaluationConfig(BaseConfig):
     ignore_start_end: float = 0.01
     match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
+    metrics: List[MetricConfig] = Field(
+        default_factory=lambda: [
+            DetectionAPConfig(),
+            ClassificationAPConfig(),
+        ]
+    )
+    plots: List[PlotConfig] = Field(
+        default_factory=lambda: [
+            ExampleGalleryConfig(),
+        ]
+    )
 
 
 def load_evaluation_config(
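
A sketch of what the expanded config admits when built in Python; the field values here are illustrative, not taken from the commit:

config = EvaluationConfig(
    ignore_start_end=0.01,
    match=StartTimeMatchConfig(distance_threshold=0.01),
    metrics=[
        DetectionAPConfig(),
        ClassificationAPConfig(exclude=["noise"]),
    ],
    plots=[ExampleGalleryConfig(examples_per_class=5)],
)

Because MetricConfig and PlotConfig are discriminated unions keyed on "name", the same structure round-trips from a config file via load_evaluation_config.
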
@@ -3,60 +3,61 @@ from typing import List
 import pandas as pd
 from soundevent.geometry import compute_bounds
 
-from batdetect2.typing.evaluate import MatchEvaluation
+from batdetect2.typing.evaluate import ClipEvaluation
 
 
-def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame:
+def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
     data = []
 
-    for match in matches:
-        gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
-        pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
-
-        sound_event_annotation = match.sound_event_annotation
-
-        if sound_event_annotation is not None:
-            geometry = sound_event_annotation.sound_event.geometry
-            assert geometry is not None
-            gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
-                compute_bounds(geometry)
-            )
-
-        if match.pred_geometry is not None:
-            pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
-                compute_bounds(match.pred_geometry)
-            )
-
-        data.append(
-            {
-                ("recording", "uuid"): match.clip.recording.uuid,
-                ("clip", "uuid"): match.clip.uuid,
-                ("clip", "start_time"): match.clip.start_time,
-                ("clip", "end_time"): match.clip.end_time,
-                ("gt", "uuid"): match.sound_event_annotation.uuid
-                if match.sound_event_annotation is not None
-                else None,
-                ("gt", "class"): match.gt_class,
-                ("gt", "det"): match.gt_det,
-                ("gt", "start_time"): gt_start_time,
-                ("gt", "end_time"): gt_end_time,
-                ("gt", "low_freq"): gt_low_freq,
-                ("gt", "high_freq"): gt_high_freq,
-                ("pred", "score"): match.pred_score,
-                ("pred", "class"): match.pred_class,
-                ("pred", "class_score"): match.pred_class_score,
-                ("pred", "start_time"): pred_start_time,
-                ("pred", "end_time"): pred_end_time,
-                ("pred", "low_freq"): pred_low_freq,
-                ("pred", "high_freq"): pred_high_freq,
-                ("match", "affinity"): match.affinity,
-                **{
-                    ("pred_class_score", key): value
-                    for key, value in match.pred_class_scores.items()
-                },
-            }
-        )
+    for clip_evaluation in clip_evaluations:
+        for match in clip_evaluation.matches:
+            gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
+            pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
+
+            sound_event_annotation = match.sound_event_annotation
+
+            if sound_event_annotation is not None:
+                geometry = sound_event_annotation.sound_event.geometry
+                assert geometry is not None
+                gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
+                    compute_bounds(geometry)
+                )
+
+            if match.pred_geometry is not None:
+                pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
+                    compute_bounds(match.pred_geometry)
+                )
+
+            data.append(
+                {
+                    ("recording", "uuid"): match.clip.recording.uuid,
+                    ("clip", "uuid"): match.clip.uuid,
+                    ("clip", "start_time"): match.clip.start_time,
+                    ("clip", "end_time"): match.clip.end_time,
+                    ("gt", "uuid"): match.sound_event_annotation.uuid
+                    if match.sound_event_annotation is not None
+                    else None,
+                    ("gt", "class"): match.gt_class,
+                    ("gt", "det"): match.gt_det,
+                    ("gt", "start_time"): gt_start_time,
+                    ("gt", "end_time"): gt_end_time,
+                    ("gt", "low_freq"): gt_low_freq,
+                    ("gt", "high_freq"): gt_high_freq,
+                    ("pred", "score"): match.pred_score,
+                    ("pred", "class"): match.pred_class,
+                    ("pred", "class_score"): match.pred_class_score,
+                    ("pred", "start_time"): pred_start_time,
+                    ("pred", "end_time"): pred_end_time,
+                    ("pred", "low_freq"): pred_low_freq,
+                    ("pred", "high_freq"): pred_high_freq,
+                    ("match", "affinity"): match.affinity,
+                    **{
+                        ("pred_class_score", key): value
+                        for key, value in match.pred_class_scores.items()
+                    },
+                }
+            )
 
     df = pd.DataFrame(data)
     df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
     return df
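
Since the resulting frame uses two-level columns, downstream code selects whole groups by the first level. A small usage sketch, assuming `clip_evaluations` comes from Evaluator.evaluate:

df = extract_matches_dataframe(clip_evaluations)

# The first column level groups related fields, e.g. all ground-truth bounds:
gt = df["gt"]  # columns: uuid, class, det, start_time, end_time, low_freq, high_freq

# Example query: false negatives are ground-truth detections without a
# predicted geometry (so the pred bounds are NaN).
false_negatives = df[df[("gt", "det")] & df[("pred", "start_time")].isna()]
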
@@ -4,11 +4,8 @@
 from soundevent import data
 
 from batdetect2.evaluate.dataframe import extract_matches_dataframe
-from batdetect2.evaluate.match import build_matcher, match_all_predictions
-from batdetect2.evaluate.metrics import (
-    ClassificationMeanAveragePrecision,
-    DetectionAveragePrecision,
-)
+from batdetect2.evaluate.evaluator import build_evaluator
+from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
 from batdetect2.models import Model
 from batdetect2.plotting.clips import build_audio_loader
 from batdetect2.postprocess import get_raw_predictions
@@ -55,6 +52,8 @@ def evaluate(
     clip_annotations = []
     predictions = []
 
+    evaluator = build_evaluator(config=config.evaluation)
+
     for batch in loader:
         outputs = model.detector(batch.spec)
 
@@ -76,20 +75,12 @@
         clip_annotations.extend(clip_annotations)
         predictions.extend(predictions)
 
-    matcher = build_matcher(config.evaluation.match)
-
-    matches = match_all_predictions(
-        clip_annotations,
-        predictions,
-        targets=targets,
-        matcher=matcher,
-    )
-
+    matches = evaluator.evaluate(clip_annotations, predictions)
     df = extract_matches_dataframe(matches)
 
     metrics = [
-        DetectionAveragePrecision(),
-        ClassificationMeanAveragePrecision(class_names=targets.class_names),
+        DetectionAP(),
+        ClassificationAP(class_names=targets.class_names),
     ]
 
     results = {

src/batdetect2/evaluate/evaluator.py (new file, 169 lines)
@@ -0,0 +1,169 @@
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

from matplotlib.figure import Figure
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.match import build_matcher, match
from batdetect2.evaluate.metrics import build_metric
from batdetect2.evaluate.plots import build_plotter
from batdetect2.targets import build_targets
from batdetect2.typing.evaluate import (
    ClipEvaluation,
    MatcherProtocol,
    MetricsProtocol,
    PlotterProtocol,
)
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol

__all__ = [
    "Evaluator",
    "build_evaluator",
]


class Evaluator:
    def __init__(
        self,
        config: EvaluationConfig,
        targets: TargetProtocol,
        matcher: MatcherProtocol,
        metrics: List[MetricsProtocol],
        plots: List[PlotterProtocol],
    ):
        self.config = config
        self.targets = targets
        self.matcher = matcher
        self.metrics = metrics
        self.plots = plots

    def match(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEvaluation:
        clip = clip_annotation.clip
        ground_truth = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.filter_sound_event_annotations(sound_event, clip)
        ]
        predictions = [
            prediction
            for prediction in predictions
            if self.filter_predictions(prediction, clip)
        ]
        return ClipEvaluation(
            clip=clip_annotation.clip,
            matches=match(
                ground_truth,
                predictions,
                clip=clip,
                targets=self.targets,
                matcher=self.matcher,
            ),
        )

    def filter_sound_event_annotations(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
        clip: data.Clip,
    ) -> bool:
        if not self.targets.filter(sound_event_annotation):
            return False

        geometry = sound_event_annotation.sound_event.geometry
        if geometry is None:
            return False

        return is_in_bounds(
            geometry,
            clip,
            self.config.ignore_start_end,
        )

    def filter_predictions(
        self,
        prediction: RawPrediction,
        clip: data.Clip,
    ) -> bool:
        return is_in_bounds(
            prediction.geometry,
            clip,
            self.config.ignore_start_end,
        )

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[ClipEvaluation]:
        if len(clip_annotations) != len(predictions):
            raise ValueError(
                "Number of annotated clips and sets of predictions do not match"
            )

        return [
            self.match(clip_annotation, preds)
            for clip_annotation, preds in zip(clip_annotations, predictions)
        ]

    def compute_metrics(
        self,
        clip_evaluations: Sequence[ClipEvaluation],
    ) -> Dict[str, float]:
        results = {}

        for metric in self.metrics:
            results.update(metric(clip_evaluations))

        return results

    def generate_plots(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Iterable[Tuple[str, Figure]]:
        for plotter in self.plots:
            for name, fig in plotter(clip_evaluations):
                yield name, fig


def build_evaluator(
    config: Optional[EvaluationConfig] = None,
    targets: Optional[TargetProtocol] = None,
    matcher: Optional[MatcherProtocol] = None,
    plots: Optional[List[PlotterProtocol]] = None,
    metrics: Optional[List[MetricsProtocol]] = None,
) -> Evaluator:
    config = config or EvaluationConfig()
    targets = targets or build_targets()
    matcher = matcher or build_matcher(config.match)

    if metrics is None:
        metrics = [
            build_metric(config, targets.class_names)
            for config in config.metrics
        ]

    if plots is None:
        plots = [build_plotter(config) for config in config.plots]

    return Evaluator(
        config=config,
        targets=targets,
        matcher=matcher,
        metrics=metrics,
        plots=plots,
    )


def is_in_bounds(
    geometry: data.Geometry,
    clip: data.Clip,
    buffer: float,
) -> bool:
    start_time = compute_bounds(geometry)[0]
    return (start_time >= clip.start_time + buffer) and (
        start_time <= clip.end_time - buffer
    )
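
Pulling the pieces together, a usage sketch of the new Evaluator. The `clip_annotations` and `predictions` inputs are assumed to come from the data loader and postprocessing steps shown in the evaluate script above:

evaluator = build_evaluator()  # default EvaluationConfig, targets, and matcher

clip_evaluations = evaluator.evaluate(clip_annotations, predictions)

scores = evaluator.compute_metrics(clip_evaluations)
# e.g. {"detection_AP": ..., "classification_mAP": ..., "classification_AP/<class>": ...}

for name, fig in evaluator.generate_plots(clip_evaluations):
    fig.savefig(f"{name.replace('/', '_')}.png")
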
@@ -1,9 +1,7 @@
 from collections.abc import Callable, Iterable, Mapping
-from dataclasses import dataclass, field
 from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
 
 import numpy as np
-from loguru import logger
 from pydantic import Field
 from soundevent import data
 from soundevent.evaluation import compute_affinity
@@ -12,6 +10,11 @@ from soundevent.geometry import compute_bounds
 
 from batdetect2.configs import BaseConfig
 from batdetect2.data._core import Registry
+from batdetect2.evaluate.affinity import (
+    AffinityConfig,
+    GeometricIOUConfig,
+    build_affinity_function,
+)
 from batdetect2.targets import build_targets
 from batdetect2.typing import (
     MatchEvaluation,
@@ -23,7 +26,88 @@ from batdetect2.typing.postprocess import RawPrediction
 MatchingGeometry = Literal["bbox", "interval", "timestamp"]
 """The geometry representation to use for matching."""
 
-matching_strategy = Registry("matching_strategy")
+matching_strategies = Registry("matching_strategy")
+
+
+def match(
+    sound_event_annotations: Sequence[data.SoundEventAnnotation],
+    raw_predictions: Sequence[RawPrediction],
+    clip: data.Clip,
+    targets: Optional[TargetProtocol] = None,
+    matcher: Optional[MatcherProtocol] = None,
+) -> List[MatchEvaluation]:
+    if matcher is None:
+        matcher = build_matcher()
+
+    if targets is None:
+        targets = build_targets()
+
+    target_geometries: List[data.Geometry] = [  # type: ignore
+        sound_event_annotation.sound_event.geometry
+        for sound_event_annotation in sound_event_annotations
+    ]
+
+    predicted_geometries = [
+        raw_prediction.geometry for raw_prediction in raw_predictions
+    ]
+
+    scores = [
+        raw_prediction.detection_score for raw_prediction in raw_predictions
+    ]
+
+    matches = []
+
+    for source_idx, target_idx, affinity in matcher(
+        ground_truth=target_geometries,
+        predictions=predicted_geometries,
+        scores=scores,
+    ):
+        target = (
+            sound_event_annotations[target_idx]
+            if target_idx is not None
+            else None
+        )
+        prediction = (
+            raw_predictions[source_idx] if source_idx is not None else None
+        )
+
+        gt_det = target_idx is not None
+        gt_class = targets.encode_class(target) if target is not None else None
+
+        pred_score = float(prediction.detection_score) if prediction else 0
+
+        pred_geometry = (
+            predicted_geometries[source_idx]
+            if source_idx is not None
+            else None
+        )
+
+        class_scores = (
+            {
+                str(class_name): float(score)
+                for class_name, score in zip(
+                    targets.class_names,
+                    prediction.class_scores,
+                )
+            }
+            if prediction is not None
+            else {}
+        )
+
+        matches.append(
+            MatchEvaluation(
+                clip=clip,
+                sound_event_annotation=target,
+                gt_det=gt_det,
+                gt_class=gt_class,
+                pred_score=pred_score,
+                pred_class_scores=class_scores,
+                pred_geometry=pred_geometry,
+                affinity=affinity,
+            )
+        )
+
+    return matches
 
 
 class StartTimeMatchConfig(BaseConfig):
@@ -31,7 +115,6 @@ class StartTimeMatchConfig(BaseConfig):
     distance_threshold: float = 0.01
 
 
-@matching_strategy.register(StartTimeMatchConfig)
 class StartTimeMatcher(MatcherProtocol):
     def __init__(self, distance_threshold: float):
         self.distance_threshold = distance_threshold
@@ -54,6 +137,9 @@ class StartTimeMatcher(MatcherProtocol):
         return cls(distance_threshold=config.distance_threshold)
 
 
+matching_strategies.register(StartTimeMatchConfig, StartTimeMatcher)
+
+
 def match_start_times(
     ground_truth: Sequence[data.Geometry],
     predictions: Sequence[data.Geometry],
@@ -74,8 +160,8 @@
 
     gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth])
     pred_times = np.array([compute_bounds(geom)[0] for geom in predictions])
-    scores = np.array(scores)
 
+    scores = np.array(scores)
     sort_args = np.argsort(scores)[::-1]
 
     distances = np.abs(gt_times[None, :] - pred_times[:, None])
@@ -143,89 +229,25 @@ _geometry_cast_functions: Mapping[
 }
 
 
-def _timestamp_affinity(
-    geometry1: data.Geometry,
-    geometry2: data.Geometry,
-    time_buffer: float = 0.01,
-    freq_buffer: float = 1000,
-) -> float:
-    assert isinstance(geometry1, data.TimeStamp)
-    assert isinstance(geometry2, data.TimeStamp)
-
-    start_time1 = geometry1.coordinates
-    start_time2 = geometry2.coordinates
-
-    a = min(start_time1, start_time2)
-    b = max(start_time1, start_time2)
-
-    if b - a >= 2 * time_buffer:
-        return 0
-
-    intersection = a - b + 2 * time_buffer
-    union = b - a + 2 * time_buffer
-    return intersection / union
-
-
-def _interval_affinity(
-    geometry1: data.Geometry,
-    geometry2: data.Geometry,
-    time_buffer: float = 0.01,
-    freq_buffer: float = 1000,
-) -> float:
-    assert isinstance(geometry1, data.TimeInterval)
-    assert isinstance(geometry2, data.TimeInterval)
-
-    start_time1, end_time1 = geometry1.coordinates
-    start_time2, end_time2 = geometry2.coordinates
-
-    start_time1 -= time_buffer
-    start_time2 -= time_buffer
-    end_time1 += time_buffer
-    end_time2 += time_buffer
-
-    intersection = max(
-        0, min(end_time1, end_time2) - max(start_time1, start_time2)
-    )
-    union = (
-        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
-    )
-
-    if union == 0:
-        return 0
-
-    return intersection / union
-
-
-_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
-    "timestamp": _timestamp_affinity,
-    "interval": _interval_affinity,
-    "bbox": compute_affinity,
-}
-
-
 class GreedyMatchConfig(BaseConfig):
     name: Literal["greedy_match"] = "greedy_match"
     geometry: MatchingGeometry = "timestamp"
-    affinity_threshold: float = 0.0
-    time_buffer: float = 0.005
-    frequency_buffer: float = 1_000
+    affinity_threshold: float = 0.5
+    affinity_function: AffinityConfig = Field(
+        default_factory=GeometricIOUConfig
+    )
 
 
-@matching_strategy.register(GreedyMatchConfig)
 class GreedyMatcher(MatcherProtocol):
     def __init__(
         self,
         geometry: MatchingGeometry,
         affinity_threshold: float,
-        time_buffer: float,
-        frequency_buffer: float,
+        affinity_function: AffinityFunction,
    ):
         self.geometry = geometry
+        self.affinity_function = affinity_function
         self.affinity_threshold = affinity_threshold
-        self.time_buffer = time_buffer
-        self.frequency_buffer = frequency_buffer
 
-        self.affinity_function = _affinity_functions[self.geometry]
         self.cast_geometry = _geometry_cast_functions[self.geometry]
 
     def __call__(
@@ -240,28 +262,27 @@ class GreedyMatcher(MatcherProtocol):
             scores=scores,
             affinity_function=self.affinity_function,
             affinity_threshold=self.affinity_threshold,
-            time_buffer=self.time_buffer,
-            freq_buffer=self.frequency_buffer,
         )
 
     @classmethod
     def from_config(cls, config: GreedyMatchConfig):
+        affinity_function = build_affinity_function(config.affinity_function)
         return cls(
             geometry=config.geometry,
             affinity_threshold=config.affinity_threshold,
-            time_buffer=config.time_buffer,
-            frequency_buffer=config.frequency_buffer,
+            affinity_function=affinity_function,
         )
 
 
+matching_strategies.register(GreedyMatchConfig, GreedyMatcher)
+
+
 def greedy_match(
     ground_truth: Sequence[data.Geometry],
     predictions: Sequence[data.Geometry],
     scores: Sequence[float],
     affinity_threshold: float = 0.5,
     affinity_function: AffinityFunction = compute_affinity,
-    time_buffer: float = 0.001,
-    freq_buffer: float = 1000,
 ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
     """Performs a greedy, one-to-one matching of source to target geometries.
 
@@ -279,10 +300,6 @@ def greedy_match(
         Confidence scores for each source geometry for prioritization.
     affinity_threshold
         The minimum affinity score required for a valid match.
-    time_buffer
-        Time tolerance in seconds for affinity calculation.
-    freq_buffer
-        Frequency tolerance in Hertz for affinity calculation.
 
     Yields
     ------
@@ -314,12 +331,7 @@
 
         affinities = np.array(
            [
-                affinity_function(
-                    source_geometry,
-                    target_geometry,
-                    time_buffer=time_buffer,
-                    freq_buffer=freq_buffer,
-                )
+                affinity_function(source_geometry, target_geometry)
                 for target_geometry in ground_truth
             ]
         )
@@ -344,12 +356,11 @@
 
 class OptimalMatchConfig(BaseConfig):
     name: Literal["optimal_match"] = "optimal_match"
-    affinity_threshold: float = 0.0
+    affinity_threshold: float = 0.5
     time_buffer: float = 0.005
     frequency_buffer: float = 1_000
 
 
-@matching_strategy.register(OptimalMatchConfig)
 class OptimalMatcher(MatcherProtocol):
     def __init__(
         self,
@@ -384,6 +395,8 @@ class OptimalMatcher(MatcherProtocol):
     )
 
 
+matching_strategies.register(OptimalMatchConfig, OptimalMatcher)
+
 
 MatchConfig = Annotated[
     Union[
         GreedyMatchConfig,
@@ -396,174 +409,4 @@ MatchConfig = Annotated[
 
 def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
     config = config or StartTimeMatchConfig()
-    return matching_strategy.build(config)
-
-
-def _is_in_bounds(
-    geometry: data.Geometry,
-    clip: data.Clip,
-    buffer: float,
-) -> bool:
-    start_time = compute_bounds(geometry)[0]
-    return (start_time >= clip.start_time + buffer) and (
-        start_time <= clip.end_time - buffer
-    )
-
-
-def match_sound_events_and_predictions(
-    clip_annotation: data.ClipAnnotation,
-    raw_predictions: List[RawPrediction],
-    targets: Optional[TargetProtocol] = None,
-    matcher: Optional[MatcherProtocol] = None,
-    ignore_start_end: float = 0.01,
-) -> List[MatchEvaluation]:
-    if matcher is None:
-        matcher = build_matcher()
-
-    if targets is None:
-        targets = build_targets()
-
-    target_sound_events = [
-        sound_event_annotation
-        for sound_event_annotation in clip_annotation.sound_events
-        if targets.filter(sound_event_annotation)
-        and sound_event_annotation.sound_event.geometry is not None
-        and _is_in_bounds(
-            sound_event_annotation.sound_event.geometry,
-            clip=clip_annotation.clip,
-            buffer=ignore_start_end,
-        )
-    ]
-
-    target_geometries: List[data.Geometry] = [
-        sound_event_annotation.sound_event.geometry
-        for sound_event_annotation in target_sound_events
-        if sound_event_annotation.sound_event.geometry is not None
-    ]
-
-    raw_predictions = [
-        raw_prediction
-        for raw_prediction in raw_predictions
-        if _is_in_bounds(
-            raw_prediction.geometry,
-            clip=clip_annotation.clip,
-            buffer=ignore_start_end,
-        )
-    ]
-
-    predicted_geometries = [
-        raw_prediction.geometry for raw_prediction in raw_predictions
-    ]
-
-    scores = [
-        raw_prediction.detection_score for raw_prediction in raw_predictions
-    ]
-
-    matches = []
-
-    for source_idx, target_idx, affinity in matcher(
-        ground_truth=target_geometries,
-        predictions=predicted_geometries,
-        scores=scores,
-    ):
-        target = (
-            target_sound_events[target_idx] if target_idx is not None else None
-        )
-        prediction = (
-            raw_predictions[source_idx] if source_idx is not None else None
-        )
-
-        gt_det = target_idx is not None
-        gt_class = targets.encode_class(target) if target is not None else None
-
-        pred_score = float(prediction.detection_score) if prediction else 0
-
-        pred_geometry = (
-            predicted_geometries[source_idx]
-            if source_idx is not None
-            else None
-        )
-
-        class_scores = (
-            {
-                str(class_name): float(score)
-                for class_name, score in zip(
-                    targets.class_names,
-                    prediction.class_scores,
-                )
-            }
-            if prediction is not None
-            else {}
-        )
-
-        matches.append(
-            MatchEvaluation(
-                clip=clip_annotation.clip,
-                sound_event_annotation=target,
-                gt_det=gt_det,
-                gt_class=gt_class,
-                pred_score=pred_score,
-                pred_class_scores=class_scores,
-                pred_geometry=pred_geometry,
-                affinity=affinity,
-            )
-        )
-
-    return matches
-
-
-def match_all_predictions(
-    clip_annotations: List[data.ClipAnnotation],
-    predictions: List[List[RawPrediction]],
-    targets: Optional[TargetProtocol] = None,
-    matcher: Optional[MatcherProtocol] = None,
-    ignore_start_end: float = 0.01,
-) -> List[MatchEvaluation]:
-    logger.info("Matching all annotations and predictions...")
-    return [
-        match
-        for clip_annotation, raw_predictions in zip(
-            clip_annotations,
-            predictions,
-        )
-        for match in match_sound_events_and_predictions(
-            clip_annotation,
-            raw_predictions,
-            targets=targets,
-            matcher=matcher,
-            ignore_start_end=ignore_start_end,
-        )
-    ]
-
-
-@dataclass
-class ClassExamples:
-    false_positives: List[MatchEvaluation] = field(default_factory=list)
-    false_negatives: List[MatchEvaluation] = field(default_factory=list)
-    true_positives: List[MatchEvaluation] = field(default_factory=list)
-    cross_triggers: List[MatchEvaluation] = field(default_factory=list)
-
-
-def group_matches(matches: List[MatchEvaluation]) -> ClassExamples:
-    class_examples = ClassExamples()
-
-    for match in matches:
-        gt_class = match.gt_class
-        pred_class = match.pred_class
-
-        if pred_class is None:
-            class_examples.false_negatives.append(match)
-            continue
-
-        if gt_class is None:
-            class_examples.false_positives.append(match)
-            continue
-
-        if gt_class != pred_class:
-            class_examples.cross_triggers.append(match)
-            continue
-
-        class_examples.true_positives.append(match)
-
-    return class_examples
+    return matching_strategies.build(config)
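
To illustrate the greedy strategy in isolation: predictions are visited in descending score order, and each consumes the best remaining ground truth above the threshold. A toy run with timestamp geometries — the values are invented and the exact yield order is assumed from the docstring, not verified against the commit:

from soundevent import data

from batdetect2.evaluate.affinity import compute_timestamp_affinity

ground_truth = [data.TimeStamp(coordinates=0.10), data.TimeStamp(coordinates=0.50)]
predictions = [data.TimeStamp(coordinates=0.11), data.TimeStamp(coordinates=0.12)]
scores = [0.9, 0.8]

pairs = list(
    greedy_match(
        ground_truth,
        predictions,
        scores,
        affinity_threshold=0.1,
        affinity_function=compute_timestamp_affinity,
    )
)
# The 0.9-score prediction claims the ground truth at 0.10 (affinity 1/3 with
# the default 10 ms buffer); the 0.8-score prediction finds no remaining
# ground truth within reach, and the event at 0.50 is left unmatched.
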
@@ -1,61 +1,151 @@
-from typing import Dict, List
+from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
 
 import numpy as np
 import pandas as pd
+from pydantic import Field
 from sklearn import metrics
 from sklearn.preprocessing import label_binarize
 
-from batdetect2.typing import MatchEvaluation, MetricsProtocol
+from batdetect2.configs import BaseConfig
+from batdetect2.data._core import Registry
+from batdetect2.typing import MetricsProtocol
+from batdetect2.typing.evaluate import ClipEvaluation
 
-__all__ = ["DetectionAveragePrecision"]
+__all__ = ["DetectionAP", "ClassificationAP"]
 
 
-class DetectionAveragePrecision(MetricsProtocol):
-    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
+metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
+
+
+class DetectionAPConfig(BaseConfig):
+    name: Literal["detection_ap"] = "detection_ap"
+
+
+class DetectionAP(MetricsProtocol):
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Dict[str, float]:
         y_true, y_score = zip(
-            *[(match.gt_det, match.pred_score) for match in matches]
+            *[
+                (match.gt_det, match.pred_score)
+                for clip_eval in clip_evaluations
+                for match in clip_eval.matches
+            ]
         )
         score = float(metrics.average_precision_score(y_true, y_score))
         return {"detection_AP": score}
 
+    @classmethod
+    def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
+        return cls()
 
-class ClassificationMeanAveragePrecision(MetricsProtocol):
-    def __init__(self, class_names: List[str]):
+
+metrics_registry.register(DetectionAPConfig, DetectionAP)
+
+
+class ClassificationAPConfig(BaseConfig):
+    name: Literal["classification_ap"] = "classification_ap"
+    include: Optional[List[str]] = None
+    exclude: Optional[List[str]] = None
+
+
+class ClassificationAP(MetricsProtocol):
+    def __init__(
+        self,
+        class_names: List[str],
+        include: Optional[List[str]] = None,
+        exclude: Optional[List[str]] = None,
+    ):
         self.class_names = class_names
+        self.selected = class_names
 
-    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
-        # NOTE: Need to exclude generic but unclassified targets
-        y_true = label_binarize(
-            [
-                match.gt_class if match.gt_class is not None else "__NONE__"
-                for match in matches
-                if not (match.gt_det and match.gt_class is None)
-            ],
-            classes=self.class_names,
-        )
-        y_pred = pd.DataFrame(
-            [
-                {
-                    name: match.pred_class_scores.get(name, 0)
-                    for name in self.class_names
-                }
-                for match in matches
-                if not (match.gt_det and match.gt_class is None)
-            ]
-        ).fillna(0)
+        if include is not None:
+            self.selected = [
+                class_name
+                for class_name in self.selected
+                if class_name in include
+            ]
 
-        ret = {}
+        if exclude is not None:
+            self.selected = [
+                class_name
+                for class_name in self.selected
+                if class_name not in exclude
+            ]
+
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Dict[str, float]:
+        y_true = []
+        y_pred = []
+
+        for clip_eval in clip_evaluations:
+            for match in clip_eval.matches:
+                # Ignore generic unclassified targets
+                if match.gt_det and match.gt_class is None:
+                    continue
+
+                y_true.append(
+                    match.gt_class
+                    if match.gt_class is not None
+                    else "__NONE__"
+                )
+
+                y_pred.append(
+                    np.array(
+                        [
+                            match.pred_class_scores.get(name, 0)
+                            for name in self.class_names
+                        ]
+                    )
+                )
+
+        y_true = label_binarize(y_true, classes=self.class_names)
+        y_pred = np.stack(y_pred)
+
+        class_scores = {}
         for class_index, class_name in enumerate(self.class_names):
             y_true_class = y_true[:, class_index]
-            y_pred_class = y_pred[class_name]
+            y_pred_class = y_pred[:, class_index]
             class_ap = metrics.average_precision_score(
                 y_true_class,
                 y_pred_class,
             )
-            ret[f"classification_AP/{class_name}"] = float(class_ap)
+            class_scores[class_name] = float(class_ap)
 
-        ret["classification_mAP"] = np.mean(
-            [value for value in ret.values() if value != 0]
+        mean_ap = np.mean(
+            [value for value in class_scores.values() if value != 0]
         )
 
-        return ret
+        return {
+            "classification_mAP": float(mean_ap),
+            **{
+                f"classification_AP/{class_name}": class_scores[class_name]
+                for class_name in self.selected
+            },
+        }
+
+    @classmethod
+    def from_config(
+        cls,
+        config: ClassificationAPConfig,
+        class_names: List[str],
+    ):
+        return cls(
+            class_names,
+            include=config.include,
+            exclude=config.exclude,
+        )
+
+
+metrics_registry.register(ClassificationAPConfig, ClassificationAP)
+
+
+MetricConfig = Annotated[
+    Union[ClassificationAPConfig, DetectionAPConfig],
+    Field(discriminator="name"),
+]
+
+
+def build_metric(config: MetricConfig, class_names: List[str]):
+    return metrics_registry.build(config, class_names)
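
A compact sketch of what the detection metric reduces to, using sklearn directly on hand-made labels (values invented):

from sklearn import metrics

# One match per row: was there a ground-truth detection, and the model's score.
y_true = [True, True, False, True, False]
y_score = [0.95, 0.80, 0.60, 0.40, 0.10]

ap = metrics.average_precision_score(y_true, y_score)
# Ranking by score gives precisions 1/1, 2/2, 3/4 at the three positives,
# so AP = (1.0 + 1.0 + 0.75) / 3 ≈ 0.92. DetectionAP reports exactly this
# number under the "detection_AP" key.
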
src/batdetect2/evaluate/plots.py (new file, 163 lines)
@@ -0,0 +1,163 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union

import matplotlib.pyplot as plt
import pandas as pd
from pydantic import Field

from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing.evaluate import (
    ClipEvaluation,
    MatchEvaluation,
    PlotterProtocol,
)
from batdetect2.typing.preprocess import AudioLoader

__all__ = [
    "build_plotter",
    "ExampleGallery",
    "ExampleGalleryConfig",
]


plots_registry: Registry[PlotterProtocol, []] = Registry("plot")


class ExampleGalleryConfig(BaseConfig):
    name: Literal["example_gallery"] = "example_gallery"
    examples_per_class: int = 5
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )


class ExampleGallery(PlotterProtocol):
    def __init__(
        self,
        examples_per_class: int,
        preprocessor: Optional[PreprocessorProtocol] = None,
        audio_loader: Optional[AudioLoader] = None,
    ):
        self.examples_per_class = examples_per_class
        self.preprocessor = preprocessor or build_preprocessor()
        self.audio_loader = audio_loader or build_audio_loader()

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        per_class_matches = group_matches(clip_evaluations)

        for class_name, matches in per_class_matches.items():
            true_positives = get_binned_sample(
                matches.true_positives,
                n_examples=self.examples_per_class,
            )

            false_positives = get_binned_sample(
                matches.false_positives,
                n_examples=self.examples_per_class,
            )

            false_negatives = random.sample(
                matches.false_negatives,
                k=min(self.examples_per_class, len(matches.false_negatives)),
            )

            cross_triggers = get_binned_sample(
                matches.cross_triggers,
                n_examples=self.examples_per_class,
            )

            fig = plot_match_gallery(
                true_positives,
                false_positives,
                false_negatives,
                cross_triggers,
                preprocessor=self.preprocessor,
                audio_loader=self.audio_loader,
                n_examples=self.examples_per_class,
            )

            yield f"example_gallery/{class_name}", fig

            plt.close(fig)

    @classmethod
    def from_config(cls, config: ExampleGalleryConfig):
        preprocessor = build_preprocessor(config.preprocessing)
        audio_loader = build_audio_loader(config.preprocessing.audio)
        return cls(
            examples_per_class=config.examples_per_class,
            preprocessor=preprocessor,
            audio_loader=audio_loader,
        )


plots_registry.register(ExampleGalleryConfig, ExampleGallery)


PlotConfig = Annotated[
    Union[ExampleGalleryConfig,], Field(discriminator="name")
]


def build_plotter(config: PlotConfig) -> PlotterProtocol:
    return plots_registry.build(config)


@dataclass
class ClassMatches:
    false_positives: List[MatchEvaluation] = field(default_factory=list)
    false_negatives: List[MatchEvaluation] = field(default_factory=list)
    true_positives: List[MatchEvaluation] = field(default_factory=list)
    cross_triggers: List[MatchEvaluation] = field(default_factory=list)


def group_matches(
    clip_evaluations: Sequence[ClipEvaluation],
) -> Dict[str, ClassMatches]:
    class_examples = defaultdict(ClassMatches)

    for clip_evaluation in clip_evaluations:
        for match in clip_evaluation.matches:
            gt_class = match.gt_class
            pred_class = match.pred_class

            if pred_class is None:
                class_examples[gt_class].false_negatives.append(match)
                continue

            if gt_class is None:
                class_examples[pred_class].false_positives.append(match)
                continue

            if gt_class != pred_class:
                class_examples[gt_class].cross_triggers.append(match)
                class_examples[pred_class].cross_triggers.append(match)
                continue

            class_examples[gt_class].true_positives.append(match)

    return class_examples


def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
    if len(matches) < n_examples:
        return matches

    indices, pred_scores = zip(
        *[
            (index, match.pred_class_scores[pred_class])
            for index, match in enumerate(matches)
            if (pred_class := match.pred_class) is not None
        ]
    )

    bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
    df = pd.DataFrame({"indices": indices, "bins": bins})
    sample = df.groupby("bins").sample(1)
    return [matches[ind] for ind in sample["indices"]]
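
The score-binned sampling above deserves a note: rather than taking the top-n matches, pd.qcut splits the predicted scores into n quantile bins and one example is drawn per bin, so the gallery spans low- and high-confidence cases. A minimal standalone illustration with fake scores:

import pandas as pd

scores = [0.05, 0.10, 0.30, 0.50, 0.70, 0.90, 0.95]
bins = pd.qcut(scores, q=3, labels=False, duplicates="drop")
# bins -> [0, 0, 0, 1, 1, 2, 2]: three quantile groups over the score range
df = pd.DataFrame({"indices": range(len(scores)), "bins": bins})
sample = df.groupby("bins").sample(1)  # one representative per score band
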
@@ -2,6 +2,7 @@ from batdetect2.plotting.clip_annotations import plot_clip_annotation
 from batdetect2.plotting.clip_predictions import plot_clip_prediction
 from batdetect2.plotting.clips import plot_clip
 from batdetect2.plotting.common import plot_spectrogram
+from batdetect2.plotting.gallery import plot_match_gallery
 from batdetect2.plotting.heatmaps import (
     plot_classification_heatmap,
     plot_detection_heatmap,
@@ -26,4 +27,5 @@ __all__ = [
     "plot_true_positive_match",
     "plot_detection_heatmap",
     "plot_classification_heatmap",
+    "plot_match_gallery",
 ]
@ -1,160 +0,0 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List

import matplotlib.pyplot as plt
import pandas as pd

from batdetect2 import plotting
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing.preprocess import PreprocessorProtocol


@dataclass
class ClassExamples:
    false_positives: List[MatchEvaluation] = field(default_factory=list)
    false_negatives: List[MatchEvaluation] = field(default_factory=list)
    true_positives: List[MatchEvaluation] = field(default_factory=list)
    cross_triggers: List[MatchEvaluation] = field(default_factory=list)


def plot_example_gallery(
    matches: List[MatchEvaluation],
    preprocessor: PreprocessorProtocol,
    n_examples: int = 5,
):
    class_examples = defaultdict(ClassExamples)

    for match in matches:
        gt_class = match.gt_class
        pred_class = match.pred_class

        if pred_class is None:
            class_examples[gt_class].false_negatives.append(match)
            continue

        if gt_class is None:
            class_examples[pred_class].false_positives.append(match)
            continue

        if gt_class != pred_class:
            class_examples[gt_class].cross_triggers.append(match)
            class_examples[pred_class].cross_triggers.append(match)
            continue

        class_examples[gt_class].true_positives.append(match)

    for class_name, examples in class_examples.items():
        true_positives = get_binned_sample(
            examples.true_positives,
            n_examples=n_examples,
        )

        false_positives = get_binned_sample(
            examples.false_positives,
            n_examples=n_examples,
        )

        false_negatives = random.sample(
            examples.false_negatives,
            k=min(n_examples, len(examples.false_negatives)),
        )

        cross_triggers = get_binned_sample(
            examples.cross_triggers,
            n_examples=n_examples,
        )

        fig = plot_class_examples(
            true_positives,
            false_positives,
            false_negatives,
            cross_triggers,
            preprocessor=preprocessor,
            n_examples=n_examples,
        )

        yield class_name, fig

        plt.close(fig)


def plot_class_examples(
    true_positives: List[MatchEvaluation],
    false_positives: List[MatchEvaluation],
    false_negatives: List[MatchEvaluation],
    cross_triggers: List[MatchEvaluation],
    preprocessor: PreprocessorProtocol,
    n_examples: int = 5,
    duration: float = 0.1,
):
    fig = plt.figure(figsize=(20, 20))

    for index, match in enumerate(true_positives[:n_examples]):
        ax = plt.subplot(4, n_examples, index + 1)
        try:
            plotting.plot_true_positive_match(
                match,
                ax=ax,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    for index, match in enumerate(false_positives[:n_examples]):
        ax = plt.subplot(4, n_examples, n_examples + index + 1)
        try:
            plotting.plot_false_positive_match(
                match,
                ax=ax,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    for index, match in enumerate(false_negatives[:n_examples]):
        ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
        try:
            plotting.plot_false_negative_match(
                match,
                ax=ax,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    for index, match in enumerate(cross_triggers[:n_examples]):
        ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
        try:
            plotting.plot_cross_trigger_match(
                match,
                ax=ax,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    return fig


def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
    if len(matches) < n_examples:
        return matches

    indices, pred_scores = zip(
        *[
            (index, match.pred_class_scores[pred_class])
            for index, match in enumerate(matches)
            if (pred_class := match.pred_class) is not None
        ]
    )

    bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
    df = pd.DataFrame({"indices": indices, "bins": bins})
    sample = df.groupby("bins").sample(1)
    return [matches[ind] for ind in sample["indices"]]
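For reference, the deleted plot_class_examples (and its replacement below) both rely on Matplotlib's 1-based subplot indexing to build a four-row grid, one row per match outcome. A standalone sketch of that layout, with placeholder titles where the real code draws spectrograms:

# --- illustrative sketch, not part of the commit ---
import matplotlib.pyplot as plt

n_examples = 3
fig = plt.figure(figsize=(9, 9))
for row, label in enumerate(["TP", "FP", "FN", "cross-trigger"]):
    for col in range(n_examples):
        # Row r, column c maps to subplot index r * n_examples + c + 1.
        ax = plt.subplot(4, n_examples, row * n_examples + col + 1)
        ax.set_title(f"{label} {col + 1}")
        ax.set_xticks([])
        ax.set_yticks([])
fig.savefig("gallery_grid.png")
plt.close(fig)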
src/batdetect2/plotting/gallery.py (new file, 81 lines)
@ -0,0 +1,81 @@
from typing import List, Optional

import matplotlib.pyplot as plt

from batdetect2.plotting.matches import (
    plot_cross_trigger_match,
    plot_false_negative_match,
    plot_false_positive_match,
    plot_true_positive_match,
)
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol

__all__ = ["plot_match_gallery"]


def plot_match_gallery(
    true_positives: List[MatchEvaluation],
    false_positives: List[MatchEvaluation],
    false_negatives: List[MatchEvaluation],
    cross_triggers: List[MatchEvaluation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    n_examples: int = 5,
    duration: float = 0.1,
):
    fig = plt.figure(figsize=(20, 20))

    for index, match in enumerate(true_positives[:n_examples]):
        ax = plt.subplot(4, n_examples, index + 1)
        try:
            plot_true_positive_match(
                match,
                ax=ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    for index, match in enumerate(false_positives[:n_examples]):
        ax = plt.subplot(4, n_examples, n_examples + index + 1)
        try:
            plot_false_positive_match(
                match,
                ax=ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    for index, match in enumerate(false_negatives[:n_examples]):
        ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
        try:
            plot_false_negative_match(
                match,
                ax=ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    for index, match in enumerate(cross_triggers[:n_examples]):
        ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
        try:
            plot_cross_trigger_match(
                match,
                ax=ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    return fig
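A hedged usage sketch of the new entry point; per the __init__ hunk above it is re-exported from batdetect2.plotting, and the empty buckets below are placeholders for MatchEvaluation lists produced by an evaluation pass:

# --- illustrative sketch, not part of the commit ---
import matplotlib.pyplot as plt

from batdetect2.plotting import plot_match_gallery

# Placeholder buckets; in practice, split MatchEvaluation objects by
# outcome (true/false positive, false negative, cross-trigger) first.
tps, fps, fns, xts = [], [], [], []

fig = plot_match_gallery(tps, fps, fns, xts, n_examples=4)
fig.savefig("match_gallery.png")
plt.close(fig)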
@ -7,8 +7,8 @@ from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper

from batdetect2.plotting.clip_predictions import plot_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.plotting.clips import AudioLoader, plot_clip
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.typing.evaluate import MatchEvaluation

__all__ = [
@ -32,6 +32,7 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_matches(
    matches: List[data.Match],
    clip: data.Clip,
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
@ -46,12 +47,11 @@ def plot_matches(
    annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
    prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
    if preprocessor is None:
        preprocessor = build_preprocessor()

    ax = plot_clip(
        clip,
        ax=ax,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        figsize=figsize,
        audio_dir=audio_dir,
        spec_cmap=spec_cmap,
@ -116,6 +116,7 @@ def plot_matches(

def plot_false_positive_match(
    match: MatchEvaluation,
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
@ -143,6 +144,7 @@ def plot_false_positive_match(

    ax = plot_clip(
        clip,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        figsize=figsize,
        ax=ax,
@ -174,6 +176,7 @@ def plot_false_positive_match(

def plot_false_negative_match(
    match: MatchEvaluation,
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
@ -203,6 +206,7 @@ def plot_false_negative_match(

    ax = plot_clip(
        clip,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        figsize=figsize,
        ax=ax,
@ -237,6 +241,7 @@ def plot_false_negative_match(
def plot_true_positive_match(
    match: MatchEvaluation,
    preprocessor: Optional[PreprocessorProtocol] = None,
    audio_loader: Optional[AudioLoader] = None,
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
@ -267,6 +272,7 @@ def plot_true_positive_match(

    ax = plot_clip(
        clip,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        figsize=figsize,
        ax=ax,
@ -312,6 +318,7 @@ def plot_true_positive_match(
def plot_cross_trigger_match(
    match: MatchEvaluation,
    preprocessor: Optional[PreprocessorProtocol] = None,
    audio_loader: Optional[AudioLoader] = None,
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
@ -342,6 +349,7 @@ def plot_cross_trigger_match(

    ax = plot_clip(
        clip,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        figsize=figsize,
        ax=ax,
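The recurring change in this file is dependency injection: every plotting helper grows an optional audio_loader and forwards it down to plot_clip instead of building a default preprocessor itself. A generic sketch of that pattern, with made-up names rather than the repo's API:

# --- illustrative sketch, not part of the commit ---
from typing import Callable, Optional

Loader = Callable[[str], bytes]  # stand-in for the AudioLoader protocol

def load_clip(path: str, loader: Optional[Loader] = None) -> bytes:
    # The lowest-level helper owns the fallback, so callers can stay lazy.
    loader = loader or (lambda p: b"\x00" * 16)
    return loader(path)

def plot_match_sketch(path: str, loader: Optional[Loader] = None) -> bytes:
    # Higher-level helpers only forward the optional dependency.
    return load_clip(path, loader=loader)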
@ -1,49 +1,26 @@
from typing import List, Optional
from typing import List

from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.evaluate.match import (
    MatchConfig,
    build_matcher,
    match_all_predictions,
)
from batdetect2.plotting.clips import PreprocessorProtocol
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.evaluate import Evaluator
from batdetect2.postprocess import get_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import get_image_plotter
from batdetect2.typing import (
    MatchEvaluation,
    MetricsProtocol,
)
from batdetect2.typing.evaluate import ClipEvaluation
from batdetect2.typing.models import ModelOutput
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.train import TrainExample


class ValidationMetrics(Callback):
    def __init__(
        self,
        metrics: List[MetricsProtocol],
        preprocessor: PreprocessorProtocol,
        plot: bool = True,
        match_config: Optional[MatchConfig] = None,
    ):
    def __init__(self, evaluator: Evaluator):
        super().__init__()

        if len(metrics) == 0:
            raise ValueError("At least one metric needs to be provided")

        self.match_config = match_config
        self.metrics = metrics
        self.preprocessor = preprocessor
        self.plot = plot

        self.matcher = build_matcher(config=match_config)
        self.evaluator = evaluator

        self._clip_annotations: List[data.ClipAnnotation] = []
        self._predictions: List[List[RawPrediction]] = []
@ -58,33 +35,22 @@ class ValidationMetrics(Callback):
    def plot_examples(
        self,
        pl_module: LightningModule,
        matches: List[MatchEvaluation],
        evaluated_clips: List[ClipEvaluation],
    ):
        plotter = get_image_plotter(pl_module.logger)  # type: ignore

        if plotter is None:
            return

        for class_name, fig in plot_example_gallery(
            matches,
            preprocessor=self.preprocessor,
            n_examples=4,
        ):
            plotter(
                f"examples/{class_name}",
                fig,
                pl_module.global_step,
            )
        for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
            plotter(figure_name, fig, pl_module.global_step)

    def log_metrics(
        self,
        pl_module: LightningModule,
        matches: List[MatchEvaluation],
        evaluated_clips: List[ClipEvaluation],
    ):
        metrics = {}
        for metric in self.metrics:
            metrics.update(metric(matches).items())

        metrics = self.evaluator.compute_metrics(evaluated_clips)
        pl_module.log_dict(metrics)

    def on_validation_epoch_end(
@ -92,17 +58,13 @@ class ValidationMetrics(Callback):
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        matches = match_all_predictions(
        clip_evaluations = self.evaluator.evaluate(
            self._clip_annotations,
            self._predictions,
            targets=pl_module.model.targets,
            matcher=self.matcher,
        )

        self.log_metrics(pl_module, matches)

        if self.plot:
            self.plot_examples(pl_module, matches)
        self.log_metrics(pl_module, clip_evaluations)
        self.plot_examples(pl_module, clip_evaluations)

        return super().on_validation_epoch_end(trainer, pl_module)
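The callback now delegates matching, metrics, and plotting to a single Evaluator object. A minimal stand-in showing the three-method contract the refactored callback relies on (evaluate, compute_metrics, generate_plots); the method names follow the diff, the internals are invented:

# --- illustrative sketch, not part of the commit ---
from typing import Dict, Iterable, List, Tuple

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

class ToyEvaluator:
    def evaluate(self, clip_annotations, predictions) -> List[tuple]:
        # The real evaluator returns List[ClipEvaluation]; pairs suffice here.
        return list(zip(clip_annotations, predictions))

    def compute_metrics(self, clip_evaluations) -> Dict[str, float]:
        return {"val/num_clips": float(len(clip_evaluations))}

    def generate_plots(
        self, clip_evaluations
    ) -> Iterable[Tuple[str, Figure]]:
        fig = plt.figure()
        yield "examples/summary", fig
        plt.close(fig)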
@ -14,7 +14,7 @@ DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1


registry: Registry[ClipperProtocol] = Registry("clipper")
clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")


class RandomClipConfig(BaseConfig):
@ -25,7 +25,6 @@ class RandomClipConfig(BaseConfig):
    min_sound_event_overlap: float = 0


@registry.register(RandomClipConfig)
class RandomClip:
    def __init__(
        self,
@ -61,6 +60,9 @@ class RandomClip:
        )


clipper_registry.register(RandomClipConfig, RandomClip)


def get_subclip_annotation(
    clip_annotation: data.ClipAnnotation,
    random: bool = True,
@ -156,7 +158,6 @@ class PaddedClipConfig(BaseConfig):
    chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION


@registry.register(PaddedClipConfig)
class PaddedClip:
    def __init__(self, chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION):
        self.chunk_size = chunk_size
@ -183,6 +184,8 @@ class PaddedClip:
        return cls(chunk_size=config.chunk_size)


clipper_registry.register(PaddedClipConfig, PaddedClip)

ClipConfig = Annotated[
    Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
]
@ -195,4 +198,4 @@ def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
        "Building clipper with config: \n{}",
        lambda: config.to_yaml_string(),
    )
    return registry.build(config)
    return clipper_registry.build(config)
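The decorator-based @registry.register(ConfigCls) is replaced here by explicit clipper_registry.register(ConfigCls, LogicCls) calls after each class definition. A self-contained sketch of that pattern, where EchoConfig/Echo are invented stand-ins and the real Registry's generics are omitted:

# --- illustrative sketch, not part of the commit ---
from typing import Dict, Literal, Type

from pydantic import BaseModel

_registry: Dict[str, type] = {}

def register(config_cls: Type[BaseModel], logic_cls: type) -> None:
    # The "name" literal default doubles as the registry key.
    _registry[config_cls.model_fields["name"].default] = logic_cls

class EchoConfig(BaseModel):
    name: Literal["echo"] = "echo"
    times: int = 2

class Echo:
    def __init__(self, times: int):
        self.times = times

    @classmethod
    def from_config(cls, config: EchoConfig) -> "Echo":
        return cls(times=config.times)

register(EchoConfig, Echo)
echo = _registry["echo"].from_config(EchoConfig())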
@ -10,10 +10,7 @@ from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.metrics import (
    ClassificationMeanAveragePrecision,
    DetectionAveragePrecision,
)
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
@ -146,7 +143,6 @@ def build_training_module(

def build_trainer_callbacks(
    targets: TargetProtocol,
    preprocessor: PreprocessorProtocol,
    config: EvaluationConfig,
    checkpoint_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
@ -161,6 +157,8 @@ def build_trainer_callbacks(
    if run_name is not None:
        checkpoint_dir = checkpoint_dir / run_name

    evaluator = build_evaluator(config=config, targets=targets)

    return [
        ModelCheckpoint(
            dirpath=str(checkpoint_dir),
@ -168,16 +166,7 @@ def build_trainer_callbacks(
            filename="best-{epoch:02d}-{val_loss:.0f}",
            monitor="total_loss/val",
        ),
        ValidationMetrics(
            metrics=[
                DetectionAveragePrecision(),
                ClassificationMeanAveragePrecision(
                    class_names=targets.class_names
                ),
            ],
            preprocessor=preprocessor,
            match_config=config.match,
        ),
        ValidationMetrics(evaluator),
    ]


@ -214,7 +203,6 @@ def build_trainer(
        callbacks=build_trainer_callbacks(
            targets,
            config=conf.evaluation,
            preprocessor=build_preprocessor(conf.preprocess),
            checkpoint_dir=checkpoint_dir,
            experiment_name=experiment_name,
            run_name=run_name,
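Net effect on trainer construction: metric and matcher choices move behind build_evaluator, so the callback list shrinks to a checkpoint plus one ValidationMetrics. A sketch of the resulting wiring, using stubs so it runs outside the repo; only the ModelCheckpoint arguments mirror the diff:

# --- illustrative sketch, not part of the commit ---
from lightning.pytorch.callbacks import Callback, ModelCheckpoint

class StubEvaluator:  # stand-in for build_evaluator(...) output
    pass

class StubValidationMetrics(Callback):  # mirrors ValidationMetrics(evaluator)
    def __init__(self, evaluator):
        super().__init__()
        self.evaluator = evaluator

callbacks = [
    ModelCheckpoint(
        dirpath="checkpoints/",
        filename="best-{epoch:02d}-{val_loss:.0f}",
        monitor="total_loss/val",
    ),
    StubValidationMetrics(StubEvaluator()),
]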
@ -11,6 +11,7 @@ from typing import (
    TypeVar,
)

from matplotlib.figure import Figure
from soundevent import data

__all__ = [
@ -50,6 +51,12 @@ class MatchEvaluation:
        return self.pred_class_scores[pred_class]


@dataclass
class ClipEvaluation:
    clip: data.Clip
    matches: List[MatchEvaluation]


class MatcherProtocol(Protocol):
    def __call__(
        self,
@ -67,10 +74,16 @@ class AffinityFunction(Protocol, Generic[Geom]):
        self,
        geometry1: Geom,
        geometry2: Geom,
        time_buffer: float = 0.01,
        freq_buffer: float = 1000,
    ) -> float: ...


class MetricsProtocol(Protocol):
    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]: ...


class PlotterProtocol(Protocol):
    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Iterable[Tuple[str, Figure]]: ...
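Because the reshaped protocols are structural (typing.Protocol), any callable with a matching __call__ satisfies them; no registration or inheritance is needed. An invented metric conforming to the new Sequence[ClipEvaluation] signature:

# --- illustrative sketch, not part of the commit ---
from typing import Dict, Sequence

from batdetect2.typing.evaluate import ClipEvaluation

class MatchesPerClip:
    """Toy metric conforming to the reshaped MetricsProtocol."""

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        total = sum(len(ce.matches) for ce in clip_evaluations)
        clips = max(len(clip_evaluations), 1)
        return {"matches_per_clip": total / clips}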