Restructure eval metrics and plotting

mbsantiago 2025-09-15 16:01:15 +01:00
parent ec1c0ff020
commit e752e96b93
20 changed files with 991 additions and 630 deletions


@@ -1,6 +1,7 @@
 from typing import Generic, Protocol, Type, TypeVar
 from pydantic import BaseModel
+from typing_extensions import ParamSpec

 __all__ = [
     "Registry",
@@ -8,26 +9,36 @@ __all__ = [
 T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
 T_Type = TypeVar("T_Type", covariant=True)
+P_Type = ParamSpec("P_Type")

-class LogicProtocol(Generic[T_Config, T_Type], Protocol):
-    """A generic protocol for the logic classes (conditions or transforms)."""
+class LogicProtocol(Generic[T_Config, T_Type, P_Type], Protocol):
+    """A generic protocol for the logic classes."""

     @classmethod
-    def from_config(cls, config: T_Config) -> T_Type: ...
+    def from_config(
+        cls,
+        config: T_Config,
+        *args: P_Type.args,
+        **kwargs: P_Type.kwargs,
+    ) -> T_Type: ...

 T_Proto = TypeVar("T_Proto", bound=LogicProtocol)

-class Registry(Generic[T_Type]):
+class Registry(Generic[T_Type, P_Type]):
     """A generic class to create and manage a registry of items."""

     def __init__(self, name: str):
         self._name = name
         self._registry = {}

-    def register(self, config_cls: Type[T_Config]):
+    def register(
+        self,
+        config_cls: Type[T_Config],
+        logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
+    ) -> None:
-        """A decorator factory to register a new item."""
+        """Register a logic class for the given config class."""
         fields = config_cls.model_fields
@@ -39,13 +50,14 @@ class Registry(Generic[T_Type]):
         if not isinstance(name, str):
             raise ValueError("'name' field must be a string literal.")

-        def decorator(logic_cls: Type[T_Proto]) -> Type[T_Proto]:
-            self._registry[name] = logic_cls
-            return logic_cls
-
-        return decorator
+        self._registry[name] = logic_cls

-    def build(self, config: BaseModel) -> T_Type:
+    def build(
+        self,
+        config: BaseModel,
+        *args: P_Type.args,
+        **kwargs: P_Type.kwargs,
+    ) -> T_Type:
         """Builds a logic instance from a config object."""
         name = getattr(config, "name")  # noqa: B009
@@ -58,4 +70,4 @@ class Registry(Generic[T_Type]):
             f"No {self._name} with name '{name}' is registered."
         )
-        return self._registry[name].from_config(config)
+        return self._registry[name].from_config(config, *args, **kwargs)
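
For context, a minimal sketch (not part of the commit) of how the reworked registry is driven: registration is now an explicit register(config_cls, logic_cls) call instead of a decorator, and any extra arguments passed to build() are forwarded to from_config() through the new ParamSpec. Greeter and GreeterConfig below are made up for illustration:

from typing import List, Literal

from pydantic import BaseModel

from batdetect2.data._core import Registry


class GreeterConfig(BaseModel):  # hypothetical config class
    name: Literal["greeter"] = "greeter"
    greeting: str = "hello"


class Greeter:  # hypothetical logic class
    def __init__(self, greeting: str, names: List[str]):
        self.greeting = greeting
        self.names = names

    @classmethod
    def from_config(cls, config: GreeterConfig, names: List[str]):
        # `names` arrives via Registry.build(config, *args, **kwargs).
        return cls(greeting=config.greeting, names=names)


# The second type parameter spells out the extra arguments from_config needs.
greeters: Registry[Greeter, [List[str]]] = Registry("greeter")
greeters.register(GreeterConfig, Greeter)

greeter = greeters.build(GreeterConfig(), ["Ann", "Ben"])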


@@ -10,7 +10,7 @@ from batdetect2.data._core import Registry
 SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

-_conditions: Registry[SoundEventCondition] = Registry("condition")
+condition_registry: Registry[SoundEventCondition, []] = Registry("condition")

 class HasTagConfig(BaseConfig):
@@ -18,7 +18,6 @@ class HasTagConfig(BaseConfig):
     tag: data.Tag

-@_conditions.register(HasTagConfig)
 class HasTag:
     def __init__(self, tag: data.Tag):
         self.tag = tag
@@ -33,12 +32,14 @@ class HasTag:
         return cls(tag=config.tag)

+condition_registry.register(HasTagConfig, HasTag)

 class HasAllTagsConfig(BaseConfig):
     name: Literal["has_all_tags"] = "has_all_tags"
     tags: List[data.Tag]

-@_conditions.register(HasAllTagsConfig)
 class HasAllTags:
     def __init__(self, tags: List[data.Tag]):
         if not tags:
@@ -56,12 +57,14 @@ class HasAllTags:
         return cls(tags=config.tags)

+condition_registry.register(HasAllTagsConfig, HasAllTags)

 class HasAnyTagConfig(BaseConfig):
     name: Literal["has_any_tag"] = "has_any_tag"
     tags: List[data.Tag]

-@_conditions.register(HasAnyTagConfig)
 class HasAnyTag:
     def __init__(self, tags: List[data.Tag]):
         if not tags:
@@ -79,6 +82,8 @@ class HasAnyTag:
         return cls(tags=config.tags)

+condition_registry.register(HasAnyTagConfig, HasAnyTag)

 Operator = Literal["gt", "gte", "lt", "lte", "eq"]
@@ -109,7 +114,6 @@ def _build_comparator(
     raise ValueError(f"Invalid operator {operator}")

-@_conditions.register(DurationConfig)
 class Duration:
     def __init__(self, operator: Operator, seconds: float):
         self.operator = operator
@@ -135,6 +139,9 @@ class Duration:
         return cls(operator=config.operator, seconds=config.seconds)

+condition_registry.register(DurationConfig, Duration)

 class FrequencyConfig(BaseConfig):
     name: Literal["frequency"] = "frequency"
     boundary: Literal["low", "high"]
@@ -142,7 +149,6 @@ class FrequencyConfig(BaseConfig):
     hertz: float

-@_conditions.register(FrequencyConfig)
 class Frequency:
     def __init__(
         self,
@@ -184,12 +190,14 @@ class Frequency:
         )

+condition_registry.register(FrequencyConfig, Frequency)

 class AllOfConfig(BaseConfig):
     name: Literal["all_of"] = "all_of"
     conditions: Sequence["SoundEventConditionConfig"]

-@_conditions.register(AllOfConfig)
 class AllOf:
     def __init__(self, conditions: List[SoundEventCondition]):
         self.conditions = conditions
@@ -207,12 +215,14 @@ class AllOf:
         return cls(conditions)

+condition_registry.register(AllOfConfig, AllOf)

 class AnyOfConfig(BaseConfig):
     name: Literal["any_of"] = "any_of"
     conditions: List["SoundEventConditionConfig"]

-@_conditions.register(AnyOfConfig)
 class AnyOf:
     def __init__(self, conditions: List[SoundEventCondition]):
         self.conditions = conditions
@@ -230,12 +240,14 @@ class AnyOf:
         return cls(conditions)

+condition_registry.register(AnyOfConfig, AnyOf)

 class NotConfig(BaseConfig):
     name: Literal["not"] = "not"
     condition: "SoundEventConditionConfig"

-@_conditions.register(NotConfig)
 class Not:
     def __init__(self, condition: SoundEventCondition):
         self.condition = condition
@@ -251,6 +263,8 @@ class Not:
         return cls(condition)

+condition_registry.register(NotConfig, Not)

 SoundEventConditionConfig = Annotated[
     Union[
         HasTagConfig,
@@ -269,7 +283,7 @@ SoundEventConditionConfig = Annotated[
 def build_sound_event_condition(
     config: SoundEventConditionConfig,
 ) -> SoundEventCondition:
-    return _conditions.build(config)
+    return condition_registry.build(config)

 def filter_clip_annotation(


@@ -17,7 +17,7 @@ SoundEventTransform = Callable[
     data.SoundEventAnnotation,
 ]

-_transforms: Registry[SoundEventTransform] = Registry("transform")
+transform_registry: Registry[SoundEventTransform, []] = Registry("transform")

 class SetFrequencyBoundConfig(BaseConfig):
@@ -26,7 +26,6 @@ class SetFrequencyBoundConfig(BaseConfig):
     hertz: float

-@_transforms.register(SetFrequencyBoundConfig)
 class SetFrequencyBound:
     def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"):
         self.hertz = hertz
@@ -69,13 +68,15 @@ class SetFrequencyBound:
         return cls(hertz=config.hertz, boundary=config.boundary)

+transform_registry.register(SetFrequencyBoundConfig, SetFrequencyBound)

 class ApplyIfConfig(BaseConfig):
     name: Literal["apply_if"] = "apply_if"
     transform: "SoundEventTransformConfig"
     condition: SoundEventConditionConfig

-@_transforms.register(ApplyIfConfig)
 class ApplyIf:
     def __init__(
         self,
@@ -101,13 +102,15 @@ class ApplyIf:
         return cls(condition=condition, transform=transform)

+transform_registry.register(ApplyIfConfig, ApplyIf)

 class ReplaceTagConfig(BaseConfig):
     name: Literal["replace_tag"] = "replace_tag"
     original: data.Tag
     replacement: data.Tag

-@_transforms.register(ReplaceTagConfig)
 class ReplaceTag:
     def __init__(
         self,
@@ -136,6 +139,9 @@ class ReplaceTag:
         return cls(original=config.original, replacement=config.replacement)

+transform_registry.register(ReplaceTagConfig, ReplaceTag)

 class MapTagValueConfig(BaseConfig):
     name: Literal["map_tag_value"] = "map_tag_value"
     tag_key: str
@@ -143,7 +149,6 @@ class MapTagValueConfig(BaseConfig):
     target_key: Optional[str] = None

-@_transforms.register(MapTagValueConfig)
 class MapTagValue:
     def __init__(
         self,
@@ -193,12 +198,14 @@ class MapTagValue:
         )

+transform_registry.register(MapTagValueConfig, MapTagValue)

 class ApplyAllConfig(BaseConfig):
     name: Literal["apply_all"] = "apply_all"
     steps: List["SoundEventTransformConfig"] = Field(default_factory=list)

-@_transforms.register(ApplyAllConfig)
 class ApplyAll:
     def __init__(self, steps: List[SoundEventTransform]):
         self.steps = steps
@@ -218,6 +225,8 @@ class ApplyAll:
         return cls(steps)

+transform_registry.register(ApplyAllConfig, ApplyAll)

 SoundEventTransformConfig = Annotated[
     Union[
         SetFrequencyBoundConfig,
@@ -233,7 +242,7 @@ SoundEventTransformConfig = Annotated[
 def build_sound_event_transform(
     config: SoundEventTransformConfig,
 ) -> SoundEventTransform:
-    return _transforms.build(config)
+    return transform_registry.build(config)

 def transform_clip_annotation(


@@ -1,6 +1,9 @@
 from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
+from batdetect2.evaluate.evaluator import Evaluator, build_evaluator

 __all__ = [
     "EvaluationConfig",
     "load_evaluation_config",
+    "Evaluator",
+    "build_evaluator",
 ]


@@ -0,0 +1,151 @@
from typing import Annotated, Literal, Optional, Union

from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity

from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.typing.evaluate import AffinityFunction

affinity_functions: Registry[AffinityFunction, []] = Registry(
    "matching_strategy"
)


class TimeAffinityConfig(BaseConfig):
    name: Literal["time_affinity"] = "time_affinity"
    time_buffer: float = 0.01


class TimeAffinity(AffinityFunction):
    def __init__(self, time_buffer: float):
        self.time_buffer = time_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_timestamp_affinity(
            geometry1, geometry2, time_buffer=self.time_buffer
        )

    @classmethod
    def from_config(cls, config: TimeAffinityConfig):
        return cls(time_buffer=config.time_buffer)


affinity_functions.register(TimeAffinityConfig, TimeAffinity)


def compute_timestamp_affinity(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
) -> float:
    assert isinstance(geometry1, data.TimeStamp)
    assert isinstance(geometry2, data.TimeStamp)
    start_time1 = geometry1.coordinates
    start_time2 = geometry2.coordinates
    a = min(start_time1, start_time2)
    b = max(start_time1, start_time2)
    if b - a >= 2 * time_buffer:
        return 0
    # Each timestamp is widened by time_buffer on both sides; the affinity
    # is the IoU of the two resulting windows.
    intersection = a - b + 2 * time_buffer
    union = b - a + 2 * time_buffer
    return intersection / union


class IntervalIOUConfig(BaseConfig):
    name: Literal["interval_iou"] = "interval_iou"
    time_buffer: float = 0.01


class IntervalIOU(AffinityFunction):
    def __init__(self, time_buffer: float):
        self.time_buffer = time_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_interval_iou(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
        )

    @classmethod
    def from_config(cls, config: IntervalIOUConfig):
        return cls(time_buffer=config.time_buffer)


affinity_functions.register(IntervalIOUConfig, IntervalIOU)


def compute_interval_iou(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
) -> float:
    assert isinstance(geometry1, data.TimeInterval)
    assert isinstance(geometry2, data.TimeInterval)
    start_time1, end_time1 = geometry1.coordinates
    start_time2, end_time2 = geometry2.coordinates
    start_time1 -= time_buffer
    start_time2 -= time_buffer
    end_time1 += time_buffer
    end_time2 += time_buffer
    intersection = max(
        0, min(end_time1, end_time2) - max(start_time1, start_time2)
    )
    union = (
        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
    )
    if union == 0:
        return 0
    return intersection / union


class GeometricIOUConfig(BaseConfig):
    name: Literal["geometric_iou"] = "geometric_iou"
    time_buffer: float = 0.01
    freq_buffer: float = 1000


class GeometricIOU(AffinityFunction):
    def __init__(self, time_buffer: float, freq_buffer: float):
        self.time_buffer = time_buffer
        self.freq_buffer = freq_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_affinity(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
            freq_buffer=self.freq_buffer,
        )

    @classmethod
    def from_config(cls, config: GeometricIOUConfig):
        return cls(
            time_buffer=config.time_buffer,
            freq_buffer=config.freq_buffer,
        )


affinity_functions.register(GeometricIOUConfig, GeometricIOU)


AffinityConfig = Annotated[
    Union[
        TimeAffinityConfig,
        IntervalIOUConfig,
        GeometricIOUConfig,
    ],
    Field(discriminator="name"),
]


def build_affinity_function(
    config: Optional[AffinityConfig] = None,
) -> AffinityFunction:
    config = config or GeometricIOUConfig()
    return affinity_functions.build(config)
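
A quick hand-check of the arithmetic in the two helpers above (illustrative numbers; assumes soundevent's TimeStamp/TimeInterval geometries as used throughout the file):

from soundevent import data

from batdetect2.evaluate.affinity import (
    compute_interval_iou,
    compute_timestamp_affinity,
)

# Onsets 5 ms apart with a 10 ms buffer: each timestamp widens to a 20 ms
# window, so intersection = 0.02 - 0.005 = 0.015 and union = 0.02 + 0.005
# = 0.025, giving affinity 0.6.
ts1 = data.TimeStamp(coordinates=0.100)
ts2 = data.TimeStamp(coordinates=0.105)
print(compute_timestamp_affinity(ts1, ts2, time_buffer=0.01))  # ~0.6

# Two 20 ms intervals offset by 10 ms, padded by 10 ms on each side:
# intersection = 0.03, union = 0.04 + 0.04 - 0.03 = 0.05, IoU = 0.6.
iv1 = data.TimeInterval(coordinates=[0.100, 0.120])
iv2 = data.TimeInterval(coordinates=[0.110, 0.130])
print(compute_interval_iou(iv1, iv2, time_buffer=0.01))  # ~0.6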


@@ -1,10 +1,16 @@
-from typing import Optional
+from typing import List, Optional

 from pydantic import Field
 from soundevent import data

 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
+from batdetect2.evaluate.metrics import (
+    ClassificationAPConfig,
+    DetectionAPConfig,
+    MetricConfig,
+)
+from batdetect2.evaluate.plots import ExampleGalleryConfig, PlotConfig

 __all__ = [
     "EvaluationConfig",
@@ -13,7 +19,19 @@ __all__ = [
 class EvaluationConfig(BaseConfig):
+    ignore_start_end: float = 0.01
     match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
+    metrics: List[MetricConfig] = Field(
+        default_factory=lambda: [
+            DetectionAPConfig(),
+            ClassificationAPConfig(),
+        ]
+    )
+    plots: List[PlotConfig] = Field(
+        default_factory=lambda: [
+            ExampleGalleryConfig(),
+        ]
+    )

 def load_evaluation_config(
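
A sketch of overriding the new defaults programmatically (all config classes are the ones imported above; the "noise" class name is made up for illustration):

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.metrics import ClassificationAPConfig, DetectionAPConfig
from batdetect2.evaluate.plots import ExampleGalleryConfig

config = EvaluationConfig(
    ignore_start_end=0.01,  # drop events whose onset falls near clip edges
    metrics=[
        DetectionAPConfig(),
        ClassificationAPConfig(exclude=["noise"]),  # hypothetical class name
    ],
    plots=[
        ExampleGalleryConfig(examples_per_class=10),
    ],
)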


@@ -3,13 +3,14 @@ from typing import List
 import pandas as pd
 from soundevent.geometry import compute_bounds

-from batdetect2.typing.evaluate import MatchEvaluation
+from batdetect2.typing.evaluate import ClipEvaluation

-def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame:
+def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
     data = []
-    for match in matches:
+    for clip_evaluation in clip_evaluations:
+        for match in clip_evaluation.matches:
             gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
             pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None


@@ -4,11 +4,8 @@ import pandas as pd
 from soundevent import data

 from batdetect2.evaluate.dataframe import extract_matches_dataframe
-from batdetect2.evaluate.match import build_matcher, match_all_predictions
-from batdetect2.evaluate.metrics import (
-    ClassificationMeanAveragePrecision,
-    DetectionAveragePrecision,
-)
+from batdetect2.evaluate.evaluator import build_evaluator
+from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
 from batdetect2.models import Model
 from batdetect2.plotting.clips import build_audio_loader
 from batdetect2.postprocess import get_raw_predictions
@@ -55,6 +52,8 @@ def evaluate(
     clip_annotations = []
     predictions = []

+    evaluator = build_evaluator(config=config.evaluation)
+
     for batch in loader:
         outputs = model.detector(batch.spec)
@@ -76,20 +75,12 @@ def evaluate(
         clip_annotations.extend(clip_annotations)
         predictions.extend(predictions)

-    matcher = build_matcher(config.evaluation.match)
-    matches = match_all_predictions(
-        clip_annotations,
-        predictions,
-        targets=targets,
-        matcher=matcher,
-    )
+    matches = evaluator.evaluate(clip_annotations, predictions)

     df = extract_matches_dataframe(matches)

     metrics = [
-        DetectionAveragePrecision(),
-        ClassificationMeanAveragePrecision(class_names=targets.class_names),
+        DetectionAP(),
+        ClassificationAP(class_names=targets.class_names),
     ]

     results = {


@@ -0,0 +1,169 @@
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

from matplotlib.figure import Figure
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.match import build_matcher, match
from batdetect2.evaluate.metrics import build_metric
from batdetect2.evaluate.plots import build_plotter
from batdetect2.targets import build_targets
from batdetect2.typing.evaluate import (
    ClipEvaluation,
    MatcherProtocol,
    MetricsProtocol,
    PlotterProtocol,
)
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol

__all__ = [
    "Evaluator",
    "build_evaluator",
]


class Evaluator:
    def __init__(
        self,
        config: EvaluationConfig,
        targets: TargetProtocol,
        matcher: MatcherProtocol,
        metrics: List[MetricsProtocol],
        plots: List[PlotterProtocol],
    ):
        self.config = config
        self.targets = targets
        self.matcher = matcher
        self.metrics = metrics
        self.plots = plots

    def match(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEvaluation:
        clip = clip_annotation.clip
        ground_truth = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.filter_sound_event_annotations(sound_event, clip)
        ]
        predictions = [
            prediction
            for prediction in predictions
            if self.filter_predictions(prediction, clip)
        ]
        return ClipEvaluation(
            clip=clip_annotation.clip,
            matches=match(
                ground_truth,
                predictions,
                clip=clip,
                targets=self.targets,
                matcher=self.matcher,
            ),
        )

    def filter_sound_event_annotations(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
        clip: data.Clip,
    ) -> bool:
        if not self.targets.filter(sound_event_annotation):
            return False

        geometry = sound_event_annotation.sound_event.geometry
        if geometry is None:
            return False

        return is_in_bounds(
            geometry,
            clip,
            self.config.ignore_start_end,
        )

    def filter_predictions(
        self,
        prediction: RawPrediction,
        clip: data.Clip,
    ) -> bool:
        return is_in_bounds(
            prediction.geometry,
            clip,
            self.config.ignore_start_end,
        )

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[ClipEvaluation]:
        if len(clip_annotations) != len(predictions):
            raise ValueError(
                "Number of annotated clips and sets of predictions do not match"
            )

        return [
            self.match(clip_annotation, preds)
            for clip_annotation, preds in zip(clip_annotations, predictions)
        ]

    def compute_metrics(
        self,
        clip_evaluations: Sequence[ClipEvaluation],
    ) -> Dict[str, float]:
        results = {}
        for metric in self.metrics:
            results.update(metric(clip_evaluations))
        return results

    def generate_plots(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Iterable[Tuple[str, Figure]]:
        for plotter in self.plots:
            for name, fig in plotter(clip_evaluations):
                yield name, fig


def build_evaluator(
    config: Optional[EvaluationConfig] = None,
    targets: Optional[TargetProtocol] = None,
    matcher: Optional[MatcherProtocol] = None,
    plots: Optional[List[PlotterProtocol]] = None,
    metrics: Optional[List[MetricsProtocol]] = None,
) -> Evaluator:
    config = config or EvaluationConfig()
    targets = targets or build_targets()
    matcher = matcher or build_matcher(config.match)

    if metrics is None:
        metrics = [
            build_metric(config, targets.class_names)
            for config in config.metrics
        ]

    if plots is None:
        plots = [build_plotter(config) for config in config.plots]

    return Evaluator(
        config=config,
        targets=targets,
        matcher=matcher,
        metrics=metrics,
        plots=plots,
    )


def is_in_bounds(
    geometry: data.Geometry,
    clip: data.Clip,
    buffer: float,
) -> bool:
    start_time = compute_bounds(geometry)[0]
    return (start_time >= clip.start_time + buffer) and (
        start_time <= clip.end_time - buffer
    )


@@ -1,9 +1,7 @@
 from collections.abc import Callable, Iterable, Mapping
-from dataclasses import dataclass, field
 from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union

 import numpy as np
-from loguru import logger
 from pydantic import Field
 from soundevent import data
 from soundevent.evaluation import compute_affinity
@@ -12,6 +10,11 @@ from soundevent.geometry import compute_bounds
 from batdetect2.configs import BaseConfig
 from batdetect2.data._core import Registry
+from batdetect2.evaluate.affinity import (
+    AffinityConfig,
+    GeometricIOUConfig,
+    build_affinity_function,
+)
 from batdetect2.targets import build_targets
 from batdetect2.typing import (
     MatchEvaluation,
@@ -23,7 +26,88 @@ from batdetect2.typing.postprocess import RawPrediction
 MatchingGeometry = Literal["bbox", "interval", "timestamp"]
 """The geometry representation to use for matching."""

-matching_strategy = Registry("matching_strategy")
+matching_strategies = Registry("matching_strategy")

+def match(
+    sound_event_annotations: Sequence[data.SoundEventAnnotation],
+    raw_predictions: Sequence[RawPrediction],
+    clip: data.Clip,
+    targets: Optional[TargetProtocol] = None,
+    matcher: Optional[MatcherProtocol] = None,
+) -> List[MatchEvaluation]:
+    if matcher is None:
+        matcher = build_matcher()
+
+    if targets is None:
+        targets = build_targets()
+
+    target_geometries: List[data.Geometry] = [  # type: ignore
+        sound_event_annotation.sound_event.geometry
+        for sound_event_annotation in sound_event_annotations
+    ]
+    predicted_geometries = [
+        raw_prediction.geometry for raw_prediction in raw_predictions
+    ]
+    scores = [
+        raw_prediction.detection_score for raw_prediction in raw_predictions
+    ]
+
+    matches = []
+    for source_idx, target_idx, affinity in matcher(
+        ground_truth=target_geometries,
+        predictions=predicted_geometries,
+        scores=scores,
+    ):
+        target = (
+            sound_event_annotations[target_idx]
+            if target_idx is not None
+            else None
+        )
+        prediction = (
+            raw_predictions[source_idx] if source_idx is not None else None
+        )
+        gt_det = target_idx is not None
+        gt_class = targets.encode_class(target) if target is not None else None
+        pred_score = float(prediction.detection_score) if prediction else 0
+        pred_geometry = (
+            predicted_geometries[source_idx]
+            if source_idx is not None
+            else None
+        )
+        class_scores = (
+            {
+                str(class_name): float(score)
+                for class_name, score in zip(
+                    targets.class_names,
+                    prediction.class_scores,
+                )
+            }
+            if prediction is not None
+            else {}
+        )
+        matches.append(
+            MatchEvaluation(
+                clip=clip,
+                sound_event_annotation=target,
+                gt_det=gt_det,
+                gt_class=gt_class,
+                pred_score=pred_score,
+                pred_class_scores=class_scores,
+                pred_geometry=pred_geometry,
+                affinity=affinity,
+            )
+        )
+
+    return matches

 class StartTimeMatchConfig(BaseConfig):
@@ -31,7 +115,6 @@ class StartTimeMatchConfig(BaseConfig):
     distance_threshold: float = 0.01

-@matching_strategy.register(StartTimeMatchConfig)
 class StartTimeMatcher(MatcherProtocol):
     def __init__(self, distance_threshold: float):
         self.distance_threshold = distance_threshold
@@ -54,6 +137,9 @@ class StartTimeMatcher(MatcherProtocol):
         return cls(distance_threshold=config.distance_threshold)

+matching_strategies.register(StartTimeMatchConfig, StartTimeMatcher)

 def match_start_times(
     ground_truth: Sequence[data.Geometry],
     predictions: Sequence[data.Geometry],
@@ -74,8 +160,8 @@ def match_start_times(
     gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth])
     pred_times = np.array([compute_bounds(geom)[0] for geom in predictions])

     scores = np.array(scores)
     sort_args = np.argsort(scores)[::-1]

     distances = np.abs(gt_times[None, :] - pred_times[:, None])
@@ -143,89 +229,25 @@ _geometry_cast_functions: Mapping[
 }

-def _timestamp_affinity(
-    geometry1: data.Geometry,
-    geometry2: data.Geometry,
-    time_buffer: float = 0.01,
-    freq_buffer: float = 1000,
-) -> float:
-    assert isinstance(geometry1, data.TimeStamp)
-    assert isinstance(geometry2, data.TimeStamp)
-    start_time1 = geometry1.coordinates
-    start_time2 = geometry2.coordinates
-    a = min(start_time1, start_time2)
-    b = max(start_time1, start_time2)
-    if b - a >= 2 * time_buffer:
-        return 0
-    intersection = a - b + 2 * time_buffer
-    union = b - a + 2 * time_buffer
-    return intersection / union
-
-def _interval_affinity(
-    geometry1: data.Geometry,
-    geometry2: data.Geometry,
-    time_buffer: float = 0.01,
-    freq_buffer: float = 1000,
-) -> float:
-    assert isinstance(geometry1, data.TimeInterval)
-    assert isinstance(geometry2, data.TimeInterval)
-    start_time1, end_time1 = geometry1.coordinates
-    start_time2, end_time2 = geometry1.coordinates
-    start_time1 -= time_buffer
-    start_time2 -= time_buffer
-    end_time1 += time_buffer
-    end_time2 += time_buffer
-    intersection = max(
-        0, min(end_time1, end_time2) - max(start_time1, start_time2)
-    )
-    union = (
-        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
-    )
-    if union == 0:
-        return 0
-    return intersection / union
-
-_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
-    "timestamp": _timestamp_affinity,
-    "interval": _interval_affinity,
-    "bbox": compute_affinity,
-}

 class GreedyMatchConfig(BaseConfig):
     name: Literal["greedy_match"] = "greedy_match"
     geometry: MatchingGeometry = "timestamp"
-    affinity_threshold: float = 0.0
-    time_buffer: float = 0.005
-    frequency_buffer: float = 1_000
+    affinity_threshold: float = 0.5
+    affinity_function: AffinityConfig = Field(
+        default_factory=GeometricIOUConfig
+    )

-@matching_strategy.register(GreedyMatchConfig)
 class GreedyMatcher(MatcherProtocol):
     def __init__(
         self,
         geometry: MatchingGeometry,
         affinity_threshold: float,
-        time_buffer: float,
-        frequency_buffer: float,
+        affinity_function: AffinityFunction,
     ):
         self.geometry = geometry
+        self.affinity_function = affinity_function
         self.affinity_threshold = affinity_threshold
-        self.time_buffer = time_buffer
-        self.frequency_buffer = frequency_buffer
-        self.affinity_function = _affinity_functions[self.geometry]
         self.cast_geometry = _geometry_cast_functions[self.geometry]

     def __call__(
@@ -240,28 +262,27 @@ class GreedyMatcher(MatcherProtocol):
             scores=scores,
             affinity_function=self.affinity_function,
             affinity_threshold=self.affinity_threshold,
-            time_buffer=self.time_buffer,
-            freq_buffer=self.frequency_buffer,
         )

     @classmethod
     def from_config(cls, config: GreedyMatchConfig):
+        affinity_function = build_affinity_function(config.affinity_function)
         return cls(
             geometry=config.geometry,
             affinity_threshold=config.affinity_threshold,
-            time_buffer=config.time_buffer,
-            frequency_buffer=config.frequency_buffer,
+            affinity_function=affinity_function,
        )

+matching_strategies.register(GreedyMatchConfig, GreedyMatcher)

 def greedy_match(
     ground_truth: Sequence[data.Geometry],
     predictions: Sequence[data.Geometry],
     scores: Sequence[float],
     affinity_threshold: float = 0.5,
     affinity_function: AffinityFunction = compute_affinity,
-    time_buffer: float = 0.001,
-    freq_buffer: float = 1000,
 ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
     """Performs a greedy, one-to-one matching of source to target geometries.
@@ -279,10 +300,6 @@ def greedy_match(
         Confidence scores for each source geometry for prioritization.
     affinity_threshold
         The minimum affinity score required for a valid match.
-    time_buffer
-        Time tolerance in seconds for affinity calculation.
-    freq_buffer
-        Frequency tolerance in Hertz for affinity calculation.

     Yields
     ------
@@ -314,12 +331,7 @@ def greedy_match(
         affinities = np.array(
             [
-                affinity_function(
-                    source_geometry,
-                    target_geometry,
-                    time_buffer=time_buffer,
-                    freq_buffer=freq_buffer,
-                )
+                affinity_function(source_geometry, target_geometry)
                 for target_geometry in ground_truth
             ]
         )
@@ -344,12 +356,11 @@ def greedy_match(
 class OptimalMatchConfig(BaseConfig):
     name: Literal["optimal_match"] = "optimal_match"
-    affinity_threshold: float = 0.0
+    affinity_threshold: float = 0.5
     time_buffer: float = 0.005
     frequency_buffer: float = 1_000

-@matching_strategy.register(OptimalMatchConfig)
 class OptimalMatcher(MatcherProtocol):
     def __init__(
         self,
@@ -384,6 +395,8 @@ class OptimalMatcher(MatcherProtocol):
     )

+matching_strategies.register(OptimalMatchConfig, OptimalMatcher)

 MatchConfig = Annotated[
     Union[
         GreedyMatchConfig,
@@ -396,174 +409,4 @@ MatchConfig = Annotated[

 def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
     config = config or StartTimeMatchConfig()
-    return matching_strategy.build(config)
+    return matching_strategies.build(config)

-def _is_in_bounds(
-    geometry: data.Geometry,
-    clip: data.Clip,
-    buffer: float,
-) -> bool:
-    start_time = compute_bounds(geometry)[0]
-    return (start_time >= clip.start_time + buffer) and (
-        start_time <= clip.end_time - buffer
-    )
-
-def match_sound_events_and_predictions(
-    clip_annotation: data.ClipAnnotation,
-    raw_predictions: List[RawPrediction],
-    targets: Optional[TargetProtocol] = None,
-    matcher: Optional[MatcherProtocol] = None,
-    ignore_start_end: float = 0.01,
-) -> List[MatchEvaluation]:
-    if matcher is None:
-        matcher = build_matcher()
-    if targets is None:
-        targets = build_targets()
-    target_sound_events = [
-        sound_event_annotation
-        for sound_event_annotation in clip_annotation.sound_events
-        if targets.filter(sound_event_annotation)
-        and sound_event_annotation.sound_event.geometry is not None
-        and _is_in_bounds(
-            sound_event_annotation.sound_event.geometry,
-            clip=clip_annotation.clip,
-            buffer=ignore_start_end,
-        )
-    ]
-    target_geometries: List[data.Geometry] = [
-        sound_event_annotation.sound_event.geometry
-        for sound_event_annotation in target_sound_events
-        if sound_event_annotation.sound_event.geometry is not None
-    ]
-    raw_predictions = [
-        raw_prediction
-        for raw_prediction in raw_predictions
-        if _is_in_bounds(
-            raw_prediction.geometry,
-            clip=clip_annotation.clip,
-            buffer=ignore_start_end,
-        )
-    ]
-    predicted_geometries = [
-        raw_prediction.geometry for raw_prediction in raw_predictions
-    ]
-    scores = [
-        raw_prediction.detection_score for raw_prediction in raw_predictions
-    ]
-    matches = []
-    for source_idx, target_idx, affinity in matcher(
-        ground_truth=target_geometries,
-        predictions=predicted_geometries,
-        scores=scores,
-    ):
-        target = (
-            target_sound_events[target_idx] if target_idx is not None else None
-        )
-        prediction = (
-            raw_predictions[source_idx] if source_idx is not None else None
-        )
-        gt_det = target_idx is not None
-        gt_class = targets.encode_class(target) if target is not None else None
-        pred_score = float(prediction.detection_score) if prediction else 0
-        pred_geometry = (
-            predicted_geometries[source_idx]
-            if source_idx is not None
-            else None
-        )
-        class_scores = (
-            {
-                str(class_name): float(score)
-                for class_name, score in zip(
-                    targets.class_names,
-                    prediction.class_scores,
-                )
-            }
-            if prediction is not None
-            else {}
-        )
-        matches.append(
-            MatchEvaluation(
-                clip=clip_annotation.clip,
-                sound_event_annotation=target,
-                gt_det=gt_det,
-                gt_class=gt_class,
-                pred_score=pred_score,
-                pred_class_scores=class_scores,
-                pred_geometry=pred_geometry,
-                affinity=affinity,
-            )
-        )
-    return matches
-
-def match_all_predictions(
-    clip_annotations: List[data.ClipAnnotation],
-    predictions: List[List[RawPrediction]],
-    targets: Optional[TargetProtocol] = None,
-    matcher: Optional[MatcherProtocol] = None,
-    ignore_start_end: float = 0.01,
-) -> List[MatchEvaluation]:
-    logger.info("Matching all annotations and predictions...")
-    return [
-        match
-        for clip_annotation, raw_predictions in zip(
-            clip_annotations,
-            predictions,
-        )
-        for match in match_sound_events_and_predictions(
-            clip_annotation,
-            raw_predictions,
-            targets=targets,
-            matcher=matcher,
-            ignore_start_end=ignore_start_end,
-        )
-    ]
-
-@dataclass
-class ClassExamples:
-    false_positives: List[MatchEvaluation] = field(default_factory=list)
-    false_negatives: List[MatchEvaluation] = field(default_factory=list)
-    true_positives: List[MatchEvaluation] = field(default_factory=list)
-    cross_triggers: List[MatchEvaluation] = field(default_factory=list)
-
-def group_matches(matches: List[MatchEvaluation]) -> ClassExamples:
-    class_examples = ClassExamples()
-    for match in matches:
-        gt_class = match.gt_class
-        pred_class = match.pred_class
-        if pred_class is None:
-            class_examples.false_negatives.append(match)
-            continue
-        if gt_class is None:
-            class_examples.false_positives.append(match)
-            continue
-        if gt_class != pred_class:
-            class_examples.cross_triggers.append(match)
-            class_examples.cross_triggers.append(match)
-            continue
-        class_examples.true_positives.append(match)
-    return class_examples
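
For reference, a small sketch of the matcher contract this module is built around: a MatcherProtocol yields (prediction_index, ground_truth_index, affinity) triples, with None on the unmatched side. The toy geometries and the exact output are illustrative, assuming the default StartTimeMatcher behavior:

from soundevent import data

from batdetect2.evaluate.match import build_matcher

gt = [data.TimeStamp(coordinates=0.10), data.TimeStamp(coordinates=0.50)]
preds = [data.TimeStamp(coordinates=0.101), data.TimeStamp(coordinates=0.80)]
scores = [0.9, 0.4]

matcher = build_matcher()  # defaults to StartTimeMatcher(distance_threshold=0.01)

for pred_idx, gt_idx, affinity in matcher(
    ground_truth=gt, predictions=preds, scores=scores
):
    if pred_idx is None:
        print(f"missed ground truth {gt_idx}")    # false negative
    elif gt_idx is None:
        print(f"spurious prediction {pred_idx}")  # false positive
    else:
        print(f"pred {pred_idx} -> gt {gt_idx} (affinity {affinity:.2f})")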


@@ -1,61 +1,151 @@
-from typing import Dict, List
+from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union

 import numpy as np
-import pandas as pd
+from pydantic import Field
 from sklearn import metrics
 from sklearn.preprocessing import label_binarize

-from batdetect2.typing import MatchEvaluation, MetricsProtocol
+from batdetect2.configs import BaseConfig
+from batdetect2.data._core import Registry
+from batdetect2.typing import MetricsProtocol
+from batdetect2.typing.evaluate import ClipEvaluation

-__all__ = ["DetectionAveragePrecision"]
+__all__ = ["DetectionAP", "ClassificationAP"]

+metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")

-class DetectionAveragePrecision(MetricsProtocol):
-    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
+class DetectionAPConfig(BaseConfig):
+    name: Literal["detection_ap"] = "detection_ap"
+
+class DetectionAP(MetricsProtocol):
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Dict[str, float]:
         y_true, y_score = zip(
-            *[(match.gt_det, match.pred_score) for match in matches]
+            *[
+                (match.gt_det, match.pred_score)
+                for clip_eval in clip_evaluations
+                for match in clip_eval.matches
+            ]
         )
         score = float(metrics.average_precision_score(y_true, y_score))
         return {"detection_AP": score}

+    @classmethod
+    def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
+        return cls()

-class ClassificationMeanAveragePrecision(MetricsProtocol):
-    def __init__(self, class_names: List[str]):
-        self.class_names = class_names
+metrics_registry.register(DetectionAPConfig, DetectionAP)

-    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
-        # NOTE: Need to exclude generic but unclassified targets
-        y_true = label_binarize(
-            [
-                match.gt_class if match.gt_class is not None else "__NONE__"
-                for match in matches
-                if not (match.gt_det and match.gt_class is None)
-            ],
-            classes=self.class_names,
-        )
-        y_pred = pd.DataFrame(
-            [
-                {
-                    name: match.pred_class_scores.get(name, 0)
-                    for name in self.class_names
-                }
-                for match in matches
-                if not (match.gt_det and match.gt_class is None)
-            ]
-        ).fillna(0)
-        ret = {}
+class ClassificationAPConfig(BaseConfig):
+    name: Literal["classification_ap"] = "classification_ap"
+    include: Optional[List[str]] = None
+    exclude: Optional[List[str]] = None
+
+class ClassificationAP(MetricsProtocol):
+    def __init__(
+        self,
+        class_names: List[str],
+        include: Optional[List[str]] = None,
+        exclude: Optional[List[str]] = None,
+    ):
+        self.class_names = class_names
+        self.selected = class_names
+
+        if include is not None:
+            self.selected = [
+                class_name
+                for class_name in self.selected
+                if class_name in include
+            ]
+
+        if exclude is not None:
+            self.selected = [
+                class_name
+                for class_name in self.selected
+                if class_name not in exclude
+            ]
+
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Dict[str, float]:
+        y_true = []
+        y_pred = []
+        for clip_eval in clip_evaluations:
+            for match in clip_eval.matches:
+                # Ignore generic unclassified targets
+                if match.gt_det and match.gt_class is None:
+                    continue
+                y_true.append(
+                    match.gt_class
+                    if match.gt_class is not None
+                    else "__NONE__"
+                )
+                y_pred.append(
+                    np.array(
+                        [
+                            match.pred_class_scores.get(name, 0)
+                            for name in self.class_names
+                        ]
+                    )
+                )
+
+        y_true = label_binarize(y_true, classes=self.class_names)
+        y_pred = np.stack(y_pred)
+
+        class_scores = {}
         for class_index, class_name in enumerate(self.class_names):
             y_true_class = y_true[:, class_index]
-            y_pred_class = y_pred[class_name]
+            y_pred_class = y_pred[:, class_index]
             class_ap = metrics.average_precision_score(
                 y_true_class,
                 y_pred_class,
             )
-            ret[f"classification_AP/{class_name}"] = float(class_ap)
+            class_scores[class_name] = float(class_ap)

-        ret["classification_mAP"] = np.mean(
-            [value for value in ret.values() if value != 0]
+        mean_ap = np.mean(
+            [value for value in class_scores.values() if value != 0]
         )
-        return ret
+        return {
+            "classification_mAP": float(mean_ap),
+            **{
+                f"classification_AP/{class_name}": class_scores[class_name]
+                for class_name in self.selected
+            },
+        }
+
+    @classmethod
+    def from_config(
+        cls,
+        config: ClassificationAPConfig,
+        class_names: List[str],
+    ):
+        return cls(
+            class_names,
+            include=config.include,
+            exclude=config.exclude,
+        )
+
+metrics_registry.register(ClassificationAPConfig, ClassificationAP)
+
+MetricConfig = Annotated[
+    Union[ClassificationAPConfig, DetectionAPConfig],
+    Field(discriminator="name"),
+]
+
+def build_metric(config: MetricConfig, class_names: List[str]):
+    return metrics_registry.build(config, class_names)
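
A mini hand-check of the sklearn reduction both metrics rely on: each match row contributes gt_det (was a ground-truth event matched to this row?) as y_true and the detection score as y_score, and unmatched ground truth enters with pred_score = 0, which is how missed detections pull average precision down. The numbers below are illustrative:

from sklearn import metrics

y_true = [True, True, False, True]  # gt_det per match row
y_score = [0.9, 0.75, 0.6, 0.2]     # pred_score per match row

# Ranked by score: hit, hit, miss, hit ->
# AP = (1/3)(1.0) + (1/3)(1.0) + (1/3)(0.75) ~= 0.92
print(metrics.average_precision_score(y_true, y_score))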


@@ -0,0 +1,163 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union

import matplotlib.pyplot as plt
import pandas as pd
from pydantic import Field

from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing.evaluate import (
    ClipEvaluation,
    MatchEvaluation,
    PlotterProtocol,
)
from batdetect2.typing.preprocess import AudioLoader

__all__ = [
    "build_plotter",
    "ExampleGallery",
    "ExampleGalleryConfig",
]

plots_registry: Registry[PlotterProtocol, []] = Registry("plot")


class ExampleGalleryConfig(BaseConfig):
    name: Literal["example_gallery"] = "example_gallery"
    examples_per_class: int = 5
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )


class ExampleGallery(PlotterProtocol):
    def __init__(
        self,
        examples_per_class: int,
        preprocessor: Optional[PreprocessorProtocol] = None,
        audio_loader: Optional[AudioLoader] = None,
    ):
        self.examples_per_class = examples_per_class
        self.preprocessor = preprocessor or build_preprocessor()
        self.audio_loader = audio_loader or build_audio_loader()

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        per_class_matches = group_matches(clip_evaluations)

        for class_name, matches in per_class_matches.items():
            true_positives = get_binned_sample(
                matches.true_positives,
                n_examples=self.examples_per_class,
            )
            false_positives = get_binned_sample(
                matches.false_positives,
                n_examples=self.examples_per_class,
            )
            false_negatives = random.sample(
                matches.false_negatives,
                k=min(self.examples_per_class, len(matches.false_negatives)),
            )
            cross_triggers = get_binned_sample(
                matches.cross_triggers,
                n_examples=self.examples_per_class,
            )
            fig = plot_match_gallery(
                true_positives,
                false_positives,
                false_negatives,
                cross_triggers,
                preprocessor=self.preprocessor,
                audio_loader=self.audio_loader,
                n_examples=self.examples_per_class,
            )
            yield f"example_gallery/{class_name}", fig
            plt.close(fig)

    @classmethod
    def from_config(cls, config: ExampleGalleryConfig):
        preprocessor = build_preprocessor(config.preprocessing)
        audio_loader = build_audio_loader(config.preprocessing.audio)
        return cls(
            examples_per_class=config.examples_per_class,
            preprocessor=preprocessor,
            audio_loader=audio_loader,
        )


plots_registry.register(ExampleGalleryConfig, ExampleGallery)

PlotConfig = Annotated[
    Union[ExampleGalleryConfig,], Field(discriminator="name")
]


def build_plotter(config: PlotConfig) -> PlotterProtocol:
    return plots_registry.build(config)


@dataclass
class ClassMatches:
    false_positives: List[MatchEvaluation] = field(default_factory=list)
    false_negatives: List[MatchEvaluation] = field(default_factory=list)
    true_positives: List[MatchEvaluation] = field(default_factory=list)
    cross_triggers: List[MatchEvaluation] = field(default_factory=list)


def group_matches(
    clip_evaluations: Sequence[ClipEvaluation],
) -> Dict[str, ClassMatches]:
    class_examples = defaultdict(ClassMatches)

    for clip_evaluation in clip_evaluations:
        for match in clip_evaluation.matches:
            gt_class = match.gt_class
            pred_class = match.pred_class

            if pred_class is None:
                class_examples[gt_class].false_negatives.append(match)
                continue

            if gt_class is None:
                class_examples[pred_class].false_positives.append(match)
                continue

            if gt_class != pred_class:
                class_examples[gt_class].cross_triggers.append(match)
                class_examples[pred_class].cross_triggers.append(match)
                continue

            class_examples[gt_class].true_positives.append(match)

    return class_examples


def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
    if len(matches) < n_examples:
        return matches

    indices, pred_scores = zip(
        *[
            (index, match.pred_class_scores[pred_class])
            for index, match in enumerate(matches)
            if (pred_class := match.pred_class) is not None
        ]
    )
    bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
    df = pd.DataFrame({"indices": indices, "bins": bins})
    sample = df.groupby("bins").sample(1)
    return [matches[ind] for ind in sample["indices"]]
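
The qcut-based sampling above can be checked in isolation: scores are cut into quantile bins and one example is drawn per bin, so the gallery spans the confidence range instead of showing only top-scoring matches. A standalone illustration with made-up scores:

import pandas as pd

scores = [0.05, 0.2, 0.35, 0.5, 0.65, 0.8, 0.95]
bins = pd.qcut(scores, q=3, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": range(len(scores)), "bins": bins})
# One randomly chosen index per score bin, e.g. [1, 3, 6].
print(df.groupby("bins").sample(1)["indices"].tolist())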


@@ -2,6 +2,7 @@ from batdetect2.plotting.clip_annotations import plot_clip_annotation
 from batdetect2.plotting.clip_predictions import plot_clip_prediction
 from batdetect2.plotting.clips import plot_clip
 from batdetect2.plotting.common import plot_spectrogram
+from batdetect2.plotting.gallery import plot_match_gallery
 from batdetect2.plotting.heatmaps import (
     plot_classification_heatmap,
     plot_detection_heatmap,
@@ -26,4 +27,5 @@ __all__ = [
     "plot_true_positive_match",
     "plot_detection_heatmap",
     "plot_classification_heatmap",
+    "plot_match_gallery",
 ]


@ -1,160 +0,0 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
from batdetect2 import plotting
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing.preprocess import PreprocessorProtocol
@dataclass
class ClassExamples:
false_positives: List[MatchEvaluation] = field(default_factory=list)
false_negatives: List[MatchEvaluation] = field(default_factory=list)
true_positives: List[MatchEvaluation] = field(default_factory=list)
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def plot_example_gallery(
matches: List[MatchEvaluation],
preprocessor: PreprocessorProtocol,
n_examples: int = 5,
):
class_examples = defaultdict(ClassExamples)
for match in matches:
gt_class = match.gt_class
pred_class = match.pred_class
if pred_class is None:
class_examples[gt_class].false_negatives.append(match)
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[gt_class].cross_triggers.append(match)
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
for class_name, examples in class_examples.items():
true_positives = get_binned_sample(
examples.true_positives,
n_examples=n_examples,
)
false_positives = get_binned_sample(
examples.false_positives,
n_examples=n_examples,
)
false_negatives = random.sample(
examples.false_negatives,
k=min(n_examples, len(examples.false_negatives)),
)
cross_triggers = get_binned_sample(
examples.cross_triggers,
n_examples=n_examples,
)
fig = plot_class_examples(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=preprocessor,
n_examples=n_examples,
)
yield class_name, fig
plt.close(fig)
def plot_class_examples(
true_positives: List[MatchEvaluation],
false_positives: List[MatchEvaluation],
false_negatives: List[MatchEvaluation],
cross_triggers: List[MatchEvaluation],
preprocessor: PreprocessorProtocol,
n_examples: int = 5,
duration: float = 0.1,
):
fig = plt.figure(figsize=(20, 20))
for index, match in enumerate(true_positives[:n_examples]):
ax = plt.subplot(4, n_examples, index + 1)
try:
plotting.plot_true_positive_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for index, match in enumerate(false_positives[:n_examples]):
ax = plt.subplot(4, n_examples, n_examples + index + 1)
try:
plotting.plot_false_positive_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for index, match in enumerate(false_negatives[:n_examples]):
ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
try:
plotting.plot_false_negative_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for index, match in enumerate(cross_triggers[:n_examples]):
ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
try:
plotting.plot_cross_trigger_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
return fig
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[
(index, match.pred_class_scores[pred_class])
for index, match in enumerate(matches)
if (pred_class := match.pred_class) is not None
]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").sample(1)
return [matches[ind] for ind in sample["indices"]]


@@ -0,0 +1,81 @@
from typing import List, Optional

import matplotlib.pyplot as plt

from batdetect2.plotting.matches import (
    plot_cross_trigger_match,
    plot_false_negative_match,
    plot_false_positive_match,
    plot_true_positive_match,
)
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol

__all__ = ["plot_match_gallery"]


def plot_match_gallery(
    true_positives: List[MatchEvaluation],
    false_positives: List[MatchEvaluation],
    false_negatives: List[MatchEvaluation],
    cross_triggers: List[MatchEvaluation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    n_examples: int = 5,
    duration: float = 0.1,
):
    fig = plt.figure(figsize=(20, 20))

    for index, match in enumerate(true_positives[:n_examples]):
        ax = plt.subplot(4, n_examples, index + 1)
        try:
            plot_true_positive_match(
                match,
                ax=ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    for index, match in enumerate(false_positives[:n_examples]):
        ax = plt.subplot(4, n_examples, n_examples + index + 1)
        try:
            plot_false_positive_match(
                match,
                ax=ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    for index, match in enumerate(false_negatives[:n_examples]):
        ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
        try:
            plot_false_negative_match(
                match,
                ax=ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    for index, match in enumerate(cross_triggers[:n_examples]):
        ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
        try:
            plot_cross_trigger_match(
                match,
                ax=ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue

    return fig


@@ -7,8 +7,8 @@ from soundevent.geometry import compute_bounds
 from soundevent.plot.tags import TagColorMapper
 
 from batdetect2.plotting.clip_predictions import plot_prediction
-from batdetect2.plotting.clips import plot_clip
-from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
+from batdetect2.plotting.clips import AudioLoader, plot_clip
+from batdetect2.preprocess import PreprocessorProtocol
 from batdetect2.typing.evaluate import MatchEvaluation
 
 __all__ = [
@@ -32,6 +32,7 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
 def plot_matches(
     matches: List[data.Match],
     clip: data.Clip,
+    audio_loader: Optional[AudioLoader] = None,
     preprocessor: Optional[PreprocessorProtocol] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
@@ -46,12 +47,11 @@ def plot_matches(
     annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
     prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
 ) -> Axes:
-    if preprocessor is None:
-        preprocessor = build_preprocessor()
-
     ax = plot_clip(
         clip,
         ax=ax,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
         figsize=figsize,
         audio_dir=audio_dir,
         spec_cmap=spec_cmap,
@@ -116,6 +116,7 @@ def plot_matches(
 
 def plot_false_positive_match(
     match: MatchEvaluation,
+    audio_loader: Optional[AudioLoader] = None,
     preprocessor: Optional[PreprocessorProtocol] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
@@ -143,6 +144,7 @@ def plot_false_positive_match(
 
     ax = plot_clip(
         clip,
+        audio_loader=audio_loader,
         preprocessor=preprocessor,
         figsize=figsize,
         ax=ax,
@@ -174,6 +176,7 @@ def plot_false_positive_match(
 
 def plot_false_negative_match(
     match: MatchEvaluation,
+    audio_loader: Optional[AudioLoader] = None,
     preprocessor: Optional[PreprocessorProtocol] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
@@ -203,6 +206,7 @@ def plot_false_negative_match(
 
     ax = plot_clip(
         clip,
+        audio_loader=audio_loader,
         preprocessor=preprocessor,
         figsize=figsize,
         ax=ax,
@@ -237,6 +241,7 @@ def plot_false_negative_match(
 def plot_true_positive_match(
     match: MatchEvaluation,
     preprocessor: Optional[PreprocessorProtocol] = None,
+    audio_loader: Optional[AudioLoader] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
@@ -267,6 +272,7 @@ def plot_true_positive_match(
 
     ax = plot_clip(
         clip,
+        audio_loader=audio_loader,
         preprocessor=preprocessor,
         figsize=figsize,
         ax=ax,
@@ -312,6 +318,7 @@ def plot_true_positive_match(
 def plot_cross_trigger_match(
     match: MatchEvaluation,
     preprocessor: Optional[PreprocessorProtocol] = None,
+    audio_loader: Optional[AudioLoader] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
@@ -342,6 +349,7 @@ def plot_cross_trigger_match(
 
     ax = plot_clip(
         clip,
+        audio_loader=audio_loader,
         preprocessor=preprocessor,
         figsize=figsize,
         ax=ax,
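After these hunks, none of the match-plotting helpers silently builds a preprocessor; callers own audio loading and pass an AudioLoader explicitly. A caller-side sketch under two assumptions: build_audio_loader accepts a default-argument call, and the plot_matches import path (not shown in this view) is guessed:

    from typing import List

    from matplotlib.axes import Axes
    from soundevent import data

    from batdetect2.plotting.clips import build_audio_loader
    from batdetect2.plotting.matches import plot_matches  # module path assumed

    def show_matches(matches: List[data.Match], clip: data.Clip) -> Axes:
        # Build the loader once and pass it through explicitly; plot_matches
        # no longer falls back to build_preprocessor() internally.
        audio_loader = build_audio_loader()  # default-arg call is an assumption
        return plot_matches(matches, clip, audio_loader=audio_loader)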

==== next file ====

@@ -1,49 +1,26 @@
-from typing import List, Optional
+from typing import List
 
 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
 from soundevent import data
 from torch.utils.data import DataLoader
 
-from batdetect2.evaluate.match import (
-    MatchConfig,
-    build_matcher,
-    match_all_predictions,
-)
-from batdetect2.plotting.clips import PreprocessorProtocol
-from batdetect2.plotting.evaluation import plot_example_gallery
+from batdetect2.evaluate import Evaluator
 from batdetect2.postprocess import get_raw_predictions
 from batdetect2.train.dataset import ValidationDataset
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import get_image_plotter
-from batdetect2.typing import (
-    MatchEvaluation,
-    MetricsProtocol,
-)
+from batdetect2.typing.evaluate import ClipEvaluation
 from batdetect2.typing.models import ModelOutput
 from batdetect2.typing.postprocess import RawPrediction
 from batdetect2.typing.train import TrainExample
 
 
 class ValidationMetrics(Callback):
-    def __init__(
-        self,
-        metrics: List[MetricsProtocol],
-        preprocessor: PreprocessorProtocol,
-        plot: bool = True,
-        match_config: Optional[MatchConfig] = None,
-    ):
+    def __init__(self, evaluator: Evaluator):
         super().__init__()
 
-        if len(metrics) == 0:
-            raise ValueError("At least one metric needs to be provided")
-
-        self.match_config = match_config
-        self.metrics = metrics
-        self.preprocessor = preprocessor
-        self.plot = plot
-        self.matcher = build_matcher(config=match_config)
+        self.evaluator = evaluator
 
         self._clip_annotations: List[data.ClipAnnotation] = []
         self._predictions: List[List[RawPrediction]] = []
@@ -58,33 +35,22 @@ class ValidationMetrics(Callback):
     def plot_examples(
         self,
         pl_module: LightningModule,
-        matches: List[MatchEvaluation],
+        evaluated_clips: List[ClipEvaluation],
     ):
         plotter = get_image_plotter(pl_module.logger)  # type: ignore
 
         if plotter is None:
            return
 
-        for class_name, fig in plot_example_gallery(
-            matches,
-            preprocessor=self.preprocessor,
-            n_examples=4,
-        ):
-            plotter(
-                f"examples/{class_name}",
-                fig,
-                pl_module.global_step,
-            )
+        for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
+            plotter(figure_name, fig, pl_module.global_step)
 
     def log_metrics(
         self,
         pl_module: LightningModule,
-        matches: List[MatchEvaluation],
+        evaluated_clips: List[ClipEvaluation],
     ):
-        metrics = {}
-        for metric in self.metrics:
-            metrics.update(metric(matches).items())
+        metrics = self.evaluator.compute_metrics(evaluated_clips)
         pl_module.log_dict(metrics)
@@ -92,17 +58,13 @@ class ValidationMetrics(Callback):
         trainer: Trainer,
         pl_module: LightningModule,
     ) -> None:
-        matches = match_all_predictions(
+        clip_evaluations = self.evaluator.evaluate(
             self._clip_annotations,
             self._predictions,
-            targets=pl_module.model.targets,
-            matcher=self.matcher,
         )
 
-        self.log_metrics(pl_module, matches)
-
-        if self.plot:
-            self.plot_examples(pl_module, matches)
+        self.log_metrics(pl_module, clip_evaluations)
+        self.plot_examples(pl_module, clip_evaluations)
 
         return super().on_validation_epoch_end(trainer, pl_module)
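The net effect is that ValidationMetrics shrinks to a thin adapter around one Evaluator that owns matching, metric computation, and plotting. A wiring sketch, written as if it lived alongside the class above; build_evaluator's keyword signature is taken from a later hunk of this commit, while the TargetProtocol import path is an assumption:

    from batdetect2.evaluate.config import EvaluationConfig
    from batdetect2.evaluate.evaluator import build_evaluator
    from batdetect2.typing import TargetProtocol  # import path assumed

    def make_validation_callback(
        config: EvaluationConfig,
        targets: TargetProtocol,
    ) -> ValidationMetrics:
        # One evaluator bundles matcher, metrics, and plotters; the callback
        # only accumulates validation predictions and delegates to it.
        evaluator = build_evaluator(config=config, targets=targets)
        return ValidationMetrics(evaluator)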

==== next file ====

@@ -14,7 +14,7 @@ DEFAULT_TRAIN_CLIP_DURATION = 0.256
 DEFAULT_MAX_EMPTY_CLIP = 0.1
 
-registry: Registry[ClipperProtocol] = Registry("clipper")
+clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
 
 
 class RandomClipConfig(BaseConfig):
@@ -25,7 +25,6 @@ class RandomClipConfig(BaseConfig):
     min_sound_event_overlap: float = 0
 
 
-@registry.register(RandomClipConfig)
 class RandomClip:
     def __init__(
         self,
@@ -61,6 +60,9 @@ class RandomClip:
         )
 
 
+clipper_registry.register(RandomClipConfig, RandomClip)
+
+
 def get_subclip_annotation(
     clip_annotation: data.ClipAnnotation,
     random: bool = True,
@@ -156,7 +158,6 @@ class PaddedClipConfig(BaseConfig):
     chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION
 
 
-@registry.register(PaddedClipConfig)
 class PaddedClip:
     def __init__(self, chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION):
         self.chunk_size = chunk_size
@@ -183,6 +184,8 @@ class PaddedClip:
         return cls(chunk_size=config.chunk_size)
 
 
+clipper_registry.register(PaddedClipConfig, PaddedClip)
+
 ClipConfig = Annotated[
     Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
 ]
@@ -195,4 +198,4 @@ def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
         "Building clipper with config: \n{}",
         lambda: config.to_yaml_string(),
     )
-    return registry.build(config)
+    return clipper_registry.build(config)
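Registration thus moves out of a decorator and into an explicit clipper_registry.register(config_cls, logic_cls) call placed after each class, which keeps the classes themselves free of registry coupling. A sketch of what adding a third clipper could look like under this API, as if written inside the same module (CenteredClip and CenteredClipConfig are invented for illustration; the literal name field is what the registry and the ClipConfig discriminator dispatch on):

    from typing import Literal

    class CenteredClipConfig(BaseConfig):
        # The literal "name" is the registry key used by Registry.build().
        name: Literal["centered"] = "centered"
        chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION

    class CenteredClip:
        def __init__(self, chunk_size: float):
            self.chunk_size = chunk_size

        @classmethod
        def from_config(cls, config: CenteredClipConfig) -> "CenteredClip":
            return cls(chunk_size=config.chunk_size)

    # Explicit call replaces the old @registry.register(...) decorator.
    clipper_registry.register(CenteredClipConfig, CenteredClip)

Note that a new config would also need to join the ClipConfig union above before build_clipper could parse it from YAML.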

==== next file ====

@@ -10,10 +10,7 @@ from soundevent import data
 from torch.utils.data import DataLoader
 
 from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.evaluate.metrics import (
-    ClassificationMeanAveragePrecision,
-    DetectionAveragePrecision,
-)
+from batdetect2.evaluate.evaluator import build_evaluator
 from batdetect2.plotting.clips import AudioLoader, build_audio_loader
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
@@ -146,7 +143,6 @@ def build_training_module(
 
 def build_trainer_callbacks(
     targets: TargetProtocol,
-    preprocessor: PreprocessorProtocol,
     config: EvaluationConfig,
     checkpoint_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
@@ -161,6 +157,8 @@ def build_trainer_callbacks(
     if run_name is not None:
         checkpoint_dir = checkpoint_dir / run_name
 
+    evaluator = build_evaluator(config=config, targets=targets)
+
     return [
         ModelCheckpoint(
             dirpath=str(checkpoint_dir),
@@ -168,16 +166,7 @@ def build_trainer_callbacks(
             filename="best-{epoch:02d}-{val_loss:.0f}",
             monitor="total_loss/val",
         ),
-        ValidationMetrics(
-            metrics=[
-                DetectionAveragePrecision(),
-                ClassificationMeanAveragePrecision(
-                    class_names=targets.class_names
-                ),
-            ],
-            preprocessor=preprocessor,
-            match_config=config.match,
-        ),
+        ValidationMetrics(evaluator),
     ]
@@ -214,7 +203,6 @@ def build_trainer(
         callbacks=build_trainer_callbacks(
             targets,
             config=conf.evaluation,
-            preprocessor=build_preprocessor(conf.preprocess),
             checkpoint_dir=checkpoint_dir,
             experiment_name=experiment_name,
             run_name=run_name,
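For reference, a call site after this hunk reduces to the following (argument names copied from the diff); the preprocessor is no longer threaded through, presumably because evaluation-time audio loading now lives behind the evaluator:

    callbacks = build_trainer_callbacks(
        targets,
        config=conf.evaluation,
        checkpoint_dir=checkpoint_dir,
        experiment_name=experiment_name,
        run_name=run_name,
    )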

==== next file ====

@@ -11,6 +11,7 @@ from typing import (
     TypeVar,
 )
 
+from matplotlib.figure import Figure
 from soundevent import data
 
 __all__ = [
@@ -50,6 +51,12 @@ class MatchEvaluation:
         return self.pred_class_scores[pred_class]
 
 
+@dataclass
+class ClipEvaluation:
+    clip: data.Clip
+    matches: List[MatchEvaluation]
+
+
 class MatcherProtocol(Protocol):
     def __call__(
         self,
@@ -67,10 +74,16 @@ class AffinityFunction(Protocol, Generic[Geom]):
         self,
         geometry1: Geom,
         geometry2: Geom,
+        time_buffer: float = 0.01,
+        freq_buffer: float = 1000,
     ) -> float: ...
 
 
 class MetricsProtocol(Protocol):
-    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Dict[str, float]: ...
+
+
+class PlotterProtocol(Protocol):
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Iterable[Tuple[str, Figure]]: ...
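PlotterProtocol is new in this file: any callable that maps a sequence of ClipEvaluation objects to (name, Figure) pairs satisfies it, which is exactly the shape the evaluator's generate_plots feeds to the image plotter in the callback earlier in this commit. A minimal hypothetical implementation, assuming only the types defined above:

    from typing import Iterable, Sequence, Tuple

    import matplotlib.pyplot as plt
    from matplotlib.figure import Figure

    class MatchesPerClipPlotter:
        """Hypothetical plotter: one bar chart of match counts per clip."""

        def __call__(
            self, clip_evaluations: Sequence[ClipEvaluation]
        ) -> Iterable[Tuple[str, Figure]]:
            fig, ax = plt.subplots()
            counts = [len(ce.matches) for ce in clip_evaluations]
            ax.bar(range(len(counts)), counts)
            ax.set_xlabel("clip index")
            ax.set_ylabel("number of matches")
            yield "matches_per_clip", fig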