Compare commits

...

6 Commits

Author      SHA1        Message                           Date
mbsantiago  30159d64a9  Update example config             2025-09-28 16:22:21 +01:00
mbsantiago  c9f0c5c431  Added bbox iou affinity function  2025-09-28 16:08:21 +01:00
mbsantiago  10865ee600  Re-org gallery example plots      2025-09-28 15:45:48 +01:00
mbsantiago  87ed44c8f7  Plotting reorganised              2025-09-27 23:58:06 +01:00
mbsantiago  df2abff654  Task/Metrics restructure          2025-09-26 15:23:38 +01:00
mbsantiago  d6ddc4514c  Better evaluation organisation    2025-09-25 17:48:29 +01:00
53 changed files with 4566 additions and 1973 deletions

View File

@@ -138,27 +138,49 @@ train:
     name: csv
 validation:
-  metrics:
-    - name: detection_ap
-    - name: detection_roc_auc
-    - name: classification_ap
-    - name: classification_roc_auc
-    - name: top_class_ap
-    - name: classification_balanced_accuracy
-    - name: clip_ap
-    - name: clip_roc_auc
+  tasks:
+    - name: sound_event_detection
+      metrics:
+        - name: average_precision
+    - name: sound_event_classification
+      metrics:
+        - name: average_precision
 evaluation:
-  match_strategy:
-    name: start_time_match
-    distance_threshold: 0.01
-  metrics:
-    - name: classification_ap
-    - name: detection_ap
-  plots:
-    - name: example_gallery
-    - name: example_clip
-    - name: detection_pr_curve
-    - name: classification_pr_curves
-    - name: detection_roc_curve
-    - name: classification_roc_curves
+  tasks:
+    - name: sound_event_detection
+      metrics:
+        - name: average_precision
+        - name: roc_auc
+      plots:
+        - name: pr_curve
+        - name: score_distribution
+        - name: example_detection
+    - name: sound_event_classification
+      metrics:
+        - name: average_precision
+        - name: roc_auc
+      plots:
+        - name: pr_curve
+    - name: top_class_detection
+      metrics:
+        - name: average_precision
+      plots:
+        - name: pr_curve
+        - name: confusion_matrix
+        - name: example_classification
+    - name: clip_detection
+      metrics:
+        - name: average_precision
+        - name: roc_auc
+      plots:
+        - name: pr_curve
+        - name: roc_curve
+        - name: score_distribution
+    - name: clip_classification
+      metrics:
+        - name: average_precision
+        - name: roc_auc
+      plots:
+        - name: pr_curve
+        - name: roc_curve

View File

@@ -1,6 +1,7 @@
 from pathlib import Path
-from typing import Optional, Sequence
+from typing import List, Optional, Sequence

+import torch
 from soundevent import data

 from batdetect2.audio import build_audio_loader
@@ -8,6 +9,7 @@ from batdetect2.config import BatDetect2Config
 from batdetect2.evaluate import build_evaluator, evaluate
 from batdetect2.models import Model, build_model
 from batdetect2.postprocess import build_postprocessor
+from batdetect2.postprocess.decoding import to_raw_predictions
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets.targets import build_targets
 from batdetect2.train import train
@@ -19,6 +21,7 @@ from batdetect2.typing import (
     PreprocessorProtocol,
     TargetProtocol,
 )
+from batdetect2.typing.postprocess import RawPrediction


 class BatDetect2API:
@@ -92,6 +95,18 @@ class BatDetect2API:
             run_name=run_name,
         )

+    def process_spectrogram(
+        self,
+        spec: torch.Tensor,
+        start_times: Optional[Sequence[float]] = None,
+    ) -> List[List[RawPrediction]]:
+        outputs = self.model.detector(spec)
+        clip_detections = self.postprocessor(outputs, start_times=start_times)
+        return [
+            to_raw_predictions(clip_dets.numpy(), self.targets)
+            for clip_dets in clip_detections
+        ]
+
     @classmethod
     def from_config(cls, config: BatDetect2Config):
         targets = build_targets(config=config.targets)
@@ -108,10 +123,7 @@ class BatDetect2API:
             config=config.postprocess,
         )

-        evaluator = build_evaluator(
-            config=config.evaluation,
-            targets=targets,
-        )
+        evaluator = build_evaluator(config=config.evaluation, targets=targets)

         # NOTE: Better to have a separate instance of
         # preprocessor and postprocessor as these may be moved
@@ -163,10 +175,7 @@ class BatDetect2API:
             config=config.postprocess,
         )

-        evaluator = build_evaluator(
-            config=config.evaluation,
-            targets=targets,
-        )
+        evaluator = build_evaluator(config=config.evaluation, targets=targets)

         return cls(
             config=config,
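
The added process_spectrogram method exposes raw predictions for a precomputed spectrogram batch. A minimal usage sketch, assuming a BatDetect2API instance built from a config as in from_config above; the tensor shape is illustrative only and not taken from this diff:

import torch

from batdetect2.config import BatDetect2Config

config = BatDetect2Config()                        # or loaded from a YAML file
api = BatDetect2API.from_config(config)            # BatDetect2API as defined above

spec = torch.randn(1, 1, 128, 512)                 # assumed (batch, channel, freq, time) layout
predictions = api.process_spectrogram(spec, start_times=[0.0])

for clip_predictions in predictions:               # one List[RawPrediction] per clip in the batch
    for prediction in clip_predictions:
        print(prediction.detection_score, prediction.geometry)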

View File

@@ -56,18 +56,16 @@ class RandomClip:
             min_sound_event_overlap=self.min_sound_event_overlap,
         )

-    @classmethod
-    def from_config(cls, config: RandomClipConfig):
-        return cls(
+    @clipper_registry.register(RandomClipConfig)
+    @staticmethod
+    def from_config(config: RandomClipConfig):
+        return RandomClip(
             duration=config.duration,
             max_empty=config.max_empty,
             min_sound_event_overlap=config.min_sound_event_overlap,
         )


-clipper_registry.register(RandomClipConfig, RandomClip)
-
-
 def get_subclip_annotation(
     clip_annotation: data.ClipAnnotation,
     random: bool = True,
@@ -184,13 +182,12 @@ class PaddedClip:
         )
         return clip_annotation.model_copy(update=dict(clip=clip))

-    @classmethod
-    def from_config(cls, config: PaddedClipConfig):
-        return cls(chunk_size=config.chunk_size)
-
-
-clipper_registry.register(PaddedClipConfig, PaddedClip)
+    @clipper_registry.register(PaddedClipConfig)
+    @staticmethod
+    def from_config(config: PaddedClipConfig):
+        return PaddedClip(chunk_size=config.chunk_size)


 ClipConfig = Annotated[
     Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
 ]

View File

@@ -53,6 +53,7 @@ class BaseConfig(BaseModel):
         """
         return yaml.dump(
             self.model_dump(
+                mode="json",
                 exclude_none=exclude_none,
                 exclude_unset=exclude_unset,
                 exclude_defaults=exclude_defaults,

View File

@@ -1,16 +1,16 @@
 import sys
-from typing import Generic, Protocol, Type, TypeVar
+from typing import Callable, Dict, Generic, Tuple, Type, TypeVar

 from pydantic import BaseModel
+from typing_extensions import assert_type

 if sys.version_info >= (3, 10):
-    from typing import ParamSpec
+    from typing import Concatenate, ParamSpec
 else:
-    from typing_extensions import ParamSpec
+    from typing_extensions import Concatenate, ParamSpec

 __all__ = [
     "Registry",
+    "SimpleRegistry",
 ]

 T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
@@ -18,19 +18,26 @@ T_Type = TypeVar("T_Type", covariant=True)
 P_Type = ParamSpec("P_Type")

-class LogicProtocol(Generic[T_Config, T_Type, P_Type], Protocol):
-    """A generic protocol for the logic classes."""
-
-    @classmethod
-    def from_config(
-        cls,
-        config: T_Config,
-        *args: P_Type.args,
-        **kwargs: P_Type.kwargs,
-    ) -> T_Type: ...
-
-
-T_Proto = TypeVar("T_Proto", bound=LogicProtocol)
+T = TypeVar("T")
+
+
+class SimpleRegistry(Generic[T]):
+    def __init__(self, name: str):
+        self._name = name
+        self._registry = {}
+
+    def register(self, name: str):
+        def decorator(obj: T) -> T:
+            self._registry[name] = obj
+            return obj
+
+        return decorator
+
+    def get(self, name: str) -> T:
+        return self._registry[name]
+
+    def has(self, name: str) -> bool:
+        return name in self._registry


 class Registry(Generic[T_Type, P_Type]):
@@ -38,13 +45,15 @@ class Registry(Generic[T_Type, P_Type]):
     def __init__(self, name: str):
         self._name = name
-        self._registry = {}
+        self._registry: Dict[
+            str, Callable[Concatenate[..., P_Type], T_Type]
+        ] = {}
+        self._config_types: Dict[str, Type[BaseModel]] = {}

     def register(
         self,
         config_cls: Type[T_Config],
-        logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
-    ) -> None:
+    ):
         fields = config_cls.model_fields

         if "name" not in fields:
@@ -52,10 +61,21 @@ class Registry(Generic[T_Type, P_Type]):
         name = fields["name"].default

+        self._config_types[name] = config_cls
+
         if not isinstance(name, str):
             raise ValueError("'name' field must be a string literal.")

-        self._registry[name] = logic_cls
+        def decorator(
+            func: Callable[Concatenate[T_Config, P_Type], T_Type],
+        ):
+            self._registry[name] = func
+            return func
+
+        return decorator
+
+    def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
+        return tuple(self._config_types.values())

     def build(
         self,
@@ -75,4 +95,4 @@ class Registry(Generic[T_Type, P_Type]):
                 f"No {self._name} with name '{name}' is registered."
             )

-        return self._registry[name].from_config(config, *args, **kwargs)
+        return self._registry[name](config, *args, **kwargs)
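
The registry now hands back a decorator instead of taking a logic class, so any callable that accepts the config can be registered, and Registry.build dispatches on the default of the config's name field. A small sketch of the new flow with a hypothetical GreeterConfig (not part of batdetect2):

from typing import Literal

from pydantic import BaseModel

from batdetect2.core.registries import Registry  # as restructured above

greeters: Registry[str, []] = Registry("greeter")


class GreeterConfig(BaseModel):
    # Registry.register keys the entry on the default of the "name" field
    name: Literal["greeter"] = "greeter"
    greeting: str = "hello"


@greeters.register(GreeterConfig)
def build_greeter(config: GreeterConfig) -> str:
    # any callable taking the config works; it no longer has to be a classmethod
    return config.greeting.upper()


print(greeters.build(GreeterConfig()))  # -> "HELLO"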

View File

@@ -10,7 +10,7 @@ from batdetect2.core.registries import Registry
 SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

-condition_registry: Registry[SoundEventCondition, []] = Registry("condition")
+conditions: Registry[SoundEventCondition, []] = Registry("condition")


 class HasTagConfig(BaseConfig):
@@ -27,12 +27,10 @@ class HasTag:
     ) -> bool:
         return self.tag in sound_event_annotation.tags

-    @classmethod
-    def from_config(cls, config: HasTagConfig):
-        return cls(tag=config.tag)
-
-
-condition_registry.register(HasTagConfig, HasTag)
+    @conditions.register(HasTagConfig)
+    @staticmethod
+    def from_config(config: HasTagConfig):
+        return HasTag(tag=config.tag)


 class HasAllTagsConfig(BaseConfig):
@@ -52,12 +50,10 @@ class HasAllTags:
     ) -> bool:
         return self.tags.issubset(sound_event_annotation.tags)

-    @classmethod
-    def from_config(cls, config: HasAllTagsConfig):
-        return cls(tags=config.tags)
-
-
-condition_registry.register(HasAllTagsConfig, HasAllTags)
+    @conditions.register(HasAllTagsConfig)
+    @staticmethod
+    def from_config(config: HasAllTagsConfig):
+        return HasAllTags(tags=config.tags)


 class HasAnyTagConfig(BaseConfig):
@@ -77,13 +73,12 @@ class HasAnyTag:
     ) -> bool:
         return bool(self.tags.intersection(sound_event_annotation.tags))

-    @classmethod
-    def from_config(cls, config: HasAnyTagConfig):
-        return cls(tags=config.tags)
-
-
-condition_registry.register(HasAnyTagConfig, HasAnyTag)
+    @conditions.register(HasAnyTagConfig)
+    @staticmethod
+    def from_config(config: HasAnyTagConfig):
+        return HasAnyTag(tags=config.tags)


 Operator = Literal["gt", "gte", "lt", "lte", "eq"]
@@ -134,12 +129,10 @@ class Duration:
         return self._comparator(duration)

-    @classmethod
-    def from_config(cls, config: DurationConfig):
-        return cls(operator=config.operator, seconds=config.seconds)
-
-
-condition_registry.register(DurationConfig, Duration)
+    @conditions.register(DurationConfig)
+    @staticmethod
+    def from_config(config: DurationConfig):
+        return Duration(operator=config.operator, seconds=config.seconds)


 class FrequencyConfig(BaseConfig):
@@ -181,18 +174,16 @@ class Frequency:
         return self._comparator(high_freq)

-    @classmethod
-    def from_config(cls, config: FrequencyConfig):
-        return cls(
+    @conditions.register(FrequencyConfig)
+    @staticmethod
+    def from_config(config: FrequencyConfig):
+        return Frequency(
             operator=config.operator,
             boundary=config.boundary,
             hertz=config.hertz,
         )


-condition_registry.register(FrequencyConfig, Frequency)
-
-
 class AllOfConfig(BaseConfig):
     name: Literal["all_of"] = "all_of"
     conditions: Sequence["SoundEventConditionConfig"]
@@ -207,15 +198,13 @@ class AllOf:
     ) -> bool:
         return all(c(sound_event_annotation) for c in self.conditions)

-    @classmethod
-    def from_config(cls, config: AllOfConfig):
+    @conditions.register(AllOfConfig)
+    @staticmethod
+    def from_config(config: AllOfConfig):
         conditions = [
             build_sound_event_condition(cond) for cond in config.conditions
         ]
-        return cls(conditions)
-
-
-condition_registry.register(AllOfConfig, AllOf)
+        return AllOf(conditions)


 class AnyOfConfig(BaseConfig):
@@ -232,15 +221,13 @@ class AnyOf:
     ) -> bool:
         return any(c(sound_event_annotation) for c in self.conditions)

-    @classmethod
-    def from_config(cls, config: AnyOfConfig):
+    @conditions.register(AnyOfConfig)
+    @staticmethod
+    def from_config(config: AnyOfConfig):
         conditions = [
             build_sound_event_condition(cond) for cond in config.conditions
         ]
-        return cls(conditions)
-
-
-condition_registry.register(AnyOfConfig, AnyOf)
+        return AnyOf(conditions)


 class NotConfig(BaseConfig):
@@ -257,14 +244,13 @@ class Not:
     ) -> bool:
         return not self.condition(sound_event_annotation)

-    @classmethod
-    def from_config(cls, config: NotConfig):
+    @conditions.register(NotConfig)
+    @staticmethod
+    def from_config(config: NotConfig):
         condition = build_sound_event_condition(config.condition)
-        return cls(condition)
-
-
-condition_registry.register(NotConfig, Not)
+        return Not(condition)


 SoundEventConditionConfig = Annotated[
     Union[
         HasTagConfig,
@@ -283,7 +269,7 @@ SoundEventConditionConfig = Annotated[
 def build_sound_event_condition(
     config: SoundEventConditionConfig,
 ) -> SoundEventCondition:
-    return condition_registry.build(config)
+    return conditions.build(config)


 def filter_clip_annotation(
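
With every condition registered on the shared conditions registry, composite configs still build the same way through build_sound_event_condition. A hedged sketch (field names follow the from_config bodies above; the "low" boundary value is an assumption):

config = AnyOfConfig(
    conditions=[
        DurationConfig(operator="lt", seconds=0.05),
        FrequencyConfig(operator="gt", boundary="low", hertz=15_000),
    ]
)

condition = build_sound_event_condition(config)

# sound_event_annotation: a soundevent data.SoundEventAnnotation from elsewhere
keep = condition(sound_event_annotation)  # True if either sub-condition holds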

View File

@@ -17,7 +17,7 @@ SoundEventTransform = Callable[
     data.SoundEventAnnotation,
 ]

-transform_registry: Registry[SoundEventTransform, []] = Registry("transform")
+transforms: Registry[SoundEventTransform, []] = Registry("transform")


 class SetFrequencyBoundConfig(BaseConfig):
@@ -63,12 +63,10 @@ class SetFrequencyBound:
             update=dict(sound_event=sound_event)
         )

-    @classmethod
-    def from_config(cls, config: SetFrequencyBoundConfig):
-        return cls(hertz=config.hertz, boundary=config.boundary)
-
-
-transform_registry.register(SetFrequencyBoundConfig, SetFrequencyBound)
+    @transforms.register(SetFrequencyBoundConfig)
+    @staticmethod
+    def from_config(config: SetFrequencyBoundConfig):
+        return SetFrequencyBound(hertz=config.hertz, boundary=config.boundary)


 class ApplyIfConfig(BaseConfig):
@@ -95,14 +93,12 @@ class ApplyIf:
         return self.transform(sound_event_annotation)

-    @classmethod
-    def from_config(cls, config: ApplyIfConfig):
+    @transforms.register(ApplyIfConfig)
+    @staticmethod
+    def from_config(config: ApplyIfConfig):
         transform = build_sound_event_transform(config.transform)
         condition = build_sound_event_condition(config.condition)
-        return cls(condition=condition, transform=transform)
-
-
-transform_registry.register(ApplyIfConfig, ApplyIf)
+        return ApplyIf(condition=condition, transform=transform)


 class ReplaceTagConfig(BaseConfig):
@@ -134,12 +130,12 @@ class ReplaceTag:
         return sound_event_annotation.model_copy(update=dict(tags=tags))

-    @classmethod
-    def from_config(cls, config: ReplaceTagConfig):
-        return cls(original=config.original, replacement=config.replacement)
-
-
-transform_registry.register(ReplaceTagConfig, ReplaceTag)
+    @transforms.register(ReplaceTagConfig)
+    @staticmethod
+    def from_config(config: ReplaceTagConfig):
+        return ReplaceTag(
+            original=config.original, replacement=config.replacement
+        )


 class MapTagValueConfig(BaseConfig):
@@ -189,18 +185,16 @@ class MapTagValue:
         return sound_event_annotation.model_copy(update=dict(tags=tags))

-    @classmethod
-    def from_config(cls, config: MapTagValueConfig):
-        return cls(
+    @transforms.register(MapTagValueConfig)
+    @staticmethod
+    def from_config(config: MapTagValueConfig):
+        return MapTagValue(
             tag_key=config.tag_key,
             value_mapping=config.value_mapping,
             target_key=config.target_key,
         )


-transform_registry.register(MapTagValueConfig, MapTagValue)
-
-
 class ApplyAllConfig(BaseConfig):
     name: Literal["apply_all"] = "apply_all"
     steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
@@ -219,14 +213,13 @@ class ApplyAll:
         return sound_event_annotation

-    @classmethod
-    def from_config(cls, config: ApplyAllConfig):
+    @transforms.register(ApplyAllConfig)
+    @staticmethod
+    def from_config(config: ApplyAllConfig):
         steps = [build_sound_event_transform(step) for step in config.steps]
-        return cls(steps)
-
-
-transform_registry.register(ApplyAllConfig, ApplyAll)
+        return ApplyAll(steps)


 SoundEventTransformConfig = Annotated[
     Union[
         SetFrequencyBoundConfig,
@@ -242,7 +235,7 @@ SoundEventTransformConfig = Annotated[
 def build_sound_event_transform(
     config: SoundEventTransformConfig,
 ) -> SoundEventTransform:
-    return transform_registry.build(config)
+    return transforms.build(config)


 def transform_clip_annotation(

View File

@@ -1,11 +1,14 @@
 from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
 from batdetect2.evaluate.evaluate import evaluate
 from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
+from batdetect2.evaluate.tasks import TaskConfig, build_task

 __all__ = [
     "EvaluationConfig",
-    "load_evaluation_config",
-    "evaluate",
     "Evaluator",
+    "TaskConfig",
     "build_evaluator",
+    "build_task",
+    "evaluate",
+    "load_evaluation_config",
 ]

View File

@@ -3,6 +3,7 @@ from typing import Annotated, Literal, Optional, Union
 from pydantic import Field
 from soundevent import data
 from soundevent.evaluation import compute_affinity
+from soundevent.geometry import compute_interval_overlap

 from batdetect2.core.configs import BaseConfig
 from batdetect2.core.registries import Registry
@@ -27,12 +28,10 @@ class TimeAffinity(AffinityFunction):
             geometry1, geometry2, time_buffer=self.time_buffer
         )

-    @classmethod
-    def from_config(cls, config: TimeAffinityConfig):
-        return cls(time_buffer=config.time_buffer)
-
-
-affinity_functions.register(TimeAffinityConfig, TimeAffinity)
+    @affinity_functions.register(TimeAffinityConfig)
+    @staticmethod
+    def from_config(config: TimeAffinityConfig):
+        return TimeAffinity(time_buffer=config.time_buffer)


 def compute_timestamp_affinity(
@@ -73,12 +72,10 @@ class IntervalIOU(AffinityFunction):
             time_buffer=self.time_buffer,
         )

-    @classmethod
-    def from_config(cls, config: IntervalIOUConfig):
-        return cls(time_buffer=config.time_buffer)
-
-
-affinity_functions.register(IntervalIOUConfig, IntervalIOU)
+    @affinity_functions.register(IntervalIOUConfig)
+    @staticmethod
+    def from_config(config: IntervalIOUConfig):
+        return IntervalIOU(time_buffer=config.time_buffer)


 def compute_interval_iou(
@@ -97,9 +94,11 @@ def compute_interval_iou(
     end_time1 += time_buffer
     end_time2 += time_buffer

-    intersection = max(
-        0, min(end_time1, end_time2) - max(start_time1, start_time2)
+    intersection = compute_interval_overlap(
+        (start_time1, end_time1),
+        (start_time2, end_time2),
     )
+
     union = (
         (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
     )
@@ -110,6 +109,86 @@ def compute_interval_iou(
     return intersection / union


+class BBoxIOUConfig(BaseConfig):
+    name: Literal["bbox_iou"] = "bbox_iou"
+    time_buffer: float = 0.01
+    freq_buffer: float = 1000
+
+
+class BBoxIOU(AffinityFunction):
+    def __init__(self, time_buffer: float, freq_buffer: float):
+        self.time_buffer = time_buffer
+        self.freq_buffer = freq_buffer
+
+    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
+        if not isinstance(geometry1, data.BoundingBox):
+            raise TypeError(
+                f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
+            )
+
+        if not isinstance(geometry2, data.BoundingBox):
+            raise TypeError(
+                f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
+            )
+
+        return bbox_iou(
+            geometry1,
+            geometry2,
+            time_buffer=self.time_buffer,
+            freq_buffer=self.freq_buffer,
+        )
+
+    @affinity_functions.register(BBoxIOUConfig)
+    @staticmethod
+    def from_config(config: BBoxIOUConfig):
+        return BBoxIOU(
+            time_buffer=config.time_buffer,
+            freq_buffer=config.freq_buffer,
+        )
+
+
+def bbox_iou(
+    geometry1: data.BoundingBox,
+    geometry2: data.BoundingBox,
+    time_buffer: float = 0.01,
+    freq_buffer: float = 1000,
+) -> float:
+    start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
+    start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
+
+    start_time1 -= time_buffer
+    start_time2 -= time_buffer
+    end_time1 += time_buffer
+    end_time2 += time_buffer
+
+    low_freq1 -= freq_buffer
+    low_freq2 -= freq_buffer
+    high_freq1 += freq_buffer
+    high_freq2 += freq_buffer
+
+    time_intersection = compute_interval_overlap(
+        (start_time1, end_time1),
+        (start_time2, end_time2),
+    )
+    freq_intersection = max(
+        0,
+        min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
+    )
+
+    intersection = time_intersection * freq_intersection
+
+    if intersection == 0:
+        return 0
+
+    union = (
+        (end_time1 - start_time1) * (high_freq1 - low_freq1)
+        + (end_time2 - start_time2) * (high_freq2 - low_freq2)
+        - intersection
+    )
+
+    return intersection / union
+
+
 class GeometricIOUConfig(BaseConfig):
     name: Literal["geometric_iou"] = "geometric_iou"
     time_buffer: float = 0.01
@@ -127,17 +206,17 @@ class GeometricIOU(AffinityFunction):
             time_buffer=self.time_buffer,
         )

-    @classmethod
-    def from_config(cls, config: GeometricIOUConfig):
-        return cls(time_buffer=config.time_buffer)
-
-
-affinity_functions.register(GeometricIOUConfig, GeometricIOU)
+    @affinity_functions.register(GeometricIOUConfig)
+    @staticmethod
+    def from_config(config: GeometricIOUConfig):
+        return GeometricIOU(time_buffer=config.time_buffer)


 AffinityConfig = Annotated[
     Union[
         TimeAffinityConfig,
         IntervalIOUConfig,
+        BBoxIOUConfig,
         GeometricIOUConfig,
     ],
     Field(discriminator="name"),
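
The new bbox_iou mirrors the interval IoU but intersects in both time and frequency. A quick hedged check with invented boxes, relying on soundevent's BoundingBox coordinate order (start_time, low_freq, end_time, high_freq) that the unpacking above assumes:

from soundevent import data

box_a = data.BoundingBox(coordinates=[0.10, 20_000, 0.20, 40_000])
box_b = data.BoundingBox(coordinates=[0.15, 30_000, 0.25, 50_000])

iou = bbox_iou(box_a, box_b, time_buffer=0.0, freq_buffer=0.0)
# time overlap 0.05 s, frequency overlap 10 kHz -> intersection = 500
# each box covers 0.10 s * 20 kHz = 2000 -> union = 2000 + 2000 - 500 = 3500
print(iou)  # ~0.143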

View File

@@ -4,13 +4,11 @@ from pydantic import Field
 from soundevent import data

 from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
-from batdetect2.evaluate.metrics import (
-    ClassificationAPConfig,
-    DetectionAPConfig,
-    MetricConfig,
+from batdetect2.evaluate.tasks import (
+    TaskConfig,
 )
-from batdetect2.evaluate.plots import PlotConfig
+from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
+from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
 from batdetect2.logging import CSVLoggerConfig, LoggerConfig

 __all__ = [
@@ -20,15 +18,12 @@ __all__ = [

 class EvaluationConfig(BaseConfig):
-    ignore_start_end: float = 0.01
-    match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig)
-    metrics: List[MetricConfig] = Field(
+    tasks: List[TaskConfig] = Field(
         default_factory=lambda: [
-            DetectionAPConfig(),
-            ClassificationAPConfig(),
+            DetectionTaskConfig(),
+            ClassificationTaskConfig(),
         ]
     )
-    plots: List[PlotConfig] = Field(default_factory=list)
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
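
EvaluationConfig is now just a list of task configs plus a logger, which is what the example YAML at the top of this diff maps onto. A hedged sketch; task fields beyond name are assumptions based on that YAML:

config = EvaluationConfig.model_validate(
    {
        "tasks": [
            {
                "name": "sound_event_detection",
                "metrics": [{"name": "average_precision"}, {"name": "roc_auc"}],
            },
            {
                "name": "sound_event_classification",
                "metrics": [{"name": "average_precision"}],
            },
        ]
    }
)
print(config.tasks)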

View File

@@ -1,23 +1,12 @@
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 from matplotlib.figure import Figure
 from soundevent import data
-from soundevent.geometry import compute_bounds

 from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.evaluate.match import build_matcher, match
-from batdetect2.evaluate.metrics import build_metric
-from batdetect2.evaluate.plots import build_plotter
+from batdetect2.evaluate.tasks import build_task
 from batdetect2.targets import build_targets
-from batdetect2.typing.evaluate import (
-    ClipEvaluation,
-    EvaluatorProtocol,
-    MatcherProtocol,
-    MetricsProtocol,
-    PlotterProtocol,
-)
-from batdetect2.typing.postprocess import RawPrediction
-from batdetect2.typing.targets import TargetProtocol
+from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol

 __all__ = [
     "Evaluator",
@@ -28,146 +17,51 @@ __all__ = [
 class Evaluator:
     def __init__(
         self,
-        config: EvaluationConfig,
         targets: TargetProtocol,
-        matcher: MatcherProtocol,
-        metrics: List[MetricsProtocol],
-        plots: List[PlotterProtocol],
+        tasks: Sequence[EvaluatorProtocol],
     ):
-        self.config = config
         self.targets = targets
-        self.matcher = matcher
-        self.metrics = metrics
-        self.plots = plots
-
-    def match(
-        self,
-        clip_annotation: data.ClipAnnotation,
-        predictions: Sequence[RawPrediction],
-    ) -> ClipEvaluation:
-        clip = clip_annotation.clip
-
-        ground_truth = [
-            sound_event
-            for sound_event in clip_annotation.sound_events
-            if self.filter_sound_event_annotations(sound_event, clip)
-        ]
-
-        predictions = [
-            prediction
-            for prediction in predictions
-            if self.filter_predictions(prediction, clip)
-        ]
-
-        return ClipEvaluation(
-            clip=clip_annotation.clip,
-            matches=match(
-                ground_truth,
-                predictions,
-                clip=clip,
-                targets=self.targets,
-                matcher=self.matcher,
-            ),
-        )
-
-    def filter_sound_event_annotations(
-        self,
-        sound_event_annotation: data.SoundEventAnnotation,
-        clip: data.Clip,
-    ) -> bool:
-        if not self.targets.filter(sound_event_annotation):
-            return False
-
-        geometry = sound_event_annotation.sound_event.geometry
-
-        if geometry is None:
-            return False
-
-        return is_in_bounds(
-            geometry,
-            clip,
-            self.config.ignore_start_end,
-        )
-
-    def filter_predictions(
-        self,
-        prediction: RawPrediction,
-        clip: data.Clip,
-    ) -> bool:
-        return is_in_bounds(
-            prediction.geometry,
-            clip,
-            self.config.ignore_start_end,
-        )
+        self.tasks = tasks

     def evaluate(
         self,
         clip_annotations: Sequence[data.ClipAnnotation],
         predictions: Sequence[Sequence[RawPrediction]],
-    ) -> List[ClipEvaluation]:
-        if len(clip_annotations) != len(predictions):
-            raise ValueError(
-                "Number of annotated clips and sets of predictions do not match"
-            )
-
+    ) -> List[Any]:
         return [
-            self.match(clip_annotation, preds)
-            for clip_annotation, preds in zip(clip_annotations, predictions)
+            task.evaluate(clip_annotations, predictions) for task in self.tasks
         ]

-    def compute_metrics(
-        self,
-        clip_evaluations: Sequence[ClipEvaluation],
-    ) -> Dict[str, float]:
+    def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
         results = {}
-        for metric in self.metrics:
-            results.update(metric(clip_evaluations))
+        for task, outputs in zip(self.tasks, eval_outputs):
+            results.update(task.compute_metrics(outputs))
         return results

     def generate_plots(
-        self, clip_evaluations: Sequence[ClipEvaluation]
+        self,
+        eval_outputs: List[Any],
     ) -> Iterable[Tuple[str, Figure]]:
-        for plotter in self.plots:
-            for name, fig in plotter(clip_evaluations):
+        for task, outputs in zip(self.tasks, eval_outputs):
+            for name, fig in task.generate_plots(outputs):
                 yield name, fig


 def build_evaluator(
-    config: Optional[EvaluationConfig] = None,
+    config: Optional[Union[EvaluationConfig, dict]] = None,
     targets: Optional[TargetProtocol] = None,
-    matcher: Optional[MatcherProtocol] = None,
-    plots: Optional[List[PlotterProtocol]] = None,
-    metrics: Optional[List[MetricsProtocol]] = None,
 ) -> EvaluatorProtocol:
-    config = config or EvaluationConfig()
     targets = targets or build_targets()
-    matcher = matcher or build_matcher(config.match_strategy)

-    if metrics is None:
-        metrics = [
-            build_metric(config, targets.class_names)
-            for config in config.metrics
-        ]
+    if config is None:
+        config = EvaluationConfig()

-    if plots is None:
-        plots = [
-            build_plotter(config, targets.class_names)
-            for config in config.plots
-        ]
+    if not isinstance(config, EvaluationConfig):
+        config = EvaluationConfig.model_validate(config)

     return Evaluator(
-        config=config,
         targets=targets,
-        matcher=matcher,
-        metrics=metrics,
-        plots=plots,
-    )
-
-
-def is_in_bounds(
-    geometry: data.Geometry,
-    clip: data.Clip,
-    buffer: float,
-) -> bool:
-    start_time = compute_bounds(geometry)[0]
-    return (start_time >= clip.start_time + buffer) and (
-        start_time <= clip.end_time - buffer
+        tasks=[build_task(task, targets=targets) for task in config.tasks],
     )
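
End to end, the Evaluator is now a thin wrapper over per-task evaluators. A hedged usage sketch; clip_annotations and predictions are assumed to come from the test pipeline, and a plain dict is accepted for config thanks to the model_validate branch above:

evaluator = build_evaluator(config={"tasks": [{"name": "sound_event_detection"}]})

outputs = evaluator.evaluate(clip_annotations, predictions)  # one output per task
scores = evaluator.compute_metrics(outputs)                  # merged dict across tasks

for name, figure in evaluator.generate_plots(outputs):
    figure.savefig(f"{name}.png")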

View File

@@ -4,11 +4,10 @@ from lightning import LightningModule
 from torch.utils.data import DataLoader

 from batdetect2.evaluate.dataset import TestDataset, TestExample
-from batdetect2.evaluate.tables import FullEvaluationTable
-from batdetect2.logging import get_image_logger, get_table_logger
+from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
-from batdetect2.typing import ClipEvaluation, EvaluatorProtocol
+from batdetect2.typing import ClipMatches, EvaluatorProtocol


 class EvaluationModule(LightningModule):
@@ -54,18 +53,8 @@ class EvaluationModule(LightningModule):
     def on_test_epoch_end(self):
         self.log_metrics(self.clip_evaluations)
         self.plot_examples(self.clip_evaluations)
-        self.log_table(self.clip_evaluations)
-
-    def log_table(self, evaluated_clips: Sequence[ClipEvaluation]):
-        table_logger = get_table_logger(self.logger)  # type: ignore
-
-        if table_logger is None:
-            return
-
-        df = FullEvaluationTable()(evaluated_clips)
-        table_logger("full_evaluation", df, 0)

-    def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]):
+    def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
         plotter = get_image_logger(self.logger)  # type: ignore

         if plotter is None:
@@ -74,7 +63,7 @@ class EvaluationModule(LightningModule):
         for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
             plotter(figure_name, fig, self.global_step)

-    def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]):
+    def log_metrics(self, evaluated_clips: Sequence[ClipMatches]):
         metrics = self.evaluator.compute_metrics(evaluated_clips)
         self.log_dict(metrics)

View File

@@ -8,8 +8,7 @@ from soundevent.evaluation import compute_affinity
 from soundevent.evaluation import match_geometries as optimal_match
 from soundevent.geometry import compute_bounds

-from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import Registry
+from batdetect2.core import BaseConfig, Registry
 from batdetect2.evaluate.affinity import (
     AffinityConfig,
     GeometricIOUConfig,
@@ -17,11 +16,13 @@ from batdetect2.evaluate.affinity import (
 )
 from batdetect2.targets import build_targets
 from batdetect2.typing import (
+    AffinityFunction,
+    MatcherProtocol,
     MatchEvaluation,
+    RawPrediction,
     TargetProtocol,
 )
-from batdetect2.typing.evaluate import AffinityFunction, MatcherProtocol
-from batdetect2.typing.postprocess import RawPrediction
+from batdetect2.typing.evaluate import ClipMatches

 MatchingGeometry = Literal["bbox", "interval", "timestamp"]
 """The geometry representation to use for matching."""
@@ -33,9 +34,10 @@ def match(
     sound_event_annotations: Sequence[data.SoundEventAnnotation],
     raw_predictions: Sequence[RawPrediction],
     clip: data.Clip,
+    scores: Optional[Sequence[float]] = None,
     targets: Optional[TargetProtocol] = None,
     matcher: Optional[MatcherProtocol] = None,
-) -> List[MatchEvaluation]:
+) -> ClipMatches:
     if matcher is None:
         matcher = build_matcher()
@@ -51,9 +53,11 @@ def match(
         raw_prediction.geometry for raw_prediction in raw_predictions
     ]

-    scores = [
-        raw_prediction.detection_score for raw_prediction in raw_predictions
-    ]
+    if scores is None:
+        scores = [
+            raw_prediction.detection_score
+            for raw_prediction in raw_predictions
+        ]

     matches = []
@@ -73,9 +77,11 @@ def match(
         gt_det = target_idx is not None
         gt_class = targets.encode_class(target) if target is not None else None
+        gt_geometry = (
+            target_geometries[target_idx] if target_idx is not None else None
+        )

         pred_score = float(prediction.detection_score) if prediction else 0
         pred_geometry = (
             predicted_geometries[source_idx]
             if source_idx is not None
@@ -84,7 +90,7 @@ def match(
         class_scores = (
             {
-                str(class_name): float(score)
+                class_name: score
                 for class_name, score in zip(
                     targets.class_names,
                     prediction.class_scores,
@@ -100,6 +106,7 @@ def match(
                 sound_event_annotation=target,
                 gt_det=gt_det,
                 gt_class=gt_class,
+                gt_geometry=gt_geometry,
                 pred_score=pred_score,
                 pred_class_scores=class_scores,
                 pred_geometry=pred_geometry,
@@ -107,7 +114,7 @@ def match(
             )
         )

-    return matches
+    return ClipMatches(clip=clip, matches=matches)


 class StartTimeMatchConfig(BaseConfig):
@@ -132,12 +139,10 @@ class StartTimeMatcher(MatcherProtocol):
             distance_threshold=self.distance_threshold,
         )

-    @classmethod
-    def from_config(cls, config: StartTimeMatchConfig) -> "StartTimeMatcher":
-        return cls(distance_threshold=config.distance_threshold)
-
-
-matching_strategies.register(StartTimeMatchConfig, StartTimeMatcher)
+    @matching_strategies.register(StartTimeMatchConfig)
+    @staticmethod
+    def from_config(config: StartTimeMatchConfig):
+        return StartTimeMatcher(distance_threshold=config.distance_threshold)


 def match_start_times(
@@ -264,19 +269,17 @@ class GreedyMatcher(MatcherProtocol):
             affinity_threshold=self.affinity_threshold,
         )

-    @classmethod
-    def from_config(cls, config: GreedyMatchConfig):
+    @matching_strategies.register(GreedyMatchConfig)
+    @staticmethod
+    def from_config(config: GreedyMatchConfig):
         affinity_function = build_affinity_function(config.affinity_function)
-        return cls(
+        return GreedyMatcher(
             geometry=config.geometry,
             affinity_threshold=config.affinity_threshold,
             affinity_function=affinity_function,
         )


-matching_strategies.register(GreedyMatchConfig, GreedyMatcher)
-
-
 def greedy_match(
     ground_truth: Sequence[data.Geometry],
     predictions: Sequence[data.Geometry],
@@ -313,21 +316,21 @@ def greedy_match(
     unassigned_gt = set(range(len(ground_truth)))

     if not predictions:
-        for target_idx in range(len(ground_truth)):
-            yield None, target_idx, 0
+        for gt_idx in range(len(ground_truth)):
+            yield None, gt_idx, 0
         return

     if not ground_truth:
-        for source_idx in range(len(predictions)):
-            yield source_idx, None, 0
+        for pred_idx in range(len(predictions)):
+            yield pred_idx, None, 0
         return

     indices = np.argsort(scores)[::-1]

-    for source_idx in indices:
-        source_geometry = predictions[source_idx]
+    for pred_idx in indices:
+        source_geometry = predictions[pred_idx]

         affinities = np.array(
             [
@@ -340,18 +343,18 @@ def greedy_match(
         affinity = affinities[closest_target]

         if affinities[closest_target] <= affinity_threshold:
-            yield source_idx, None, 0
+            yield pred_idx, None, 0
             continue

         if closest_target not in unassigned_gt:
-            yield source_idx, None, 0
+            yield pred_idx, None, 0
             continue

         unassigned_gt.remove(closest_target)
-        yield source_idx, closest_target, affinity
+        yield pred_idx, closest_target, affinity

-    for target_idx in unassigned_gt:
-        yield None, target_idx, 0
+    for gt_idx in unassigned_gt:
+        yield None, gt_idx, 0


 class OptimalMatchConfig(BaseConfig):
@@ -386,17 +389,16 @@ class OptimalMatcher(MatcherProtocol):
             affinity_threshold=self.affinity_threshold,
         )

-    @classmethod
-    def from_config(cls, config: OptimalMatchConfig):
-        return cls(
+    @matching_strategies.register(OptimalMatchConfig)
+    @staticmethod
+    def from_config(config: OptimalMatchConfig):
+        return OptimalMatcher(
             affinity_threshold=config.affinity_threshold,
             time_buffer=config.time_buffer,
             frequency_buffer=config.frequency_buffer,
         )


-matching_strategies.register(OptimalMatchConfig, OptimalMatcher)
-
-
 MatchConfig = Annotated[
     Union[
         GreedyMatchConfig,
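
match() now returns a ClipMatches and accepts optional precomputed scores. A hedged sketch of calling it directly; the annotations, predictions, clip and targets are assumed to exist already:

clip_matches = match(
    sound_event_annotations,
    raw_predictions,
    clip=clip,
    scores=None,               # falls back to each prediction's detection_score
    targets=targets,
    matcher=build_matcher(),   # default matcher, as in the function body above
)

for m in clip_matches.matches:
    print(m.gt_det, m.gt_class, m.pred_score, m.gt_geometry)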

View File

@@ -1,712 +0,0 @@
from collections import defaultdict
from collections.abc import Callable, Mapping
from typing import (
Annotated,
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation, MetricsProtocol
__all__ = ["DetectionAP", "ClassificationAP"]
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
APImplementation = Literal["sklearn", "pascal_voc"]
class DetectionAPConfig(BaseConfig):
name: Literal["detection_ap"] = "detection_ap"
ap_implementation: APImplementation = "pascal_voc"
class DetectionAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation = "pascal_voc",
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
score = float(self.metric(y_true, y_score))
return {"detection_AP": score}
@classmethod
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
return cls(implementation=config.ap_implementation)
metrics_registry.register(DetectionAPConfig, DetectionAP)
class DetectionROCAUCConfig(BaseConfig):
name: Literal["detection_roc_auc"] = "detection_roc_auc"
class DetectionROCAUC(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
score = float(metrics.roc_auc_score(y_true, y_score))
return {"detection_ROC_AUC": score}
@classmethod
def from_config(
cls, config: DetectionROCAUCConfig, class_names: List[str]
):
return cls()
metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC)
class ClassificationAPConfig(BaseConfig):
name: Literal["classification_ap"] = "classification_ap"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
implementation: APImplementation = "pascal_voc",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_ap = self.metric(y_true_class, y_pred_class)
class_scores[class_name] = float(class_ap)
mean_ap = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"classification_mAP": float(mean_ap),
**{
f"classification_AP/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClassificationAPConfig,
class_names: List[str],
):
return cls(
class_names,
implementation=config.ap_implementation,
include=config.include,
exclude=config.exclude,
)
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
class ClassificationROCAUCConfig(BaseConfig):
name: Literal["classification_roc_auc"] = "classification_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
class_scores[class_name] = float(class_roc_auc)
mean_roc_auc = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"classification_macro_average_ROC_AUC": float(mean_roc_auc),
**{
f"classification_ROC_AUC/{class_name}": class_scores[
class_name
]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClassificationROCAUCConfig,
class_names: List[str],
):
return cls(
class_names,
include=config.include,
exclude=config.exclude,
)
metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC)
class TopClassAPConfig(BaseConfig):
name: Literal["top_class_ap"] = "top_class_ap"
ap_implementation: APImplementation = "pascal_voc"
class TopClassAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation = "pascal_voc",
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
top_class = match.pred_class
y_true.append(top_class == match.gt_class)
y_score.append(match.pred_class_score)
score = float(self.metric(y_true, y_score))
return {"top_class_AP": score}
@classmethod
def from_config(cls, config: TopClassAPConfig, class_names: List[str]):
return cls(implementation=config.ap_implementation)
metrics_registry.register(TopClassAPConfig, TopClassAP)
class ClassificationBalancedAccuracyConfig(BaseConfig):
name: Literal["classification_balanced_accuracy"] = (
"classification_balanced_accuracy"
)
class ClassificationBalancedAccuracy(MetricsProtocol):
def __init__(self, class_names: List[str]):
self.class_names = class_names
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
top_class = match.pred_class
# Focus on matches
if match.gt_class is None or top_class is None:
continue
y_true.append(self.class_names.index(match.gt_class))
y_pred.append(self.class_names.index(top_class))
score = float(metrics.balanced_accuracy_score(y_true, y_pred))
return {"classification_balanced_accuracy": score}
@classmethod
def from_config(
cls,
config: ClassificationBalancedAccuracyConfig,
class_names: List[str],
):
return cls(class_names)
metrics_registry.register(
ClassificationBalancedAccuracyConfig,
ClassificationBalancedAccuracy,
)
class ClipDetectionAPConfig(BaseConfig):
name: Literal["clip_detection_ap"] = "clip_detection_ap"
ap_implementation: APImplementation = "pascal_voc"
class ClipDetectionAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {"clip_detection_ap": self.metric(y_true, y_score)}
@classmethod
def from_config(
cls,
config: ClipDetectionAPConfig,
class_names: List[str],
):
return cls(implementation=config.ap_implementation)
metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)
class ClipDetectionROCAUCConfig(BaseConfig):
name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"
class ClipDetectionROCAUC(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {
"clip_detection_ap": float(metrics.roc_auc_score(y_true, y_score))
}
@classmethod
def from_config(
cls,
config: ClipDetectionROCAUCConfig,
class_names: List[str],
):
return cls()
metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)
class ClipMulticlassAPConfig(BaseConfig):
name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipMulticlassAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
implementation: APImplementation,
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
clip_classes = set()
clip_scores = defaultdict(list)
for match in clip_eval.matches:
if match.gt_class is not None:
clip_classes.add(match.gt_class)
for class_name, score in match.pred_class_scores.items():
clip_scores[class_name].append(score)
y_true.append(clip_classes)
y_pred.append(
np.array(
[
# Get max score for each class
max(clip_scores.get(class_name, [0]))
for class_name in self.class_names
]
)
)
y_true = preprocessing.MultiLabelBinarizer(
classes=self.class_names
).fit_transform(y_true)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_ap = self.metric(y_true_class, y_pred_class)
class_scores[class_name] = float(class_ap)
mean_ap = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"clip_multiclass_mAP": float(mean_ap),
**{
f"clip_multiclass_AP/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls, config: ClipMulticlassAPConfig, class_names: List[str]
):
return cls(
implementation=config.ap_implementation,
include=config.include,
exclude=config.exclude,
class_names=class_names,
)
metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)
class ClipMulticlassROCAUCConfig(BaseConfig):
name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipMulticlassROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
clip_classes = set()
clip_scores = defaultdict(list)
for match in clip_eval.matches:
if match.gt_class is not None:
clip_classes.add(match.gt_class)
for class_name, score in match.pred_class_scores.items():
clip_scores[class_name].append(score)
y_true.append(clip_classes)
y_pred.append(
np.array(
[
# Get maximum score for each class
max(clip_scores.get(class_name, [0]))
for class_name in self.class_names
]
)
)
y_true = preprocessing.MultiLabelBinarizer(
classes=self.class_names
).fit_transform(y_true)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
class_scores[class_name] = float(class_roc_auc)
mean_roc_auc = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"clip_multiclass_macro_ROC_AUC": float(mean_roc_auc),
**{
f"clip_multiclass_ROC_AUC/{class_name}": class_scores[
class_name
]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClipMulticlassROCAUCConfig,
class_names: List[str],
):
return cls(
include=config.include,
exclude=config.exclude,
class_names=class_names,
)
metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)
MetricConfig = Annotated[
Union[
DetectionAPConfig,
DetectionROCAUCConfig,
ClassificationAPConfig,
ClassificationROCAUCConfig,
TopClassAPConfig,
ClassificationBalancedAccuracyConfig,
ClipDetectionAPConfig,
ClipDetectionROCAUCConfig,
ClipMulticlassAPConfig,
ClipMulticlassROCAUCConfig,
],
Field(discriminator="name"),
]
def build_metric(config: MetricConfig, class_names: List[str]):
return metrics_registry.build(config, class_names)
def pascal_voc_average_precision(y_true, y_score) -> float:
y_true = np.array(y_true)
y_score = np.array(y_score)
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
num_positives = y_true.sum()
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c,
np.finfo(np.float64).eps,
)
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
# pascal 12 way
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return ave_prec
_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = {
"sklearn": metrics.average_precision_score,
"pascal_voc": pascal_voc_average_precision,
}
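The mapping above lets callers swap the average-precision implementation by key. A minimal, hypothetical usage sketch (the toy labels and the "pascal_voc" lookup below are illustrative only, not part of this diff):

# Dispatch through _ap_impl_mapping; "sklearn" and "pascal_voc" are the two
# registered implementations. The toy ranking is made up for illustration.
y_true = [1, 0, 1, 1, 0]
y_score = [0.9, 0.8, 0.7, 0.4, 0.2]
ap_fn = _ap_impl_mapping["pascal_voc"]
print(ap_fn(y_true, y_score))  # interpolated AP for this toy ranking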

View File

@ -0,0 +1,267 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction, TargetProtocol
__all__ = [
"ClassificationMetric",
"ClassificationMetricConfig",
"build_classification_metric",
]
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
is_prediction: bool
is_ground_truth: bool
is_generic: bool
true_class: Optional[str]
score: float
@dataclass
class ClipEval:
clip: data.Clip
matches: Mapping[str, List[MatchEval]]
ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
Registry("classification_metric")
)
class BaseClassificationConfig(BaseConfig):
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class BaseClassificationMetric:
def __init__(
self,
targets: TargetProtocol,
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.targets = targets
self.include = include
self.exclude = exclude
def include_class(self, class_name: str) -> bool:
if self.include is not None:
return class_name in self.include
if self.exclude is not None:
return class_name not in self.exclude
return True
class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
name: Literal["average_precision"] = "average_precision"
ignore_non_predictions: bool = True
ignore_generic: bool = True
label: str = "average_precision"
class ClassificationAveragePrecision(BaseClassificationMetric):
def __init__(
self,
targets: TargetProtocol,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "average_precision",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
super().__init__(include=include, exclude=exclude, targets=targets)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.label = label
def __call__(
self, clip_evaluations: Sequence[ClipEval]
) -> Dict[str, float]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
class_scores = {
class_name: average_precision(
y_true[class_name],
y_score[class_name],
num_positives=num_positives[class_name],
)
for class_name in self.targets.class_names
}
mean_score = float(
            np.mean([v for v in class_scores.values() if not np.isnan(v)])
)
return {
f"mean_{self.label}": mean_score,
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if self.include_class(class_name)
},
}
@classification_metrics.register(ClassificationAveragePrecisionConfig)
@staticmethod
def from_config(
config: ClassificationAveragePrecisionConfig,
targets: TargetProtocol,
):
return ClassificationAveragePrecision(
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
label=config.label,
include=config.include,
exclude=config.exclude,
)
class ClassificationROCAUCConfig(BaseClassificationConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ClassificationROCAUC(BaseClassificationMetric):
def __init__(
self,
targets: TargetProtocol,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "roc_auc",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.targets = targets
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.label = label
self.include = include
self.exclude = exclude
def __call__(
self, clip_evaluations: Sequence[ClipEval]
) -> Dict[str, float]:
y_true, y_score, _ = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
class_scores = {
class_name: float(
metrics.roc_auc_score(
y_true[class_name],
y_score[class_name],
)
)
for class_name in self.targets.class_names
}
mean_score = float(
            np.mean([v for v in class_scores.values() if not np.isnan(v)])
)
return {
f"mean_{self.label}": mean_score,
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if self.include_class(class_name)
},
}
@classification_metrics.register(ClassificationROCAUCConfig)
@staticmethod
def from_config(
config: ClassificationROCAUCConfig, targets: TargetProtocol
):
        return ClassificationROCAUC(
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            label=config.label,
            include=config.include,
            exclude=config.exclude,
        )
ClassificationMetricConfig = Annotated[
Union[
ClassificationAveragePrecisionConfig,
ClassificationROCAUCConfig,
],
Field(discriminator="name"),
]
def build_classification_metric(
config: ClassificationMetricConfig,
targets: TargetProtocol,
) -> ClassificationMetric:
return classification_metrics.build(config, targets)
def _extract_per_class_metric_data(
clip_evaluations: Sequence[ClipEval],
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
):
y_true = defaultdict(list)
y_score = defaultdict(list)
num_positives = defaultdict(lambda: 0)
for clip_eval in clip_evaluations:
for class_name, matches in clip_eval.matches.items():
for m in matches:
# Exclude matches with ground truth sounds where the class
# is unknown
if m.is_generic and ignore_generic:
continue
is_class = m.true_class == class_name
if is_class:
num_positives[class_name] += 1
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and ignore_non_predictions:
continue
y_true[class_name].append(is_class)
y_score[class_name].append(m.score)
return y_true, y_score, num_positives
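A minimal sketch of wiring one of these per-class metrics from config; here `targets` stands in for any TargetProtocol implementation and `clip_evals` for a sequence of ClipEval objects produced by the evaluator (both assumed, not shown in this diff):

# Assumes `targets` (TargetProtocol) and `clip_evals` (Sequence[ClipEval]) already exist.
config = ClassificationAveragePrecisionConfig(exclude=["noise"])  # "noise" is an illustrative class name
metric = build_classification_metric(config, targets)
scores = metric(clip_evals)
# -> {"mean_average_precision": ..., "average_precision/<class_name>": ..., ...}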

View File

@ -0,0 +1,135 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union
import numpy as np
from pydantic import Field
from sklearn import metrics
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision
@dataclass
class ClipEval:
true_classes: Set[str]
class_scores: Dict[str, float]
ClipClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
"clip_classification_metric"
)
class ClipClassificationAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
class ClipClassificationAveragePrecision:
def __init__(self, label: str = "average_precision"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = defaultdict(list)
y_score = defaultdict(list)
for clip_eval in clip_evaluations:
for class_name, score in clip_eval.class_scores.items():
y_true[class_name].append(class_name in clip_eval.true_classes)
y_score[class_name].append(score)
class_scores = {
class_name: float(
average_precision(
y_true=y_true[class_name],
y_score=y_score[class_name],
)
)
for class_name in y_true
}
mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])
return {
f"mean_{self.label}": float(mean),
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if not np.isnan(score)
},
}
@clip_classification_metrics.register(
ClipClassificationAveragePrecisionConfig
)
@staticmethod
def from_config(config: ClipClassificationAveragePrecisionConfig):
return ClipClassificationAveragePrecision(label=config.label)
class ClipClassificationROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
class ClipClassificationROCAUC:
def __init__(self, label: str = "roc_auc"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = defaultdict(list)
y_score = defaultdict(list)
for clip_eval in clip_evaluations:
for class_name, score in clip_eval.class_scores.items():
y_true[class_name].append(class_name in clip_eval.true_classes)
y_score[class_name].append(score)
class_scores = {
class_name: float(
metrics.roc_auc_score(
y_true=y_true[class_name],
y_score=y_score[class_name],
)
)
for class_name in y_true
}
mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])
return {
f"mean_{self.label}": float(mean),
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if not np.isnan(score)
},
}
@clip_classification_metrics.register(ClipClassificationROCAUCConfig)
@staticmethod
def from_config(config: ClipClassificationROCAUCConfig):
return ClipClassificationROCAUC(label=config.label)
ClipClassificationMetricConfig = Annotated[
Union[
ClipClassificationAveragePrecisionConfig,
ClipClassificationROCAUCConfig,
],
Field(discriminator="name"),
]
def build_clip_metric(config: ClipClassificationMetricConfig):
return clip_classification_metrics.build(config)

View File

@ -0,0 +1,173 @@
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Union
import numpy as np
from pydantic import Field
from sklearn import metrics
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision
@dataclass
class ClipEval:
gt_det: bool
score: float
ClipDetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
"clip_detection_metric"
)
class ClipDetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
class ClipDetectionAveragePrecision:
def __init__(self, label: str = "average_precision"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
y_true.append(clip_eval.gt_det)
y_score.append(clip_eval.score)
score = average_precision(y_true=y_true, y_score=y_score)
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionAveragePrecisionConfig)
@staticmethod
def from_config(config: ClipDetectionAveragePrecisionConfig):
return ClipDetectionAveragePrecision(label=config.label)
class ClipDetectionROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
class ClipDetectionROCAUC:
def __init__(self, label: str = "roc_auc"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
y_true.append(clip_eval.gt_det)
y_score.append(clip_eval.score)
score = float(metrics.roc_auc_score(y_true=y_true, y_score=y_score))
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionROCAUCConfig)
@staticmethod
def from_config(config: ClipDetectionROCAUCConfig):
return ClipDetectionROCAUC(label=config.label)
class ClipDetectionRecallConfig(BaseConfig):
name: Literal["recall"] = "recall"
threshold: float = 0.5
label: str = "recall"
class ClipDetectionRecall:
def __init__(self, threshold: float, label: str = "recall"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_positives = 0
true_positives = 0
for clip_eval in clip_evaluations:
if clip_eval.gt_det:
num_positives += 1
if clip_eval.score >= self.threshold and clip_eval.gt_det:
true_positives += 1
if num_positives == 0:
return {self.label: np.nan}
score = true_positives / num_positives
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionRecallConfig)
@staticmethod
def from_config(config: ClipDetectionRecallConfig):
return ClipDetectionRecall(
threshold=config.threshold, label=config.label
)
class ClipDetectionPrecisionConfig(BaseConfig):
name: Literal["precision"] = "precision"
threshold: float = 0.5
label: str = "precision"
class ClipDetectionPrecision:
def __init__(self, threshold: float, label: str = "precision"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_detections = 0
true_positives = 0
for clip_eval in clip_evaluations:
if clip_eval.score >= self.threshold:
num_detections += 1
if clip_eval.score >= self.threshold and clip_eval.gt_det:
true_positives += 1
if num_detections == 0:
return {self.label: np.nan}
score = true_positives / num_detections
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionPrecisionConfig)
@staticmethod
def from_config(config: ClipDetectionPrecisionConfig):
return ClipDetectionPrecision(
threshold=config.threshold, label=config.label
)
ClipDetectionMetricConfig = Annotated[
Union[
ClipDetectionAveragePrecisionConfig,
ClipDetectionROCAUCConfig,
ClipDetectionRecallConfig,
ClipDetectionPrecisionConfig,
],
Field(discriminator="name"),
]
def build_clip_metric(config: ClipDetectionMetricConfig):
return clip_detection_metrics.build(config)
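A minimal sketch of the clip-level detection metrics in use (the two ClipEval values are made up for illustration):

# Build a recall metric from config and score two illustrative clips.
metric = build_clip_metric(ClipDetectionRecallConfig(threshold=0.3))
clip_evals = [
    ClipEval(gt_det=True, score=0.8),   # clip containing a call, confidently detected
    ClipEval(gt_det=False, score=0.1),  # quiet clip, correctly given a low score
]
print(metric(clip_evals))  # {"recall": 1.0}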

View File

@ -0,0 +1,60 @@
from typing import Optional, Tuple
import numpy as np
__all__ = [
"compute_precision_recall",
"average_precision",
]
def compute_precision_recall(
y_true,
y_score,
num_positives: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
y_true = np.array(y_true)
y_score = np.array(y_score)
if num_positives is None:
num_positives = y_true.sum()
# Sort by score
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
y_score_sorted = y_score[sort_ind]
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c,
np.finfo(np.float64).eps,
)
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
return precision, recall, y_score_sorted
def average_precision(
y_true,
y_score,
num_positives: Optional[int] = None,
) -> float:
precision, recall, _ = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
    # PASCAL VOC 2012-style interpolation: make the precision envelope monotonically decreasing
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return ave_prec
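A small worked example of the two helpers above (toy values, for illustration only):

# Three scored matches, two of them true positives, plus one missed
# ground-truth call that never produced a prediction, hence num_positives=3.
y_true = [True, False, True]
y_score = [0.9, 0.6, 0.3]
precision, recall, _ = compute_precision_recall(y_true, y_score, num_positives=3)
ap = average_precision(y_true, y_score, num_positives=3)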

View File

@ -0,0 +1,226 @@
from dataclasses import dataclass
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
__all__ = [
"DetectionMetricConfig",
"DetectionMetric",
"build_detection_metric",
]
@dataclass
class MatchEval:
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
is_prediction: bool
is_ground_truth: bool
score: float
@dataclass
class ClipEval:
clip: data.Clip
matches: List[MatchEval]
DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")
class DetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
ignore_non_predictions: bool = True
class DetectionAveragePrecision:
def __init__(self, label: str, ignore_non_predictions: bool = True):
self.ignore_non_predictions = ignore_non_predictions
self.label = label
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evals:
for m in clip_eval.matches:
num_positives += int(m.is_ground_truth)
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
ap = average_precision(y_true, y_score, num_positives=num_positives)
return {self.label: ap}
@detection_metrics.register(DetectionAveragePrecisionConfig)
@staticmethod
def from_config(config: DetectionAveragePrecisionConfig):
return DetectionAveragePrecision(
label=config.label,
ignore_non_predictions=config.ignore_non_predictions,
)
class DetectionROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
ignore_non_predictions: bool = True
class DetectionROCAUC:
def __init__(
self,
label: str = "roc_auc",
ignore_non_predictions: bool = True,
):
self.label = label
self.ignore_non_predictions = ignore_non_predictions
def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
y_true: List[bool] = []
y_score: List[float] = []
for clip_eval in clip_evals:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
score = float(metrics.roc_auc_score(y_true, y_score))
return {self.label: score}
@detection_metrics.register(DetectionROCAUCConfig)
@staticmethod
def from_config(config: DetectionROCAUCConfig):
return DetectionROCAUC(
label=config.label,
ignore_non_predictions=config.ignore_non_predictions,
)
class DetectionRecallConfig(BaseConfig):
name: Literal["recall"] = "recall"
label: str = "recall"
threshold: float = 0.5
class DetectionRecall:
def __init__(self, threshold: float, label: str = "recall"):
self.label = label
self.threshold = threshold
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_positives = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_ground_truth:
num_positives += 1
if m.score >= self.threshold and m.is_ground_truth:
true_positives += 1
if num_positives == 0:
return {self.label: np.nan}
score = true_positives / num_positives
return {self.label: score}
@detection_metrics.register(DetectionRecallConfig)
@staticmethod
def from_config(config: DetectionRecallConfig):
return DetectionRecall(threshold=config.threshold, label=config.label)
class DetectionPrecisionConfig(BaseConfig):
name: Literal["precision"] = "precision"
label: str = "precision"
threshold: float = 0.5
class DetectionPrecision:
def __init__(self, threshold: float, label: str = "precision"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_detections = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
is_detection = m.score >= self.threshold
if is_detection:
num_detections += 1
if is_detection and m.is_ground_truth:
true_positives += 1
if num_detections == 0:
return {self.label: np.nan}
score = true_positives / num_detections
return {self.label: score}
@detection_metrics.register(DetectionPrecisionConfig)
@staticmethod
def from_config(config: DetectionPrecisionConfig):
return DetectionPrecision(
threshold=config.threshold,
label=config.label,
)
DetectionMetricConfig = Annotated[
Union[
DetectionAveragePrecisionConfig,
DetectionROCAUCConfig,
DetectionRecallConfig,
DetectionPrecisionConfig,
],
Field(discriminator="name"),
]
def build_detection_metric(config: DetectionMetricConfig):
return detection_metrics.build(config)
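A minimal sketch of the sound-event detection metrics; the MatchEval fields mirror how the evaluator would fill them, and `clip` is a placeholder for a soundevent data.Clip (not constructed here):

# One confidently scored detection matched to a ground-truth sound event.
match = MatchEval(
    gt=None,              # placeholder for the matched SoundEventAnnotation
    pred=None,            # placeholder for the matched RawPrediction
    is_prediction=True,
    is_ground_truth=True,
    score=0.9,
)
metric = build_detection_metric(DetectionRecallConfig(threshold=0.5))
print(metric([ClipEval(clip=clip, matches=[match])]))  # {"recall": 1.0}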

View File

@ -0,0 +1,314 @@
from dataclasses import dataclass
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
__all__ = [
"TopClassMetricConfig",
"TopClassMetric",
"build_top_class_metric",
]
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
is_ground_truth: bool
is_generic: bool
is_prediction: bool
pred_class: Optional[str]
true_class: Optional[str]
score: float
@dataclass
class ClipEval:
clip: data.Clip
matches: List[MatchEval]
TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")
class TopClassAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
ignore_generic: bool = True
ignore_non_predictions: bool = True
class TopClassAveragePrecision:
def __init__(
self,
ignore_generic: bool = True,
ignore_non_predictions: bool = True,
label: str = "average_precision",
):
self.ignore_generic = ignore_generic
self.ignore_non_predictions = ignore_non_predictions
self.label = label
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evals:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
num_positives += int(m.is_ground_truth)
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
score = average_precision(y_true, y_score, num_positives=num_positives)
return {self.label: score}
@top_class_metrics.register(TopClassAveragePrecisionConfig)
@staticmethod
def from_config(config: TopClassAveragePrecisionConfig):
        return TopClassAveragePrecision(
            ignore_generic=config.ignore_generic,
            ignore_non_predictions=config.ignore_non_predictions,
            label=config.label,
        )
class TopClassROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
ignore_generic: bool = True
ignore_non_predictions: bool = True
label: str = "roc_auc"
class TopClassROCAUC:
def __init__(
self,
ignore_generic: bool = True,
ignore_non_predictions: bool = True,
label: str = "roc_auc",
):
self.ignore_generic = ignore_generic
self.ignore_non_predictions = ignore_non_predictions
self.label = label
def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
y_true: List[bool] = []
y_score: List[float] = []
for clip_eval in clip_evals:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
score = float(metrics.roc_auc_score(y_true, y_score))
return {self.label: score}
@top_class_metrics.register(TopClassROCAUCConfig)
@staticmethod
def from_config(config: TopClassROCAUCConfig):
        return TopClassROCAUC(
            ignore_generic=config.ignore_generic,
            ignore_non_predictions=config.ignore_non_predictions,
            label=config.label,
        )
class TopClassRecallConfig(BaseConfig):
name: Literal["recall"] = "recall"
threshold: float = 0.5
label: str = "recall"
class TopClassRecall:
def __init__(self, threshold: float, label: str = "recall"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_positives = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_ground_truth:
num_positives += 1
if m.score >= self.threshold and m.pred_class == m.true_class:
true_positives += 1
if num_positives == 0:
return {self.label: np.nan}
score = true_positives / num_positives
return {self.label: score}
@top_class_metrics.register(TopClassRecallConfig)
@staticmethod
def from_config(config: TopClassRecallConfig):
return TopClassRecall(
threshold=config.threshold,
label=config.label,
)
class TopClassPrecisionConfig(BaseConfig):
name: Literal["precision"] = "precision"
threshold: float = 0.5
label: str = "precision"
class TopClassPrecision:
def __init__(self, threshold: float, label: str = "precision"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_detections = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
is_detection = m.score >= self.threshold
if is_detection:
num_detections += 1
if is_detection and m.pred_class == m.true_class:
true_positives += 1
if num_detections == 0:
return {self.label: np.nan}
score = true_positives / num_detections
return {self.label: score}
@top_class_metrics.register(TopClassPrecisionConfig)
@staticmethod
def from_config(config: TopClassPrecisionConfig):
return TopClassPrecision(
threshold=config.threshold,
label=config.label,
)
class BalancedAccuracyConfig(BaseConfig):
name: Literal["balanced_accuracy"] = "balanced_accuracy"
label: str = "balanced_accuracy"
exclude_noise: bool = False
noise_class: str = "noise"
class BalancedAccuracy:
def __init__(
self,
exclude_noise: bool = True,
noise_class: str = "noise",
label: str = "balanced_accuracy",
):
self.exclude_noise = exclude_noise
self.noise_class = noise_class
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true: List[str] = []
y_pred: List[str] = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic:
# Ignore matches that correspond to a sound event
# with unknown class
continue
if not m.is_ground_truth and self.exclude_noise:
# Ignore predictions that were not matched to a
# ground truth
continue
if m.pred_class is None and self.exclude_noise:
# Ignore non-predictions
continue
y_true.append(m.true_class or self.noise_class)
y_pred.append(m.pred_class or self.noise_class)
encoder = preprocessing.LabelEncoder()
encoder.fit(list(set(y_true) | set(y_pred)))
y_true = encoder.transform(y_true)
y_pred = encoder.transform(y_pred)
score = metrics.balanced_accuracy_score(y_true, y_pred)
return {self.label: score}
@top_class_metrics.register(BalancedAccuracyConfig)
@staticmethod
def from_config(config: BalancedAccuracyConfig):
return BalancedAccuracy(
exclude_noise=config.exclude_noise,
noise_class=config.noise_class,
label=config.label,
)
TopClassMetricConfig = Annotated[
Union[
TopClassAveragePrecisionConfig,
TopClassROCAUCConfig,
TopClassRecallConfig,
TopClassPrecisionConfig,
BalancedAccuracyConfig,
],
Field(discriminator="name"),
]
def build_top_class_metric(config: TopClassMetricConfig):
return top_class_metrics.build(config)

View File

@ -1,560 +0,0 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import BaseConfig, Registry
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.matches import plot_matches
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import (
AudioLoader,
ClipEvaluation,
MatchEvaluation,
PlotterProtocol,
PreprocessorProtocol,
)
__all__ = [
"build_plotter",
"ExampleGallery",
"ExampleGalleryConfig",
]
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")
class ExampleGalleryConfig(BaseConfig):
name: Literal["example_gallery"] = "example_gallery"
examples_per_class: int = 5
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleGallery(PlotterProtocol):
def __init__(
self,
examples_per_class: int,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
):
self.examples_per_class = examples_per_class
self.preprocessor = preprocessor or build_preprocessor()
self.audio_loader = audio_loader or build_audio_loader()
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
per_class_matches = group_matches(clip_evaluations)
for class_name, matches in per_class_matches.items():
true_positives = get_binned_sample(
matches.true_positives,
n_examples=self.examples_per_class,
)
false_positives = get_binned_sample(
matches.false_positives,
n_examples=self.examples_per_class,
)
false_negatives = random.sample(
matches.false_negatives,
k=min(self.examples_per_class, len(matches.false_negatives)),
)
cross_triggers = get_binned_sample(
matches.cross_triggers,
n_examples=self.examples_per_class,
)
fig = plot_match_gallery(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
n_examples=self.examples_per_class,
)
yield f"example_gallery/{class_name}", fig
plt.close(fig)
@classmethod
def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
audio_loader = build_audio_loader(config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return cls(
examples_per_class=config.examples_per_class,
preprocessor=preprocessor,
audio_loader=audio_loader,
)
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
class ClipEvaluationPlotConfig(BaseConfig):
name: Literal["example_clip"] = "example_clip"
num_plots: int = 5
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class PlotClipEvaluation(PlotterProtocol):
def __init__(
self,
num_plots: int = 3,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
):
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.num_plots = num_plots
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
examples = random.sample(
clip_evaluations,
k=min(self.num_plots, len(clip_evaluations)),
)
for index, clip_evaluation in enumerate(examples):
fig, ax = plt.subplots()
plot_matches(
clip_evaluation.matches,
clip=clip_evaluation.clip,
audio_loader=self.audio_loader,
ax=ax,
)
yield f"clip_evaluation/example_{index}", fig
plt.close(fig)
@classmethod
def from_config(
cls,
config: ClipEvaluationPlotConfig,
class_names: List[str],
):
audio_loader = build_audio_loader(config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return cls(
num_plots=config.num_plots,
preprocessor=preprocessor,
audio_loader=audio_loader,
)
plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)
class DetectionPRCurveConfig(BaseConfig):
name: Literal["detection_pr_curve"] = "detection_pr_curve"
class DetectionPRCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
fig, ax = plt.subplots()
ax.plot(recall, precision, label="Detector")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()
yield "detection_pr_curve", fig
@classmethod
def from_config(
cls,
config: DetectionPRCurveConfig,
class_names: List[str],
):
return cls()
plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)
class ClassificationPRCurvesConfig(BaseConfig):
name: Literal["classification_pr_curves"] = "classification_pr_curves"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationPRCurves(PlotterProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
fig, ax = plt.subplots(figsize=(10, 10))
for class_index, class_name in enumerate(self.class_names):
if class_name not in self.selected:
continue
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
precision, recall, _ = metrics.precision_recall_curve(
y_true_class,
y_pred_class,
)
ax.plot(recall, precision, label=class_name)
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
yield "classification_pr_curve", fig
@classmethod
def from_config(
cls,
config: ClassificationPRCurvesConfig,
class_names: List[str],
):
return cls(
class_names=class_names,
include=config.include,
exclude=config.exclude,
)
plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)
class DetectionROCCurveConfig(BaseConfig):
name: Literal["detection_roc_curve"] = "detection_roc_curve"
class DetectionROCCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
fig, ax = plt.subplots()
ax.plot(fpr, tpr, label="Detection")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend()
yield "detection_roc_curve", fig
@classmethod
def from_config(
cls,
config: DetectionROCCurveConfig,
class_names: List[str],
):
return cls()
plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)
class ClassificationROCCurvesConfig(BaseConfig):
name: Literal["classification_roc_curves"] = "classification_roc_curves"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationROCCurves(PlotterProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
fig, ax = plt.subplots(figsize=(10, 10))
for class_index, class_name in enumerate(self.class_names):
if class_name not in self.selected:
continue
y_true_class = y_true[:, class_index]
y_roced_class = y_pred[:, class_index]
fpr, tpr, _ = metrics.roc_curve(
y_true_class,
y_roced_class,
)
ax.plot(fpr, tpr, label=class_name)
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
yield "classification_roc_curve", fig
@classmethod
def from_config(
cls,
config: ClassificationROCCurvesConfig,
class_names: List[str],
):
return cls(
class_names=class_names,
include=config.include,
exclude=config.exclude,
)
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
class ConfusionMatrixConfig(BaseConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
background_class: str = "noise"
class ConfusionMatrix(PlotterProtocol):
def __init__(self, background_class: str, class_names: List[str]):
self.background_class = background_class
self.class_names = class_names
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else self.background_class
)
top_class = match.pred_class
y_pred.append(
top_class
if top_class is not None
else self.background_class
)
display = metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=[*self.class_names, self.background_class],
)
yield "confusion_matrix", display.figure_
@classmethod
def from_config(
cls,
config: ConfusionMatrixConfig,
class_names: List[str],
):
return cls(
background_class=config.background_class,
class_names=class_names,
)
plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
PlotConfig = Annotated[
Union[
ExampleGalleryConfig,
ClipEvaluationPlotConfig,
DetectionPRCurveConfig,
ClassificationPRCurvesConfig,
DetectionROCCurveConfig,
ClassificationROCCurvesConfig,
ConfusionMatrixConfig,
],
Field(discriminator="name"),
]
def build_plotter(
config: PlotConfig, class_names: List[str]
) -> PlotterProtocol:
return plots_registry.build(config, class_names)
@dataclass
class ClassMatches:
false_positives: List[MatchEvaluation] = field(default_factory=list)
false_negatives: List[MatchEvaluation] = field(default_factory=list)
true_positives: List[MatchEvaluation] = field(default_factory=list)
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def group_matches(
clip_evaluations: Sequence[ClipEvaluation],
) -> Dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches)
for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches:
gt_class = match.gt_class
pred_class = match.pred_class
if pred_class is None:
class_examples[gt_class].false_negatives.append(match)
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[gt_class].cross_triggers.append(match)
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
return class_examples
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[
(index, match.pred_class_scores[pred_class])
for index, match in enumerate(matches)
if (pred_class := match.pred_class) is not None
]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").sample(1)
return [matches[ind] for ind in sample["indices"]]

View File

@ -0,0 +1,54 @@
from typing import Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol
class BasePlotConfig(BaseConfig):
label: str = "plot"
theme: str = "default"
title: Optional[str] = None
figsize: tuple[int, int] = (10, 10)
dpi: int = 100
class BasePlot:
def __init__(
self,
targets: TargetProtocol,
label: str = "plot",
figsize: tuple[int, int] = (10, 10),
title: Optional[str] = None,
dpi: int = 100,
theme: str = "default",
):
self.targets = targets
self.label = label
self.figsize = figsize
self.dpi = dpi
self.theme = theme
self.title = title
def create_figure(self) -> Figure:
plt.style.use(self.theme)
fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
if self.title is not None:
fig.suptitle(self.title)
return fig
@classmethod
def build(cls, config: BasePlotConfig, targets: TargetProtocol, **kwargs):
return cls(
targets=targets,
figsize=config.figsize,
dpi=config.dpi,
theme=config.theme,
label=config.label,
title=config.title,
**kwargs,
)

View File

@ -0,0 +1,370 @@
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.classification import (
ClipEval,
_extract_per_class_metric_data,
)
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curve,
plot_pr_curves,
plot_roc_curve,
plot_roc_curves,
plot_threshold_precision_curve,
plot_threshold_precision_curves,
plot_threshold_recall_curve,
plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol
ClassificationPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
Registry("classification_plot")
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Classification Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives=num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (precision, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ThresholdPrecisionCurveConfig(BasePlotConfig):
name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
label: str = "threshold_precision_curve"
title: Optional[str] = "Classification Threshold-Precision Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ThresholdPrecisionCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_threshold_precision_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (precision, _, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_threshold_precision_curve(
thresholds,
precision,
ax=ax,
)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ThresholdPrecisionCurveConfig)
@staticmethod
def from_config(
config: ThresholdPrecisionCurveConfig, targets: TargetProtocol
):
return ThresholdPrecisionCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ThresholdRecallCurveConfig(BasePlotConfig):
name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
label: str = "threshold_recall_curve"
title: Optional[str] = "Classification Threshold-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ThresholdRecallCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_threshold_recall_curves(data, ax=ax, add_legend=True)
yield self.label, fig
return
for class_name, (_, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_threshold_recall_curve(
thresholds,
recall,
ax=ax,
)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ThresholdRecallCurveConfig)
@staticmethod
def from_config(
config: ThresholdRecallCurveConfig, targets: TargetProtocol
):
return ThresholdRecallCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Classification ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, _ = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: metrics.roc_curve(
y_true[class_name],
y_score[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (fpr, tpr, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
ClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ThresholdPrecisionCurveConfig,
ThresholdRecallCurveConfig,
],
Field(discriminator="name"),
]
def build_classification_plotter(
config: ClassificationPlotConfig,
targets: TargetProtocol,
) -> ClassificationPlotter:
return classification_plots.build(config, targets)

View File

@ -0,0 +1,189 @@
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curve,
plot_pr_curves,
plot_roc_curve,
plot_roc_curves,
)
from batdetect2.typing import TargetProtocol
__all__ = [
"ClipClassificationPlotConfig",
"ClipClassificationPlotter",
"build_clip_classification_plotter",
]
ClipClassificationPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
clip_classification_plots: Registry[
ClipClassificationPlotter, [TargetProtocol]
] = Registry("clip_classification_plot")
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Clip Classification Precision-Recall Curve"
separate_figures: bool = False
class PRCurve(BasePlot):
def __init__(
self,
*args,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
data = {}
for class_name in self.targets.class_names:
y_true = [class_name in c.true_classes for c in clip_evaluations]
y_score = [
c.class_scores.get(class_name, 0) for c in clip_evaluations
]
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
)
data[class_name] = (precision, recall, thresholds)
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (precision, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@clip_classification_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
separate_figures=config.separate_figures,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Clip Classification ROC Curve"
separate_figures: bool = False
class ROCCurve(BasePlot):
def __init__(
self,
*args,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
data = {}
for class_name in self.targets.class_names:
y_true = [class_name in c.true_classes for c in clip_evaluations]
y_score = [
c.class_scores.get(class_name, 0) for c in clip_evaluations
]
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
data[class_name] = (fpr, tpr, thresholds)
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (fpr, tpr, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@clip_classification_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
separate_figures=config.separate_figures,
)
ClipClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
],
Field(discriminator="name"),
]
def build_clip_classification_plotter(
config: ClipClassificationPlotConfig,
targets: TargetProtocol,
) -> ClipClassificationPlotter:
return clip_classification_plots.build(config, targets)

View File

@ -0,0 +1,163 @@
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol
__all__ = [
"ClipDetectionPlotConfig",
"ClipDetectionPlotter",
"build_clip_detection_plotter",
]
ClipDetectionPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
Registry("clip_detection_plot")
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Clip Detection Precision-Recall Curve"
class PRCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
yield self.label, fig
@clip_detection_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Clip Detection ROC Curve"
class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
yield self.label, fig
@clip_detection_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
)
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Clip Detection Score Distribution"
class ScoreDistributionPlot(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
fig = self.create_figure()
ax = fig.subplots()
df = pd.DataFrame({"is_true": y_true, "score": y_score})
sns.histplot(
data=df,
x="score",
binwidth=0.025,
binrange=(0, 1),
hue="is_true",
ax=ax,
stat="probability",
common_norm=False,
)
yield self.label, fig
@clip_detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod
def from_config(
config: ScoreDistributionPlotConfig, targets: TargetProtocol
):
return ScoreDistributionPlot.build(
config=config,
targets=targets,
)
ClipDetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
],
Field(discriminator="name"),
]
def build_clip_detection_plotter(
config: ClipDetectionPlotConfig,
targets: TargetProtocol,
) -> ClipDetectionPlotter:
return clip_detection_plots.build(config, targets)
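A minimal sketch of driving one of these plotters; `targets` and `clip_evals` are placeholders for a TargetProtocol implementation and the clip-level evaluations (both assumed, not shown here):

# Build the PR-curve plotter from config and save each yielded figure.
plotter = build_clip_detection_plotter(PRCurveConfig(), targets)
for name, fig in plotter(clip_evals):
    fig.savefig(f"{name.replace('/', '_')}.png")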

View File

@ -0,0 +1,309 @@
import random
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.detections import plot_clip_detections
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
name="detection_plot"
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Detection Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evals:
for m in clip_eval.matches:
num_positives += int(m.is_ground_truth)
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
yield self.label, fig
@detection_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
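
Note on the loop above: `num_positives` is incremented before unmatched ground truths are filtered out, so recall is still measured against every ground-truth sound event even when `ignore_non_predictions` drops missed ground truths from `y_true`. A toy numeric illustration of why that matters (plain Python, not repository code, assuming `compute_precision_recall` divides true positives by `num_positives`):

y_true = [True, True, False]  # only matches that carry a prediction score
y_score = [0.9, 0.6, 0.4]
num_positives = 4  # two matched ground truths plus two that no prediction covered

threshold = 0.5
tp = sum(t and s >= threshold for t, s in zip(y_true, y_score))  # 2
fp = sum((not t) and s >= threshold for t, s in zip(y_true, y_score))  # 0

precision = tp / (tp + fp)  # 1.0
recall = tp / num_positives  # 0.5, not tp / sum(y_true) == 1.0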
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Detection ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
yield self.label, fig
@detection_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Detection Score Distribution"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ScoreDistributionPlot(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
df = pd.DataFrame({"is_true": y_true, "score": y_score})
fig = self.create_figure()
ax = fig.subplots()
sns.histplot(
data=df,
x="score",
binwidth=0.025,
binrange=(0, 1),
hue="is_true",
ax=ax,
stat="probability",
common_norm=False,
)
yield self.label, fig
@detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod
def from_config(
config: ScoreDistributionPlotConfig, targets: TargetProtocol
):
return ScoreDistributionPlot.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ExampleDetectionPlotConfig(BasePlotConfig):
name: Literal["example_detection"] = "example_detection"
label: str = "example_detection"
title: Optional[str] = "Example Detection"
figsize: tuple[int, int] = (10, 4)
num_examples: int = 5
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleDetectionPlot(BasePlot):
def __init__(
self,
*args,
num_examples: int = 5,
threshold: float = 0.2,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_examples = num_examples
self.audio_loader = audio_loader
self.threshold = threshold
self.preprocessor = preprocessor
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
sample = clip_evaluations
if self.num_examples < len(sample):
sample = random.sample(sample, self.num_examples)
for num_example, clip_eval in enumerate(sample):
fig = self.create_figure()
ax = fig.subplots()
plot_clip_detections(
clip_eval,
ax=ax,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
)
yield f"{self.label}/example_{num_example}", fig
plt.close(fig)
@detection_plots.register(ExampleDetectionPlotConfig)
@staticmethod
def from_config(
config: ExampleDetectionPlotConfig,
targets: TargetProtocol,
):
return ExampleDetectionPlot.build(
config=config,
targets=targets,
num_examples=config.num_examples,
threshold=config.threshold,
audio_loader=build_audio_loader(config.audio),
preprocessor=build_preprocessor(config.preprocessing),
)
DetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
ExampleDetectionPlotConfig,
],
Field(discriminator="name"),
]
def build_detection_plotter(
config: DetectionPlotConfig,
targets: TargetProtocol,
) -> DetectionPlotter:
return detection_plots.build(config, targets)

View File

@ -0,0 +1,444 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
Annotated,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
name="top_class_plot"
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Top Class Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
num_positives += int(m.is_ground_truth)
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
yield self.label, fig
@top_class_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Top Class ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
yield self.label, fig
@top_class_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ConfusionMatrixConfig(BasePlotConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
title: Optional[str] = "Top Class Confusion Matrix"
figsize: tuple[int, int] = (10, 10)
label: str = "confusion_matrix"
exclude_generic: bool = True
exclude_noise: bool = False
noise_class: str = "noise"
normalize: Literal["true", "pred", "all", "none"] = "true"
threshold: float = 0.2
add_colorbar: bool = True
cmap: str = "Blues"
class ConfusionMatrix(BasePlot):
def __init__(
self,
*args,
exclude_generic: bool = True,
exclude_noise: bool = False,
noise_class: str = "noise",
add_colorbar: bool = True,
normalize: Literal["true", "pred", "all", "none"] = "true",
cmap: str = "Blues",
threshold: float = 0.2,
**kwargs,
):
super().__init__(*args, **kwargs)
self.exclude_generic = exclude_generic
self.exclude_noise = exclude_noise
self.noise_class = noise_class
self.normalize = normalize
self.add_colorbar = add_colorbar
self.threshold = threshold
self.cmap = cmap
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true: List[str] = []
y_pred: List[str] = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
true_class = m.true_class
pred_class = m.pred_class
if not m.is_prediction and self.exclude_noise:
# Ignore matches that don't correspond to a prediction
continue
if not m.is_ground_truth and self.exclude_noise:
# Ignore matches that don't correspond to a ground truth
continue
if m.score < self.threshold:
if self.exclude_noise:
continue
pred_class = self.noise_class
if m.is_generic:
if self.exclude_generic:
# Ignore gt sounds with unknown class
continue
true_class = self.targets.detection_class_name
y_true.append(true_class or self.noise_class)
y_pred.append(pred_class or self.noise_class)
fig = self.create_figure()
ax = fig.subplots()
class_names = [*self.targets.class_names]
if not self.exclude_generic:
class_names.append(self.targets.detection_class_name)
if not self.exclude_noise:
class_names.append(self.noise_class)
metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=class_names,
ax=ax,
xticks_rotation="vertical",
cmap=self.cmap,
colorbar=self.add_colorbar,
normalize=self.normalize if self.normalize != "none" else None,
values_format=".2f",
)
yield self.label, fig
@top_class_plots.register(ConfusionMatrixConfig)
@staticmethod
def from_config(config: ConfusionMatrixConfig, targets: TargetProtocol):
return ConfusionMatrix.build(
config=config,
targets=targets,
exclude_generic=config.exclude_generic,
exclude_noise=config.exclude_noise,
noise_class=config.noise_class,
add_colorbar=config.add_colorbar,
normalize=config.normalize,
cmap=config.cmap,
threshold=config.threshold,
)
class ExampleClassificationPlotConfig(BasePlotConfig):
name: Literal["example_classification"] = "example_classification"
label: str = "example_classification"
title: Optional[str] = "Example Classification"
num_examples: int = 4
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleClassificationPlot(BasePlot):
def __init__(
self,
*args,
num_examples: int = 4,
threshold: float = 0.2,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_examples = num_examples
self.audio_loader = audio_loader
self.threshold = threshold
self.preprocessor = preprocessor
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
grouped = group_matches(clip_evaluations, threshold=self.threshold)
for class_name, matches in grouped.items():
true_positives: List[MatchEval] = get_binned_sample(
matches.true_positives,
n_examples=self.num_examples,
)
false_positives: List[MatchEval] = get_binned_sample(
matches.false_positives,
n_examples=self.num_examples,
)
false_negatives: List[MatchEval] = random.sample(
matches.false_negatives,
k=min(self.num_examples, len(matches.false_negatives)),
)
cross_triggers: List[MatchEval] = get_binned_sample(
matches.cross_triggers, n_examples=self.num_examples
)
fig = self.create_figure()
fig = plot_match_gallery(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
n_examples=self.num_examples,
fig=fig,
)
if self.title is not None:
fig.suptitle(f"{self.title}: {class_name}")
else:
fig.suptitle(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@top_class_plots.register(ExampleClassificationPlotConfig)
@staticmethod
def from_config(
config: ExampleClassificationPlotConfig,
targets: TargetProtocol,
):
return ExampleClassificationPlot.build(
config=config,
targets=targets,
num_examples=config.num_examples,
threshold=config.threshold,
audio_loader=build_audio_loader(config.audio),
preprocessor=build_preprocessor(config.preprocessing),
)
TopClassPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ConfusionMatrixConfig,
ExampleClassificationPlotConfig,
],
Field(discriminator="name"),
]
def build_top_class_plotter(
config: TopClassPlotConfig,
targets: TargetProtocol,
) -> TopClassPlotter:
return top_class_plots.build(config, targets)
@dataclass
class ClassMatches:
false_positives: List[MatchEval] = field(default_factory=list)
false_negatives: List[MatchEval] = field(default_factory=list)
true_positives: List[MatchEval] = field(default_factory=list)
cross_triggers: List[MatchEval] = field(default_factory=list)
def group_matches(
clip_evals: Sequence[ClipEval],
threshold: float = 0.2,
) -> Dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches)
for clip_eval in clip_evals:
for match in clip_eval.matches:
gt_class = match.true_class
pred_class = match.pred_class
is_pred = match.score >= threshold
if not is_pred and gt_class is not None:
class_examples[gt_class].false_negatives.append(match)
continue
if not is_pred:
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
return class_examples
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[(index, match.score) for index, match in enumerate(matches)]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").sample(1)
return [matches[ind] for ind in sample["indices"]]

View File

@ -5,9 +5,9 @@ from pydantic import Field
 from soundevent.geometry import compute_bounds

 from batdetect2.core import BaseConfig, Registry
-from batdetect2.typing import ClipEvaluation
+from batdetect2.typing import ClipMatches

-EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame]
+EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]

 tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
@ -21,20 +21,18 @@ class FullEvaluationTableConfig(BaseConfig):

 class FullEvaluationTable:
     def __call__(
-        self, clip_evaluations: Sequence[ClipEvaluation]
+        self, clip_evaluations: Sequence[ClipMatches]
     ) -> pd.DataFrame:
         return extract_matches_dataframe(clip_evaluations)

-    @classmethod
-    def from_config(cls, config: FullEvaluationTableConfig):
-        return cls()
-
-
-tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)
+    @tables_registry.register(FullEvaluationTableConfig)
+    @staticmethod
+    def from_config(config: FullEvaluationTableConfig):
+        return FullEvaluationTable()


 def extract_matches_dataframe(
-    clip_evaluations: Sequence[ClipEvaluation],
+    clip_evaluations: Sequence[ClipMatches],
 ) -> pd.DataFrame:
     data = []
@ -78,8 +76,8 @@ def extract_matches_dataframe(
             ("gt", "low_freq"): gt_low_freq,
             ("gt", "high_freq"): gt_high_freq,
             ("pred", "score"): match.pred_score,
-            ("pred", "class"): match.pred_class,
-            ("pred", "class_score"): match.pred_class_score,
+            ("pred", "class"): match.top_class,
+            ("pred", "class_score"): match.top_class_score,
             ("pred", "start_time"): pred_start_time,
             ("pred", "end_time"): pred_end_time,
             ("pred", "low_freq"): pred_low_freq,

View File

@ -0,0 +1,39 @@
from typing import Annotated, Optional, Union
from pydantic import Field
from batdetect2.evaluate.tasks.base import tasks_registry
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.clip_classification import (
ClipClassificationTaskConfig,
)
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
__all__ = [
"TaskConfig",
"build_task",
]
TaskConfig = Annotated[
Union[
ClassificationTaskConfig,
DetectionTaskConfig,
ClipDetectionTaskConfig,
ClipClassificationTaskConfig,
TopClassDetectionTaskConfig,
],
Field(discriminator="name"),
]
def build_task(
config: TaskConfig,
targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
targets = targets or build_targets()
return tasks_registry.build(config, targets)
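
A hedged usage sketch of this module: each entry of the evaluation `tasks` list is parsed into one of the configs above, turned into an evaluator with `build_task`, and then run over annotated clips and raw predictions. The data variables below are placeholders, and the import paths assume these modules are importable as laid out in this change:

from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig

configs = [DetectionTaskConfig(), TopClassDetectionTaskConfig()]
tasks = [build_task(config) for config in configs]

clip_annotations = ...  # Sequence[data.ClipAnnotation] from the test set
predictions = ...  # Sequence[Sequence[RawPrediction]], one list per clip

for task in tasks:
    outputs = task.evaluate(clip_annotations, predictions)
    scores = task.compute_metrics(outputs)  # e.g. {"detection/average_precision": ...}
    for name, fig in task.generate_plots(outputs):
        fig.savefig(f"{name.replace('/', '_')}.png")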

View File

@ -0,0 +1,175 @@
from typing import (
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.match import (
MatchConfig,
StartTimeMatchConfig,
build_matcher,
)
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"BaseTaskConfig",
"BaseTask",
]
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
"tasks"
)
T_Output = TypeVar("T_Output")
class BaseTaskConfig(BaseConfig):
prefix: str
ignore_start_end: float = 0.01
matching_strategy: MatchConfig = Field(
default_factory=StartTimeMatchConfig
)
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
targets: TargetProtocol
matcher: MatcherProtocol
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
ignore_start_end: float
prefix: str
def __init__(
self,
matcher: MatcherProtocol,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
prefix: str,
ignore_start_end: float = 0.01,
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
):
self.matcher = matcher
self.metrics = metrics
self.plots = plots or []
self.targets = targets
self.prefix = prefix
self.ignore_start_end = ignore_start_end
def compute_metrics(
self,
eval_outputs: List[T_Output],
) -> Dict[str, float]:
scores = [metric(eval_outputs) for metric in self.metrics]
return {
f"{self.prefix}/{name}": score
for metric_output in scores
for name, score in metric_output.items()
}
def generate_plots(
self, eval_outputs: List[T_Output]
) -> Iterable[Tuple[str, Figure]]:
for plot in self.plots:
for name, fig in plot(eval_outputs):
yield f"{self.prefix}/{name}", fig
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[T_Output]:
return [
self.evaluate_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions)
]
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> T_Output: ...
def include_sound_event_annotation(
self,
sound_event_annotation: data.SoundEventAnnotation,
clip: data.Clip,
) -> bool:
if not self.targets.filter(sound_event_annotation):
return False
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
return is_in_bounds(
geometry,
clip,
self.ignore_start_end,
)
def include_prediction(
self,
prediction: RawPrediction,
clip: data.Clip,
) -> bool:
return is_in_bounds(
prediction.geometry,
clip,
self.ignore_start_end,
)
@classmethod
def build(
cls,
config: BaseTaskConfig,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
**kwargs,
):
matcher = build_matcher(config.matching_strategy)
return cls(
matcher=matcher,
targets=targets,
metrics=metrics,
plots=plots,
prefix=config.prefix,
ignore_start_end=config.ignore_start_end,
**kwargs,
)
def is_in_bounds(
geometry: data.Geometry,
clip: data.Clip,
buffer: float,
) -> bool:
start_time = compute_bounds(geometry)[0]
return (start_time >= clip.start_time + buffer) and (
start_time <= clip.end_time - buffer
)
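
The extension pattern implied by this base class: subclass `BaseTask`, implement `evaluate_clip`, and register a config through `tasks_registry`, as the concrete tasks below do. A stripped-down, hypothetical example (a sketch mirroring those patterns, not part of the change):

from dataclasses import dataclass
from typing import Literal, Sequence

from soundevent import data

from batdetect2.typing import RawPrediction, TargetProtocol


@dataclass
class ClipCounts:
    n_ground_truth: int
    n_predictions: int


class CountTaskConfig(BaseTaskConfig):
    name: Literal["count"] = "count"
    prefix: str = "count"


class CountTask(BaseTask[ClipCounts]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipCounts:
        clip = clip_annotation.clip
        n_gt = sum(
            self.include_sound_event_annotation(se, clip)
            for se in clip_annotation.sound_events
        )
        n_pred = sum(self.include_prediction(p, clip) for p in predictions)
        return ClipCounts(n_ground_truth=n_gt, n_predictions=n_pred)

    @tasks_registry.register(CountTaskConfig)
    @staticmethod
    def from_config(config: CountTaskConfig, targets: TargetProtocol):
        # Real tasks pass metric/plot callables built from their configs.
        return CountTask.build(config=config, targets=targets, metrics=[])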

View File

@ -0,0 +1,149 @@
from typing import (
List,
Literal,
Sequence,
)
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.classification import (
ClassificationAveragePrecisionConfig,
ClassificationMetricConfig,
ClipEval,
MatchEval,
build_classification_metric,
)
from batdetect2.evaluate.plots.classification import (
ClassificationPlotConfig,
build_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class ClassificationTaskConfig(BaseTaskConfig):
name: Literal["sound_event_classification"] = "sound_event_classification"
prefix: str = "classification"
metrics: List[ClassificationMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
)
plots: List[ClassificationPlotConfig] = Field(default_factory=list)
include_generics: bool = True
class ClassificationTask(BaseTask[ClipEval]):
def __init__(
self,
*args,
include_generics: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.include_generics = include_generics
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
]
all_gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.include_sound_event_annotation(sound_event, clip)
]
per_class_matches = {}
for class_name in self.targets.class_names:
class_idx = self.targets.class_names.index(class_name)
# Only match to targets of the given class
gts = [
sound_event
for sound_event in all_gts
if self.is_class(sound_event, class_name)
]
scores = [float(pred.class_scores[class_idx]) for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
true_class = (
self.targets.encode_class(gt) if gt is not None else None
)
score = (
float(pred.class_scores[class_idx])
if pred is not None
else 0
)
matches.append(
MatchEval(
clip=clip,
gt=gt,
pred=pred,
is_prediction=pred is not None,
is_ground_truth=gt is not None,
is_generic=gt is not None and true_class is None,
true_class=true_class,
score=score,
)
)
per_class_matches[class_name] = matches
return ClipEval(clip=clip, matches=per_class_matches)
def is_class(
self,
sound_event: data.SoundEventAnnotation,
class_name: str,
) -> bool:
sound_event_class = self.targets.encode_class(sound_event)
if sound_event_class is None and self.include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name
@tasks_registry.register(ClassificationTaskConfig)
@staticmethod
def from_config(
config: ClassificationTaskConfig,
targets: TargetProtocol,
):
metrics = [
build_classification_metric(metric, targets)
for metric in config.metrics
]
plots = [
build_classification_plotter(plot, targets)
for plot in config.plots
]
return ClassificationTask.build(
config=config,
plots=plots,
targets=targets,
metrics=metrics,
)

View File

@ -0,0 +1,85 @@
from collections import defaultdict
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.clip_classification import (
ClipClassificationAveragePrecisionConfig,
ClipClassificationMetricConfig,
ClipEval,
build_clip_metric,
)
from batdetect2.evaluate.plots.clip_classification import (
ClipClassificationPlotConfig,
build_clip_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class ClipClassificationTaskConfig(BaseTaskConfig):
name: Literal["clip_classification"] = "clip_classification"
prefix: str = "clip_classification"
metrics: List[ClipClassificationMetricConfig] = Field(
default_factory=lambda: [
ClipClassificationAveragePrecisionConfig(),
]
)
plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)
class ClipClassificationTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gt_classes = set()
for sound_event in clip_annotation.sound_events:
if not self.include_sound_event_annotation(sound_event, clip):
continue
class_name = self.targets.encode_class(sound_event)
if class_name is None:
continue
gt_classes.add(class_name)
pred_scores = defaultdict(float)
for pred in predictions:
if not self.include_prediction(pred, clip):
continue
for class_idx, class_name in enumerate(self.targets.class_names):
pred_scores[class_name] = max(
float(pred.class_scores[class_idx]),
pred_scores[class_name],
)
return ClipEval(true_classes=gt_classes, class_scores=pred_scores)
@tasks_registry.register(ClipClassificationTaskConfig)
@staticmethod
def from_config(
config: ClipClassificationTaskConfig,
targets: TargetProtocol,
):
metrics = [build_clip_metric(metric) for metric in config.metrics]
plots = [
build_clip_classification_plotter(plot, targets)
for plot in config.plots
]
return ClipClassificationTask.build(
config=config,
plots=plots,
metrics=metrics,
targets=targets,
)

View File

@ -0,0 +1,76 @@
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.clip_detection import (
ClipDetectionAveragePrecisionConfig,
ClipDetectionMetricConfig,
ClipEval,
build_clip_metric,
)
from batdetect2.evaluate.plots.clip_detection import (
ClipDetectionPlotConfig,
build_clip_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class ClipDetectionTaskConfig(BaseTaskConfig):
name: Literal["clip_detection"] = "clip_detection"
prefix: str = "clip_detection"
metrics: List[ClipDetectionMetricConfig] = Field(
default_factory=lambda: [
ClipDetectionAveragePrecisionConfig(),
]
)
plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)
class ClipDetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gt_det = any(
self.include_sound_event_annotation(sound_event, clip)
for sound_event in clip_annotation.sound_events
)
pred_score = 0
for pred in predictions:
if not self.include_prediction(pred, clip):
continue
pred_score = max(pred_score, pred.detection_score)
return ClipEval(
gt_det=gt_det,
score=pred_score,
)
@tasks_registry.register(ClipDetectionTaskConfig)
@staticmethod
def from_config(
config: ClipDetectionTaskConfig,
targets: TargetProtocol,
):
metrics = [build_clip_metric(metric) for metric in config.metrics]
plots = [
build_clip_detection_plotter(plot, targets)
for plot in config.plots
]
return ClipDetectionTask.build(
config=config,
metrics=metrics,
targets=targets,
plots=plots,
)

View File

@ -0,0 +1,88 @@
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.detection import (
ClipEval,
DetectionAveragePrecisionConfig,
DetectionMetricConfig,
MatchEval,
build_detection_metric,
)
from batdetect2.evaluate.plots.detection import (
DetectionPlotConfig,
build_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class DetectionTaskConfig(BaseTaskConfig):
name: Literal["sound_event_detection"] = "sound_event_detection"
prefix: str = "detection"
metrics: List[DetectionMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()]
)
plots: List[DetectionPlotConfig] = Field(default_factory=list)
class DetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.include_sound_event_annotation(sound_event, clip)
]
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
]
scores = [pred.detection_score for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
matches.append(
MatchEval(
gt=gt,
pred=pred,
is_prediction=pred is not None,
is_ground_truth=gt is not None,
score=pred.detection_score if pred is not None else 0,
)
)
return ClipEval(clip=clip, matches=matches)
@tasks_registry.register(DetectionTaskConfig)
@staticmethod
def from_config(
config: DetectionTaskConfig,
targets: TargetProtocol,
):
metrics = [build_detection_metric(metric) for metric in config.metrics]
plots = [
build_detection_plotter(plot, targets) for plot in config.plots
]
return DetectionTask.build(
config=config,
metrics=metrics,
targets=targets,
plots=plots,
)

View File

@ -0,0 +1,111 @@
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.top_class import (
ClipEval,
MatchEval,
TopClassAveragePrecisionConfig,
TopClassMetricConfig,
build_top_class_metric,
)
from batdetect2.evaluate.plots.top_class import (
TopClassPlotConfig,
build_top_class_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class TopClassDetectionTaskConfig(BaseTaskConfig):
name: Literal["top_class_detection"] = "top_class_detection"
prefix: str = "top_class"
metrics: List[TopClassMetricConfig] = Field(
default_factory=lambda: [TopClassAveragePrecisionConfig()]
)
plots: List[TopClassPlotConfig] = Field(default_factory=list)
class TopClassDetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.include_sound_event_annotation(sound_event, clip)
]
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
]
# Take the highest score for each prediction
scores = [pred.class_scores.max() for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
true_class = (
self.targets.encode_class(gt) if gt is not None else None
)
class_idx = (
pred.class_scores.argmax() if pred is not None else None
)
score = (
float(pred.class_scores[class_idx]) if pred is not None else 0
)
pred_class = (
self.targets.class_names[class_idx]
if class_idx is not None
else None
)
matches.append(
MatchEval(
clip=clip,
gt=gt,
pred=pred,
is_ground_truth=gt is not None,
is_prediction=pred is not None,
true_class=true_class,
is_generic=gt is not None and true_class is None,
pred_class=pred_class,
score=score,
)
)
return ClipEval(clip=clip, matches=matches)
@tasks_registry.register(TopClassDetectionTaskConfig)
@staticmethod
def from_config(
config: TopClassDetectionTaskConfig,
targets: TargetProtocol,
):
metrics = [build_top_class_metric(metric) for metric in config.metrics]
plots = [
build_top_class_plotter(plot, targets) for plot in config.plots
]
return TopClassDetectionTask.build(
config=config,
plots=plots,
metrics=metrics,
targets=targets,
)

View File

@ -11,7 +11,6 @@ from batdetect2.plotting.matches import (
     plot_cross_trigger_match,
     plot_false_negative_match,
     plot_false_positive_match,
-    plot_matches,
     plot_true_positive_match,
 )
@ -22,7 +21,6 @@ __all__ = [
     "plot_cross_trigger_match",
     "plot_false_negative_match",
     "plot_false_positive_match",
-    "plot_matches",
     "plot_spectrogram",
     "plot_true_positive_match",
     "plot_detection_heatmap",

View File

@ -65,8 +65,6 @@ def plot_anchor_points(
         if not targets.filter(sound_event):
             continue

-        sound_event = targets.transform(sound_event)
-
         position, _ = targets.encode_roi(sound_event)
         positions.append(position)

View File

@ -19,7 +19,7 @@ def create_ax(
 ) -> axes.Axes:
     """Create a new axis if none is provided"""
     if ax is None:
-        _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore
+        _, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs)  # type: ignore
     return ax  # type: ignore
@ -66,6 +66,9 @@ def plot_spectrogram(
         vmax=vmax,
     )

+    ax.set_xlim(start_time, end_time)
+    ax.set_ylim(min_freq, max_freq)
+
     if add_colorbar:
         plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))

View File

@ -0,0 +1,113 @@
from typing import Optional
from matplotlib import axes, patches
from soundevent.plot import plot_geometry
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.plotting.clips import (
AudioLoader,
PreprocessorProtocol,
plot_clip,
)
from batdetect2.plotting.common import create_ax
__all__ = [
"plot_clip_detections",
]
def plot_clip_detections(
clip_eval: ClipEval,
figsize: tuple[int, int] = (10, 10),
ax: Optional[axes.Axes] = None,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
threshold: float = 0.2,
add_legend: bool = True,
add_title: bool = True,
fill: bool = False,
linewidth: float = 1.0,
gt_color: str = "green",
gt_linestyle: str = "-",
true_pred_color: str = "yellow",
true_pred_linestyle: str = "--",
false_pred_color: str = "blue",
false_pred_linestyle: str = "-",
missed_gt_color: str = "red",
missed_gt_linestyle: str = "-",
) -> axes.Axes:
ax = create_ax(figsize=figsize, ax=ax)
plot_clip(
clip_eval.clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
ax=ax,
)
for m in clip_eval.matches:
is_match = (
m.pred is not None and m.gt is not None and m.score >= threshold
)
if m.pred is not None:
color = true_pred_color if is_match else false_pred_color
plot_geometry(
m.pred.geometry,
ax=ax,
add_points=False,
facecolor="none" if not fill else color,
alpha=m.pred.detection_score,
linewidth=linewidth,
linestyle=true_pred_linestyle
if is_match
else false_pred_linestyle,
color=color,
)
if m.gt is not None:
color = gt_color if is_match else missed_gt_color
plot_geometry(
m.gt.sound_event.geometry, # type: ignore
ax=ax,
add_points=False,
linewidth=linewidth,
facecolor="none" if not fill else color,
linestyle=gt_linestyle if is_match else missed_gt_linestyle,
color=color,
)
if add_title:
ax.set_title(clip_eval.clip.recording.path.name)
if add_legend:
ax.legend(
handles=[
patches.Patch(
label="found GT",
edgecolor=gt_color,
facecolor="none" if not fill else gt_color,
linestyle=gt_linestyle,
),
patches.Patch(
label="missed GT",
edgecolor=missed_gt_color,
facecolor="none" if not fill else missed_gt_color,
linestyle=missed_gt_linestyle,
),
patches.Patch(
label="true Det",
edgecolor=true_pred_color,
facecolor="none" if not fill else true_pred_color,
linestyle=true_pred_linestyle,
),
patches.Patch(
label="false Det",
edgecolor=false_pred_color,
facecolor="none" if not fill else false_pred_color,
linestyle=false_pred_linestyle,
),
]
)
return ax

View File

@ -1,81 +1,109 @@
-from typing import List, Optional
+from typing import Optional, Sequence

 import matplotlib.pyplot as plt
+from matplotlib.figure import Figure

 from batdetect2.plotting.matches import (
+    MatchProtocol,
     plot_cross_trigger_match,
     plot_false_negative_match,
     plot_false_positive_match,
     plot_true_positive_match,
 )
-from batdetect2.typing.evaluate import MatchEvaluation
 from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol

 __all__ = ["plot_match_gallery"]


 def plot_match_gallery(
-    true_positives: List[MatchEvaluation],
-    false_positives: List[MatchEvaluation],
-    false_negatives: List[MatchEvaluation],
-    cross_triggers: List[MatchEvaluation],
+    true_positives: Sequence[MatchProtocol],
+    false_positives: Sequence[MatchProtocol],
+    false_negatives: Sequence[MatchProtocol],
+    cross_triggers: Sequence[MatchProtocol],
     audio_loader: Optional[AudioLoader] = None,
     preprocessor: Optional[PreprocessorProtocol] = None,
     n_examples: int = 5,
     duration: float = 0.1,
+    fig: Optional[Figure] = None,
 ):
-    fig = plt.figure(figsize=(20, 20))
+    if fig is None:
+        fig = plt.figure(figsize=(20, 20))

-    for index, match in enumerate(true_positives[:n_examples]):
-        ax = plt.subplot(4, n_examples, index + 1)
+    axes = fig.subplots(
+        nrows=4,
+        ncols=n_examples,
+        sharex="none",
+        sharey="row",
+    )
+
+    for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
         try:
             plot_true_positive_match(
-                match,
-                ax=ax,
+                tp_match,
+                ax=tp_ax,
                 audio_loader=audio_loader,
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
+        except (
+            ValueError,
+            AssertionError,
+            RuntimeError,
+            FileNotFoundError,
+        ):
             continue

-    for index, match in enumerate(false_positives[:n_examples]):
-        ax = plt.subplot(4, n_examples, n_examples + index + 1)
+    for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
         try:
             plot_false_positive_match(
-                match,
-                ax=ax,
+                fp_match,
+                ax=fp_ax,
                 audio_loader=audio_loader,
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
+        except (
+            ValueError,
+            AssertionError,
+            RuntimeError,
+            FileNotFoundError,
+        ):
             continue

-    for index, match in enumerate(false_negatives[:n_examples]):
-        ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
+    for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
         try:
             plot_false_negative_match(
-                match,
-                ax=ax,
+                fn_match,
+                ax=fn_ax,
                 audio_loader=audio_loader,
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
+        except (
+            ValueError,
+            AssertionError,
+            RuntimeError,
+            FileNotFoundError,
+        ):
             continue

-    for index, match in enumerate(cross_triggers[:n_examples]):
-        ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
+    for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
         try:
             plot_cross_trigger_match(
-                match,
-                ax=ax,
+                ct_match,
+                ax=ct_ax,
                 audio_loader=audio_loader,
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
+        except (
+            ValueError,
+            AssertionError,
+            RuntimeError,
+            FileNotFoundError,
+        ):
             continue

+    fig.tight_layout()
     return fig

View File

@ -1,16 +1,17 @@
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Protocol, Tuple, Union

-import matplotlib.pyplot as plt
 from matplotlib.axes import Axes
 from soundevent import data, plot
 from soundevent.geometry import compute_bounds
-from soundevent.plot.tags import TagColorMapper

-from batdetect2.plotting.clips import AudioLoader, plot_clip
-from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
+from batdetect2.plotting.clips import plot_clip
+from batdetect2.typing import (
+    AudioLoader,
+    PreprocessorProtocol,
+    RawPrediction,
+)

 __all__ = [
-    "plot_matches",
     "plot_false_positive_match",
     "plot_true_positive_match",
     "plot_false_negative_match",
@ -18,6 +19,14 @@ __all__ = [
 ]


+class MatchProtocol(Protocol):
+    clip: data.Clip
+    gt: Optional[data.SoundEventAnnotation]
+    pred: Optional[RawPrediction]
+    score: float
+    true_class: Optional[str]
+
+
 DEFAULT_DURATION = 0.05
 DEFAULT_FALSE_POSITIVE_COLOR = "orange"
 DEFAULT_FALSE_NEGATIVE_COLOR = "red"
@ -27,88 +36,8 @@ DEFAULT_ANNOTATION_LINE_STYLE = "-"
DEFAULT_PREDICTION_LINE_STYLE = "--" DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_matches(
matches: List[MatchEvaluation],
clip: data.Clip,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
color_mapper: Optional[TagColorMapper] = None,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
) -> Axes:
ax = plot_clip(
clip,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
if color_mapper is None:
color_mapper = TagColorMapper()
for match in matches:
if match.is_cross_trigger():
plot_cross_trigger_match(
match,
ax=ax,
fill=fill,
add_points=add_points,
add_spectrogram=False,
use_score=True,
color=cross_trigger_color,
add_text=False,
)
elif match.is_true_positive():
plot_true_positive_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
use_score=True,
add_points=add_points,
color=true_positive_color,
add_text=False,
)
elif match.is_false_negative():
plot_false_negative_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
add_points=add_points,
color=false_negative_color,
add_text=False,
)
elif match.is_false_positive:
plot_false_positive_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
use_score=True,
add_points=add_points,
color=false_positive_color,
add_text=False,
)
else:
continue
return ax
def plot_false_positive_match( def plot_false_positive_match(
match: MatchEvaluation, match: MatchProtocol,
audio_loader: Optional[AudioLoader] = None, audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Optional[Tuple[int, int]] = None,
@ -119,21 +48,24 @@ def plot_false_positive_match(
add_spectrogram: bool = True, add_spectrogram: bool = True,
add_text: bool = True, add_text: bool = True,
add_points: bool = False, add_points: bool = False,
add_title: bool = True,
fill: bool = False, fill: bool = False,
spec_cmap: str = "gray", spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_POSITIVE_COLOR, color: str = DEFAULT_FALSE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small", fontsize: Union[float, str] = "small",
) -> Axes: ) -> Axes:
assert match.pred_geometry is not None assert match.pred is not None
assert match.sound_event_annotation is None
start_time, _, _, high_freq = compute_bounds(match.pred_geometry) start_time, _, _, high_freq = compute_bounds(match.pred.geometry)
clip = data.Clip( clip = data.Clip(
start_time=max(start_time - duration / 2, 0), start_time=max(
start_time - duration / 2,
0,
),
end_time=min( end_time=min(
start_time + duration / 2, start_time + duration / 2,
match.clip.end_time, match.clip.recording.duration,
), ),
recording=match.clip.recording, recording=match.clip.recording,
) )
@ -150,30 +82,33 @@ def plot_false_positive_match(
) )
ax = plot.plot_geometry( ax = plot.plot_geometry(
match.pred_geometry, match.pred.geometry,
ax=ax, ax=ax,
add_points=add_points, add_points=add_points,
facecolor="none" if not fill else None, facecolor="none" if not fill else None,
alpha=match.pred_score if use_score else 1, alpha=match.score if use_score else 1,
color=color, color=color,
) )
if add_text: if add_text:
plt.text( ax.text(
start_time, start_time,
high_freq, high_freq,
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ", f"score={match.score:.2f}",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,
fontsize=fontsize, fontsize=fontsize,
) )
if add_title:
ax.set_title("False Positive")
return ax return ax
def plot_false_negative_match( def plot_false_negative_match(
match: MatchEvaluation, match: MatchProtocol,
audio_loader: Optional[AudioLoader] = None, audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Optional[Tuple[int, int]] = None,
@ -182,26 +117,28 @@ def plot_false_negative_match(
duration: float = DEFAULT_DURATION, duration: float = DEFAULT_DURATION,
add_spectrogram: bool = True, add_spectrogram: bool = True,
add_points: bool = False, add_points: bool = False,
add_text: bool = True, add_title: bool = True,
fill: bool = False, fill: bool = False,
spec_cmap: str = "gray", spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_NEGATIVE_COLOR, color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes: ) -> Axes:
assert match.pred_geometry is None assert match.gt is not None
assert match.sound_event_annotation is not None
sound_event = match.sound_event_annotation.sound_event geometry = match.gt.sound_event.geometry
geometry = sound_event.geometry
assert geometry is not None assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry) start_time = compute_bounds(geometry)[0]
clip = data.Clip( clip = data.Clip(
start_time=max(start_time - duration / 2, 0), start_time=max(
end_time=min( start_time - duration / 2,
start_time + duration / 2, sound_event.recording.duration 0,
), ),
recording=sound_event.recording, end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
),
recording=match.clip.recording,
) )
if add_spectrogram: if add_spectrogram:
@ -215,33 +152,23 @@ def plot_false_negative_match(
spec_cmap=spec_cmap, spec_cmap=spec_cmap,
) )
ax = plot.plot_annotation( ax = plot.plot_geometry(
match.sound_event_annotation, geometry,
ax=ax, ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points, add_points=add_points,
facecolor="none" if not fill else None, facecolor="none" if not fill else None,
alpha=1, alpha=1,
color=color, color=color,
) )
if add_text: if add_title:
plt.text( ax.set_title("False Negative")
start_time,
high_freq,
f"False Negative \nClass: {match.gt_class} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax return ax
def plot_true_positive_match( def plot_true_positive_match(
match: MatchEvaluation, match: MatchProtocol,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None, audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Optional[Tuple[int, int]] = None,
@ -258,39 +185,42 @@ def plot_true_positive_match(
fontsize: Union[float, str] = "small", fontsize: Union[float, str] = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
add_title: bool = True,
) -> Axes: ) -> Axes:
assert match.sound_event_annotation is not None assert match.gt is not None
assert match.pred_geometry is not None assert match.pred is not None
sound_event = match.sound_event_annotation.sound_event
geometry = sound_event.geometry geometry = match.gt.sound_event.geometry
assert geometry is not None assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry) start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip( clip = data.Clip(
start_time=max(start_time - duration / 2, 0), start_time=max(
end_time=min( start_time - duration / 2,
start_time + duration / 2, sound_event.recording.duration 0,
), ),
recording=sound_event.recording, end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
),
recording=match.clip.recording,
) )
if add_spectrogram: if add_spectrogram:
ax = plot_clip( ax = plot_clip(
clip, clip,
ax=ax,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
figsize=figsize, figsize=figsize,
ax=ax,
audio_dir=audio_dir, audio_dir=audio_dir,
spec_cmap=spec_cmap, spec_cmap=spec_cmap,
) )
ax = plot.plot_annotation( ax = plot.plot_geometry(
match.sound_event_annotation, geometry,
ax=ax, ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points, add_points=add_points,
facecolor="none" if not fill else None, facecolor="none" if not fill else None,
alpha=1, alpha=1,
@ -299,31 +229,34 @@ def plot_true_positive_match(
) )
plot.plot_geometry( plot.plot_geometry(
match.pred_geometry, match.pred.geometry,
ax=ax, ax=ax,
add_points=add_points, add_points=add_points,
facecolor="none" if not fill else None, facecolor="none" if not fill else None,
alpha=match.pred_score if use_score else 1, alpha=match.score if use_score else 1,
color=color, color=color,
linestyle=prediction_linestyle, linestyle=prediction_linestyle,
) )
if add_text: if add_text:
plt.text( ax.text(
start_time, start_time,
high_freq, high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", f"score={match.score:.2f}",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,
fontsize=fontsize, fontsize=fontsize,
) )
if add_title:
ax.set_title("True Positive")
return ax return ax
def plot_cross_trigger_match( def plot_cross_trigger_match(
match: MatchEvaluation, match: MatchProtocol,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None, audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Optional[Tuple[int, int]] = None,
@ -334,6 +267,7 @@ def plot_cross_trigger_match(
add_spectrogram: bool = True, add_spectrogram: bool = True,
add_points: bool = False, add_points: bool = False,
add_text: bool = True, add_text: bool = True,
add_title: bool = True,
fill: bool = False, fill: bool = False,
spec_cmap: str = "gray", spec_cmap: str = "gray",
color: str = DEFAULT_CROSS_TRIGGER_COLOR, color: str = DEFAULT_CROSS_TRIGGER_COLOR,
@ -341,20 +275,24 @@ def plot_cross_trigger_match(
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes: ) -> Axes:
assert match.sound_event_annotation is not None assert match.gt is not None
assert match.pred_geometry is not None assert match.pred is not None
sound_event = match.sound_event_annotation.sound_event
geometry = sound_event.geometry geometry = match.gt.sound_event.geometry
assert geometry is not None assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry) start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip( clip = data.Clip(
start_time=max(start_time - duration / 2, 0), start_time=max(
end_time=min( start_time - duration / 2,
start_time + duration / 2, sound_event.recording.duration 0,
), ),
recording=sound_event.recording, end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
),
recording=match.clip.recording,
) )
if add_spectrogram: if add_spectrogram:
@ -368,11 +306,9 @@ def plot_cross_trigger_match(
spec_cmap=spec_cmap, spec_cmap=spec_cmap,
) )
ax = plot.plot_annotation( ax = plot.plot_geometry(
match.sound_event_annotation, geometry,
ax=ax, ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points, add_points=add_points,
facecolor="none" if not fill else None, facecolor="none" if not fill else None,
alpha=1, alpha=1,
@ -381,24 +317,28 @@ def plot_cross_trigger_match(
) )
ax = plot.plot_geometry( ax = plot.plot_geometry(
match.pred_geometry, match.pred.geometry,
ax=ax, ax=ax,
add_points=add_points, add_points=add_points,
facecolor="none" if not fill else None, facecolor="none" if not fill else None,
alpha=match.pred_score if use_score else 1, alpha=match.score if use_score else 1,
color=color, color=color,
linestyle=prediction_linestyle, linestyle=prediction_linestyle,
) )
if add_text: if add_text:
plt.text( ax.text(
start_time, start_time,
high_freq, high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", f"score={match.score:.2f}\nclass={match.true_class}",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,
fontsize=fontsize, fontsize=fontsize,
) )
if add_title:
ax.set_title("Cross Trigger")
return ax return ax

View File

@ -0,0 +1,286 @@
from typing import Dict, Optional, Tuple
import numpy as np
import seaborn as sns
from cycler import cycler
from matplotlib import axes
from batdetect2.plotting.common import create_ax
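# Build a property cycle that pairs each colour with the 9 linestyle/marker
# combinations below, giving many distinct styles before any combination repeats.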
def set_default_styler(ax: axes.Axes) -> axes.Axes:
color_cycler = cycler(color=sns.color_palette("muted"))
style_cycler = cycler(linestyle=["-", "--", ":"]) * cycler(
marker=["o", "s", "^"]
)
custom_cycler = color_cycler * len(style_cycler) + style_cycler * len(
color_cycler
)
ax.set_prop_cycle(custom_cycler)
return ax
def set_default_style(ax: axes.Axes) -> axes.Axes:
ax = set_default_styler(ax)
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
return ax
def plot_pr_curve(
precision: np.ndarray,
recall: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
add_legend: bool = False,
label: str = "PR Curve",
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(
recall,
precision,
label=label,
marker="o",
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_legend:
ax.legend()
if add_labels:
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
return ax
def plot_pr_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (precision, recall, thresholds) in data.items():
ax.plot(
recall,
precision,
label=name,
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
return ax
def plot_threshold_precision_curve(
threshold: np.ndarray,
precision: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(threshold, precision, markevery=_get_marker_positions(threshold))
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Precision")
return ax
def plot_threshold_precision_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (precision, _, thresholds) in data.items():
ax.plot(
thresholds,
precision,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Precision")
return ax
def plot_threshold_recall_curve(
threshold: np.ndarray,
recall: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(threshold, recall, markevery=_get_marker_positions(threshold))
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Recall")
return ax
def plot_threshold_recall_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (_, recall, thresholds) in data.items():
ax.plot(
thresholds,
recall,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Recall")
return ax
def plot_roc_curve(
fpr: np.ndarray,
tpr: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(
fpr,
tpr,
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
return ax
def plot_roc_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (fpr, tpr, thresholds) in data.items():
ax.plot(
fpr,
tpr,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
return ax
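# Map evenly spaced score cut points (0.0 to 1.0) onto indices of the thresholds
# array, so that curve markers are placed at comparable operating points.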
def _get_marker_positions(
thresholds: np.ndarray,
n_points: int = 11,
) -> np.ndarray:
size = len(thresholds)
cut_points = np.linspace(0, 1, n_points)
indices = np.searchsorted(thresholds[::-1], cut_points)
return np.clip(size - indices, 0, size - 1) # type: ignore
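The module above only draws curves that have already been computed; nothing in this diff shows how the inputs are produced. A minimal usage sketch follows, assuming the file is importable as batdetect2.plotting.curves (the module path is not visible in this hunk) and using scikit-learn purely to fabricate precision-recall data:

import numpy as np
from sklearn.metrics import precision_recall_curve

from batdetect2.plotting.curves import plot_pr_curves  # assumed module path

# Fabricated scores for two classes, just to exercise the plotting helper.
rng = np.random.default_rng(0)
curves = {}
for name in ("Myotis daubentonii", "Pipistrellus pipistrellus"):
    y_true = rng.integers(0, 2, size=500)
    y_score = np.clip(0.5 * y_true + rng.normal(0.25, 0.2, size=500), 0, 1)
    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    curves[name] = (precision, recall, thresholds)

ax = plot_pr_curves(curves)
ax.figure.savefig("pr_curves.png", bbox_inches="tight")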

View File

@ -28,12 +28,10 @@ class CenterAudio(torch.nn.Module):
def forward(self, wav: torch.Tensor) -> torch.Tensor: def forward(self, wav: torch.Tensor) -> torch.Tensor:
return center_tensor(wav) return center_tensor(wav)
@classmethod @audio_transforms.register(CenterAudioConfig)
def from_config(cls, config: CenterAudioConfig, samplerate: int): @staticmethod
return cls() def from_config(config: CenterAudioConfig, samplerate: int):
return CenterAudio()
audio_transforms.register(CenterAudioConfig, CenterAudio)
class ScaleAudioConfig(BaseConfig): class ScaleAudioConfig(BaseConfig):
@ -44,12 +42,10 @@ class ScaleAudio(torch.nn.Module):
def forward(self, wav: torch.Tensor) -> torch.Tensor: def forward(self, wav: torch.Tensor) -> torch.Tensor:
return peak_normalize(wav) return peak_normalize(wav)
@classmethod @audio_transforms.register(ScaleAudioConfig)
def from_config(cls, config: ScaleAudioConfig, samplerate: int): @staticmethod
return cls() def from_config(config: ScaleAudioConfig, samplerate: int):
return ScaleAudio()
audio_transforms.register(ScaleAudioConfig, ScaleAudio)
class FixDurationConfig(BaseConfig): class FixDurationConfig(BaseConfig):
@ -75,13 +71,12 @@ class FixDuration(torch.nn.Module):
return torch.nn.functional.pad(wav, (0, self.length - length)) return torch.nn.functional.pad(wav, (0, self.length - length))
@classmethod @audio_transforms.register(FixDurationConfig)
def from_config(cls, config: FixDurationConfig, samplerate: int): @staticmethod
return cls(samplerate=samplerate, duration=config.duration) def from_config(config: FixDurationConfig, samplerate: int):
return FixDuration(samplerate=samplerate, duration=config.duration)
audio_transforms.register(FixDurationConfig, FixDuration)
AudioTransform = Annotated[ AudioTransform = Annotated[
Union[ Union[
FixDurationConfig, FixDurationConfig,

View File

@ -285,10 +285,11 @@ class PCEN(torch.nn.Module):
* torch.expm1(self.power * torch.log1p(S * smooth / self.bias)) * torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
).to(spec.dtype) ).to(spec.dtype)
@classmethod @spectrogram_transforms.register(PcenConfig)
def from_config(cls, config: PcenConfig, samplerate: int): @staticmethod
def from_config(config: PcenConfig, samplerate: int):
smooth = _compute_smoothing_constant(samplerate, config.time_constant) smooth = _compute_smoothing_constant(samplerate, config.time_constant)
return cls( return PCEN(
smoothing_constant=smooth, smoothing_constant=smooth,
gain=config.gain, gain=config.gain,
bias=config.bias, bias=config.bias,
@ -296,9 +297,6 @@ class PCEN(torch.nn.Module):
) )
spectrogram_transforms.register(PcenConfig, PCEN)
def _compute_smoothing_constant( def _compute_smoothing_constant(
samplerate: int, samplerate: int,
time_constant: float, time_constant: float,
@ -335,12 +333,10 @@ class ScaleAmplitude(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
return self.scaler(spec) return self.scaler(spec)
@classmethod @spectrogram_transforms.register(ScaleAmplitudeConfig)
def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int): @staticmethod
return cls(scale=config.scale) def from_config(config: ScaleAmplitudeConfig, samplerate: int):
return ScaleAmplitude(scale=config.scale)
spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude)
class SpectralMeanSubstractionConfig(BaseConfig): class SpectralMeanSubstractionConfig(BaseConfig):
@ -352,19 +348,13 @@ class SpectralMeanSubstraction(torch.nn.Module):
mean = spec.mean(-1, keepdim=True) mean = spec.mean(-1, keepdim=True)
return (spec - mean).clamp(min=0) return (spec - mean).clamp(min=0)
@classmethod @spectrogram_transforms.register(SpectralMeanSubstractionConfig)
@staticmethod
def from_config( def from_config(
cls,
config: SpectralMeanSubstractionConfig, config: SpectralMeanSubstractionConfig,
samplerate: int, samplerate: int,
): ):
return cls() return SpectralMeanSubstraction()
spectrogram_transforms.register(
SpectralMeanSubstractionConfig,
SpectralMeanSubstraction,
)
class PeakNormalizeConfig(BaseConfig): class PeakNormalizeConfig(BaseConfig):
@ -375,13 +365,12 @@ class PeakNormalize(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
return peak_normalize(spec) return peak_normalize(spec)
@classmethod @spectrogram_transforms.register(PeakNormalizeConfig)
def from_config(cls, config: PeakNormalizeConfig, samplerate: int): @staticmethod
return cls() def from_config(config: PeakNormalizeConfig, samplerate: int):
return PeakNormalize()
spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize)
SpectrogramTransform = Annotated[ SpectrogramTransform = Annotated[
Union[ Union[
PcenConfig, PcenConfig,

View File

@ -99,7 +99,7 @@ DEFAULT_DETECTION_CLASS = TargetClassConfig(
DEFAULT_CLASSES = [ DEFAULT_CLASSES = [
TargetClassConfig( TargetClassConfig(
name="barbar", name="barbar",
tags=[data.Tag(key="class", value="Barbastellus barbastellus")], tags=[data.Tag(key="class", value="Barbastella barbastellus")],
), ),
TargetClassConfig( TargetClassConfig(
name="eptser", name="eptser",

View File

@ -1,11 +1,11 @@
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
AugmentationsConfig, AugmentationsConfig,
EchoAugmentationConfig, AddEchoConfig,
FrequencyMaskAugmentationConfig, MaskFrequencyConfig,
RandomAudioSource, RandomAudioSource,
TimeMaskAugmentationConfig, MaskTimeConfig,
VolumeAugmentationConfig, ScaleVolumeConfig,
WarpAugmentationConfig, WarpConfig,
add_echo, add_echo,
build_augmentations, build_augmentations,
mask_frequency, mask_frequency,
@ -43,20 +43,20 @@ __all__ = [
"AugmentationsConfig", "AugmentationsConfig",
"ClassificationLossConfig", "ClassificationLossConfig",
"DetectionLossConfig", "DetectionLossConfig",
"EchoAugmentationConfig", "AddEchoConfig",
"FrequencyMaskAugmentationConfig", "MaskFrequencyConfig",
"LossConfig", "LossConfig",
"LossFunction", "LossFunction",
"PLTrainerConfig", "PLTrainerConfig",
"RandomAudioSource", "RandomAudioSource",
"SizeLossConfig", "SizeLossConfig",
"TimeMaskAugmentationConfig", "MaskTimeConfig",
"TrainingConfig", "TrainingConfig",
"TrainingDataset", "TrainingDataset",
"TrainingModule", "TrainingModule",
"ValidationDataset", "ValidationDataset",
"VolumeAugmentationConfig", "ScaleVolumeConfig",
"WarpAugmentationConfig", "WarpConfig",
"add_echo", "add_echo",
"build_augmentations", "build_augmentations",
"build_clip_labeler", "build_clip_labeler",

View File

@ -12,21 +12,23 @@ from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry from soundevent.geometry import scale_geometry, shift_geometry
from batdetect2.audio.clips import get_subclip_annotation from batdetect2.audio.clips import get_subclip_annotation
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.arrays import adjust_width from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry
from batdetect2.typing import AudioLoader, Augmentation from batdetect2.typing import AudioLoader, Augmentation
__all__ = [ __all__ = [
"AugmentationConfig", "AugmentationConfig",
"AugmentationsConfig", "AugmentationsConfig",
"DEFAULT_AUGMENTATION_CONFIG", "DEFAULT_AUGMENTATION_CONFIG",
"EchoAugmentationConfig", "AddEchoConfig",
"AudioSource", "AudioSource",
"FrequencyMaskAugmentationConfig", "MaskFrequencyConfig",
"MixAugmentationConfig", "MixAudioConfig",
"TimeMaskAugmentationConfig", "MaskTimeConfig",
"VolumeAugmentationConfig", "ScaleVolumeConfig",
"WarpAugmentationConfig", "WarpConfig",
"add_echo", "add_echo",
"build_augmentations", "build_augmentations",
"load_augmentation_config", "load_augmentation_config",
@ -37,10 +39,19 @@ __all__ = [
"warp_spectrogram", "warp_spectrogram",
] ]
AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]] AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
audio_augmentations: Registry[Augmentation, [int, Optional[AudioSource]]] = (
Registry(name="audio_augmentation")
)
class MixAugmentationConfig(BaseConfig): spec_augmentations: Registry[Augmentation, []] = Registry(
name="spec_augmentation"
)
class MixAudioConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples).""" """Configuration for MixUp augmentation (mixing two examples)."""
name: Literal["mix_audio"] = "mix_audio" name: Literal["mix_audio"] = "mix_audio"
@ -87,6 +98,27 @@ class MixAudio(torch.nn.Module):
) )
return mixed_audio, mixed_annotations return mixed_audio, mixed_annotations
@audio_augmentations.register(MixAudioConfig)
@staticmethod
def from_config(
config: MixAudioConfig,
samplerate: int,
source: Optional[AudioSource],
):
if source is None:
warnings.warn(
"Mix audio augmentation ('mix_audio') requires an "
"'example_source' callable to be provided.",
stacklevel=2,
)
return lambda wav, clip_annotation: (wav, clip_annotation)
return MixAudio(
example_source=source,
min_weight=config.min_weight,
max_weight=config.max_weight,
)
def mix_audio( def mix_audio(
wav1: torch.Tensor, wav1: torch.Tensor,
@ -136,7 +168,7 @@ def combine_clip_annotations(
) )
class EchoAugmentationConfig(BaseConfig): class AddEchoConfig(BaseConfig):
"""Configuration for adding synthetic echo/reverb.""" """Configuration for adding synthetic echo/reverb."""
name: Literal["add_echo"] = "add_echo" name: Literal["add_echo"] = "add_echo"
@ -149,14 +181,17 @@ class EchoAugmentationConfig(BaseConfig):
class AddEcho(torch.nn.Module): class AddEcho(torch.nn.Module):
def __init__( def __init__(
self, self,
samplerate: int = TARGET_SAMPLERATE_HZ,
min_weight: float = 0.1, min_weight: float = 0.1,
max_weight: float = 1.0, max_weight: float = 1.0,
max_delay: int = 2560, max_delay: float = 0.005,
): ):
super().__init__() super().__init__()
self.samplerate = samplerate
self.min_weight = min_weight self.min_weight = min_weight
self.max_weight = max_weight self.max_weight = max_weight
self.max_delay = max_delay self.max_delay_s = max_delay
self.max_delay = int(max_delay * samplerate)
def forward( def forward(
self, self,
@ -167,6 +202,20 @@ class AddEcho(torch.nn.Module):
weight = np.random.uniform(self.min_weight, self.max_weight) weight = np.random.uniform(self.min_weight, self.max_weight)
return add_echo(wav, delay=delay, weight=weight), clip_annotation return add_echo(wav, delay=delay, weight=weight), clip_annotation
@audio_augmentations.register(AddEchoConfig)
@staticmethod
def from_config(
config: AddEchoConfig,
samplerate: int,
source: Optional[AudioSource],
):
return AddEcho(
samplerate=samplerate,
min_weight=config.min_weight,
max_weight=config.max_weight,
max_delay=config.max_delay,
)
def add_echo( def add_echo(
wav: torch.Tensor, wav: torch.Tensor,
@ -183,7 +232,7 @@ def add_echo(
return mix_audio(wav, audio_delay, weight) return mix_audio(wav, audio_delay, weight)
class VolumeAugmentationConfig(BaseConfig): class ScaleVolumeConfig(BaseConfig):
"""Configuration for random volume scaling of the spectrogram.""" """Configuration for random volume scaling of the spectrogram."""
name: Literal["scale_volume"] = "scale_volume" name: Literal["scale_volume"] = "scale_volume"
@ -206,19 +255,27 @@ class ScaleVolume(torch.nn.Module):
factor = np.random.uniform(self.min_scaling, self.max_scaling) factor = np.random.uniform(self.min_scaling, self.max_scaling)
return scale_volume(spec, factor=factor), clip_annotation return scale_volume(spec, factor=factor), clip_annotation
@spec_augmentations.register(ScaleVolumeConfig)
@staticmethod
def from_config(config: ScaleVolumeConfig):
return ScaleVolume(
min_scaling=config.min_scaling,
max_scaling=config.max_scaling,
)
def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor: def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
"""Scale the amplitude of the spectrogram by a factor.""" """Scale the amplitude of the spectrogram by a factor."""
return spec * factor return spec * factor
class WarpAugmentationConfig(BaseConfig): class WarpConfig(BaseConfig):
name: Literal["warp"] = "warp" name: Literal["warp"] = "warp"
probability: float = 0.2 probability: float = 0.2
delta: float = 0.04 delta: float = 0.04
class WarpSpectrogram(torch.nn.Module): class Warp(torch.nn.Module):
def __init__(self, delta: float = 0.04) -> None: def __init__(self, delta: float = 0.04) -> None:
super().__init__() super().__init__()
self.delta = delta self.delta = delta
@ -234,6 +291,11 @@ class WarpSpectrogram(torch.nn.Module):
warp_clip_annotation(clip_annotation, factor=factor), warp_clip_annotation(clip_annotation, factor=factor),
) )
@spec_augmentations.register(WarpConfig)
@staticmethod
def from_config(config: WarpConfig):
return Warp(delta=config.delta)
def warp_sound_event_annotation( def warp_sound_event_annotation(
sound_event_annotation: data.SoundEventAnnotation, sound_event_annotation: data.SoundEventAnnotation,
@ -294,7 +356,7 @@ def warp_spectrogram(
).squeeze(0) ).squeeze(0)
class TimeMaskAugmentationConfig(BaseConfig): class MaskTimeConfig(BaseConfig):
name: Literal["mask_time"] = "mask_time" name: Literal["mask_time"] = "mask_time"
probability: float = 0.2 probability: float = 0.2
max_perc: float = 0.05 max_perc: float = 0.05
@ -336,6 +398,14 @@ class MaskTime(torch.nn.Module):
] ]
return mask_time(spec, masks), clip_annotation return mask_time(spec, masks), clip_annotation
@spec_augmentations.register(MaskTimeConfig)
@staticmethod
def from_config(config: MaskTimeConfig):
return MaskTime(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
def mask_time( def mask_time(
spec: torch.Tensor, spec: torch.Tensor,
@ -351,7 +421,7 @@ def mask_time(
return spec return spec
class FrequencyMaskAugmentationConfig(BaseConfig): class MaskFrequencyConfig(BaseConfig):
name: Literal["mask_freq"] = "mask_freq" name: Literal["mask_freq"] = "mask_freq"
probability: float = 0.2 probability: float = 0.2
max_perc: float = 0.10 max_perc: float = 0.10
@ -394,6 +464,14 @@ class MaskFrequency(torch.nn.Module):
] ]
return mask_frequency(spec, masks), clip_annotation return mask_frequency(spec, masks), clip_annotation
@spec_augmentations.register(MaskFrequencyConfig)
@staticmethod
def from_config(config: MaskFrequencyConfig):
return MaskFrequency(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
def mask_frequency( def mask_frequency(
spec: torch.Tensor, spec: torch.Tensor,
@ -410,8 +488,8 @@ def mask_frequency(
AudioAugmentationConfig = Annotated[ AudioAugmentationConfig = Annotated[
Union[ Union[
MixAugmentationConfig, MixAudioConfig,
EchoAugmentationConfig, AddEchoConfig,
], ],
Field(discriminator="name"), Field(discriminator="name"),
] ]
@ -419,22 +497,22 @@ AudioAugmentationConfig = Annotated[
SpectrogramAugmentationConfig = Annotated[ SpectrogramAugmentationConfig = Annotated[
Union[ Union[
VolumeAugmentationConfig, ScaleVolumeConfig,
WarpAugmentationConfig, WarpConfig,
FrequencyMaskAugmentationConfig, MaskFrequencyConfig,
TimeMaskAugmentationConfig, MaskTimeConfig,
], ],
Field(discriminator="name"), Field(discriminator="name"),
] ]
AugmentationConfig = Annotated[ AugmentationConfig = Annotated[
Union[ Union[
MixAugmentationConfig, MixAudioConfig,
EchoAugmentationConfig, AddEchoConfig,
VolumeAugmentationConfig, ScaleVolumeConfig,
WarpAugmentationConfig, WarpConfig,
FrequencyMaskAugmentationConfig, MaskFrequencyConfig,
TimeMaskAugmentationConfig, MaskTimeConfig,
], ],
Field(discriminator="name"), Field(discriminator="name"),
] ]
@ -513,7 +591,7 @@ def build_augmentation_from_config(
) )
if config.name == "warp": if config.name == "warp":
return WarpSpectrogram( return Warp(
delta=config.delta, delta=config.delta,
) )
@ -538,14 +616,14 @@ def build_augmentation_from_config(
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig( DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
enabled=True, enabled=True,
audio=[ audio=[
MixAugmentationConfig(), MixAudioConfig(),
EchoAugmentationConfig(), AddEchoConfig(),
], ],
spectrogram=[ spectrogram=[
VolumeAugmentationConfig(), ScaleVolumeConfig(),
WarpAugmentationConfig(), WarpConfig(),
TimeMaskAugmentationConfig(), MaskTimeConfig(),
FrequencyMaskAugmentationConfig(), MaskFrequencyConfig(),
], ],
) )
@ -566,9 +644,9 @@ class AugmentationSequence(torch.nn.Module):
return tensor, clip_annotation return tensor, clip_annotation
def build_augmentation_sequence( def build_audio_augmentations(
samplerate: int, steps: Optional[Sequence[AudioAugmentationConfig]] = None,
steps: Optional[Sequence[AugmentationConfig]] = None, samplerate: int = TARGET_SAMPLERATE_HZ,
audio_source: Optional[AudioSource] = None, audio_source: Optional[AudioSource] = None,
) -> Optional[Augmentation]: ) -> Optional[Augmentation]:
if not steps: if not steps:
@ -577,10 +655,8 @@ def build_augmentation_sequence(
augmentations = [] augmentations = []
for step_config in steps: for step_config in steps:
augmentation = build_augmentation_from_config( augmentation = audio_augmentations.build(
step_config, step_config, samplerate, audio_source
samplerate=samplerate,
audio_source=audio_source,
) )
if augmentation is None: if augmentation is None:
@ -596,6 +672,30 @@ def build_augmentation_sequence(
return AugmentationSequence(augmentations) return AugmentationSequence(augmentations)
def build_spectrogram_augmentations(
steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None,
) -> Optional[Augmentation]:
if not steps:
return None
augmentations = []
for step_config in steps:
augmentation = spec_augmentations.build(step_config)
if augmentation is None:
continue
augmentations.append(
MaybeApply(
augmentation=augmentation,
probability=step_config.probability,
)
)
return AugmentationSequence(augmentations)
def build_augmentations( def build_augmentations(
samplerate: int, samplerate: int,
config: Optional[AugmentationsConfig] = None, config: Optional[AugmentationsConfig] = None,
@ -609,16 +709,14 @@ def build_augmentations(
lambda: config.to_yaml_string(), lambda: config.to_yaml_string(),
) )
audio_augmentation = build_augmentation_sequence( audio_augmentation = build_audio_augmentations(
samplerate,
steps=config.audio, steps=config.audio,
samplerate=samplerate,
audio_source=audio_source, audio_source=audio_source,
) )
spectrogram_augmentation = build_augmentation_sequence( spectrogram_augmentation = build_spectrogram_augmentations(
samplerate, steps=config.spectrogram,
steps=config.audio,
audio_source=audio_source,
) )
return audio_augmentation, spectrogram_augmentation return audio_augmentation, spectrogram_augmentation
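After this restructuring, the audio and spectrogram pipelines are built from their own config lists through the two registries above. A short sketch of assembling the renamed config classes and building both pipelines; the values are the defaults shown in this diff and the snippet is illustrative rather than a recommended configuration:

from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.train.augmentations import (
    AddEchoConfig,
    AugmentationsConfig,
    MaskFrequencyConfig,
    MaskTimeConfig,
    MixAudioConfig,
    ScaleVolumeConfig,
    WarpConfig,
    build_augmentations,
)

config = AugmentationsConfig(
    enabled=True,
    audio=[MixAudioConfig(), AddEchoConfig()],
    spectrogram=[
        ScaleVolumeConfig(),
        WarpConfig(),
        MaskTimeConfig(),
        MaskFrequencyConfig(),
    ],
)

# Without an audio source, MixAudio falls back to a no-op (with a warning),
# so this is enough for a quick smoke test.
audio_aug, spec_aug = build_augmentations(
    samplerate=TARGET_SAMPLERATE_HZ,
    config=config,
)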

View File

@ -1,4 +1,4 @@
from typing import List from typing import Any, List
from lightning import LightningModule, Trainer from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback from lightning.pytorch.callbacks import Callback
@ -10,7 +10,6 @@ from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import ( from batdetect2.typing import (
ClipEvaluation,
EvaluatorProtocol, EvaluatorProtocol,
ModelOutput, ModelOutput,
RawPrediction, RawPrediction,
@ -36,23 +35,23 @@ class ValidationMetrics(Callback):
def generate_plots( def generate_plots(
self, self,
eval_outputs: Any,
pl_module: LightningModule, pl_module: LightningModule,
evaluated_clips: List[ClipEvaluation],
): ):
plotter = get_image_logger(pl_module.logger) # type: ignore plotter = get_image_logger(pl_module.logger) # type: ignore
if plotter is None: if plotter is None:
return return
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips): for figure_name, fig in self.evaluator.generate_plots(eval_outputs):
plotter(figure_name, fig, pl_module.global_step) plotter(figure_name, fig, pl_module.global_step)
def log_metrics( def log_metrics(
self, self,
eval_outputs: Any,
pl_module: LightningModule, pl_module: LightningModule,
evaluated_clips: List[ClipEvaluation],
): ):
metrics = self.evaluator.compute_metrics(evaluated_clips) metrics = self.evaluator.compute_metrics(eval_outputs)
pl_module.log_dict(metrics) pl_module.log_dict(metrics)
def on_validation_epoch_end( def on_validation_epoch_end(
@ -60,13 +59,13 @@ class ValidationMetrics(Callback):
trainer: Trainer, trainer: Trainer,
pl_module: LightningModule, pl_module: LightningModule,
) -> None: ) -> None:
clip_evaluations = self.evaluator.evaluate( eval_outputs = self.evaluator.evaluate(
self._clip_annotations, self._clip_annotations,
self._predictions, self._predictions,
) )
self.log_metrics(pl_module, clip_evaluations) self.log_metrics(eval_outputs, pl_module)
self.generate_plots(pl_module, clip_evaluations) self.generate_plots(eval_outputs, pl_module)
return super().on_validation_epoch_end(trainer, pl_module) return super().on_validation_epoch_end(trainer, pl_module)

View File

@ -8,7 +8,7 @@ from loguru import logger
from soundevent import data from soundevent import data
from batdetect2.audio import build_audio_loader from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.evaluate import build_evaluator
from batdetect2.logging import build_logger from batdetect2.logging import build_logger
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
@ -105,7 +105,10 @@ def train(
trainer = trainer or build_trainer( trainer = trainer or build_trainer(
config, config,
targets=targets, targets=targets,
evaluator=build_evaluator(config.train.validation, targets=targets), evaluator=build_evaluator(
config.train.validation,
targets=targets,
),
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
log_dir=log_dir, log_dir=log_dir,
experiment_name=experiment_name, experiment_name=experiment_name,
@ -143,7 +146,7 @@ def build_trainer_callbacks(
ModelCheckpoint( ModelCheckpoint(
dirpath=str(checkpoint_dir), dirpath=str(checkpoint_dir),
save_top_k=1, save_top_k=1,
monitor="total_loss/val", monitor="classification/mean_average_precision",
), ),
ValidationMetrics(evaluator), ValidationMetrics(evaluator),
] ]
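One detail worth noting about the checkpoint change above: Lightning's ModelCheckpoint defaults to mode="min", so monitoring an average-precision metric (where higher is better) presumably also needs mode="max"; that argument is not visible in this hunk. A sketch of the intended configuration, with a placeholder checkpoint directory:

from pathlib import Path

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_dir = Path("checkpoints")  # placeholder; the real directory comes from the caller

checkpoint = ModelCheckpoint(
    dirpath=str(checkpoint_dir),
    save_top_k=1,
    monitor="classification/mean_average_precision",
    mode="max",  # higher average precision is better; the callback's default mode is "min"
)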

View File

@ -1,6 +1,8 @@
from batdetect2.typing.evaluate import ( from batdetect2.typing.evaluate import (
ClipEvaluation, AffinityFunction,
ClipMatches,
EvaluatorProtocol, EvaluatorProtocol,
MatcherProtocol,
MatchEvaluation, MatchEvaluation,
MetricsProtocol, MetricsProtocol,
PlotterProtocol, PlotterProtocol,
@ -36,19 +38,22 @@ from batdetect2.typing.train import (
) )
__all__ = [ __all__ = [
"AffinityFunction",
"AudioLoader", "AudioLoader",
"Augmentation", "Augmentation",
"BackboneModel", "BackboneModel",
"BatDetect2Prediction", "BatDetect2Prediction",
"ClipEvaluation", "ClipMatches",
"ClipLabeller", "ClipLabeller",
"ClipperProtocol", "ClipperProtocol",
"DetectionModel", "DetectionModel",
"EvaluatorProtocol",
"GeometryDecoder", "GeometryDecoder",
"Heatmaps", "Heatmaps",
"LossProtocol", "LossProtocol",
"Losses", "Losses",
"MatchEvaluation", "MatchEvaluation",
"MatcherProtocol",
"MetricsProtocol", "MetricsProtocol",
"ModelOutput", "ModelOutput",
"PlotterProtocol", "PlotterProtocol",
@ -63,5 +68,4 @@ __all__ = [
"SoundEventFilter", "SoundEventFilter",
"TargetProtocol", "TargetProtocol",
"TrainExample", "TrainExample",
"EvaluatorProtocol",
] ]

View File

@ -31,6 +31,7 @@ class MatchEvaluation:
sound_event_annotation: Optional[data.SoundEventAnnotation] sound_event_annotation: Optional[data.SoundEventAnnotation]
gt_det: bool gt_det: bool
gt_class: Optional[str] gt_class: Optional[str]
gt_geometry: Optional[data.Geometry]
pred_score: float pred_score: float
pred_class_scores: Dict[str, float] pred_class_scores: Dict[str, float]
@ -39,44 +40,32 @@ class MatchEvaluation:
affinity: float affinity: float
@property @property
def pred_class(self) -> Optional[str]: def top_class(self) -> Optional[str]:
if not self.pred_class_scores: if not self.pred_class_scores:
return None return None
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
@property @property
def pred_class_score(self) -> float: def is_prediction(self) -> bool:
pred_class = self.pred_class return self.pred_geometry is not None
@property
def is_generic(self) -> bool:
return self.gt_det and self.gt_class is None
@property
def top_class_score(self) -> float:
pred_class = self.top_class
if pred_class is None: if pred_class is None:
return 0 return 0
return self.pred_class_scores[pred_class] return self.pred_class_scores[pred_class]
def is_true_positive(self, threshold: float = 0) -> bool:
return (
self.gt_det
and self.pred_score > threshold
and self.gt_class == self.pred_class
)
def is_false_positive(self, threshold: float = 0) -> bool:
return self.gt_det is None and self.pred_score > threshold
def is_false_negative(self, threshold: float = 0) -> bool:
return self.gt_det and self.pred_score <= threshold
def is_cross_trigger(self, threshold: float = 0) -> bool:
return (
self.gt_det
and self.pred_score > threshold
and self.gt_class != self.pred_class
)
@dataclass @dataclass
class ClipEvaluation: class ClipMatches:
clip: data.Clip clip: data.Clip
matches: List[MatchEvaluation] matches: List[MatchEvaluation]
@ -103,29 +92,36 @@ class AffinityFunction(Protocol, Generic[Geom]):
class MetricsProtocol(Protocol): class MetricsProtocol(Protocol):
def __call__( def __call__(
self, clip_evaluations: Sequence[ClipEvaluation] self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> Dict[str, float]: ... ) -> Dict[str, float]: ...
class PlotterProtocol(Protocol): class PlotterProtocol(Protocol):
def __call__( def __call__(
self, clip_evaluations: Sequence[ClipEvaluation] self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> Iterable[Tuple[str, Figure]]: ... ) -> Iterable[Tuple[str, Figure]]: ...
class EvaluatorProtocol(Protocol): EvaluationOutput = TypeVar("EvaluationOutput")
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
targets: TargetProtocol targets: TargetProtocol
def evaluate( def evaluate(
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]], predictions: Sequence[Sequence[RawPrediction]],
) -> List[ClipEvaluation]: ... ) -> EvaluationOutput: ...
def compute_metrics( def compute_metrics(
self, clip_evaluations: Sequence[ClipEvaluation] self, eval_outputs: EvaluationOutput
) -> Dict[str, float]: ... ) -> Dict[str, float]: ...
def generate_plots( def generate_plots(
self, clip_evaluations: Sequence[ClipEvaluation] self, eval_outputs: EvaluationOutput
) -> Iterable[Tuple[str, Figure]]: ... ) -> Iterable[Tuple[str, Figure]]: ...
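With EvaluatorProtocol now generic over its EvaluationOutput, an evaluator can thread any intermediate representation from evaluate() into compute_metrics() and generate_plots(). The following toy evaluator is shown only to illustrate the protocol shape (the real evaluators live in batdetect2.evaluate, and the protocol additionally declares a targets: TargetProtocol attribute, omitted here for brevity):

from typing import Dict, Iterable, Sequence, Tuple

from matplotlib.figure import Figure
from soundevent import data

from batdetect2.typing import RawPrediction


class CountingEvaluator:
    """Toy evaluator whose EvaluationOutput is a plain dict of counts."""

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> Dict[str, float]:
        n_gt = sum(len(clip.sound_events) for clip in clip_annotations)
        n_pred = sum(len(preds) for preds in predictions)
        return {"n_annotations": float(n_gt), "n_predictions": float(n_pred)}

    def compute_metrics(self, eval_outputs: Dict[str, float]) -> Dict[str, float]:
        return dict(eval_outputs)

    def generate_plots(
        self, eval_outputs: Dict[str, float]
    ) -> Iterable[Tuple[str, Figure]]:
        return []  # no figures for this toy example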