Restructure eval metrics and plotting

mbsantiago 2025-09-15 16:01:15 +01:00
parent ec1c0ff020
commit e752e96b93
20 changed files with 991 additions and 630 deletions

View File

@ -1,6 +1,7 @@
from typing import Generic, Protocol, Type, TypeVar
from pydantic import BaseModel
from typing_extensions import ParamSpec
__all__ = [
"Registry",
@ -8,26 +9,36 @@ __all__ = [
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
T_Type = TypeVar("T_Type", covariant=True)
P_Type = ParamSpec("P_Type")
class LogicProtocol(Generic[T_Config, T_Type], Protocol):
"""A generic protocol for the logic classes (conditions or transforms)."""
class LogicProtocol(Generic[T_Config, T_Type, P_Type], Protocol):
"""A generic protocol for the logic classes."""
@classmethod
def from_config(cls, config: T_Config) -> T_Type: ...
def from_config(
cls,
config: T_Config,
*args: P_Type.args,
**kwargs: P_Type.kwargs,
) -> T_Type: ...
T_Proto = TypeVar("T_Proto", bound=LogicProtocol)
class Registry(Generic[T_Type]):
class Registry(Generic[T_Type, P_Type]):
"""A generic class to create and manage a registry of items."""
def __init__(self, name: str):
self._name = name
self._registry = {}
def register(self, config_cls: Type[T_Config]):
def register(
self,
config_cls: Type[T_Config],
logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
) -> None:
"""A decorator factory to register a new item."""
fields = config_cls.model_fields
@ -39,13 +50,14 @@ class Registry(Generic[T_Type]):
if not isinstance(name, str):
raise ValueError("'name' field must be a string literal.")
def decorator(logic_cls: Type[T_Proto]) -> Type[T_Proto]:
self._registry[name] = logic_cls
return logic_cls
self._registry[name] = logic_cls
return decorator
def build(self, config: BaseModel) -> T_Type:
def build(
self,
config: BaseModel,
*args: P_Type.args,
**kwargs: P_Type.kwargs,
) -> T_Type:
"""Builds a logic instance from a config object."""
name = getattr(config, "name") # noqa: B009
@ -58,4 +70,4 @@ class Registry(Generic[T_Type]):
f"No {self._name} with name '{name}' is registered."
)
return self._registry[name].from_config(config)
return self._registry[name].from_config(config, *args, **kwargs)
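
For orientation, a minimal sketch of the reworked Registry API: registration is now explicit (config class paired with logic class) instead of decorator-based, and extra arguments are forwarded through build to from_config via the new ParamSpec. FooConfig, Foo, and the "foo" names are hypothetical, not part of this diff:

from typing import List, Literal
from pydantic import BaseModel

class FooConfig(BaseModel):
    # Registry.register requires a 'name' string-literal field on the config
    name: Literal["foo"] = "foo"
    factor: float = 2.0

class Foo:
    def __init__(self, factor: float, class_names: List[str]):
        self.factor = factor
        self.class_names = class_names

    @classmethod
    def from_config(cls, config: FooConfig, class_names: List[str]) -> "Foo":
        # extra positional args flow through Registry.build unchanged
        return cls(config.factor, class_names)

registry: Registry[Foo, [List[str]]] = Registry("foo")
registry.register(FooConfig, Foo)
foo = registry.build(FooConfig(), ["bar", "baz"])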

View File

@ -10,7 +10,7 @@ from batdetect2.data._core import Registry
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
_conditions: Registry[SoundEventCondition] = Registry("condition")
condition_registry: Registry[SoundEventCondition, []] = Registry("condition")
class HasTagConfig(BaseConfig):
@ -18,7 +18,6 @@ class HasTagConfig(BaseConfig):
tag: data.Tag
@_conditions.register(HasTagConfig)
class HasTag:
def __init__(self, tag: data.Tag):
self.tag = tag
@ -33,12 +32,14 @@ class HasTag:
return cls(tag=config.tag)
condition_registry.register(HasTagConfig, HasTag)
class HasAllTagsConfig(BaseConfig):
name: Literal["has_all_tags"] = "has_all_tags"
tags: List[data.Tag]
@_conditions.register(HasAllTagsConfig)
class HasAllTags:
def __init__(self, tags: List[data.Tag]):
if not tags:
@ -56,12 +57,14 @@ class HasAllTags:
return cls(tags=config.tags)
condition_registry.register(HasAllTagsConfig, HasAllTags)
class HasAnyTagConfig(BaseConfig):
name: Literal["has_any_tag"] = "has_any_tag"
tags: List[data.Tag]
@_conditions.register(HasAnyTagConfig)
class HasAnyTag:
def __init__(self, tags: List[data.Tag]):
if not tags:
@ -79,6 +82,8 @@ class HasAnyTag:
return cls(tags=config.tags)
condition_registry.register(HasAnyTagConfig, HasAnyTag)
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
@ -109,7 +114,6 @@ def _build_comparator(
raise ValueError(f"Invalid operator {operator}")
@_conditions.register(DurationConfig)
class Duration:
def __init__(self, operator: Operator, seconds: float):
self.operator = operator
@ -135,6 +139,9 @@ class Duration:
return cls(operator=config.operator, seconds=config.seconds)
condition_registry.register(DurationConfig, Duration)
class FrequencyConfig(BaseConfig):
name: Literal["frequency"] = "frequency"
boundary: Literal["low", "high"]
@ -142,7 +149,6 @@ class FrequencyConfig(BaseConfig):
hertz: float
@_conditions.register(FrequencyConfig)
class Frequency:
def __init__(
self,
@ -184,12 +190,14 @@ class Frequency:
)
condition_registry.register(FrequencyConfig, Frequency)
class AllOfConfig(BaseConfig):
name: Literal["all_of"] = "all_of"
conditions: Sequence["SoundEventConditionConfig"]
@_conditions.register(AllOfConfig)
class AllOf:
def __init__(self, conditions: List[SoundEventCondition]):
self.conditions = conditions
@ -207,12 +215,14 @@ class AllOf:
return cls(conditions)
condition_registry.register(AllOfConfig, AllOf)
class AnyOfConfig(BaseConfig):
name: Literal["any_of"] = "any_of"
conditions: List["SoundEventConditionConfig"]
@_conditions.register(AnyOfConfig)
class AnyOf:
def __init__(self, conditions: List[SoundEventCondition]):
self.conditions = conditions
@ -230,12 +240,14 @@ class AnyOf:
return cls(conditions)
condition_registry.register(AnyOfConfig, AnyOf)
class NotConfig(BaseConfig):
name: Literal["not"] = "not"
condition: "SoundEventConditionConfig"
@_conditions.register(NotConfig)
class Not:
def __init__(self, condition: SoundEventCondition):
self.condition = condition
@ -251,6 +263,8 @@ class Not:
return cls(condition)
condition_registry.register(NotConfig, Not)
SoundEventConditionConfig = Annotated[
Union[
HasTagConfig,
@ -269,7 +283,7 @@ SoundEventConditionConfig = Annotated[
def build_sound_event_condition(
config: SoundEventConditionConfig,
) -> SoundEventCondition:
return _conditions.build(config)
return condition_registry.build(config)
def filter_clip_annotation(

View File

@ -17,7 +17,7 @@ SoundEventTransform = Callable[
data.SoundEventAnnotation,
]
_transforms: Registry[SoundEventTransform] = Registry("transform")
transform_registry: Registry[SoundEventTransform, []] = Registry("transform")
class SetFrequencyBoundConfig(BaseConfig):
@ -26,7 +26,6 @@ class SetFrequencyBoundConfig(BaseConfig):
hertz: float
@_transforms.register(SetFrequencyBoundConfig)
class SetFrequencyBound:
def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"):
self.hertz = hertz
@ -69,13 +68,15 @@ class SetFrequencyBound:
return cls(hertz=config.hertz, boundary=config.boundary)
transform_registry.register(SetFrequencyBoundConfig, SetFrequencyBound)
class ApplyIfConfig(BaseConfig):
name: Literal["apply_if"] = "apply_if"
transform: "SoundEventTransformConfig"
condition: SoundEventConditionConfig
@_transforms.register(ApplyIfConfig)
class ApplyIf:
def __init__(
self,
@ -101,13 +102,15 @@ class ApplyIf:
return cls(condition=condition, transform=transform)
transform_registry.register(ApplyIfConfig, ApplyIf)
class ReplaceTagConfig(BaseConfig):
name: Literal["replace_tag"] = "replace_tag"
original: data.Tag
replacement: data.Tag
@_transforms.register(ReplaceTagConfig)
class ReplaceTag:
def __init__(
self,
@ -136,6 +139,9 @@ class ReplaceTag:
return cls(original=config.original, replacement=config.replacement)
transform_registry.register(ReplaceTagConfig, ReplaceTag)
class MapTagValueConfig(BaseConfig):
name: Literal["map_tag_value"] = "map_tag_value"
tag_key: str
@ -143,7 +149,6 @@ class MapTagValueConfig(BaseConfig):
target_key: Optional[str] = None
@_transforms.register(MapTagValueConfig)
class MapTagValue:
def __init__(
self,
@ -193,12 +198,14 @@ class MapTagValue:
)
transform_registry.register(MapTagValueConfig, MapTagValue)
class ApplyAllConfig(BaseConfig):
name: Literal["apply_all"] = "apply_all"
steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
@_transforms.register(ApplyAllConfig)
class ApplyAll:
def __init__(self, steps: List[SoundEventTransform]):
self.steps = steps
@ -218,6 +225,8 @@ class ApplyAll:
return cls(steps)
transform_registry.register(ApplyAllConfig, ApplyAll)
SoundEventTransformConfig = Annotated[
Union[
SetFrequencyBoundConfig,
@ -233,7 +242,7 @@ SoundEventTransformConfig = Annotated[
def build_sound_event_transform(
config: SoundEventTransformConfig,
) -> SoundEventTransform:
return _transforms.build(config)
return transform_registry.build(config)
def transform_clip_annotation(

View File

@ -1,6 +1,9 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
__all__ = [
"EvaluationConfig",
"load_evaluation_config",
"Evaluator",
"build_evaluator",
]

View File

@ -0,0 +1,151 @@
from typing import Annotated, Literal, Optional, Union
from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.typing.evaluate import AffinityFunction
affinity_functions: Registry[AffinityFunction, []] = Registry(
    "affinity_function"
)
class TimeAffinityConfig(BaseConfig):
name: Literal["time_affinity"] = "time_affinity"
time_buffer: float = 0.01
class TimeAffinity(AffinityFunction):
def __init__(self, time_buffer: float):
self.time_buffer = time_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
return compute_timestamp_affinity(
geometry1, geometry2, time_buffer=self.time_buffer
)
@classmethod
def from_config(cls, config: TimeAffinityConfig):
return cls(time_buffer=config.time_buffer)
affinity_functions.register(TimeAffinityConfig, TimeAffinity)
def compute_timestamp_affinity(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeStamp)
assert isinstance(geometry2, data.TimeStamp)
start_time1 = geometry1.coordinates
start_time2 = geometry2.coordinates
a = min(start_time1, start_time2)
b = max(start_time1, start_time2)
if b - a >= 2 * time_buffer:
return 0
intersection = a - b + 2 * time_buffer
union = b - a + 2 * time_buffer
return intersection / union
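
As a sanity check of the formula: each timestamp is widened to an interval of length 2 * time_buffer, and the affinity is the IoU of those two intervals. A worked example, assuming soundevent's TimeStamp takes a scalar coordinates value as the asserts above imply:

g1 = data.TimeStamp(coordinates=0.100)
g2 = data.TimeStamp(coordinates=0.105)
# intersection = 2 * 0.01 - 0.005 = 0.015; union = 0.005 + 2 * 0.01 = 0.025
assert abs(compute_timestamp_affinity(g1, g2, time_buffer=0.01) - 0.6) < 1e-9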
class IntervalIOUConfig(BaseConfig):
name: Literal["interval_iou"] = "interval_iou"
time_buffer: float = 0.01
class IntervalIOU(AffinityFunction):
def __init__(self, time_buffer: float):
self.time_buffer = time_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
return compute_interval_iou(
geometry1,
geometry2,
time_buffer=self.time_buffer,
)
@classmethod
def from_config(cls, config: IntervalIOUConfig):
return cls(time_buffer=config.time_buffer)
affinity_functions.register(IntervalIOUConfig, IntervalIOU)
def compute_interval_iou(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeInterval)
assert isinstance(geometry2, data.TimeInterval)
start_time1, end_time1 = geometry1.coordinates
    start_time2, end_time2 = geometry2.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
intersection = max(
0, min(end_time1, end_time2) - max(start_time1, start_time2)
)
union = (
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
)
if union == 0:
return 0
return intersection / union
class GeometricIOUConfig(BaseConfig):
name: Literal["geometric_iou"] = "geometric_iou"
time_buffer: float = 0.01
freq_buffer: float = 1000
class GeometricIOU(AffinityFunction):
    def __init__(self, time_buffer: float, freq_buffer: float):
        self.time_buffer = time_buffer
        self.freq_buffer = freq_buffer
    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        # forward both buffers so the config's freq_buffer is not silently ignored
        return compute_affinity(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
            freq_buffer=self.freq_buffer,
        )
    @classmethod
    def from_config(cls, config: GeometricIOUConfig):
        return cls(
            time_buffer=config.time_buffer,
            freq_buffer=config.freq_buffer,
        )
affinity_functions.register(GeometricIOUConfig, GeometricIOU)
AffinityConfig = Annotated[
Union[
TimeAffinityConfig,
IntervalIOUConfig,
GeometricIOUConfig,
],
Field(discriminator="name"),
]
def build_affinity_function(
config: Optional[AffinityConfig] = None,
) -> AffinityFunction:
config = config or GeometricIOUConfig()
return affinity_functions.build(config)
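
A short usage sketch of the new affinity registry; the TimeInterval construction assumes soundevent's [start, end] coordinates layout that compute_interval_iou unpacks above:

g1 = data.TimeInterval(coordinates=[0.10, 0.14])
g2 = data.TimeInterval(coordinates=[0.11, 0.15])
affinity = build_affinity_function(IntervalIOUConfig(time_buffer=0.02))
print(affinity(g1, g2))  # buffered interval IoU
default = build_affinity_function()  # no config: falls back to GeometricIOUConfig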

View File

@ -1,10 +1,16 @@
from typing import Optional
from typing import List, Optional
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
from batdetect2.evaluate.metrics import (
ClassificationAPConfig,
DetectionAPConfig,
MetricConfig,
)
from batdetect2.evaluate.plots import ExampleGalleryConfig, PlotConfig
__all__ = [
"EvaluationConfig",
@ -13,7 +19,19 @@ __all__ = [
class EvaluationConfig(BaseConfig):
ignore_start_end: float = 0.01
match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
metrics: List[MetricConfig] = Field(
default_factory=lambda: [
DetectionAPConfig(),
ClassificationAPConfig(),
]
)
plots: List[PlotConfig] = Field(
default_factory=lambda: [
ExampleGalleryConfig(),
]
)
def load_evaluation_config(

View File

@ -3,60 +3,61 @@ from typing import List
import pandas as pd
from soundevent.geometry import compute_bounds
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing.evaluate import ClipEvaluation
def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame:
def extract_matches_dataframe(
    clip_evaluations: List[ClipEvaluation],
) -> pd.DataFrame:
data = []
for match in matches:
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches:
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
sound_event_annotation = match.sound_event_annotation
sound_event_annotation = match.sound_event_annotation
if sound_event_annotation is not None:
geometry = sound_event_annotation.sound_event.geometry
assert geometry is not None
gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
compute_bounds(geometry)
if sound_event_annotation is not None:
geometry = sound_event_annotation.sound_event.geometry
assert geometry is not None
gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
compute_bounds(geometry)
)
if match.pred_geometry is not None:
pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
compute_bounds(match.pred_geometry)
)
data.append(
{
("recording", "uuid"): match.clip.recording.uuid,
("clip", "uuid"): match.clip.uuid,
("clip", "start_time"): match.clip.start_time,
("clip", "end_time"): match.clip.end_time,
("gt", "uuid"): match.sound_event_annotation.uuid
if match.sound_event_annotation is not None
else None,
("gt", "class"): match.gt_class,
("gt", "det"): match.gt_det,
("gt", "start_time"): gt_start_time,
("gt", "end_time"): gt_end_time,
("gt", "low_freq"): gt_low_freq,
("gt", "high_freq"): gt_high_freq,
("pred", "score"): match.pred_score,
("pred", "class"): match.pred_class,
("pred", "class_score"): match.pred_class_score,
("pred", "start_time"): pred_start_time,
("pred", "end_time"): pred_end_time,
("pred", "low_freq"): pred_low_freq,
("pred", "high_freq"): pred_high_freq,
("match", "affinity"): match.affinity,
**{
("pred_class_score", key): value
for key, value in match.pred_class_scores.items()
},
}
)
if match.pred_geometry is not None:
pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
compute_bounds(match.pred_geometry)
)
data.append(
{
("recording", "uuid"): match.clip.recording.uuid,
("clip", "uuid"): match.clip.uuid,
("clip", "start_time"): match.clip.start_time,
("clip", "end_time"): match.clip.end_time,
("gt", "uuid"): match.sound_event_annotation.uuid
if match.sound_event_annotation is not None
else None,
("gt", "class"): match.gt_class,
("gt", "det"): match.gt_det,
("gt", "start_time"): gt_start_time,
("gt", "end_time"): gt_end_time,
("gt", "low_freq"): gt_low_freq,
("gt", "high_freq"): gt_high_freq,
("pred", "score"): match.pred_score,
("pred", "class"): match.pred_class,
("pred", "class_score"): match.pred_class_score,
("pred", "start_time"): pred_start_time,
("pred", "end_time"): pred_end_time,
("pred", "low_freq"): pred_low_freq,
("pred", "high_freq"): pred_high_freq,
("match", "affinity"): match.affinity,
**{
("pred_class_score", key): value
for key, value in match.pred_class_scores.items()
},
}
)
df = pd.DataFrame(data)
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
return df

View File

@ -4,11 +4,8 @@ import pandas as pd
from soundevent import data
from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.evaluate.match import build_matcher, match_all_predictions
from batdetect2.evaluate.metrics import (
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
from batdetect2.models import Model
from batdetect2.plotting.clips import build_audio_loader
from batdetect2.postprocess import get_raw_predictions
@ -55,6 +52,8 @@ def evaluate(
clip_annotations = []
predictions = []
evaluator = build_evaluator(config=config.evaluation)
for batch in loader:
outputs = model.detector(batch.spec)
@ -76,20 +75,12 @@ def evaluate(
clip_annotations.extend(clip_annotations)
predictions.extend(predictions)
matcher = build_matcher(config.evaluation.match)
matches = match_all_predictions(
clip_annotations,
predictions,
targets=targets,
matcher=matcher,
)
matches = evaluator.evaluate(clip_annotations, predictions)
df = extract_matches_dataframe(matches)
metrics = [
DetectionAveragePrecision(),
ClassificationMeanAveragePrecision(class_names=targets.class_names),
DetectionAP(),
ClassificationAP(class_names=targets.class_names),
]
results = {

View File

@ -0,0 +1,169 @@
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from matplotlib.figure import Figure
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.match import build_matcher, match
from batdetect2.evaluate.metrics import build_metric
from batdetect2.evaluate.plots import build_plotter
from batdetect2.targets import build_targets
from batdetect2.typing.evaluate import (
ClipEvaluation,
MatcherProtocol,
MetricsProtocol,
PlotterProtocol,
)
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"Evaluator",
"build_evaluator",
]
class Evaluator:
def __init__(
self,
config: EvaluationConfig,
targets: TargetProtocol,
matcher: MatcherProtocol,
metrics: List[MetricsProtocol],
plots: List[PlotterProtocol],
):
self.config = config
self.targets = targets
self.matcher = matcher
self.metrics = metrics
self.plots = plots
def match(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEvaluation:
clip = clip_annotation.clip
ground_truth = [
sound_event
for sound_event in clip_annotation.sound_events
if self.filter_sound_event_annotations(sound_event, clip)
]
predictions = [
prediction
for prediction in predictions
if self.filter_predictions(prediction, clip)
]
return ClipEvaluation(
clip=clip_annotation.clip,
matches=match(
ground_truth,
predictions,
clip=clip,
targets=self.targets,
matcher=self.matcher,
),
)
def filter_sound_event_annotations(
self,
sound_event_annotation: data.SoundEventAnnotation,
clip: data.Clip,
) -> bool:
if not self.targets.filter(sound_event_annotation):
return False
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
return is_in_bounds(
geometry,
clip,
self.config.ignore_start_end,
)
def filter_predictions(
self,
prediction: RawPrediction,
clip: data.Clip,
) -> bool:
return is_in_bounds(
prediction.geometry,
clip,
self.config.ignore_start_end,
)
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[ClipEvaluation]:
if len(clip_annotations) != len(predictions):
raise ValueError(
"Number of annotated clips and sets of predictions do not match"
)
return [
self.match(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions)
]
def compute_metrics(
self,
clip_evaluations: Sequence[ClipEvaluation],
) -> Dict[str, float]:
results = {}
for metric in self.metrics:
results.update(metric(clip_evaluations))
return results
def generate_plots(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]:
for plotter in self.plots:
for name, fig in plotter(clip_evaluations):
yield name, fig
def build_evaluator(
config: Optional[EvaluationConfig] = None,
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
plots: Optional[List[PlotterProtocol]] = None,
metrics: Optional[List[MetricsProtocol]] = None,
) -> Evaluator:
config = config or EvaluationConfig()
targets = targets or build_targets()
matcher = matcher or build_matcher(config.match)
if metrics is None:
metrics = [
build_metric(config, targets.class_names)
for config in config.metrics
]
if plots is None:
plots = [build_plotter(config) for config in config.plots]
return Evaluator(
config=config,
targets=targets,
matcher=matcher,
metrics=metrics,
plots=plots,
)
def is_in_bounds(
geometry: data.Geometry,
clip: data.Clip,
buffer: float,
) -> bool:
start_time = compute_bounds(geometry)[0]
return (start_time >= clip.start_time + buffer) and (
start_time <= clip.end_time - buffer
)
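
Taken together, a hypothetical end-to-end use of the new evaluator, assuming the config classes are imported from batdetect2.evaluate.config, batdetect2.evaluate.metrics, and batdetect2.evaluate.plots as in this diff; clip_annotations and predictions are assumed to come from elsewhere (e.g. get_raw_predictions), and the "noise" class name is illustrative:

config = EvaluationConfig(
    ignore_start_end=0.02,
    metrics=[DetectionAPConfig(), ClassificationAPConfig(exclude=["noise"])],
    plots=[ExampleGalleryConfig(examples_per_class=4)],
)
evaluator = build_evaluator(config=config)
clip_evaluations = evaluator.evaluate(clip_annotations, predictions)
print(evaluator.compute_metrics(clip_evaluations))
for name, fig in evaluator.generate_plots(clip_evaluations):
    fig.savefig(f"{name.replace('/', '_')}.png")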

View File

@ -1,9 +1,7 @@
from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
from loguru import logger
from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
@ -12,6 +10,11 @@ from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.evaluate.affinity import (
AffinityConfig,
GeometricIOUConfig,
build_affinity_function,
)
from batdetect2.targets import build_targets
from batdetect2.typing import (
MatchEvaluation,
@ -23,7 +26,88 @@ from batdetect2.typing.postprocess import RawPrediction
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""
matching_strategy = Registry("matching_strategy")
matching_strategies = Registry("matching_strategy")
def match(
sound_event_annotations: Sequence[data.SoundEventAnnotation],
raw_predictions: Sequence[RawPrediction],
clip: data.Clip,
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
) -> List[MatchEvaluation]:
if matcher is None:
matcher = build_matcher()
if targets is None:
targets = build_targets()
target_geometries: List[data.Geometry] = [ # type: ignore
sound_event_annotation.sound_event.geometry
for sound_event_annotation in sound_event_annotations
]
predicted_geometries = [
raw_prediction.geometry for raw_prediction in raw_predictions
]
scores = [
raw_prediction.detection_score for raw_prediction in raw_predictions
]
matches = []
for source_idx, target_idx, affinity in matcher(
ground_truth=target_geometries,
predictions=predicted_geometries,
scores=scores,
):
target = (
sound_event_annotations[target_idx]
if target_idx is not None
else None
)
prediction = (
raw_predictions[source_idx] if source_idx is not None else None
)
gt_det = target_idx is not None
gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.detection_score) if prediction else 0
pred_geometry = (
predicted_geometries[source_idx]
if source_idx is not None
else None
)
class_scores = (
{
str(class_name): float(score)
for class_name, score in zip(
targets.class_names,
prediction.class_scores,
)
}
if prediction is not None
else {}
)
matches.append(
MatchEvaluation(
clip=clip,
sound_event_annotation=target,
gt_det=gt_det,
gt_class=gt_class,
pred_score=pred_score,
pred_class_scores=class_scores,
pred_geometry=pred_geometry,
affinity=affinity,
)
)
return matches
class StartTimeMatchConfig(BaseConfig):
@ -31,7 +115,6 @@ class StartTimeMatchConfig(BaseConfig):
distance_threshold: float = 0.01
@matching_strategy.register(StartTimeMatchConfig)
class StartTimeMatcher(MatcherProtocol):
def __init__(self, distance_threshold: float):
self.distance_threshold = distance_threshold
@ -54,6 +137,9 @@ class StartTimeMatcher(MatcherProtocol):
return cls(distance_threshold=config.distance_threshold)
matching_strategies.register(StartTimeMatchConfig, StartTimeMatcher)
def match_start_times(
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
@ -74,8 +160,8 @@ def match_start_times(
gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth])
pred_times = np.array([compute_bounds(geom)[0] for geom in predictions])
scores = np.array(scores)
scores = np.array(scores)
sort_args = np.argsort(scores)[::-1]
distances = np.abs(gt_times[None, :] - pred_times[:, None])
@ -143,89 +229,25 @@ _geometry_cast_functions: Mapping[
}
def _timestamp_affinity(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
assert isinstance(geometry1, data.TimeStamp)
assert isinstance(geometry2, data.TimeStamp)
start_time1 = geometry1.coordinates
start_time2 = geometry2.coordinates
a = min(start_time1, start_time2)
b = max(start_time1, start_time2)
if b - a >= 2 * time_buffer:
return 0
intersection = a - b + 2 * time_buffer
union = b - a + 2 * time_buffer
return intersection / union
def _interval_affinity(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
assert isinstance(geometry1, data.TimeInterval)
assert isinstance(geometry2, data.TimeInterval)
start_time1, end_time1 = geometry1.coordinates
start_time2, end_time2 = geometry1.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
intersection = max(
0, min(end_time1, end_time2) - max(start_time1, start_time2)
)
union = (
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
)
if union == 0:
return 0
return intersection / union
_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
"timestamp": _timestamp_affinity,
"interval": _interval_affinity,
"bbox": compute_affinity,
}
class GreedyMatchConfig(BaseConfig):
name: Literal["greedy_match"] = "greedy_match"
geometry: MatchingGeometry = "timestamp"
affinity_threshold: float = 0.0
time_buffer: float = 0.005
frequency_buffer: float = 1_000
affinity_threshold: float = 0.5
affinity_function: AffinityConfig = Field(
default_factory=GeometricIOUConfig
)
@matching_strategy.register(GreedyMatchConfig)
class GreedyMatcher(MatcherProtocol):
def __init__(
self,
geometry: MatchingGeometry,
affinity_threshold: float,
time_buffer: float,
frequency_buffer: float,
affinity_function: AffinityFunction,
):
self.geometry = geometry
self.affinity_function = affinity_function
self.affinity_threshold = affinity_threshold
self.time_buffer = time_buffer
self.frequency_buffer = frequency_buffer
self.affinity_function = _affinity_functions[self.geometry]
self.cast_geometry = _geometry_cast_functions[self.geometry]
def __call__(
@ -240,28 +262,27 @@ class GreedyMatcher(MatcherProtocol):
scores=scores,
affinity_function=self.affinity_function,
affinity_threshold=self.affinity_threshold,
time_buffer=self.time_buffer,
freq_buffer=self.frequency_buffer,
)
@classmethod
def from_config(cls, config: GreedyMatchConfig):
affinity_function = build_affinity_function(config.affinity_function)
return cls(
geometry=config.geometry,
affinity_threshold=config.affinity_threshold,
time_buffer=config.time_buffer,
frequency_buffer=config.frequency_buffer,
affinity_function=affinity_function,
)
matching_strategies.register(GreedyMatchConfig, GreedyMatcher)
def greedy_match(
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
affinity_threshold: float = 0.5,
affinity_function: AffinityFunction = compute_affinity,
time_buffer: float = 0.001,
freq_buffer: float = 1000,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
"""Performs a greedy, one-to-one matching of source to target geometries.
@ -279,10 +300,6 @@ def greedy_match(
Confidence scores for each source geometry for prioritization.
affinity_threshold
The minimum affinity score required for a valid match.
time_buffer
Time tolerance in seconds for affinity calculation.
freq_buffer
Frequency tolerance in Hertz for affinity calculation.
Yields
------
@ -314,12 +331,7 @@ def greedy_match(
affinities = np.array(
[
affinity_function(
source_geometry,
target_geometry,
time_buffer=time_buffer,
freq_buffer=freq_buffer,
)
affinity_function(source_geometry, target_geometry)
for target_geometry in ground_truth
]
)
@ -344,12 +356,11 @@ def greedy_match(
class OptimalMatchConfig(BaseConfig):
name: Literal["optimal_match"] = "optimal_match"
affinity_threshold: float = 0.0
affinity_threshold: float = 0.5
time_buffer: float = 0.005
frequency_buffer: float = 1_000
@matching_strategy.register(OptimalMatchConfig)
class OptimalMatcher(MatcherProtocol):
def __init__(
self,
@ -384,6 +395,8 @@ class OptimalMatcher(MatcherProtocol):
)
matching_strategies.register(OptimalMatchConfig, OptimalMatcher)
MatchConfig = Annotated[
Union[
GreedyMatchConfig,
@ -396,174 +409,4 @@ MatchConfig = Annotated[
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
config = config or StartTimeMatchConfig()
return matching_strategy.build(config)
def _is_in_bounds(
geometry: data.Geometry,
clip: data.Clip,
buffer: float,
) -> bool:
start_time = compute_bounds(geometry)[0]
return (start_time >= clip.start_time + buffer) and (
start_time <= clip.end_time - buffer
)
def match_sound_events_and_predictions(
clip_annotation: data.ClipAnnotation,
raw_predictions: List[RawPrediction],
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
ignore_start_end: float = 0.01,
) -> List[MatchEvaluation]:
if matcher is None:
matcher = build_matcher()
if targets is None:
targets = build_targets()
target_sound_events = [
sound_event_annotation
for sound_event_annotation in clip_annotation.sound_events
if targets.filter(sound_event_annotation)
and sound_event_annotation.sound_event.geometry is not None
and _is_in_bounds(
sound_event_annotation.sound_event.geometry,
clip=clip_annotation.clip,
buffer=ignore_start_end,
)
]
target_geometries: List[data.Geometry] = [
sound_event_annotation.sound_event.geometry
for sound_event_annotation in target_sound_events
if sound_event_annotation.sound_event.geometry is not None
]
raw_predictions = [
raw_prediction
for raw_prediction in raw_predictions
if _is_in_bounds(
raw_prediction.geometry,
clip=clip_annotation.clip,
buffer=ignore_start_end,
)
]
predicted_geometries = [
raw_prediction.geometry for raw_prediction in raw_predictions
]
scores = [
raw_prediction.detection_score for raw_prediction in raw_predictions
]
matches = []
for source_idx, target_idx, affinity in matcher(
ground_truth=target_geometries,
predictions=predicted_geometries,
scores=scores,
):
target = (
target_sound_events[target_idx] if target_idx is not None else None
)
prediction = (
raw_predictions[source_idx] if source_idx is not None else None
)
gt_det = target_idx is not None
gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.detection_score) if prediction else 0
pred_geometry = (
predicted_geometries[source_idx]
if source_idx is not None
else None
)
class_scores = (
{
str(class_name): float(score)
for class_name, score in zip(
targets.class_names,
prediction.class_scores,
)
}
if prediction is not None
else {}
)
matches.append(
MatchEvaluation(
clip=clip_annotation.clip,
sound_event_annotation=target,
gt_det=gt_det,
gt_class=gt_class,
pred_score=pred_score,
pred_class_scores=class_scores,
pred_geometry=pred_geometry,
affinity=affinity,
)
)
return matches
def match_all_predictions(
clip_annotations: List[data.ClipAnnotation],
predictions: List[List[RawPrediction]],
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
ignore_start_end: float = 0.01,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...")
return [
match
for clip_annotation, raw_predictions in zip(
clip_annotations,
predictions,
)
for match in match_sound_events_and_predictions(
clip_annotation,
raw_predictions,
targets=targets,
matcher=matcher,
ignore_start_end=ignore_start_end,
)
]
@dataclass
class ClassExamples:
false_positives: List[MatchEvaluation] = field(default_factory=list)
false_negatives: List[MatchEvaluation] = field(default_factory=list)
true_positives: List[MatchEvaluation] = field(default_factory=list)
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def group_matches(matches: List[MatchEvaluation]) -> ClassExamples:
class_examples = ClassExamples()
for match in matches:
gt_class = match.gt_class
pred_class = match.pred_class
if pred_class is None:
class_examples.false_negatives.append(match)
continue
if gt_class is None:
class_examples.false_positives.append(match)
continue
if gt_class != pred_class:
class_examples.cross_triggers.append(match)
class_examples.cross_triggers.append(match)
continue
class_examples.true_positives.append(match)
return class_examples
return matching_strategies.build(config)
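
With the buffers now owned by the affinity functions rather than the matcher, a greedy matcher is configured by composition; a sketch, assuming TimeAffinityConfig is imported from batdetect2.evaluate.affinity and the geometry and score lists exist as in match above:

matcher = build_matcher(
    GreedyMatchConfig(
        affinity_threshold=0.5,
        affinity_function=TimeAffinityConfig(time_buffer=0.005),
    )
)
for pred_idx, gt_idx, affinity in matcher(
    ground_truth=target_geometries,
    predictions=predicted_geometries,
    scores=scores,
):
    print(pred_idx, gt_idx, affinity)  # None on either side marks an unmatched item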

View File

@ -1,61 +1,151 @@
from typing import Dict, List
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
import numpy as np
import pandas as pd
from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.typing import MatchEvaluation, MetricsProtocol
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.typing import MetricsProtocol
from batdetect2.typing.evaluate import ClipEvaluation
__all__ = ["DetectionAveragePrecision"]
__all__ = ["DetectionAP", "ClassificationAP"]
class DetectionAveragePrecision(MetricsProtocol):
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
class DetectionAPConfig(BaseConfig):
name: Literal["detection_ap"] = "detection_ap"
class DetectionAP(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true, y_score = zip(
*[(match.gt_det, match.pred_score) for match in matches]
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
score = float(metrics.average_precision_score(y_true, y_score))
return {"detection_AP": score}
@classmethod
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
return cls()
class ClassificationMeanAveragePrecision(MetricsProtocol):
def __init__(self, class_names: List[str]):
metrics_registry.register(DetectionAPConfig, DetectionAP)
class ClassificationAPConfig(BaseConfig):
name: Literal["classification_ap"] = "classification_ap"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
# NOTE: Need to exclude generic but unclassified targets
y_true = label_binarize(
[
match.gt_class if match.gt_class is not None else "__NONE__"
for match in matches
if not (match.gt_det and match.gt_class is None)
],
classes=self.class_names,
)
y_pred = pd.DataFrame(
[
{
name: match.pred_class_scores.get(name, 0)
for name in self.class_names
}
for match in matches
if not (match.gt_det and match.gt_class is None)
]
).fillna(0)
self.selected = class_names
ret = {}
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[class_name]
y_pred_class = y_pred[:, class_index]
class_ap = metrics.average_precision_score(
y_true_class,
y_pred_class,
)
ret[f"classification_AP/{class_name}"] = float(class_ap)
class_scores[class_name] = float(class_ap)
ret["classification_mAP"] = np.mean(
[value for value in ret.values() if value != 0]
mean_ap = np.mean(
[value for value in class_scores.values() if value != 0]
)
return ret
return {
"classification_mAP": float(mean_ap),
**{
f"classification_AP/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClassificationAPConfig,
class_names: List[str],
):
return cls(
class_names,
include=config.include,
exclude=config.exclude,
)
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
MetricConfig = Annotated[
Union[ClassificationAPConfig, DetectionAPConfig],
Field(discriminator="name"),
]
def build_metric(config: MetricConfig, class_names: List[str]):
return metrics_registry.build(config, class_names)
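
A brief sketch of config-driven metric construction; the class names are illustrative and clip_evaluations is assumed to come from Evaluator.evaluate:

metric = build_metric(
    ClassificationAPConfig(include=["Myotis myotis"]),
    class_names=["Myotis myotis", "Pipistrellus pipistrellus"],
)
scores = metric(clip_evaluations)
# -> {"classification_mAP": ..., "classification_AP/Myotis myotis": ...}

Note that only the included class is reported per-class, while the mAP is still computed over all class names passed in.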

View File

@ -0,0 +1,163 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
import matplotlib.pyplot as plt
import pandas as pd
from pydantic import Field
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing.evaluate import (
ClipEvaluation,
MatchEvaluation,
PlotterProtocol,
)
from batdetect2.typing.preprocess import AudioLoader
__all__ = [
"build_plotter",
"ExampleGallery",
"ExampleGalleryConfig",
]
plots_registry: Registry[PlotterProtocol, []] = Registry("plot")
class ExampleGalleryConfig(BaseConfig):
name: Literal["example_gallery"] = "example_gallery"
examples_per_class: int = 5
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleGallery(PlotterProtocol):
def __init__(
self,
examples_per_class: int,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
):
self.examples_per_class = examples_per_class
self.preprocessor = preprocessor or build_preprocessor()
self.audio_loader = audio_loader or build_audio_loader()
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
per_class_matches = group_matches(clip_evaluations)
for class_name, matches in per_class_matches.items():
true_positives = get_binned_sample(
matches.true_positives,
n_examples=self.examples_per_class,
)
false_positives = get_binned_sample(
matches.false_positives,
n_examples=self.examples_per_class,
)
false_negatives = random.sample(
matches.false_negatives,
k=min(self.examples_per_class, len(matches.false_negatives)),
)
cross_triggers = get_binned_sample(
matches.cross_triggers,
n_examples=self.examples_per_class,
)
fig = plot_match_gallery(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
n_examples=self.examples_per_class,
)
yield f"example_gallery/{class_name}", fig
plt.close(fig)
@classmethod
def from_config(cls, config: ExampleGalleryConfig):
preprocessor = build_preprocessor(config.preprocessing)
audio_loader = build_audio_loader(config.preprocessing.audio)
return cls(
examples_per_class=config.examples_per_class,
preprocessor=preprocessor,
audio_loader=audio_loader,
)
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
PlotConfig = Annotated[
Union[ExampleGalleryConfig,], Field(discriminator="name")
]
def build_plotter(config: PlotConfig) -> PlotterProtocol:
return plots_registry.build(config)
@dataclass
class ClassMatches:
false_positives: List[MatchEvaluation] = field(default_factory=list)
false_negatives: List[MatchEvaluation] = field(default_factory=list)
true_positives: List[MatchEvaluation] = field(default_factory=list)
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def group_matches(
clip_evaluations: Sequence[ClipEvaluation],
) -> Dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches)
for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches:
gt_class = match.gt_class
pred_class = match.pred_class
if pred_class is None:
class_examples[gt_class].false_negatives.append(match)
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[gt_class].cross_triggers.append(match)
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
return class_examples
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[
(index, match.pred_class_scores[pred_class])
for index, match in enumerate(matches)
if (pred_class := match.pred_class) is not None
]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").sample(1)
return [matches[ind] for ind in sample["indices"]]
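
get_binned_sample stratifies by predicted class score: pd.qcut splits the scores into up to n_examples quantile bins and one match is drawn per bin, so the gallery spans low- to high-confidence examples instead of only the top scores. A toy run of the binning step:

import pandas as pd

scores = [0.12, 0.2, 0.33, 0.41, 0.55, 0.63, 0.72, 0.8, 0.91, 0.97]
bins = pd.qcut(scores, q=5, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": range(len(scores)), "bins": bins})
print(df.groupby("bins").sample(1))  # one representative per score quintile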

View File

@ -2,6 +2,7 @@ from batdetect2.plotting.clip_annotations import plot_clip_annotation
from batdetect2.plotting.clip_predictions import plot_clip_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.common import plot_spectrogram
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.heatmaps import (
plot_classification_heatmap,
plot_detection_heatmap,
@ -26,4 +27,5 @@ __all__ = [
"plot_true_positive_match",
"plot_detection_heatmap",
"plot_classification_heatmap",
"plot_match_gallery",
]

View File

@ -1,160 +0,0 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
from batdetect2 import plotting
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing.preprocess import PreprocessorProtocol
@dataclass
class ClassExamples:
false_positives: List[MatchEvaluation] = field(default_factory=list)
false_negatives: List[MatchEvaluation] = field(default_factory=list)
true_positives: List[MatchEvaluation] = field(default_factory=list)
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def plot_example_gallery(
matches: List[MatchEvaluation],
preprocessor: PreprocessorProtocol,
n_examples: int = 5,
):
class_examples = defaultdict(ClassExamples)
for match in matches:
gt_class = match.gt_class
pred_class = match.pred_class
if pred_class is None:
class_examples[gt_class].false_negatives.append(match)
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[gt_class].cross_triggers.append(match)
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
for class_name, examples in class_examples.items():
true_positives = get_binned_sample(
examples.true_positives,
n_examples=n_examples,
)
false_positives = get_binned_sample(
examples.false_positives,
n_examples=n_examples,
)
false_negatives = random.sample(
examples.false_negatives,
k=min(n_examples, len(examples.false_negatives)),
)
cross_triggers = get_binned_sample(
examples.cross_triggers,
n_examples=n_examples,
)
fig = plot_class_examples(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=preprocessor,
n_examples=n_examples,
)
yield class_name, fig
plt.close(fig)
def plot_class_examples(
true_positives: List[MatchEvaluation],
false_positives: List[MatchEvaluation],
false_negatives: List[MatchEvaluation],
cross_triggers: List[MatchEvaluation],
preprocessor: PreprocessorProtocol,
n_examples: int = 5,
duration: float = 0.1,
):
fig = plt.figure(figsize=(20, 20))
for index, match in enumerate(true_positives[:n_examples]):
ax = plt.subplot(4, n_examples, index + 1)
try:
plotting.plot_true_positive_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for index, match in enumerate(false_positives[:n_examples]):
ax = plt.subplot(4, n_examples, n_examples + index + 1)
try:
plotting.plot_false_positive_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for index, match in enumerate(false_negatives[:n_examples]):
ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
try:
plotting.plot_false_negative_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for index, match in enumerate(cross_triggers[:n_examples]):
ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
try:
plotting.plot_cross_trigger_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
return fig
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[
(index, match.pred_class_scores[pred_class])
for index, match in enumerate(matches)
if (pred_class := match.pred_class) is not None
]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").sample(1)
return [matches[ind] for ind in sample["indices"]]

View File

@ -0,0 +1,81 @@
from typing import List, Optional
import matplotlib.pyplot as plt
from batdetect2.plotting.matches import (
plot_cross_trigger_match,
plot_false_negative_match,
plot_false_positive_match,
plot_true_positive_match,
)
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
__all__ = ["plot_match_gallery"]
def plot_match_gallery(
true_positives: List[MatchEvaluation],
false_positives: List[MatchEvaluation],
false_negatives: List[MatchEvaluation],
cross_triggers: List[MatchEvaluation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
n_examples: int = 5,
duration: float = 0.1,
):
fig = plt.figure(figsize=(20, 20))
for index, match in enumerate(true_positives[:n_examples]):
ax = plt.subplot(4, n_examples, index + 1)
try:
plot_true_positive_match(
match,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for index, match in enumerate(false_positives[:n_examples]):
ax = plt.subplot(4, n_examples, n_examples + index + 1)
try:
plot_false_positive_match(
match,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for index, match in enumerate(false_negatives[:n_examples]):
ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
try:
plot_false_negative_match(
match,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for index, match in enumerate(cross_triggers[:n_examples]):
ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
try:
plot_cross_trigger_match(
match,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
return fig

View File

@ -7,8 +7,8 @@ from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper
from batdetect2.plotting.clip_predictions import plot_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.plotting.clips import AudioLoader, plot_clip
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.typing.evaluate import MatchEvaluation
__all__ = [
@ -32,6 +32,7 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_matches(
matches: List[data.Match],
clip: data.Clip,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
@ -46,12 +47,11 @@ def plot_matches(
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
if preprocessor is None:
preprocessor = build_preprocessor()
ax = plot_clip(
clip,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
@ -116,6 +116,7 @@ def plot_matches(
def plot_false_positive_match(
match: MatchEvaluation,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
@ -143,6 +144,7 @@ def plot_false_positive_match(
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
@ -174,6 +176,7 @@ def plot_false_positive_match(
def plot_false_negative_match(
match: MatchEvaluation,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
@ -203,6 +206,7 @@ def plot_false_negative_match(
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
@ -237,6 +241,7 @@ def plot_false_negative_match(
def plot_true_positive_match(
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
@ -267,6 +272,7 @@ def plot_true_positive_match(
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
@ -312,6 +318,7 @@ def plot_true_positive_match(
def plot_cross_trigger_match(
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
@ -342,6 +349,7 @@ def plot_cross_trigger_match(
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,

View File

@ -1,49 +1,26 @@
from typing import List, Optional
from typing import List
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.match import (
MatchConfig,
build_matcher,
match_all_predictions,
)
from batdetect2.plotting.clips import PreprocessorProtocol
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.evaluate import Evaluator
from batdetect2.postprocess import get_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import get_image_plotter
from batdetect2.typing import (
MatchEvaluation,
MetricsProtocol,
)
from batdetect2.typing.evaluate import ClipEvaluation
from batdetect2.typing.models import ModelOutput
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.train import TrainExample
class ValidationMetrics(Callback):
def __init__(
self,
metrics: List[MetricsProtocol],
preprocessor: PreprocessorProtocol,
plot: bool = True,
match_config: Optional[MatchConfig] = None,
):
def __init__(self, evaluator: Evaluator):
super().__init__()
if len(metrics) == 0:
raise ValueError("At least one metric needs to be provided")
self.match_config = match_config
self.metrics = metrics
self.preprocessor = preprocessor
self.plot = plot
self.matcher = build_matcher(config=match_config)
self.evaluator = evaluator
self._clip_annotations: List[data.ClipAnnotation] = []
self._predictions: List[List[RawPrediction]] = []
@ -58,33 +35,22 @@ class ValidationMetrics(Callback):
def plot_examples(
self,
pl_module: LightningModule,
matches: List[MatchEvaluation],
evaluated_clips: List[ClipEvaluation],
):
plotter = get_image_plotter(pl_module.logger) # type: ignore
if plotter is None:
return
for class_name, fig in plot_example_gallery(
matches,
preprocessor=self.preprocessor,
n_examples=4,
):
plotter(
f"examples/{class_name}",
fig,
pl_module.global_step,
)
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
plotter(figure_name, fig, pl_module.global_step)
def log_metrics(
self,
pl_module: LightningModule,
matches: List[MatchEvaluation],
evaluated_clips: List[ClipEvaluation],
):
metrics = {}
for metric in self.metrics:
metrics.update(metric(matches).items())
metrics = self.evaluator.compute_metrics(evaluated_clips)
pl_module.log_dict(metrics)
def on_validation_epoch_end(
@ -92,17 +58,13 @@ class ValidationMetrics(Callback):
trainer: Trainer,
pl_module: LightningModule,
) -> None:
matches = match_all_predictions(
clip_evaluations = self.evaluator.evaluate(
self._clip_annotations,
self._predictions,
targets=pl_module.model.targets,
matcher=self.matcher,
)
self.log_metrics(pl_module, matches)
if self.plot:
self.plot_examples(pl_module, matches)
self.log_metrics(pl_module, clip_evaluations)
self.plot_examples(pl_module, clip_evaluations)
return super().on_validation_epoch_end(trainer, pl_module)

View File

@ -14,7 +14,7 @@ DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1
registry: Registry[ClipperProtocol] = Registry("clipper")
clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
class RandomClipConfig(BaseConfig):
@ -25,7 +25,6 @@ class RandomClipConfig(BaseConfig):
min_sound_event_overlap: float = 0
@registry.register(RandomClipConfig)
class RandomClip:
def __init__(
self,
@ -61,6 +60,9 @@ class RandomClip:
)
clipper_registry.register(RandomClipConfig, RandomClip)
def get_subclip_annotation(
clip_annotation: data.ClipAnnotation,
random: bool = True,
@ -156,7 +158,6 @@ class PaddedClipConfig(BaseConfig):
chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION
@registry.register(PaddedClipConfig)
class PaddedClip:
def __init__(self, chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION):
self.chunk_size = chunk_size
@ -183,6 +184,8 @@ class PaddedClip:
return cls(chunk_size=config.chunk_size)
clipper_registry.register(PaddedClipConfig, PaddedClip)
ClipConfig = Annotated[
Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
]
@ -195,4 +198,4 @@ def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
"Building clipper with config: \n{}",
lambda: config.to_yaml_string(),
)
return registry.build(config)
return clipper_registry.build(config)

View File

@ -10,10 +10,7 @@ from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.metrics import (
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
@ -146,7 +143,6 @@ def build_training_module(
def build_trainer_callbacks(
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
config: EvaluationConfig,
checkpoint_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
@ -161,6 +157,8 @@ def build_trainer_callbacks(
if run_name is not None:
checkpoint_dir = checkpoint_dir / run_name
evaluator = build_evaluator(config=config, targets=targets)
return [
ModelCheckpoint(
dirpath=str(checkpoint_dir),
@ -168,16 +166,7 @@ def build_trainer_callbacks(
filename="best-{epoch:02d}-{val_loss:.0f}",
monitor="total_loss/val",
),
ValidationMetrics(
metrics=[
DetectionAveragePrecision(),
ClassificationMeanAveragePrecision(
class_names=targets.class_names
),
],
preprocessor=preprocessor,
match_config=config.match,
),
ValidationMetrics(evaluator),
]
@ -214,7 +203,6 @@ def build_trainer(
callbacks=build_trainer_callbacks(
targets,
config=conf.evaluation,
preprocessor=build_preprocessor(conf.preprocess),
checkpoint_dir=checkpoint_dir,
experiment_name=experiment_name,
run_name=run_name,

View File

@ -11,6 +11,7 @@ from typing import (
TypeVar,
)
from matplotlib.figure import Figure
from soundevent import data
__all__ = [
@ -50,6 +51,12 @@ class MatchEvaluation:
return self.pred_class_scores[pred_class]
@dataclass
class ClipEvaluation:
clip: data.Clip
matches: List[MatchEvaluation]
class MatcherProtocol(Protocol):
def __call__(
self,
@ -67,10 +74,16 @@ class AffinityFunction(Protocol, Generic[Geom]):
self,
geometry1: Geom,
geometry2: Geom,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float: ...
class MetricsProtocol(Protocol):
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]: ...
class PlotterProtocol(Protocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]: ...