mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Compare commits
6 Commits
4cd983a2c2
...
30159d64a9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
30159d64a9 | ||
|
|
c9f0c5c431 | ||
|
|
10865ee600 | ||
|
|
87ed44c8f7 | ||
|
|
df2abff654 | ||
|
|
d6ddc4514c |
@ -138,27 +138,49 @@ train:
|
|||||||
name: csv
|
name: csv
|
||||||
|
|
||||||
validation:
|
validation:
|
||||||
|
tasks:
|
||||||
|
- name: sound_event_detection
|
||||||
metrics:
|
metrics:
|
||||||
- name: detection_ap
|
- name: average_precision
|
||||||
- name: detection_roc_auc
|
- name: sound_event_classification
|
||||||
- name: classification_ap
|
metrics:
|
||||||
- name: classification_roc_auc
|
- name: average_precision
|
||||||
- name: top_class_ap
|
|
||||||
- name: classification_balanced_accuracy
|
|
||||||
- name: clip_ap
|
|
||||||
- name: clip_roc_auc
|
|
||||||
|
|
||||||
evaluation:
|
evaluation:
|
||||||
match_strategy:
|
tasks:
|
||||||
name: start_time_match
|
- name: sound_event_detection
|
||||||
distance_threshold: 0.01
|
|
||||||
metrics:
|
metrics:
|
||||||
- name: classification_ap
|
- name: average_precision
|
||||||
- name: detection_ap
|
- name: roc_auc
|
||||||
plots:
|
plots:
|
||||||
- name: example_gallery
|
- name: pr_curve
|
||||||
- name: example_clip
|
- name: score_distribution
|
||||||
- name: detection_pr_curve
|
- name: example_detection
|
||||||
- name: classification_pr_curves
|
- name: sound_event_classification
|
||||||
- name: detection_roc_curve
|
metrics:
|
||||||
- name: classification_roc_curves
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: top_class_detection
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: confusion_matrix
|
||||||
|
- name: example_classification
|
||||||
|
- name: clip_detection
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: roc_curve
|
||||||
|
- name: score_distribution
|
||||||
|
- name: clip_classification
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: roc_curve
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Sequence
|
from typing import List, Optional, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
@ -8,6 +9,7 @@ from batdetect2.config import BatDetect2Config
|
|||||||
from batdetect2.evaluate import build_evaluator, evaluate
|
from batdetect2.evaluate import build_evaluator, evaluate
|
||||||
from batdetect2.models import Model, build_model
|
from batdetect2.models import Model, build_model
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
|
from batdetect2.postprocess.decoding import to_raw_predictions
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets.targets import build_targets
|
from batdetect2.targets.targets import build_targets
|
||||||
from batdetect2.train import train
|
from batdetect2.train import train
|
||||||
@ -19,6 +21,7 @@ from batdetect2.typing import (
|
|||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
from batdetect2.typing.postprocess import RawPrediction
|
||||||
|
|
||||||
|
|
||||||
class BatDetect2API:
|
class BatDetect2API:
|
||||||
@ -92,6 +95,18 @@ class BatDetect2API:
|
|||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def process_spectrogram(
|
||||||
|
self,
|
||||||
|
spec: torch.Tensor,
|
||||||
|
start_times: Optional[Sequence[float]] = None,
|
||||||
|
) -> List[List[RawPrediction]]:
|
||||||
|
outputs = self.model.detector(spec)
|
||||||
|
clip_detections = self.postprocessor(outputs, start_times=start_times)
|
||||||
|
return [
|
||||||
|
to_raw_predictions(clip_dets.numpy(), self.targets)
|
||||||
|
for clip_dets in clip_detections
|
||||||
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: BatDetect2Config):
|
def from_config(cls, config: BatDetect2Config):
|
||||||
targets = build_targets(config=config.targets)
|
targets = build_targets(config=config.targets)
|
||||||
@ -108,10 +123,7 @@ class BatDetect2API:
|
|||||||
config=config.postprocess,
|
config=config.postprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(
|
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||||
config=config.evaluation,
|
|
||||||
targets=targets,
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE: Better to have a separate instance of
|
# NOTE: Better to have a separate instance of
|
||||||
# preprocessor and postprocessor as these may be moved
|
# preprocessor and postprocessor as these may be moved
|
||||||
@ -163,10 +175,7 @@ class BatDetect2API:
|
|||||||
config=config.postprocess,
|
config=config.postprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(
|
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||||
config=config.evaluation,
|
|
||||||
targets=targets,
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
config=config,
|
config=config,
|
||||||
|
|||||||
@ -56,18 +56,16 @@ class RandomClip:
|
|||||||
min_sound_event_overlap=self.min_sound_event_overlap,
|
min_sound_event_overlap=self.min_sound_event_overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@clipper_registry.register(RandomClipConfig)
|
||||||
def from_config(cls, config: RandomClipConfig):
|
@staticmethod
|
||||||
return cls(
|
def from_config(config: RandomClipConfig):
|
||||||
|
return RandomClip(
|
||||||
duration=config.duration,
|
duration=config.duration,
|
||||||
max_empty=config.max_empty,
|
max_empty=config.max_empty,
|
||||||
min_sound_event_overlap=config.min_sound_event_overlap,
|
min_sound_event_overlap=config.min_sound_event_overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
clipper_registry.register(RandomClipConfig, RandomClip)
|
|
||||||
|
|
||||||
|
|
||||||
def get_subclip_annotation(
|
def get_subclip_annotation(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
random: bool = True,
|
random: bool = True,
|
||||||
@ -184,13 +182,12 @@ class PaddedClip:
|
|||||||
)
|
)
|
||||||
return clip_annotation.model_copy(update=dict(clip=clip))
|
return clip_annotation.model_copy(update=dict(clip=clip))
|
||||||
|
|
||||||
@classmethod
|
@clipper_registry.register(PaddedClipConfig)
|
||||||
def from_config(cls, config: PaddedClipConfig):
|
@staticmethod
|
||||||
return cls(chunk_size=config.chunk_size)
|
def from_config(config: PaddedClipConfig):
|
||||||
|
return PaddedClip(chunk_size=config.chunk_size)
|
||||||
|
|
||||||
|
|
||||||
clipper_registry.register(PaddedClipConfig, PaddedClip)
|
|
||||||
|
|
||||||
ClipConfig = Annotated[
|
ClipConfig = Annotated[
|
||||||
Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
|
Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
|
||||||
]
|
]
|
||||||
|
|||||||
@ -53,6 +53,7 @@ class BaseConfig(BaseModel):
|
|||||||
"""
|
"""
|
||||||
return yaml.dump(
|
return yaml.dump(
|
||||||
self.model_dump(
|
self.model_dump(
|
||||||
|
mode="json",
|
||||||
exclude_none=exclude_none,
|
exclude_none=exclude_none,
|
||||||
exclude_unset=exclude_unset,
|
exclude_unset=exclude_unset,
|
||||||
exclude_defaults=exclude_defaults,
|
exclude_defaults=exclude_defaults,
|
||||||
|
|||||||
@ -1,16 +1,16 @@
|
|||||||
import sys
|
import sys
|
||||||
from typing import Generic, Protocol, Type, TypeVar
|
from typing import Callable, Dict, Generic, Tuple, Type, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import assert_type
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 10):
|
if sys.version_info >= (3, 10):
|
||||||
from typing import ParamSpec
|
from typing import Concatenate, ParamSpec
|
||||||
else:
|
else:
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import Concatenate, ParamSpec
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Registry",
|
"Registry",
|
||||||
|
"SimpleRegistry",
|
||||||
]
|
]
|
||||||
|
|
||||||
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
|
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
|
||||||
@ -18,19 +18,26 @@ T_Type = TypeVar("T_Type", covariant=True)
|
|||||||
P_Type = ParamSpec("P_Type")
|
P_Type = ParamSpec("P_Type")
|
||||||
|
|
||||||
|
|
||||||
class LogicProtocol(Generic[T_Config, T_Type, P_Type], Protocol):
|
T = TypeVar("T")
|
||||||
"""A generic protocol for the logic classes."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls,
|
|
||||||
config: T_Config,
|
|
||||||
*args: P_Type.args,
|
|
||||||
**kwargs: P_Type.kwargs,
|
|
||||||
) -> T_Type: ...
|
|
||||||
|
|
||||||
|
|
||||||
T_Proto = TypeVar("T_Proto", bound=LogicProtocol)
|
class SimpleRegistry(Generic[T]):
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self._name = name
|
||||||
|
self._registry = {}
|
||||||
|
|
||||||
|
def register(self, name: str):
|
||||||
|
def decorator(obj: T) -> T:
|
||||||
|
self._registry[name] = obj
|
||||||
|
return obj
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def get(self, name: str) -> T:
|
||||||
|
return self._registry[name]
|
||||||
|
|
||||||
|
def has(self, name: str) -> bool:
|
||||||
|
return name in self._registry
|
||||||
|
|
||||||
|
|
||||||
class Registry(Generic[T_Type, P_Type]):
|
class Registry(Generic[T_Type, P_Type]):
|
||||||
@ -38,13 +45,15 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
self._name = name
|
self._name = name
|
||||||
self._registry = {}
|
self._registry: Dict[
|
||||||
|
str, Callable[Concatenate[..., P_Type], T_Type]
|
||||||
|
] = {}
|
||||||
|
self._config_types: Dict[str, Type[BaseModel]] = {}
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self,
|
self,
|
||||||
config_cls: Type[T_Config],
|
config_cls: Type[T_Config],
|
||||||
logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
|
):
|
||||||
) -> None:
|
|
||||||
fields = config_cls.model_fields
|
fields = config_cls.model_fields
|
||||||
|
|
||||||
if "name" not in fields:
|
if "name" not in fields:
|
||||||
@ -52,10 +61,21 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
|
|
||||||
name = fields["name"].default
|
name = fields["name"].default
|
||||||
|
|
||||||
|
self._config_types[name] = config_cls
|
||||||
|
|
||||||
if not isinstance(name, str):
|
if not isinstance(name, str):
|
||||||
raise ValueError("'name' field must be a string literal.")
|
raise ValueError("'name' field must be a string literal.")
|
||||||
|
|
||||||
self._registry[name] = logic_cls
|
def decorator(
|
||||||
|
func: Callable[Concatenate[T_Config, P_Type], T_Type],
|
||||||
|
):
|
||||||
|
self._registry[name] = func
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
|
||||||
|
return tuple(self._config_types.values())
|
||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
@ -75,4 +95,4 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
f"No {self._name} with name '{name}' is registered."
|
f"No {self._name} with name '{name}' is registered."
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._registry[name].from_config(config, *args, **kwargs)
|
return self._registry[name](config, *args, **kwargs)
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from batdetect2.core.registries import Registry
|
|||||||
|
|
||||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||||
|
|
||||||
condition_registry: Registry[SoundEventCondition, []] = Registry("condition")
|
conditions: Registry[SoundEventCondition, []] = Registry("condition")
|
||||||
|
|
||||||
|
|
||||||
class HasTagConfig(BaseConfig):
|
class HasTagConfig(BaseConfig):
|
||||||
@ -27,12 +27,10 @@ class HasTag:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
return self.tag in sound_event_annotation.tags
|
return self.tag in sound_event_annotation.tags
|
||||||
|
|
||||||
@classmethod
|
@conditions.register(HasTagConfig)
|
||||||
def from_config(cls, config: HasTagConfig):
|
@staticmethod
|
||||||
return cls(tag=config.tag)
|
def from_config(config: HasTagConfig):
|
||||||
|
return HasTag(tag=config.tag)
|
||||||
|
|
||||||
condition_registry.register(HasTagConfig, HasTag)
|
|
||||||
|
|
||||||
|
|
||||||
class HasAllTagsConfig(BaseConfig):
|
class HasAllTagsConfig(BaseConfig):
|
||||||
@ -52,12 +50,10 @@ class HasAllTags:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
return self.tags.issubset(sound_event_annotation.tags)
|
return self.tags.issubset(sound_event_annotation.tags)
|
||||||
|
|
||||||
@classmethod
|
@conditions.register(HasAllTagsConfig)
|
||||||
def from_config(cls, config: HasAllTagsConfig):
|
@staticmethod
|
||||||
return cls(tags=config.tags)
|
def from_config(config: HasAllTagsConfig):
|
||||||
|
return HasAllTags(tags=config.tags)
|
||||||
|
|
||||||
condition_registry.register(HasAllTagsConfig, HasAllTags)
|
|
||||||
|
|
||||||
|
|
||||||
class HasAnyTagConfig(BaseConfig):
|
class HasAnyTagConfig(BaseConfig):
|
||||||
@ -77,13 +73,12 @@ class HasAnyTag:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
return bool(self.tags.intersection(sound_event_annotation.tags))
|
return bool(self.tags.intersection(sound_event_annotation.tags))
|
||||||
|
|
||||||
@classmethod
|
@conditions.register(HasAnyTagConfig)
|
||||||
def from_config(cls, config: HasAnyTagConfig):
|
@staticmethod
|
||||||
return cls(tags=config.tags)
|
def from_config(config: HasAnyTagConfig):
|
||||||
|
return HasAnyTag(tags=config.tags)
|
||||||
|
|
||||||
|
|
||||||
condition_registry.register(HasAnyTagConfig, HasAnyTag)
|
|
||||||
|
|
||||||
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
||||||
|
|
||||||
|
|
||||||
@ -134,12 +129,10 @@ class Duration:
|
|||||||
|
|
||||||
return self._comparator(duration)
|
return self._comparator(duration)
|
||||||
|
|
||||||
@classmethod
|
@conditions.register(DurationConfig)
|
||||||
def from_config(cls, config: DurationConfig):
|
@staticmethod
|
||||||
return cls(operator=config.operator, seconds=config.seconds)
|
def from_config(config: DurationConfig):
|
||||||
|
return Duration(operator=config.operator, seconds=config.seconds)
|
||||||
|
|
||||||
condition_registry.register(DurationConfig, Duration)
|
|
||||||
|
|
||||||
|
|
||||||
class FrequencyConfig(BaseConfig):
|
class FrequencyConfig(BaseConfig):
|
||||||
@ -181,18 +174,16 @@ class Frequency:
|
|||||||
|
|
||||||
return self._comparator(high_freq)
|
return self._comparator(high_freq)
|
||||||
|
|
||||||
@classmethod
|
@conditions.register(FrequencyConfig)
|
||||||
def from_config(cls, config: FrequencyConfig):
|
@staticmethod
|
||||||
return cls(
|
def from_config(config: FrequencyConfig):
|
||||||
|
return Frequency(
|
||||||
operator=config.operator,
|
operator=config.operator,
|
||||||
boundary=config.boundary,
|
boundary=config.boundary,
|
||||||
hertz=config.hertz,
|
hertz=config.hertz,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
condition_registry.register(FrequencyConfig, Frequency)
|
|
||||||
|
|
||||||
|
|
||||||
class AllOfConfig(BaseConfig):
|
class AllOfConfig(BaseConfig):
|
||||||
name: Literal["all_of"] = "all_of"
|
name: Literal["all_of"] = "all_of"
|
||||||
conditions: Sequence["SoundEventConditionConfig"]
|
conditions: Sequence["SoundEventConditionConfig"]
|
||||||
@ -207,15 +198,13 @@ class AllOf:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
return all(c(sound_event_annotation) for c in self.conditions)
|
return all(c(sound_event_annotation) for c in self.conditions)
|
||||||
|
|
||||||
@classmethod
|
@conditions.register(AllOfConfig)
|
||||||
def from_config(cls, config: AllOfConfig):
|
@staticmethod
|
||||||
|
def from_config(config: AllOfConfig):
|
||||||
conditions = [
|
conditions = [
|
||||||
build_sound_event_condition(cond) for cond in config.conditions
|
build_sound_event_condition(cond) for cond in config.conditions
|
||||||
]
|
]
|
||||||
return cls(conditions)
|
return AllOf(conditions)
|
||||||
|
|
||||||
|
|
||||||
condition_registry.register(AllOfConfig, AllOf)
|
|
||||||
|
|
||||||
|
|
||||||
class AnyOfConfig(BaseConfig):
|
class AnyOfConfig(BaseConfig):
|
||||||
@ -232,15 +221,13 @@ class AnyOf:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
return any(c(sound_event_annotation) for c in self.conditions)
|
return any(c(sound_event_annotation) for c in self.conditions)
|
||||||
|
|
||||||
@classmethod
|
@conditions.register(AnyOfConfig)
|
||||||
def from_config(cls, config: AnyOfConfig):
|
@staticmethod
|
||||||
|
def from_config(config: AnyOfConfig):
|
||||||
conditions = [
|
conditions = [
|
||||||
build_sound_event_condition(cond) for cond in config.conditions
|
build_sound_event_condition(cond) for cond in config.conditions
|
||||||
]
|
]
|
||||||
return cls(conditions)
|
return AnyOf(conditions)
|
||||||
|
|
||||||
|
|
||||||
condition_registry.register(AnyOfConfig, AnyOf)
|
|
||||||
|
|
||||||
|
|
||||||
class NotConfig(BaseConfig):
|
class NotConfig(BaseConfig):
|
||||||
@ -257,14 +244,13 @@ class Not:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
return not self.condition(sound_event_annotation)
|
return not self.condition(sound_event_annotation)
|
||||||
|
|
||||||
@classmethod
|
@conditions.register(NotConfig)
|
||||||
def from_config(cls, config: NotConfig):
|
@staticmethod
|
||||||
|
def from_config(config: NotConfig):
|
||||||
condition = build_sound_event_condition(config.condition)
|
condition = build_sound_event_condition(config.condition)
|
||||||
return cls(condition)
|
return Not(condition)
|
||||||
|
|
||||||
|
|
||||||
condition_registry.register(NotConfig, Not)
|
|
||||||
|
|
||||||
SoundEventConditionConfig = Annotated[
|
SoundEventConditionConfig = Annotated[
|
||||||
Union[
|
Union[
|
||||||
HasTagConfig,
|
HasTagConfig,
|
||||||
@ -283,7 +269,7 @@ SoundEventConditionConfig = Annotated[
|
|||||||
def build_sound_event_condition(
|
def build_sound_event_condition(
|
||||||
config: SoundEventConditionConfig,
|
config: SoundEventConditionConfig,
|
||||||
) -> SoundEventCondition:
|
) -> SoundEventCondition:
|
||||||
return condition_registry.build(config)
|
return conditions.build(config)
|
||||||
|
|
||||||
|
|
||||||
def filter_clip_annotation(
|
def filter_clip_annotation(
|
||||||
|
|||||||
@ -17,7 +17,7 @@ SoundEventTransform = Callable[
|
|||||||
data.SoundEventAnnotation,
|
data.SoundEventAnnotation,
|
||||||
]
|
]
|
||||||
|
|
||||||
transform_registry: Registry[SoundEventTransform, []] = Registry("transform")
|
transforms: Registry[SoundEventTransform, []] = Registry("transform")
|
||||||
|
|
||||||
|
|
||||||
class SetFrequencyBoundConfig(BaseConfig):
|
class SetFrequencyBoundConfig(BaseConfig):
|
||||||
@ -63,12 +63,10 @@ class SetFrequencyBound:
|
|||||||
update=dict(sound_event=sound_event)
|
update=dict(sound_event=sound_event)
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@transforms.register(SetFrequencyBoundConfig)
|
||||||
def from_config(cls, config: SetFrequencyBoundConfig):
|
@staticmethod
|
||||||
return cls(hertz=config.hertz, boundary=config.boundary)
|
def from_config(config: SetFrequencyBoundConfig):
|
||||||
|
return SetFrequencyBound(hertz=config.hertz, boundary=config.boundary)
|
||||||
|
|
||||||
transform_registry.register(SetFrequencyBoundConfig, SetFrequencyBound)
|
|
||||||
|
|
||||||
|
|
||||||
class ApplyIfConfig(BaseConfig):
|
class ApplyIfConfig(BaseConfig):
|
||||||
@ -95,14 +93,12 @@ class ApplyIf:
|
|||||||
|
|
||||||
return self.transform(sound_event_annotation)
|
return self.transform(sound_event_annotation)
|
||||||
|
|
||||||
@classmethod
|
@transforms.register(ApplyIfConfig)
|
||||||
def from_config(cls, config: ApplyIfConfig):
|
@staticmethod
|
||||||
|
def from_config(config: ApplyIfConfig):
|
||||||
transform = build_sound_event_transform(config.transform)
|
transform = build_sound_event_transform(config.transform)
|
||||||
condition = build_sound_event_condition(config.condition)
|
condition = build_sound_event_condition(config.condition)
|
||||||
return cls(condition=condition, transform=transform)
|
return ApplyIf(condition=condition, transform=transform)
|
||||||
|
|
||||||
|
|
||||||
transform_registry.register(ApplyIfConfig, ApplyIf)
|
|
||||||
|
|
||||||
|
|
||||||
class ReplaceTagConfig(BaseConfig):
|
class ReplaceTagConfig(BaseConfig):
|
||||||
@ -134,12 +130,12 @@ class ReplaceTag:
|
|||||||
|
|
||||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||||
|
|
||||||
@classmethod
|
@transforms.register(ReplaceTagConfig)
|
||||||
def from_config(cls, config: ReplaceTagConfig):
|
@staticmethod
|
||||||
return cls(original=config.original, replacement=config.replacement)
|
def from_config(config: ReplaceTagConfig):
|
||||||
|
return ReplaceTag(
|
||||||
|
original=config.original, replacement=config.replacement
|
||||||
transform_registry.register(ReplaceTagConfig, ReplaceTag)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MapTagValueConfig(BaseConfig):
|
class MapTagValueConfig(BaseConfig):
|
||||||
@ -189,18 +185,16 @@ class MapTagValue:
|
|||||||
|
|
||||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||||
|
|
||||||
@classmethod
|
@transforms.register(MapTagValueConfig)
|
||||||
def from_config(cls, config: MapTagValueConfig):
|
@staticmethod
|
||||||
return cls(
|
def from_config(config: MapTagValueConfig):
|
||||||
|
return MapTagValue(
|
||||||
tag_key=config.tag_key,
|
tag_key=config.tag_key,
|
||||||
value_mapping=config.value_mapping,
|
value_mapping=config.value_mapping,
|
||||||
target_key=config.target_key,
|
target_key=config.target_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
transform_registry.register(MapTagValueConfig, MapTagValue)
|
|
||||||
|
|
||||||
|
|
||||||
class ApplyAllConfig(BaseConfig):
|
class ApplyAllConfig(BaseConfig):
|
||||||
name: Literal["apply_all"] = "apply_all"
|
name: Literal["apply_all"] = "apply_all"
|
||||||
steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
|
steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
|
||||||
@ -219,14 +213,13 @@ class ApplyAll:
|
|||||||
|
|
||||||
return sound_event_annotation
|
return sound_event_annotation
|
||||||
|
|
||||||
@classmethod
|
@transforms.register(ApplyAllConfig)
|
||||||
def from_config(cls, config: ApplyAllConfig):
|
@staticmethod
|
||||||
|
def from_config(config: ApplyAllConfig):
|
||||||
steps = [build_sound_event_transform(step) for step in config.steps]
|
steps = [build_sound_event_transform(step) for step in config.steps]
|
||||||
return cls(steps)
|
return ApplyAll(steps)
|
||||||
|
|
||||||
|
|
||||||
transform_registry.register(ApplyAllConfig, ApplyAll)
|
|
||||||
|
|
||||||
SoundEventTransformConfig = Annotated[
|
SoundEventTransformConfig = Annotated[
|
||||||
Union[
|
Union[
|
||||||
SetFrequencyBoundConfig,
|
SetFrequencyBoundConfig,
|
||||||
@ -242,7 +235,7 @@ SoundEventTransformConfig = Annotated[
|
|||||||
def build_sound_event_transform(
|
def build_sound_event_transform(
|
||||||
config: SoundEventTransformConfig,
|
config: SoundEventTransformConfig,
|
||||||
) -> SoundEventTransform:
|
) -> SoundEventTransform:
|
||||||
return transform_registry.build(config)
|
return transforms.build(config)
|
||||||
|
|
||||||
|
|
||||||
def transform_clip_annotation(
|
def transform_clip_annotation(
|
||||||
|
|||||||
@ -1,11 +1,14 @@
|
|||||||
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
||||||
from batdetect2.evaluate.evaluate import evaluate
|
from batdetect2.evaluate.evaluate import evaluate
|
||||||
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
||||||
|
from batdetect2.evaluate.tasks import TaskConfig, build_task
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EvaluationConfig",
|
"EvaluationConfig",
|
||||||
"load_evaluation_config",
|
|
||||||
"evaluate",
|
|
||||||
"Evaluator",
|
"Evaluator",
|
||||||
|
"TaskConfig",
|
||||||
"build_evaluator",
|
"build_evaluator",
|
||||||
|
"build_task",
|
||||||
|
"evaluate",
|
||||||
|
"load_evaluation_config",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from typing import Annotated, Literal, Optional, Union
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import compute_affinity
|
from soundevent.evaluation import compute_affinity
|
||||||
|
from soundevent.geometry import compute_interval_overlap
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.core.registries import Registry
|
from batdetect2.core.registries import Registry
|
||||||
@ -27,12 +28,10 @@ class TimeAffinity(AffinityFunction):
|
|||||||
geometry1, geometry2, time_buffer=self.time_buffer
|
geometry1, geometry2, time_buffer=self.time_buffer
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@affinity_functions.register(TimeAffinityConfig)
|
||||||
def from_config(cls, config: TimeAffinityConfig):
|
@staticmethod
|
||||||
return cls(time_buffer=config.time_buffer)
|
def from_config(config: TimeAffinityConfig):
|
||||||
|
return TimeAffinity(time_buffer=config.time_buffer)
|
||||||
|
|
||||||
affinity_functions.register(TimeAffinityConfig, TimeAffinity)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_timestamp_affinity(
|
def compute_timestamp_affinity(
|
||||||
@ -73,12 +72,10 @@ class IntervalIOU(AffinityFunction):
|
|||||||
time_buffer=self.time_buffer,
|
time_buffer=self.time_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@affinity_functions.register(IntervalIOUConfig)
|
||||||
def from_config(cls, config: IntervalIOUConfig):
|
@staticmethod
|
||||||
return cls(time_buffer=config.time_buffer)
|
def from_config(config: IntervalIOUConfig):
|
||||||
|
return IntervalIOU(time_buffer=config.time_buffer)
|
||||||
|
|
||||||
affinity_functions.register(IntervalIOUConfig, IntervalIOU)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_interval_iou(
|
def compute_interval_iou(
|
||||||
@ -97,9 +94,11 @@ def compute_interval_iou(
|
|||||||
end_time1 += time_buffer
|
end_time1 += time_buffer
|
||||||
end_time2 += time_buffer
|
end_time2 += time_buffer
|
||||||
|
|
||||||
intersection = max(
|
intersection = compute_interval_overlap(
|
||||||
0, min(end_time1, end_time2) - max(start_time1, start_time2)
|
(start_time1, end_time1),
|
||||||
|
(start_time2, end_time2),
|
||||||
)
|
)
|
||||||
|
|
||||||
union = (
|
union = (
|
||||||
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
|
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
|
||||||
)
|
)
|
||||||
@ -110,6 +109,86 @@ def compute_interval_iou(
|
|||||||
return intersection / union
|
return intersection / union
|
||||||
|
|
||||||
|
|
||||||
|
class BBoxIOUConfig(BaseConfig):
|
||||||
|
name: Literal["bbox_iou"] = "bbox_iou"
|
||||||
|
time_buffer: float = 0.01
|
||||||
|
freq_buffer: float = 1000
|
||||||
|
|
||||||
|
|
||||||
|
class BBoxIOU(AffinityFunction):
|
||||||
|
def __init__(self, time_buffer: float, freq_buffer: float):
|
||||||
|
self.time_buffer = time_buffer
|
||||||
|
self.freq_buffer = freq_buffer
|
||||||
|
|
||||||
|
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
||||||
|
if not isinstance(geometry1, data.BoundingBox):
|
||||||
|
raise TypeError(
|
||||||
|
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(geometry2, data.BoundingBox):
|
||||||
|
raise TypeError(
|
||||||
|
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
|
||||||
|
)
|
||||||
|
return bbox_iou(
|
||||||
|
geometry1,
|
||||||
|
geometry2,
|
||||||
|
time_buffer=self.time_buffer,
|
||||||
|
freq_buffer=self.freq_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
@affinity_functions.register(BBoxIOUConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: BBoxIOUConfig):
|
||||||
|
return BBoxIOU(
|
||||||
|
time_buffer=config.time_buffer,
|
||||||
|
freq_buffer=config.freq_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def bbox_iou(
|
||||||
|
geometry1: data.BoundingBox,
|
||||||
|
geometry2: data.BoundingBox,
|
||||||
|
time_buffer: float = 0.01,
|
||||||
|
freq_buffer: float = 1000,
|
||||||
|
) -> float:
|
||||||
|
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
|
||||||
|
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
|
||||||
|
|
||||||
|
start_time1 -= time_buffer
|
||||||
|
start_time2 -= time_buffer
|
||||||
|
end_time1 += time_buffer
|
||||||
|
end_time2 += time_buffer
|
||||||
|
|
||||||
|
low_freq1 -= freq_buffer
|
||||||
|
low_freq2 -= freq_buffer
|
||||||
|
high_freq1 += freq_buffer
|
||||||
|
high_freq2 += freq_buffer
|
||||||
|
|
||||||
|
time_intersection = compute_interval_overlap(
|
||||||
|
(start_time1, end_time1),
|
||||||
|
(start_time2, end_time2),
|
||||||
|
)
|
||||||
|
|
||||||
|
freq_intersection = max(
|
||||||
|
0,
|
||||||
|
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
|
||||||
|
)
|
||||||
|
|
||||||
|
intersection = time_intersection * freq_intersection
|
||||||
|
|
||||||
|
if intersection == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
union = (
|
||||||
|
(end_time1 - start_time1) * (high_freq1 - low_freq1)
|
||||||
|
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
|
||||||
|
- intersection
|
||||||
|
)
|
||||||
|
|
||||||
|
return intersection / union
|
||||||
|
|
||||||
|
|
||||||
class GeometricIOUConfig(BaseConfig):
|
class GeometricIOUConfig(BaseConfig):
|
||||||
name: Literal["geometric_iou"] = "geometric_iou"
|
name: Literal["geometric_iou"] = "geometric_iou"
|
||||||
time_buffer: float = 0.01
|
time_buffer: float = 0.01
|
||||||
@ -127,17 +206,17 @@ class GeometricIOU(AffinityFunction):
|
|||||||
time_buffer=self.time_buffer,
|
time_buffer=self.time_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@affinity_functions.register(GeometricIOUConfig)
|
||||||
def from_config(cls, config: GeometricIOUConfig):
|
@staticmethod
|
||||||
return cls(time_buffer=config.time_buffer)
|
def from_config(config: GeometricIOUConfig):
|
||||||
|
return GeometricIOU(time_buffer=config.time_buffer)
|
||||||
|
|
||||||
|
|
||||||
affinity_functions.register(GeometricIOUConfig, GeometricIOU)
|
|
||||||
|
|
||||||
AffinityConfig = Annotated[
|
AffinityConfig = Annotated[
|
||||||
Union[
|
Union[
|
||||||
TimeAffinityConfig,
|
TimeAffinityConfig,
|
||||||
IntervalIOUConfig,
|
IntervalIOUConfig,
|
||||||
|
BBoxIOUConfig,
|
||||||
GeometricIOUConfig,
|
GeometricIOUConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
|
|||||||
@ -4,13 +4,11 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
|
from batdetect2.evaluate.tasks import (
|
||||||
from batdetect2.evaluate.metrics import (
|
TaskConfig,
|
||||||
ClassificationAPConfig,
|
|
||||||
DetectionAPConfig,
|
|
||||||
MetricConfig,
|
|
||||||
)
|
)
|
||||||
from batdetect2.evaluate.plots import PlotConfig
|
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
||||||
|
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||||
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
|
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -20,15 +18,12 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class EvaluationConfig(BaseConfig):
|
class EvaluationConfig(BaseConfig):
|
||||||
ignore_start_end: float = 0.01
|
tasks: List[TaskConfig] = Field(
|
||||||
match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig)
|
|
||||||
metrics: List[MetricConfig] = Field(
|
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
DetectionAPConfig(),
|
DetectionTaskConfig(),
|
||||||
ClassificationAPConfig(),
|
ClassificationTaskConfig(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
plots: List[PlotConfig] = Field(default_factory=list)
|
|
||||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,23 +1,12 @@
|
|||||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.geometry import compute_bounds
|
|
||||||
|
|
||||||
from batdetect2.evaluate.config import EvaluationConfig
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
from batdetect2.evaluate.match import build_matcher, match
|
from batdetect2.evaluate.tasks import build_task
|
||||||
from batdetect2.evaluate.metrics import build_metric
|
|
||||||
from batdetect2.evaluate.plots import build_plotter
|
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing.evaluate import (
|
from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol
|
||||||
ClipEvaluation,
|
|
||||||
EvaluatorProtocol,
|
|
||||||
MatcherProtocol,
|
|
||||||
MetricsProtocol,
|
|
||||||
PlotterProtocol,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.postprocess import RawPrediction
|
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Evaluator",
|
"Evaluator",
|
||||||
@ -28,146 +17,51 @@ __all__ = [
|
|||||||
class Evaluator:
|
class Evaluator:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: EvaluationConfig,
|
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
matcher: MatcherProtocol,
|
tasks: Sequence[EvaluatorProtocol],
|
||||||
metrics: List[MetricsProtocol],
|
|
||||||
plots: List[PlotterProtocol],
|
|
||||||
):
|
):
|
||||||
self.config = config
|
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.matcher = matcher
|
self.tasks = tasks
|
||||||
self.metrics = metrics
|
|
||||||
self.plots = plots
|
|
||||||
|
|
||||||
def match(
|
|
||||||
self,
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
predictions: Sequence[RawPrediction],
|
|
||||||
) -> ClipEvaluation:
|
|
||||||
clip = clip_annotation.clip
|
|
||||||
ground_truth = [
|
|
||||||
sound_event
|
|
||||||
for sound_event in clip_annotation.sound_events
|
|
||||||
if self.filter_sound_event_annotations(sound_event, clip)
|
|
||||||
]
|
|
||||||
predictions = [
|
|
||||||
prediction
|
|
||||||
for prediction in predictions
|
|
||||||
if self.filter_predictions(prediction, clip)
|
|
||||||
]
|
|
||||||
return ClipEvaluation(
|
|
||||||
clip=clip_annotation.clip,
|
|
||||||
matches=match(
|
|
||||||
ground_truth,
|
|
||||||
predictions,
|
|
||||||
clip=clip,
|
|
||||||
targets=self.targets,
|
|
||||||
matcher=self.matcher,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def filter_sound_event_annotations(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
clip: data.Clip,
|
|
||||||
) -> bool:
|
|
||||||
if not self.targets.filter(sound_event_annotation):
|
|
||||||
return False
|
|
||||||
|
|
||||||
geometry = sound_event_annotation.sound_event.geometry
|
|
||||||
if geometry is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return is_in_bounds(
|
|
||||||
geometry,
|
|
||||||
clip,
|
|
||||||
self.config.ignore_start_end,
|
|
||||||
)
|
|
||||||
|
|
||||||
def filter_predictions(
|
|
||||||
self,
|
|
||||||
prediction: RawPrediction,
|
|
||||||
clip: data.Clip,
|
|
||||||
) -> bool:
|
|
||||||
return is_in_bounds(
|
|
||||||
prediction.geometry,
|
|
||||||
clip,
|
|
||||||
self.config.ignore_start_end,
|
|
||||||
)
|
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[Sequence[RawPrediction]],
|
predictions: Sequence[Sequence[RawPrediction]],
|
||||||
) -> List[ClipEvaluation]:
|
) -> List[Any]:
|
||||||
if len(clip_annotations) != len(predictions):
|
|
||||||
raise ValueError(
|
|
||||||
"Number of annotated clips and sets of predictions do not match"
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
self.match(clip_annotation, preds)
|
task.evaluate(clip_annotations, predictions) for task in self.tasks
|
||||||
for clip_annotation, preds in zip(clip_annotations, predictions)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def compute_metrics(
|
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
|
||||||
self,
|
|
||||||
clip_evaluations: Sequence[ClipEvaluation],
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
for metric in self.metrics:
|
for task, outputs in zip(self.tasks, eval_outputs):
|
||||||
results.update(metric(clip_evaluations))
|
results.update(task.compute_metrics(outputs))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def generate_plots(
|
def generate_plots(
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
self,
|
||||||
|
eval_outputs: List[Any],
|
||||||
) -> Iterable[Tuple[str, Figure]]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
for plotter in self.plots:
|
for task, outputs in zip(self.tasks, eval_outputs):
|
||||||
for name, fig in plotter(clip_evaluations):
|
for name, fig in task.generate_plots(outputs):
|
||||||
yield name, fig
|
yield name, fig
|
||||||
|
|
||||||
|
|
||||||
def build_evaluator(
|
def build_evaluator(
|
||||||
config: Optional[EvaluationConfig] = None,
|
config: Optional[Union[EvaluationConfig, dict]] = None,
|
||||||
targets: Optional[TargetProtocol] = None,
|
targets: Optional[TargetProtocol] = None,
|
||||||
matcher: Optional[MatcherProtocol] = None,
|
|
||||||
plots: Optional[List[PlotterProtocol]] = None,
|
|
||||||
metrics: Optional[List[MetricsProtocol]] = None,
|
|
||||||
) -> EvaluatorProtocol:
|
) -> EvaluatorProtocol:
|
||||||
config = config or EvaluationConfig()
|
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
matcher = matcher or build_matcher(config.match_strategy)
|
|
||||||
|
|
||||||
if metrics is None:
|
if config is None:
|
||||||
metrics = [
|
config = EvaluationConfig()
|
||||||
build_metric(config, targets.class_names)
|
|
||||||
for config in config.metrics
|
|
||||||
]
|
|
||||||
|
|
||||||
if plots is None:
|
if not isinstance(config, EvaluationConfig):
|
||||||
plots = [
|
config = EvaluationConfig.model_validate(config)
|
||||||
build_plotter(config, targets.class_names)
|
|
||||||
for config in config.plots
|
|
||||||
]
|
|
||||||
|
|
||||||
return Evaluator(
|
return Evaluator(
|
||||||
config=config,
|
|
||||||
targets=targets,
|
targets=targets,
|
||||||
matcher=matcher,
|
tasks=[build_task(task, targets=targets) for task in config.tasks],
|
||||||
metrics=metrics,
|
|
||||||
plots=plots,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_in_bounds(
|
|
||||||
geometry: data.Geometry,
|
|
||||||
clip: data.Clip,
|
|
||||||
buffer: float,
|
|
||||||
) -> bool:
|
|
||||||
start_time = compute_bounds(geometry)[0]
|
|
||||||
return (start_time >= clip.start_time + buffer) and (
|
|
||||||
start_time <= clip.end_time - buffer
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -4,11 +4,10 @@ from lightning import LightningModule
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||||
from batdetect2.evaluate.tables import FullEvaluationTable
|
from batdetect2.logging import get_image_logger
|
||||||
from batdetect2.logging import get_image_logger, get_table_logger
|
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
from batdetect2.typing import ClipEvaluation, EvaluatorProtocol
|
from batdetect2.typing import ClipMatches, EvaluatorProtocol
|
||||||
|
|
||||||
|
|
||||||
class EvaluationModule(LightningModule):
|
class EvaluationModule(LightningModule):
|
||||||
@ -54,18 +53,8 @@ class EvaluationModule(LightningModule):
|
|||||||
def on_test_epoch_end(self):
|
def on_test_epoch_end(self):
|
||||||
self.log_metrics(self.clip_evaluations)
|
self.log_metrics(self.clip_evaluations)
|
||||||
self.plot_examples(self.clip_evaluations)
|
self.plot_examples(self.clip_evaluations)
|
||||||
self.log_table(self.clip_evaluations)
|
|
||||||
|
|
||||||
def log_table(self, evaluated_clips: Sequence[ClipEvaluation]):
|
def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
|
||||||
table_logger = get_table_logger(self.logger) # type: ignore
|
|
||||||
|
|
||||||
if table_logger is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
df = FullEvaluationTable()(evaluated_clips)
|
|
||||||
table_logger("full_evaluation", df, 0)
|
|
||||||
|
|
||||||
def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]):
|
|
||||||
plotter = get_image_logger(self.logger) # type: ignore
|
plotter = get_image_logger(self.logger) # type: ignore
|
||||||
|
|
||||||
if plotter is None:
|
if plotter is None:
|
||||||
@ -74,7 +63,7 @@ class EvaluationModule(LightningModule):
|
|||||||
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
|
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
|
||||||
plotter(figure_name, fig, self.global_step)
|
plotter(figure_name, fig, self.global_step)
|
||||||
|
|
||||||
def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]):
|
def log_metrics(self, evaluated_clips: Sequence[ClipMatches]):
|
||||||
metrics = self.evaluator.compute_metrics(evaluated_clips)
|
metrics = self.evaluator.compute_metrics(evaluated_clips)
|
||||||
self.log_dict(metrics)
|
self.log_dict(metrics)
|
||||||
|
|
||||||
|
|||||||
@ -8,8 +8,7 @@ from soundevent.evaluation import compute_affinity
|
|||||||
from soundevent.evaluation import match_geometries as optimal_match
|
from soundevent.evaluation import match_geometries as optimal_match
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.core.registries import Registry
|
|
||||||
from batdetect2.evaluate.affinity import (
|
from batdetect2.evaluate.affinity import (
|
||||||
AffinityConfig,
|
AffinityConfig,
|
||||||
GeometricIOUConfig,
|
GeometricIOUConfig,
|
||||||
@ -17,11 +16,13 @@ from batdetect2.evaluate.affinity import (
|
|||||||
)
|
)
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
|
AffinityFunction,
|
||||||
|
MatcherProtocol,
|
||||||
MatchEvaluation,
|
MatchEvaluation,
|
||||||
|
RawPrediction,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.evaluate import AffinityFunction, MatcherProtocol
|
from batdetect2.typing.evaluate import ClipMatches
|
||||||
from batdetect2.typing.postprocess import RawPrediction
|
|
||||||
|
|
||||||
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
|
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
|
||||||
"""The geometry representation to use for matching."""
|
"""The geometry representation to use for matching."""
|
||||||
@ -33,9 +34,10 @@ def match(
|
|||||||
sound_event_annotations: Sequence[data.SoundEventAnnotation],
|
sound_event_annotations: Sequence[data.SoundEventAnnotation],
|
||||||
raw_predictions: Sequence[RawPrediction],
|
raw_predictions: Sequence[RawPrediction],
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
|
scores: Optional[Sequence[float]] = None,
|
||||||
targets: Optional[TargetProtocol] = None,
|
targets: Optional[TargetProtocol] = None,
|
||||||
matcher: Optional[MatcherProtocol] = None,
|
matcher: Optional[MatcherProtocol] = None,
|
||||||
) -> List[MatchEvaluation]:
|
) -> ClipMatches:
|
||||||
if matcher is None:
|
if matcher is None:
|
||||||
matcher = build_matcher()
|
matcher = build_matcher()
|
||||||
|
|
||||||
@ -51,8 +53,10 @@ def match(
|
|||||||
raw_prediction.geometry for raw_prediction in raw_predictions
|
raw_prediction.geometry for raw_prediction in raw_predictions
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if scores is None:
|
||||||
scores = [
|
scores = [
|
||||||
raw_prediction.detection_score for raw_prediction in raw_predictions
|
raw_prediction.detection_score
|
||||||
|
for raw_prediction in raw_predictions
|
||||||
]
|
]
|
||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
@ -73,9 +77,11 @@ def match(
|
|||||||
|
|
||||||
gt_det = target_idx is not None
|
gt_det = target_idx is not None
|
||||||
gt_class = targets.encode_class(target) if target is not None else None
|
gt_class = targets.encode_class(target) if target is not None else None
|
||||||
|
gt_geometry = (
|
||||||
|
target_geometries[target_idx] if target_idx is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
pred_score = float(prediction.detection_score) if prediction else 0
|
pred_score = float(prediction.detection_score) if prediction else 0
|
||||||
|
|
||||||
pred_geometry = (
|
pred_geometry = (
|
||||||
predicted_geometries[source_idx]
|
predicted_geometries[source_idx]
|
||||||
if source_idx is not None
|
if source_idx is not None
|
||||||
@ -84,7 +90,7 @@ def match(
|
|||||||
|
|
||||||
class_scores = (
|
class_scores = (
|
||||||
{
|
{
|
||||||
str(class_name): float(score)
|
class_name: score
|
||||||
for class_name, score in zip(
|
for class_name, score in zip(
|
||||||
targets.class_names,
|
targets.class_names,
|
||||||
prediction.class_scores,
|
prediction.class_scores,
|
||||||
@ -100,6 +106,7 @@ def match(
|
|||||||
sound_event_annotation=target,
|
sound_event_annotation=target,
|
||||||
gt_det=gt_det,
|
gt_det=gt_det,
|
||||||
gt_class=gt_class,
|
gt_class=gt_class,
|
||||||
|
gt_geometry=gt_geometry,
|
||||||
pred_score=pred_score,
|
pred_score=pred_score,
|
||||||
pred_class_scores=class_scores,
|
pred_class_scores=class_scores,
|
||||||
pred_geometry=pred_geometry,
|
pred_geometry=pred_geometry,
|
||||||
@ -107,7 +114,7 @@ def match(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return matches
|
return ClipMatches(clip=clip, matches=matches)
|
||||||
|
|
||||||
|
|
||||||
class StartTimeMatchConfig(BaseConfig):
|
class StartTimeMatchConfig(BaseConfig):
|
||||||
@ -132,12 +139,10 @@ class StartTimeMatcher(MatcherProtocol):
|
|||||||
distance_threshold=self.distance_threshold,
|
distance_threshold=self.distance_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@matching_strategies.register(StartTimeMatchConfig)
|
||||||
def from_config(cls, config: StartTimeMatchConfig) -> "StartTimeMatcher":
|
@staticmethod
|
||||||
return cls(distance_threshold=config.distance_threshold)
|
def from_config(config: StartTimeMatchConfig):
|
||||||
|
return StartTimeMatcher(distance_threshold=config.distance_threshold)
|
||||||
|
|
||||||
matching_strategies.register(StartTimeMatchConfig, StartTimeMatcher)
|
|
||||||
|
|
||||||
|
|
||||||
def match_start_times(
|
def match_start_times(
|
||||||
@ -264,19 +269,17 @@ class GreedyMatcher(MatcherProtocol):
|
|||||||
affinity_threshold=self.affinity_threshold,
|
affinity_threshold=self.affinity_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@matching_strategies.register(GreedyMatchConfig)
|
||||||
def from_config(cls, config: GreedyMatchConfig):
|
@staticmethod
|
||||||
|
def from_config(config: GreedyMatchConfig):
|
||||||
affinity_function = build_affinity_function(config.affinity_function)
|
affinity_function = build_affinity_function(config.affinity_function)
|
||||||
return cls(
|
return GreedyMatcher(
|
||||||
geometry=config.geometry,
|
geometry=config.geometry,
|
||||||
affinity_threshold=config.affinity_threshold,
|
affinity_threshold=config.affinity_threshold,
|
||||||
affinity_function=affinity_function,
|
affinity_function=affinity_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
matching_strategies.register(GreedyMatchConfig, GreedyMatcher)
|
|
||||||
|
|
||||||
|
|
||||||
def greedy_match(
|
def greedy_match(
|
||||||
ground_truth: Sequence[data.Geometry],
|
ground_truth: Sequence[data.Geometry],
|
||||||
predictions: Sequence[data.Geometry],
|
predictions: Sequence[data.Geometry],
|
||||||
@ -313,21 +316,21 @@ def greedy_match(
|
|||||||
unassigned_gt = set(range(len(ground_truth)))
|
unassigned_gt = set(range(len(ground_truth)))
|
||||||
|
|
||||||
if not predictions:
|
if not predictions:
|
||||||
for target_idx in range(len(ground_truth)):
|
for gt_idx in range(len(ground_truth)):
|
||||||
yield None, target_idx, 0
|
yield None, gt_idx, 0
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if not ground_truth:
|
if not ground_truth:
|
||||||
for source_idx in range(len(predictions)):
|
for pred_idx in range(len(predictions)):
|
||||||
yield source_idx, None, 0
|
yield pred_idx, None, 0
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
indices = np.argsort(scores)[::-1]
|
indices = np.argsort(scores)[::-1]
|
||||||
|
|
||||||
for source_idx in indices:
|
for pred_idx in indices:
|
||||||
source_geometry = predictions[source_idx]
|
source_geometry = predictions[pred_idx]
|
||||||
|
|
||||||
affinities = np.array(
|
affinities = np.array(
|
||||||
[
|
[
|
||||||
@ -340,18 +343,18 @@ def greedy_match(
|
|||||||
affinity = affinities[closest_target]
|
affinity = affinities[closest_target]
|
||||||
|
|
||||||
if affinities[closest_target] <= affinity_threshold:
|
if affinities[closest_target] <= affinity_threshold:
|
||||||
yield source_idx, None, 0
|
yield pred_idx, None, 0
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if closest_target not in unassigned_gt:
|
if closest_target not in unassigned_gt:
|
||||||
yield source_idx, None, 0
|
yield pred_idx, None, 0
|
||||||
continue
|
continue
|
||||||
|
|
||||||
unassigned_gt.remove(closest_target)
|
unassigned_gt.remove(closest_target)
|
||||||
yield source_idx, closest_target, affinity
|
yield pred_idx, closest_target, affinity
|
||||||
|
|
||||||
for target_idx in unassigned_gt:
|
for gt_idx in unassigned_gt:
|
||||||
yield None, target_idx, 0
|
yield None, gt_idx, 0
|
||||||
|
|
||||||
|
|
||||||
class OptimalMatchConfig(BaseConfig):
|
class OptimalMatchConfig(BaseConfig):
|
||||||
@ -386,17 +389,16 @@ class OptimalMatcher(MatcherProtocol):
|
|||||||
affinity_threshold=self.affinity_threshold,
|
affinity_threshold=self.affinity_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@matching_strategies.register(OptimalMatchConfig)
|
||||||
def from_config(cls, config: OptimalMatchConfig):
|
@staticmethod
|
||||||
return cls(
|
def from_config(config: OptimalMatchConfig):
|
||||||
|
return OptimalMatcher(
|
||||||
affinity_threshold=config.affinity_threshold,
|
affinity_threshold=config.affinity_threshold,
|
||||||
time_buffer=config.time_buffer,
|
time_buffer=config.time_buffer,
|
||||||
frequency_buffer=config.frequency_buffer,
|
frequency_buffer=config.frequency_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
matching_strategies.register(OptimalMatchConfig, OptimalMatcher)
|
|
||||||
|
|
||||||
MatchConfig = Annotated[
|
MatchConfig = Annotated[
|
||||||
Union[
|
Union[
|
||||||
GreedyMatchConfig,
|
GreedyMatchConfig,
|
||||||
|
|||||||
@ -1,712 +0,0 @@
|
|||||||
from collections import defaultdict
|
|
||||||
from collections.abc import Callable, Mapping
|
|
||||||
from typing import (
|
|
||||||
Annotated,
|
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from pydantic import Field
|
|
||||||
from sklearn import metrics, preprocessing
|
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig, Registry
|
|
||||||
from batdetect2.typing import ClipEvaluation, MetricsProtocol
|
|
||||||
|
|
||||||
__all__ = ["DetectionAP", "ClassificationAP"]
|
|
||||||
|
|
||||||
|
|
||||||
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
|
|
||||||
|
|
||||||
|
|
||||||
APImplementation = Literal["sklearn", "pascal_voc"]
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionAPConfig(BaseConfig):
|
|
||||||
name: Literal["detection_ap"] = "detection_ap"
|
|
||||||
ap_implementation: APImplementation = "pascal_voc"
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionAP(MetricsProtocol):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
implementation: APImplementation = "pascal_voc",
|
|
||||||
):
|
|
||||||
self.implementation = implementation
|
|
||||||
self.metric = _ap_impl_mapping[self.implementation]
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
y_true, y_score = zip(
|
|
||||||
*[
|
|
||||||
(match.gt_det, match.pred_score)
|
|
||||||
for clip_eval in clip_evaluations
|
|
||||||
for match in clip_eval.matches
|
|
||||||
]
|
|
||||||
)
|
|
||||||
score = float(self.metric(y_true, y_score))
|
|
||||||
return {"detection_AP": score}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
|
|
||||||
return cls(implementation=config.ap_implementation)
|
|
||||||
|
|
||||||
|
|
||||||
metrics_registry.register(DetectionAPConfig, DetectionAP)
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionROCAUCConfig(BaseConfig):
|
|
||||||
name: Literal["detection_roc_auc"] = "detection_roc_auc"
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionROCAUC(MetricsProtocol):
|
|
||||||
def __call__(
|
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
y_true, y_score = zip(
|
|
||||||
*[
|
|
||||||
(match.gt_det, match.pred_score)
|
|
||||||
for clip_eval in clip_evaluations
|
|
||||||
for match in clip_eval.matches
|
|
||||||
]
|
|
||||||
)
|
|
||||||
score = float(metrics.roc_auc_score(y_true, y_score))
|
|
||||||
return {"detection_ROC_AUC": score}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls, config: DetectionROCAUCConfig, class_names: List[str]
|
|
||||||
):
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
|
|
||||||
metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationAPConfig(BaseConfig):
|
|
||||||
name: Literal["classification_ap"] = "classification_ap"
|
|
||||||
ap_implementation: APImplementation = "pascal_voc"
|
|
||||||
include: Optional[List[str]] = None
|
|
||||||
exclude: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationAP(MetricsProtocol):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
class_names: List[str],
|
|
||||||
implementation: APImplementation = "pascal_voc",
|
|
||||||
include: Optional[List[str]] = None,
|
|
||||||
exclude: Optional[List[str]] = None,
|
|
||||||
):
|
|
||||||
self.implementation = implementation
|
|
||||||
self.metric = _ap_impl_mapping[self.implementation]
|
|
||||||
self.class_names = class_names
|
|
||||||
|
|
||||||
self.selected = class_names
|
|
||||||
|
|
||||||
if include is not None:
|
|
||||||
self.selected = [
|
|
||||||
class_name
|
|
||||||
for class_name in self.selected
|
|
||||||
if class_name in include
|
|
||||||
]
|
|
||||||
|
|
||||||
if exclude is not None:
|
|
||||||
self.selected = [
|
|
||||||
class_name
|
|
||||||
for class_name in self.selected
|
|
||||||
if class_name not in exclude
|
|
||||||
]
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
y_true = []
|
|
||||||
y_pred = []
|
|
||||||
|
|
||||||
for clip_eval in clip_evaluations:
|
|
||||||
for match in clip_eval.matches:
|
|
||||||
# Ignore generic unclassified targets
|
|
||||||
if match.gt_det and match.gt_class is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
y_true.append(
|
|
||||||
match.gt_class
|
|
||||||
if match.gt_class is not None
|
|
||||||
else "__NONE__"
|
|
||||||
)
|
|
||||||
|
|
||||||
y_pred.append(
|
|
||||||
np.array(
|
|
||||||
[
|
|
||||||
match.pred_class_scores.get(name, 0)
|
|
||||||
for name in self.class_names
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
|
|
||||||
y_pred = np.stack(y_pred)
|
|
||||||
|
|
||||||
class_scores = {}
|
|
||||||
for class_index, class_name in enumerate(self.class_names):
|
|
||||||
y_true_class = y_true[:, class_index]
|
|
||||||
y_pred_class = y_pred[:, class_index]
|
|
||||||
class_ap = self.metric(y_true_class, y_pred_class)
|
|
||||||
class_scores[class_name] = float(class_ap)
|
|
||||||
|
|
||||||
mean_ap = np.mean(
|
|
||||||
[value for value in class_scores.values() if value != 0]
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"classification_mAP": float(mean_ap),
|
|
||||||
**{
|
|
||||||
f"classification_AP/{class_name}": class_scores[class_name]
|
|
||||||
for class_name in self.selected
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls,
|
|
||||||
config: ClassificationAPConfig,
|
|
||||||
class_names: List[str],
|
|
||||||
):
|
|
||||||
return cls(
|
|
||||||
class_names,
|
|
||||||
implementation=config.ap_implementation,
|
|
||||||
include=config.include,
|
|
||||||
exclude=config.exclude,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationROCAUCConfig(BaseConfig):
|
|
||||||
name: Literal["classification_roc_auc"] = "classification_roc_auc"
|
|
||||||
include: Optional[List[str]] = None
|
|
||||||
exclude: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationROCAUC(MetricsProtocol):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
class_names: List[str],
|
|
||||||
include: Optional[List[str]] = None,
|
|
||||||
exclude: Optional[List[str]] = None,
|
|
||||||
):
|
|
||||||
self.class_names = class_names
|
|
||||||
self.selected = class_names
|
|
||||||
|
|
||||||
if include is not None:
|
|
||||||
self.selected = [
|
|
||||||
class_name
|
|
||||||
for class_name in self.selected
|
|
||||||
if class_name in include
|
|
||||||
]
|
|
||||||
|
|
||||||
if exclude is not None:
|
|
||||||
self.selected = [
|
|
||||||
class_name
|
|
||||||
for class_name in self.selected
|
|
||||||
if class_name not in exclude
|
|
||||||
]
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
y_true = []
|
|
||||||
y_pred = []
|
|
||||||
|
|
||||||
for clip_eval in clip_evaluations:
|
|
||||||
for match in clip_eval.matches:
|
|
||||||
# Ignore generic unclassified targets
|
|
||||||
if match.gt_det and match.gt_class is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
y_true.append(
|
|
||||||
match.gt_class
|
|
||||||
if match.gt_class is not None
|
|
||||||
else "__NONE__"
|
|
||||||
)
|
|
||||||
|
|
||||||
y_pred.append(
|
|
||||||
np.array(
|
|
||||||
[
|
|
||||||
match.pred_class_scores.get(name, 0)
|
|
||||||
for name in self.class_names
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
|
|
||||||
y_pred = np.stack(y_pred)
|
|
||||||
|
|
||||||
class_scores = {}
|
|
||||||
for class_index, class_name in enumerate(self.class_names):
|
|
||||||
y_true_class = y_true[:, class_index]
|
|
||||||
y_pred_class = y_pred[:, class_index]
|
|
||||||
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
|
|
||||||
class_scores[class_name] = float(class_roc_auc)
|
|
||||||
|
|
||||||
mean_roc_auc = np.mean(
|
|
||||||
[value for value in class_scores.values() if value != 0]
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"classification_macro_average_ROC_AUC": float(mean_roc_auc),
|
|
||||||
**{
|
|
||||||
f"classification_ROC_AUC/{class_name}": class_scores[
|
|
||||||
class_name
|
|
||||||
]
|
|
||||||
for class_name in self.selected
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls,
|
|
||||||
config: ClassificationROCAUCConfig,
|
|
||||||
class_names: List[str],
|
|
||||||
):
|
|
||||||
return cls(
|
|
||||||
class_names,
|
|
||||||
include=config.include,
|
|
||||||
exclude=config.exclude,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC)
|
|
||||||
|
|
||||||
|
|
||||||
class TopClassAPConfig(BaseConfig):
|
|
||||||
name: Literal["top_class_ap"] = "top_class_ap"
|
|
||||||
ap_implementation: APImplementation = "pascal_voc"
|
|
||||||
|
|
||||||
|
|
||||||
class TopClassAP(MetricsProtocol):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
implementation: APImplementation = "pascal_voc",
|
|
||||||
):
|
|
||||||
self.implementation = implementation
|
|
||||||
self.metric = _ap_impl_mapping[self.implementation]
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
y_true = []
|
|
||||||
y_score = []
|
|
||||||
|
|
||||||
for clip_eval in clip_evaluations:
|
|
||||||
for match in clip_eval.matches:
|
|
||||||
# Ignore generic unclassified targets
|
|
||||||
if match.gt_det and match.gt_class is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
top_class = match.pred_class
|
|
||||||
|
|
||||||
y_true.append(top_class == match.gt_class)
|
|
||||||
y_score.append(match.pred_class_score)
|
|
||||||
|
|
||||||
score = float(self.metric(y_true, y_score))
|
|
||||||
return {"top_class_AP": score}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: TopClassAPConfig, class_names: List[str]):
|
|
||||||
return cls(implementation=config.ap_implementation)
|
|
||||||
|
|
||||||
|
|
||||||
metrics_registry.register(TopClassAPConfig, TopClassAP)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationBalancedAccuracyConfig(BaseConfig):
|
|
||||||
name: Literal["classification_balanced_accuracy"] = (
|
|
||||||
"classification_balanced_accuracy"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationBalancedAccuracy(MetricsProtocol):
    """Balanced accuracy of the top predicted class on classified matches."""

    def __init__(self, class_names: List[str]):
        self.class_names = class_names

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        """Return ``{"classification_balanced_accuracy": value}``.

        Only matches that have both a ground truth class and a predicted
        class contribute to the score.
        """
        # Focus on matches
        pairs = [
            (match.gt_class, match.pred_class)
            for clip_eval in clip_evaluations
            for match in clip_eval.matches
            if match.gt_class is not None and match.pred_class is not None
        ]

        # Class names are encoded by their position in ``class_names``.
        y_true = [self.class_names.index(gt_class) for gt_class, _ in pairs]
        y_pred = [self.class_names.index(top_class) for _, top_class in pairs]

        score = float(metrics.balanced_accuracy_score(y_true, y_pred))
        return {"classification_balanced_accuracy": score}

    @classmethod
    def from_config(
        cls,
        config: ClassificationBalancedAccuracyConfig,
        class_names: List[str],
    ):
        """Build the metric; the config carries no extra options."""
        return cls(class_names)


metrics_registry.register(
    ClassificationBalancedAccuracyConfig,
    ClassificationBalancedAccuracy,
)
|
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionAPConfig(BaseConfig):
    """Configuration for the clip-level detection average precision."""

    # Discriminator value selecting this metric inside ``MetricConfig``.
    name: Literal["clip_detection_ap"] = "clip_detection_ap"

    # Key into ``_ap_impl_mapping`` choosing the AP implementation.
    ap_implementation: APImplementation = "pascal_voc"
|
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionAP(MetricsProtocol):
    """Average precision for detecting whether a clip contains any target.

    Each clip yields one sample: the label is whether any match has a
    ground truth detection, the score is the maximum detection score in
    the clip (0 when the clip has no matches).
    """

    def __init__(self, implementation: APImplementation):
        self.implementation = implementation
        self.metric = _ap_impl_mapping[implementation]

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        """Return ``{"clip_detection_ap": value}``."""
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            matches = clip_eval.matches
            scores = [match.pred_score for match in matches]

            y_true.append(any(match.gt_det for match in matches))
            # Fall back to a zero score for clips without any match.
            y_score.append(max(scores or [0]))

        return {"clip_detection_ap": self.metric(y_true, y_score)}

    @classmethod
    def from_config(
        cls,
        config: ClipDetectionAPConfig,
        class_names: List[str],
    ):
        """Build the metric from its config; ``class_names`` is unused."""
        return cls(implementation=config.ap_implementation)


metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)
|
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionROCAUCConfig(BaseConfig):
    """Configuration for the clip-level detection ROC AUC metric."""

    # Discriminator value selecting this metric inside ``MetricConfig``.
    name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"
|
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionROCAUC(MetricsProtocol):
    """ROC AUC for detecting whether a clip contains any target sound.

    Each clip yields one sample: the label is whether any match has a
    ground truth detection, the score is the maximum detection score in
    the clip (0 when the clip has no matches).
    """

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        """Return ``{"clip_detection_roc_auc": value}``."""
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            clip_det = []
            clip_scores = []

            for match in clip_eval.matches:
                clip_det.append(match.gt_det)
                clip_scores.append(match.pred_score)

            y_true.append(any(clip_det))
            # Fall back to a zero score for clips without any match.
            y_score.append(max(clip_scores or [0]))

        # The result used to be reported under "clip_detection_ap", which
        # collided with the key produced by ``ClipDetectionAP``. Use the
        # name this metric is registered under instead.
        return {
            "clip_detection_roc_auc": float(
                metrics.roc_auc_score(y_true, y_score)
            )
        }

    @classmethod
    def from_config(
        cls,
        config: ClipDetectionROCAUCConfig,
        class_names: List[str],
    ):
        """Build the metric; neither argument carries options it needs."""
        return cls()


metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)
|
|
||||||
|
|
||||||
|
|
||||||
class ClipMulticlassAPConfig(BaseConfig):
    """Configuration for the clip-level multiclass average precision."""

    # Discriminator value selecting this metric inside ``MetricConfig``.
    name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"

    # Key into ``_ap_impl_mapping`` choosing the AP implementation.
    ap_implementation: APImplementation = "pascal_voc"

    # Optional filters on which classes get an individual entry in the
    # reported results.
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ClipMulticlassAP(MetricsProtocol):
    """Per-class average precision computed at the clip level.

    For every clip, the set of ground truth classes present is compared
    against the maximum score predicted for each class within the clip.
    """

    def __init__(
        self,
        class_names: List[str],
        implementation: APImplementation,
        include: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
    ):
        self.implementation = implementation
        self.metric = _ap_impl_mapping[implementation]
        self.class_names = class_names

        # Classes reported individually: start from all classes, then
        # narrow down with the optional include / exclude filters.
        chosen = class_names
        if include is not None:
            chosen = [name for name in chosen if name in include]
        if exclude is not None:
            chosen = [name for name in chosen if name not in exclude]
        self.selected = chosen

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        """Return the mean AP plus one AP entry per selected class."""
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            present = {
                match.gt_class
                for match in clip_eval.matches
                if match.gt_class is not None
            }

            scores_by_class = defaultdict(list)
            for match in clip_eval.matches:
                for class_name, score in match.pred_class_scores.items():
                    scores_by_class[class_name].append(score)

            # Get max score for each class
            clip_vector = np.array(
                [
                    max(scores_by_class.get(class_name, [0]))
                    for class_name in self.class_names
                ]
            )

            y_true.append(present)
            y_pred.append(clip_vector)

        binarizer = preprocessing.MultiLabelBinarizer(
            classes=self.class_names
        )
        y_true = binarizer.fit_transform(y_true)
        y_pred = np.stack(y_pred)

        class_scores = {
            class_name: float(
                self.metric(y_true[:, column], y_pred[:, column])
            )
            for column, class_name in enumerate(self.class_names)
        }

        # Classes scoring exactly zero are left out of the mean.
        non_zero = [value for value in class_scores.values() if value != 0]
        mean_ap = np.mean(non_zero)

        results = {"clip_multiclass_mAP": float(mean_ap)}
        for class_name in self.selected:
            key = f"clip_multiclass_AP/{class_name}"
            results[key] = class_scores[class_name]
        return results

    @classmethod
    def from_config(
        cls, config: ClipMulticlassAPConfig, class_names: List[str]
    ):
        """Build the metric from its config and the known class names."""
        return cls(
            implementation=config.ap_implementation,
            include=config.include,
            exclude=config.exclude,
            class_names=class_names,
        )


metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)
|
|
||||||
|
|
||||||
|
|
||||||
class ClipMulticlassROCAUCConfig(BaseConfig):
    """Configuration for the clip-level multiclass ROC AUC metric."""

    # Discriminator value selecting this metric inside ``MetricConfig``.
    name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"

    # Optional filters on which classes get an individual entry in the
    # reported results.
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ClipMulticlassROCAUC(MetricsProtocol):
    """Per-class ROC AUC computed at the clip level.

    For every clip, the set of ground truth classes present is compared
    against the maximum score predicted for each class within the clip.
    """

    def __init__(
        self,
        class_names: List[str],
        include: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
    ):
        self.class_names = class_names

        # Classes reported individually: start from all classes, then
        # narrow down with the optional include / exclude filters.
        chosen = class_names
        if include is not None:
            chosen = [name for name in chosen if name in include]
        if exclude is not None:
            chosen = [name for name in chosen if name not in exclude]
        self.selected = chosen

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        """Return the macro ROC AUC plus one entry per selected class."""
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            present = {
                match.gt_class
                for match in clip_eval.matches
                if match.gt_class is not None
            }

            scores_by_class = defaultdict(list)
            for match in clip_eval.matches:
                for class_name, score in match.pred_class_scores.items():
                    scores_by_class[class_name].append(score)

            # Get maximum score for each class
            clip_vector = np.array(
                [
                    max(scores_by_class.get(class_name, [0]))
                    for class_name in self.class_names
                ]
            )

            y_true.append(present)
            y_pred.append(clip_vector)

        binarizer = preprocessing.MultiLabelBinarizer(
            classes=self.class_names
        )
        y_true = binarizer.fit_transform(y_true)
        y_pred = np.stack(y_pred)

        class_scores = {
            class_name: float(
                metrics.roc_auc_score(y_true[:, column], y_pred[:, column])
            )
            for column, class_name in enumerate(self.class_names)
        }

        # Classes scoring exactly zero are left out of the mean.
        non_zero = [value for value in class_scores.values() if value != 0]
        mean_roc_auc = np.mean(non_zero)

        results = {"clip_multiclass_macro_ROC_AUC": float(mean_roc_auc)}
        for class_name in self.selected:
            key = f"clip_multiclass_ROC_AUC/{class_name}"
            results[key] = class_scores[class_name]
        return results

    @classmethod
    def from_config(
        cls,
        config: ClipMulticlassROCAUCConfig,
        class_names: List[str],
    ):
        """Build the metric from its config and the known class names."""
        return cls(
            include=config.include,
            exclude=config.exclude,
            class_names=class_names,
        )


metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)
|
|
||||||
|
|
||||||
MetricConfig = Annotated[
|
|
||||||
Union[
|
|
||||||
DetectionAPConfig,
|
|
||||||
DetectionROCAUCConfig,
|
|
||||||
ClassificationAPConfig,
|
|
||||||
ClassificationROCAUCConfig,
|
|
||||||
TopClassAPConfig,
|
|
||||||
ClassificationBalancedAccuracyConfig,
|
|
||||||
ClipDetectionAPConfig,
|
|
||||||
ClipDetectionROCAUCConfig,
|
|
||||||
ClipMulticlassAPConfig,
|
|
||||||
ClipMulticlassROCAUCConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_metric(config: MetricConfig, class_names: List[str]):
    """Build the metric described by ``config`` via ``metrics_registry``.

    ``class_names`` is forwarded unchanged to the registered builder.
    """
    built = metrics_registry.build(config, class_names)
    return built
|
|
||||||
|
|
||||||
|
|
||||||
def pascal_voc_average_precision(y_true, y_score) -> float:
    """Compute average precision the Pascal VOC 2012 way.

    Samples are ranked by decreasing score, precision is replaced by its
    monotonically non-increasing envelope, and the area under the
    resulting precision-recall curve is summed over the recall steps.

    Parameters
    ----------
    y_true
        Binary (0/1 or bool) labels, one per sample.
    y_score
        Scores used to rank the samples; higher ranks first.

    Returns
    -------
    float
        The average precision. It is 0.0 when there are no positives.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    sort_ind = np.argsort(y_score)[::-1]
    y_true_sorted = y_true[sort_ind]

    num_positives = y_true.sum()
    false_pos_c = np.cumsum(1 - y_true_sorted)
    true_pos_c = np.cumsum(y_true_sorted)

    # With no positives the recall is 0/0; silence the warning locally,
    # the resulting NaNs are mapped to 0 just below.
    with np.errstate(divide="ignore", invalid="ignore"):
        recall = true_pos_c / num_positives

    precision = true_pos_c / np.maximum(
        true_pos_c + false_pos_c,
        np.finfo(np.float64).eps,
    )

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

    # pascal 12 way
    mprec = np.hstack((0, precision, 0))
    mrec = np.hstack((0, recall, 1))

    # Precision envelope: running maximum taken from the right.
    mprec = np.maximum.accumulate(mprec[::-1])[::-1]

    # Sum the area only where recall actually changes.
    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
    ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()

    return float(ave_prec)
|
|
||||||
|
|
||||||
|
|
||||||
_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = {
|
|
||||||
"sklearn": metrics.average_precision_score,
|
|
||||||
"pascal_voc": pascal_voc_average_precision,
|
|
||||||
}
|
|
||||||
0
src/batdetect2/evaluate/metrics/__init__.py
Normal file
0
src/batdetect2/evaluate/metrics/__init__.py
Normal file
267
src/batdetect2/evaluate/metrics/classification.py
Normal file
267
src/batdetect2/evaluate/metrics/classification.py
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import Field
|
||||||
|
from sklearn import metrics
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.core import BaseConfig, Registry
|
||||||
|
from batdetect2.evaluate.metrics.common import average_precision
|
||||||
|
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ClassificationMetric",
|
||||||
|
"ClassificationMetricConfig",
|
||||||
|
"build_classification_metric",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MatchEval:
    """Evaluation record for one match within a clip."""

    # Clip the match belongs to.
    clip: data.Clip
    # Ground truth annotation and raw prediction; either may be absent.
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

    # Whether the match corresponds to a prediction.
    is_prediction: bool
    # Presumably whether the match has a ground truth sound event.
    is_ground_truth: bool
    # Whether the ground truth sound has an unknown class.
    is_generic: bool
    # Ground truth class name, if any.
    true_class: Optional[str]
    # Score used to rank this match.
    score: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ClipEval:
    """Per-class match evaluations for a single clip."""

    # Clip that was evaluated.
    clip: data.Clip
    # Match evaluations keyed by class name.
    matches: Mapping[str, List[MatchEval]]
|
||||||
|
|
||||||
|
|
||||||
|
ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
|
||||||
|
|
||||||
|
|
||||||
|
classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
|
||||||
|
Registry("classification_metric")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseClassificationConfig(BaseConfig):
    """Options shared by the classification metric configurations."""

    # When set, only these classes are reported individually.
    include: Optional[List[str]] = None

    # When set (and ``include`` is not), these classes are not reported
    # individually.
    exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseClassificationMetric:
    """Shared state and class filtering for classification metrics."""

    def __init__(
        self,
        targets: TargetProtocol,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.targets = targets
        self.include = include
        self.exclude = exclude

    def include_class(self, class_name: str) -> bool:
        """Tell whether ``class_name`` should be reported individually.

        An ``include`` list takes precedence: when it is set, ``exclude``
        is not consulted at all.
        """
        if self.include is not None:
            return class_name in self.include

        if self.exclude is None:
            return True

        return class_name not in self.exclude
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
    """Configuration for the per-class average precision metric."""

    # Discriminator value selecting this metric in the config union.
    name: Literal["average_precision"] = "average_precision"

    # Skip matches that do not correspond to a prediction.
    ignore_non_predictions: bool = True

    # Skip matches whose ground truth class is unknown.
    ignore_generic: bool = True

    # Prefix used for the keys of the reported results.
    label: str = "average_precision"
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationAveragePrecision(BaseClassificationMetric):
    """Per-class average precision over sound event matches.

    Reports ``mean_<label>`` (mean over classes with a defined score) and
    one ``<label>/<class_name>`` entry for every class accepted by
    ``include_class``.
    """

    def __init__(
        self,
        targets: TargetProtocol,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        label: str = "average_precision",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        super().__init__(include=include, exclude=exclude, targets=targets)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.label = label

    def __call__(
        self, clip_evaluations: Sequence[ClipEval]
    ) -> Dict[str, float]:
        """Compute the metric over all clip evaluations."""
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        class_scores = {
            class_name: average_precision(
                y_true[class_name],
                y_score[class_name],
                num_positives=num_positives[class_name],
            )
            for class_name in self.targets.class_names
        }

        # NaN never compares equal to anything, so ``v != np.nan`` kept
        # every value and a single undefined class score turned the mean
        # into NaN. ``np.isnan`` is the correct test.
        valid_scores = [v for v in class_scores.values() if not np.isnan(v)]

        # Avoid numpy's "mean of empty slice" warning when no class has a
        # defined score; the mean is reported as NaN in that case.
        mean_score = (
            float(np.mean(valid_scores)) if valid_scores else float("nan")
        )

        return {
            f"mean_{self.label}": mean_score,
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if self.include_class(class_name)
            },
        }

    @classification_metrics.register(ClassificationAveragePrecisionConfig)
    @staticmethod
    def from_config(
        config: ClassificationAveragePrecisionConfig,
        targets: TargetProtocol,
    ):
        """Build the metric from its configuration and the targets."""
        return ClassificationAveragePrecision(
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            label=config.label,
            include=config.include,
            exclude=config.exclude,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationROCAUCConfig(BaseClassificationConfig):
    """Configuration for the per-class ROC AUC metric."""

    # Discriminator value selecting this metric in the config union.
    name: Literal["roc_auc"] = "roc_auc"

    # Prefix used for the keys of the reported results.
    label: str = "roc_auc"

    # Skip matches that do not correspond to a prediction.
    ignore_non_predictions: bool = True

    # Skip matches whose ground truth class is unknown.
    ignore_generic: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationROCAUC(BaseClassificationMetric):
    """Per-class ROC AUC over sound event matches.

    Reports ``mean_<label>`` (mean over classes with a defined score) and
    one ``<label>/<class_name>`` entry for every class accepted by
    ``include_class``.
    """

    def __init__(
        self,
        targets: TargetProtocol,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        label: str = "roc_auc",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        # Let the base class own targets / include / exclude, as
        # ``ClassificationAveragePrecision`` does.
        super().__init__(include=include, exclude=exclude, targets=targets)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.label = label

    def __call__(
        self, clip_evaluations: Sequence[ClipEval]
    ) -> Dict[str, float]:
        """Compute the metric over all clip evaluations."""
        y_true, y_score, _ = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        # NOTE(review): ``roc_auc_score`` may raise or warn when a class
        # has only one label value present; confirm how callers expect
        # that case to be handled.
        class_scores = {
            class_name: float(
                metrics.roc_auc_score(
                    y_true[class_name],
                    y_score[class_name],
                )
            )
            for class_name in self.targets.class_names
        }

        # NaN never compares equal to anything, so ``v != np.nan`` kept
        # every value. ``np.isnan`` is the correct test.
        valid_scores = [v for v in class_scores.values() if not np.isnan(v)]

        # Avoid numpy's "mean of empty slice" warning when no class has a
        # defined score; the mean is reported as NaN in that case.
        mean_score = (
            float(np.mean(valid_scores)) if valid_scores else float("nan")
        )

        return {
            f"mean_{self.label}": mean_score,
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if self.include_class(class_name)
            },
        }

    @classification_metrics.register(ClassificationROCAUCConfig)
    @staticmethod
    def from_config(
        config: ClassificationROCAUCConfig, targets: TargetProtocol
    ):
        """Build the metric from its configuration and the targets."""
        return ClassificationROCAUC(
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            label=config.label,
            # These were previously dropped, so the config's class
            # filters had no effect on this metric.
            include=config.include,
            exclude=config.exclude,
        )
|
||||||
|
|
||||||
|
|
||||||
|
ClassificationMetricConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
ClassificationAveragePrecisionConfig,
|
||||||
|
ClassificationROCAUCConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_classification_metric(
    config: ClassificationMetricConfig,
    targets: TargetProtocol,
) -> ClassificationMetric:
    """Build the classification metric described by ``config``.

    The lookup is delegated to the ``classification_metrics`` registry;
    ``targets`` is forwarded unchanged to the registered builder.
    """
    metric = classification_metrics.build(config, targets)
    return metric
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_per_class_metric_data(
    clip_evaluations: Sequence[ClipEval],
    ignore_non_predictions: bool = True,
    ignore_generic: bool = True,
):
    """Collect per-class labels, scores and positive counts.

    Returns a ``(y_true, y_score, num_positives)`` tuple of dictionaries
    keyed by class name. Positives are counted before matches without a
    prediction are discarded, so ``num_positives`` may exceed the number
    of true entries in ``y_true``.
    """
    labels = defaultdict(list)
    scores = defaultdict(list)
    positives = defaultdict(int)

    for clip_eval in clip_evaluations:
        for class_name, class_matches in clip_eval.matches.items():
            for match in class_matches:
                # Exclude matches with ground truth sounds where the class
                # is unknown
                if ignore_generic and match.is_generic:
                    continue

                is_positive = match.true_class == class_name
                if is_positive:
                    positives[class_name] += 1

                # Ignore matches that don't correspond to a prediction
                if ignore_non_predictions and not match.is_prediction:
                    continue

                labels[class_name].append(is_positive)
                scores[class_name].append(match.score)

    return labels, scores, positives
|
||||||
135
src/batdetect2/evaluate/metrics/clip_classification.py
Normal file
135
src/batdetect2/evaluate/metrics/clip_classification.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import Field
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.core.registries import Registry
|
||||||
|
from batdetect2.evaluate.metrics.common import average_precision
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ClipEval:
    """Clip-level classification evaluation record."""

    # Names of the classes that are truly present in the clip.
    true_classes: Set[str]
    # Clip-level score for each class name.
    class_scores: Dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
|
ClipClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
|
||||||
|
|
||||||
|
clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
|
||||||
|
"clip_classification_metric"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipClassificationAveragePrecisionConfig(BaseConfig):
    """Configuration for clip-level classification average precision."""

    # Discriminator value selecting this metric in the config union.
    name: Literal["average_precision"] = "average_precision"

    # Prefix used for the keys of the reported results.
    label: str = "average_precision"
|
||||||
|
|
||||||
|
|
||||||
|
class ClipClassificationAveragePrecision:
    """Per-class average precision computed from clip-level scores.

    Reports ``mean_<label>`` and one ``<label>/<class_name>`` entry for
    every class whose score is not NaN.
    """

    def __init__(self, label: str = "average_precision"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Compute the metric over all clip evaluations."""
        labels = defaultdict(list)
        scores = defaultdict(list)

        for clip_eval in clip_evaluations:
            present = clip_eval.true_classes
            for class_name, value in clip_eval.class_scores.items():
                labels[class_name].append(class_name in present)
                scores[class_name].append(value)

        per_class = {}
        for class_name in labels:
            per_class[class_name] = float(
                average_precision(
                    y_true=labels[class_name],
                    y_score=scores[class_name],
                )
            )

        # Classes with an undefined (NaN) score are left out of both the
        # mean and the per-class entries.
        defined = {
            class_name: value
            for class_name, value in per_class.items()
            if not np.isnan(value)
        }
        mean = np.mean(list(defined.values()))

        results = {f"mean_{self.label}": float(mean)}
        for class_name, value in defined.items():
            results[f"{self.label}/{class_name}"] = value
        return results

    @clip_classification_metrics.register(
        ClipClassificationAveragePrecisionConfig
    )
    @staticmethod
    def from_config(config: ClipClassificationAveragePrecisionConfig):
        """Build the metric from its configuration."""
        return ClipClassificationAveragePrecision(label=config.label)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipClassificationROCAUCConfig(BaseConfig):
    """Configuration for the clip-level classification ROC AUC metric."""

    # Discriminator value selecting this metric in the config union.
    name: Literal["roc_auc"] = "roc_auc"

    # Prefix used for the keys of the reported results.
    label: str = "roc_auc"
|
||||||
|
|
||||||
|
|
||||||
|
class ClipClassificationROCAUC:
    """Per-class ROC AUC computed from clip-level scores.

    Reports ``mean_<label>`` and one ``<label>/<class_name>`` entry for
    every class whose score is not NaN.
    """

    def __init__(self, label: str = "roc_auc"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Compute the metric over all clip evaluations."""
        labels = defaultdict(list)
        scores = defaultdict(list)

        for clip_eval in clip_evaluations:
            present = clip_eval.true_classes
            for class_name, value in clip_eval.class_scores.items():
                labels[class_name].append(class_name in present)
                scores[class_name].append(value)

        per_class = {}
        for class_name in labels:
            per_class[class_name] = float(
                metrics.roc_auc_score(
                    y_true=labels[class_name],
                    y_score=scores[class_name],
                )
            )

        # Classes with an undefined (NaN) score are left out of both the
        # mean and the per-class entries.
        defined = {
            class_name: value
            for class_name, value in per_class.items()
            if not np.isnan(value)
        }
        mean = np.mean(list(defined.values()))

        results = {f"mean_{self.label}": float(mean)}
        for class_name, value in defined.items():
            results[f"{self.label}/{class_name}"] = value
        return results

    @clip_classification_metrics.register(ClipClassificationROCAUCConfig)
    @staticmethod
    def from_config(config: ClipClassificationROCAUCConfig):
        """Build the metric from its configuration."""
        return ClipClassificationROCAUC(label=config.label)
|
||||||
|
|
||||||
|
|
||||||
|
ClipClassificationMetricConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
ClipClassificationAveragePrecisionConfig,
|
||||||
|
ClipClassificationROCAUCConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_clip_metric(config: ClipClassificationMetricConfig):
    """Build the clip classification metric described by ``config``.

    The lookup is delegated to the ``clip_classification_metrics``
    registry.
    """
    metric = clip_classification_metrics.build(config)
    return metric
|
||||||
173
src/batdetect2/evaluate/metrics/clip_detection.py
Normal file
173
src/batdetect2/evaluate/metrics/clip_detection.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Annotated, Callable, Dict, Literal, Sequence, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import Field
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.core.registries import Registry
|
||||||
|
from batdetect2.evaluate.metrics.common import average_precision
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ClipEval:
    """Clip-level detection evaluation record."""

    # Ground truth detection flag for the clip.
    gt_det: bool
    # Clip-level detection score.
    score: float
|
||||||
|
|
||||||
|
|
||||||
|
ClipDetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
|
||||||
|
|
||||||
|
clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
|
||||||
|
"clip_detection_metric"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionAveragePrecisionConfig(BaseConfig):
    """Configuration for the clip detection average precision metric."""

    # Discriminator value selecting this metric in the config union.
    name: Literal["average_precision"] = "average_precision"

    # Key under which the result is reported.
    label: str = "average_precision"
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionAveragePrecision:
    """Average precision of the clip-level detection scores."""

    def __init__(self, label: str = "average_precision"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Return ``{label: average precision}`` over all clips."""
        labels = [clip_eval.gt_det for clip_eval in clip_evaluations]
        scores = [clip_eval.score for clip_eval in clip_evaluations]

        return {self.label: average_precision(y_true=labels, y_score=scores)}

    @clip_detection_metrics.register(ClipDetectionAveragePrecisionConfig)
    @staticmethod
    def from_config(config: ClipDetectionAveragePrecisionConfig):
        """Build the metric from its configuration."""
        return ClipDetectionAveragePrecision(label=config.label)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionROCAUCConfig(BaseConfig):
    """Configuration for the clip detection ROC AUC metric."""

    # Discriminator value selecting this metric in the config union.
    name: Literal["roc_auc"] = "roc_auc"

    # Key under which the result is reported.
    label: str = "roc_auc"
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionROCAUC:
    """ROC AUC of the clip-level detection scores."""

    def __init__(self, label: str = "roc_auc"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Return ``{label: ROC AUC}`` over all clips."""
        labels = [clip_eval.gt_det for clip_eval in clip_evaluations]
        scores = [clip_eval.score for clip_eval in clip_evaluations]

        auc = metrics.roc_auc_score(y_true=labels, y_score=scores)
        return {self.label: float(auc)}

    @clip_detection_metrics.register(ClipDetectionROCAUCConfig)
    @staticmethod
    def from_config(config: ClipDetectionROCAUCConfig):
        """Build the metric from its configuration."""
        return ClipDetectionROCAUC(label=config.label)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionRecallConfig(BaseConfig):
    """Configuration for the clip detection recall metric."""

    # Discriminator value selecting this metric in the config union.
    name: Literal["recall"] = "recall"

    # Minimum clip score counted as a detection.
    threshold: float = 0.5

    # Key under which the result is reported.
    label: str = "recall"
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionRecall:
    """Fraction of positive clips scoring at or above a threshold."""

    def __init__(self, threshold: float, label: str = "recall"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Return ``{label: recall}``; NaN when no clip is positive."""
        positives = [
            clip_eval for clip_eval in clip_evaluations if clip_eval.gt_det
        ]

        # Recall is undefined without any positive clip.
        if not positives:
            return {self.label: np.nan}

        detected = sum(
            1 for clip_eval in positives if clip_eval.score >= self.threshold
        )
        return {self.label: detected / len(positives)}

    @clip_detection_metrics.register(ClipDetectionRecallConfig)
    @staticmethod
    def from_config(config: ClipDetectionRecallConfig):
        """Build the metric from its configuration."""
        return ClipDetectionRecall(
            threshold=config.threshold, label=config.label
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionPrecisionConfig(BaseConfig):
    """Configuration for the clip detection precision metric."""

    # Discriminator value selecting this metric in the config union.
    name: Literal["precision"] = "precision"

    # Minimum clip score counted as a detection.
    threshold: float = 0.5

    # Key under which the result is reported.
    label: str = "precision"
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionPrecision:
    """Fraction of clips at or above a threshold that are positive."""

    def __init__(self, threshold: float, label: str = "precision"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Return ``{label: precision}``; NaN when nothing is detected."""
        detections = [
            clip_eval
            for clip_eval in clip_evaluations
            if clip_eval.score >= self.threshold
        ]

        # Precision is undefined without any detection.
        if not detections:
            return {self.label: np.nan}

        correct = sum(1 for clip_eval in detections if clip_eval.gt_det)
        return {self.label: correct / len(detections)}

    @clip_detection_metrics.register(ClipDetectionPrecisionConfig)
    @staticmethod
    def from_config(config: ClipDetectionPrecisionConfig):
        """Build the metric from its configuration."""
        return ClipDetectionPrecision(
            threshold=config.threshold, label=config.label
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Tagged union of all clip-detection metric configs. The ``name`` field of
# each member is the discriminator used to pick the concrete config class.
ClipDetectionMetricConfig = Annotated[
    Union[
        ClipDetectionAveragePrecisionConfig,
        ClipDetectionROCAUCConfig,
        ClipDetectionRecallConfig,
        ClipDetectionPrecisionConfig,
    ],
    Field(discriminator="name"),
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_clip_metric(config: ClipDetectionMetricConfig):
    """Instantiate the clip-detection metric described by ``config``.

    The work is delegated to the ``clip_detection_metrics`` registry.
    """
    metric = clip_detection_metrics.build(config)
    return metric
|
||||||
60
src/batdetect2/evaluate/metrics/common.py
Normal file
60
src/batdetect2/evaluate/metrics/common.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"compute_precision_recall",
|
||||||
|
"average_precision",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def compute_precision_recall(
    y_true,
    y_score,
    num_positives: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute cumulative precision and recall over score-ranked samples.

    Parameters
    ----------
    y_true
        Binary labels, one per sample (truthy / 1 means positive).
    y_score
        Scores aligned with ``y_true``; higher scores are ranked first.
    num_positives
        Denominator used for recall. Defaults to the number of positives
        found in ``y_true``. Pass it explicitly when some positives are not
        represented in ``y_true``.

    Returns
    -------
    precision, recall, y_score_sorted
        Three arrays of the same length as the input, ordered by decreasing
        score. When there are no positives at all, recall is all zeros.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    if num_positives is None:
        num_positives = y_true.sum()

    # Rank samples from highest to lowest score.
    sort_ind = np.argsort(y_score)[::-1]
    y_true_sorted = y_true[sort_ind]
    y_score_sorted = y_score[sort_ind]

    false_pos_c = np.cumsum(1 - y_true_sorted)
    true_pos_c = np.cumsum(y_true_sorted)

    if num_positives > 0:
        recall = true_pos_c / num_positives
    else:
        # Recall is undefined without positives. Report zeros instead of
        # dividing by zero, which would emit warnings and produce NaN/inf.
        recall = np.zeros(true_pos_c.shape, dtype=np.float64)

    # The epsilon floor keeps the denominator strictly positive.
    precision = true_pos_c / np.maximum(
        true_pos_c + false_pos_c,
        np.finfo(np.float64).eps,
    )

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0
    return precision, recall, y_score_sorted
|
||||||
|
|
||||||
|
|
||||||
|
def average_precision(
    y_true,
    y_score,
    num_positives: Optional[int] = None,
) -> float:
    """Area under the interpolated precision-recall curve.

    Precision is first replaced by its running maximum from the right, so
    the curve is non-increasing in recall, and then summed over the recall
    steps. The original comment calls this the "pascal 12 way" (presumably
    Pascal VOC 2012).

    Arguments are forwarded unchanged to ``compute_precision_recall``.
    """
    precision, recall, _ = compute_precision_recall(
        y_true,
        y_score,
        num_positives=num_positives,
    )

    # Sentinels: the curve starts at recall 0 and ends at recall 1, with
    # zero precision at both padded ends.
    padded_precision = np.hstack((0, precision, 0))
    padded_recall = np.hstack((0, recall, 1))

    # Right-to-left running maximum of precision.
    envelope = np.maximum.accumulate(padded_precision[::-1])[::-1]

    # Indices at which recall changes value.
    steps = np.flatnonzero(padded_recall[1:] != padded_recall[:-1]) + 1
    widths = padded_recall[steps] - padded_recall[steps - 1]

    return (widths * envelope[steps]).sum()
|
||||||
226
src/batdetect2/evaluate/metrics/detection.py
Normal file
226
src/batdetect2/evaluate/metrics/detection.py
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import Field
|
||||||
|
from sklearn import metrics
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.core import BaseConfig, Registry
|
||||||
|
from batdetect2.evaluate.metrics.common import average_precision
|
||||||
|
from batdetect2.typing import RawPrediction
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DetectionMetricConfig",
|
||||||
|
"DetectionMetric",
|
||||||
|
"build_detection_metric",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MatchEval:
    """One evaluated pairing of a ground-truth sound event and a prediction."""

    # Ground-truth annotation for this match, if any.
    gt: Optional[data.SoundEventAnnotation]
    # Raw model prediction for this match, if any.
    pred: Optional[RawPrediction]

    # Flags saying which side(s) of the pairing are present.
    is_prediction: bool
    is_ground_truth: bool

    # Score compared against metric thresholds and used for ranking.
    score: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ClipEval:
    """All evaluated matches that belong to a single clip."""

    # Clip the matches were computed for.
    clip: data.Clip
    # Match records iterated by the detection metrics.
    matches: List[MatchEval]
|
||||||
|
|
||||||
|
|
||||||
|
# A detection metric maps a sequence of clip evaluations to named scores.
DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]

# Registry with which the metric classes in this module register their
# config types; used by ``build_detection_metric``.
detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionAveragePrecisionConfig(BaseConfig):
    """Settings for the detection average-precision metric."""

    # Discriminator value identifying this metric inside the config union.
    name: Literal["average_precision"] = "average_precision"

    # Key used for this metric in the returned results dictionary.
    label: str = "average_precision"

    # When True, matches without a prediction are left out of the ranking.
    ignore_non_predictions: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionAveragePrecision:
    """Average precision of detections pooled over all evaluated clips."""

    def __init__(self, label: str, ignore_non_predictions: bool = True):
        self.label = label
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evals: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Return ``{label: average precision}`` for the given clips.

        Every ground-truth match counts towards the number of positives,
        including those dropped from the ranking because they have no
        prediction.
        """
        all_matches = [
            m for clip_eval in clip_evals for m in clip_eval.matches
        ]
        num_positives = sum(int(m.is_ground_truth) for m in all_matches)

        # Keep a match unless it has no prediction and those are ignored.
        ranked = [
            m
            for m in all_matches
            if m.is_prediction or not self.ignore_non_predictions
        ]
        y_true = [m.is_ground_truth for m in ranked]
        y_score = [m.score for m in ranked]

        ap = average_precision(y_true, y_score, num_positives=num_positives)
        return {self.label: ap}

    @detection_metrics.register(DetectionAveragePrecisionConfig)
    @staticmethod
    def from_config(config: DetectionAveragePrecisionConfig):
        """Build the metric from its configuration object."""
        return DetectionAveragePrecision(
            ignore_non_predictions=config.ignore_non_predictions,
            label=config.label,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionROCAUCConfig(BaseConfig):
    """Settings for the detection ROC AUC metric."""

    # Discriminator value identifying this metric inside the config union.
    name: Literal["roc_auc"] = "roc_auc"

    # Key used for this metric in the returned results dictionary.
    label: str = "roc_auc"

    # When True, matches without a prediction are excluded from the score.
    ignore_non_predictions: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionROCAUC:
    """Area under the ROC curve for detections pooled over all clips."""

    def __init__(
        self,
        label: str = "roc_auc",
        ignore_non_predictions: bool = True,
    ):
        self.label = label
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
        """Return ``{label: ROC AUC}`` for the given clip evaluations.

        The value is NaN when the retained matches are empty or all carry
        the same label, because ROC AUC is undefined in that case. This
        mirrors the recall/precision metrics in this module, which also
        report NaN for degenerate input.
        """
        y_true: List[bool] = []
        y_score: List[float] = []

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        # ROC AUC needs at least one positive and one negative sample.
        if len({bool(value) for value in y_true}) < 2:
            return {self.label: np.nan}

        score = float(metrics.roc_auc_score(y_true, y_score))
        return {self.label: score}

    @detection_metrics.register(DetectionROCAUCConfig)
    @staticmethod
    def from_config(config: DetectionROCAUCConfig):
        """Build the metric from its configuration object."""
        return DetectionROCAUC(
            label=config.label,
            ignore_non_predictions=config.ignore_non_predictions,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionRecallConfig(BaseConfig):
    """Settings for the detection recall metric."""

    # Discriminator value identifying this metric inside the config union.
    name: Literal["recall"] = "recall"

    # Key used for this metric in the returned results dictionary.
    label: str = "recall"

    # A ground-truth match is recalled when its score is >= this value.
    threshold: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionRecall:
    """Fraction of ground-truth matches scored at or above a threshold."""

    def __init__(self, threshold: float, label: str = "recall"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Return ``{label: recall}``; NaN when there is no ground truth."""
        ground_truths = [
            m
            for clip_eval in clip_evaluations
            for m in clip_eval.matches
            if m.is_ground_truth
        ]

        if not ground_truths:
            return {self.label: np.nan}

        recalled = sum(1 for m in ground_truths if m.score >= self.threshold)
        return {self.label: recalled / len(ground_truths)}

    @detection_metrics.register(DetectionRecallConfig)
    @staticmethod
    def from_config(config: DetectionRecallConfig):
        """Build the metric from its configuration object."""
        return DetectionRecall(
            label=config.label,
            threshold=config.threshold,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionPrecisionConfig(BaseConfig):
    """Settings for the detection precision metric."""

    # Discriminator value identifying this metric inside the config union.
    name: Literal["precision"] = "precision"

    # Key used for this metric in the returned results dictionary.
    label: str = "precision"

    # A match counts as a detection when its score is >= this value.
    threshold: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionPrecision:
    """Fraction of thresholded detections that are ground truth."""

    def __init__(self, threshold: float, label: str = "precision"):
        self.label = label
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Return ``{label: precision}``; NaN when nothing is detected."""
        detections = [
            m
            for clip_eval in clip_evaluations
            for m in clip_eval.matches
            if m.score >= self.threshold
        ]

        if not detections:
            return {self.label: np.nan}

        correct = sum(1 for m in detections if m.is_ground_truth)
        return {self.label: correct / len(detections)}

    @detection_metrics.register(DetectionPrecisionConfig)
    @staticmethod
    def from_config(config: DetectionPrecisionConfig):
        """Build the metric from its configuration object."""
        return DetectionPrecision(
            label=config.label,
            threshold=config.threshold,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Tagged union of all detection metric configs. The ``name`` field of each
# member is the discriminator used to pick the concrete config class.
DetectionMetricConfig = Annotated[
    Union[
        DetectionAveragePrecisionConfig,
        DetectionROCAUCConfig,
        DetectionRecallConfig,
        DetectionPrecisionConfig,
    ],
    Field(discriminator="name"),
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_detection_metric(config: DetectionMetricConfig):
    """Instantiate the detection metric described by ``config``.

    The work is delegated to the ``detection_metrics`` registry.
    """
    metric = detection_metrics.build(config)
    return metric
|
||||||
314
src/batdetect2/evaluate/metrics/top_class.py
Normal file
314
src/batdetect2/evaluate/metrics/top_class.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import Field
|
||||||
|
from sklearn import metrics, preprocessing
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.core import BaseConfig, Registry
|
||||||
|
from batdetect2.evaluate.metrics.common import average_precision
|
||||||
|
from batdetect2.typing import RawPrediction
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TopClassMetricConfig",
|
||||||
|
"TopClassMetric",
|
||||||
|
"build_top_class_metric",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MatchEval:
    """One evaluated match, annotated with top-class information."""

    # Clip this match belongs to.
    clip: data.Clip
    # Ground-truth annotation for this match, if any.
    gt: Optional[data.SoundEventAnnotation]
    # Raw model prediction for this match, if any.
    pred: Optional[RawPrediction]

    is_ground_truth: bool
    # True for a ground-truth sound whose class is unknown.
    is_generic: bool
    is_prediction: bool

    # Class names compared for equality by the top-class metrics.
    pred_class: Optional[str]
    true_class: Optional[str]

    # Score compared against metric thresholds and used for ranking.
    score: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ClipEval:
    """All top-class match records that belong to a single clip."""

    # Clip the matches were computed for.
    clip: data.Clip
    # Match records iterated by the top-class metrics.
    matches: List[MatchEval]
|
||||||
|
|
||||||
|
|
||||||
|
# A top-class metric maps a sequence of clip evaluations to named scores.
TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]

# Registry with which the metric classes in this module register their
# config types; used by ``build_top_class_metric``.
top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassAveragePrecisionConfig(BaseConfig):
    """Settings for the top-class average-precision metric."""

    # Discriminator value identifying this metric inside the config union.
    name: Literal["average_precision"] = "average_precision"

    # Key used for this metric in the returned results dictionary.
    label: str = "average_precision"

    # When True, matches flagged as generic are skipped entirely.
    ignore_generic: bool = True

    # When True, matches without a prediction are left out of the ranking.
    ignore_non_predictions: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassAveragePrecision:
    """Average precision of top-class predictions pooled over all clips."""

    def __init__(
        self,
        ignore_generic: bool = True,
        ignore_non_predictions: bool = True,
        label: str = "average_precision",
    ):
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions
        self.label = label

    def __call__(
        self,
        clip_evals: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Return ``{label: average precision}`` for the given clips.

        A retained match is a positive sample when its predicted class
        equals its true class. The recall denominator counts every
        non-skipped ground-truth match, including those without a
        prediction.
        """
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                num_positives += int(m.is_ground_truth)

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        score = average_precision(y_true, y_score, num_positives=num_positives)
        return {self.label: score}

    @top_class_metrics.register(TopClassAveragePrecisionConfig)
    @staticmethod
    def from_config(config: TopClassAveragePrecisionConfig):
        """Build the metric from its configuration object."""
        return TopClassAveragePrecision(
            ignore_generic=config.ignore_generic,
            # Previously dropped, so this config field had no effect.
            ignore_non_predictions=config.ignore_non_predictions,
            label=config.label,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassROCAUCConfig(BaseConfig):
    """Settings for the top-class ROC AUC metric."""

    # Discriminator value identifying this metric inside the config union.
    name: Literal["roc_auc"] = "roc_auc"

    # When True, matches flagged as generic are skipped entirely.
    ignore_generic: bool = True

    # When True, matches without a prediction are excluded from the score.
    ignore_non_predictions: bool = True

    # Key used for this metric in the returned results dictionary.
    label: str = "roc_auc"
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassROCAUC:
    """Area under the ROC curve for top-class predictions."""

    def __init__(
        self,
        ignore_generic: bool = True,
        ignore_non_predictions: bool = True,
        label: str = "roc_auc",
    ):
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions
        self.label = label

    def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
        """Return ``{label: ROC AUC}`` for the given clip evaluations.

        A retained match is a positive sample when its predicted class
        equals its true class. The value is NaN when the retained samples
        are empty or all share one label, because ROC AUC is undefined
        there. This mirrors the recall/precision metrics in this module.
        """
        y_true: List[bool] = []
        y_score: List[float] = []

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        # ROC AUC needs at least one positive and one negative sample.
        if len({bool(value) for value in y_true}) < 2:
            return {self.label: np.nan}

        score = float(metrics.roc_auc_score(y_true, y_score))
        return {self.label: score}

    @top_class_metrics.register(TopClassROCAUCConfig)
    @staticmethod
    def from_config(config: TopClassROCAUCConfig):
        """Build the metric from its configuration object."""
        return TopClassROCAUC(
            ignore_generic=config.ignore_generic,
            # Previously dropped, so this config field had no effect.
            ignore_non_predictions=config.ignore_non_predictions,
            label=config.label,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassRecallConfig(BaseConfig):
    """Settings for the top-class recall metric."""

    # Discriminator value identifying this metric inside the config union.
    name: Literal["recall"] = "recall"

    # Only matches with a score >= this value can count as true positives.
    threshold: float = 0.5

    # Key used for this metric in the returned results dictionary.
    label: str = "recall"
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassRecall:
    """Recall of correctly classified matches at a fixed score threshold."""

    def __init__(self, threshold: float, label: str = "recall"):
        self.label = label
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Return ``{label: recall}``; NaN when there is no ground truth.

        The numerator counts matches whose score reaches the threshold and
        whose predicted class equals the true class. The denominator counts
        all ground-truth matches.
        """
        pooled = [
            m for clip_eval in clip_evaluations for m in clip_eval.matches
        ]

        num_positives = sum(1 for m in pooled if m.is_ground_truth)
        true_positives = sum(
            1
            for m in pooled
            if m.score >= self.threshold and m.pred_class == m.true_class
        )

        if num_positives == 0:
            return {self.label: np.nan}

        return {self.label: true_positives / num_positives}

    @top_class_metrics.register(TopClassRecallConfig)
    @staticmethod
    def from_config(config: TopClassRecallConfig):
        """Build the metric from its configuration object."""
        return TopClassRecall(
            label=config.label,
            threshold=config.threshold,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassPrecisionConfig(BaseConfig):
    """Settings for the top-class precision metric."""

    # Discriminator value identifying this metric inside the config union.
    name: Literal["precision"] = "precision"

    # A match counts as a detection when its score is >= this value.
    threshold: float = 0.5

    # Key used for this metric in the returned results dictionary.
    label: str = "precision"
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassPrecision:
    """Fraction of thresholded detections whose predicted class is right."""

    def __init__(self, threshold: float, label: str = "precision"):
        self.label = label
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Return ``{label: precision}``; NaN when nothing is detected."""
        detections = [
            m
            for clip_eval in clip_evaluations
            for m in clip_eval.matches
            if m.score >= self.threshold
        ]

        if not detections:
            return {self.label: np.nan}

        correct = sum(1 for m in detections if m.pred_class == m.true_class)
        return {self.label: correct / len(detections)}

    @top_class_metrics.register(TopClassPrecisionConfig)
    @staticmethod
    def from_config(config: TopClassPrecisionConfig):
        """Build the metric from its configuration object."""
        return TopClassPrecision(
            label=config.label,
            threshold=config.threshold,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class BalancedAccuracyConfig(BaseConfig):
    """Settings for the balanced-accuracy metric."""

    # Discriminator value identifying this metric inside the config union.
    name: Literal["balanced_accuracy"] = "balanced_accuracy"

    # Key used for this metric in the returned results dictionary.
    label: str = "balanced_accuracy"

    # When True, matches lacking a ground truth or a predicted class are
    # dropped instead of being labelled with ``noise_class``.
    exclude_noise: bool = False

    # Label substituted for a missing true or predicted class.
    noise_class: str = "noise"
|
||||||
|
|
||||||
|
|
||||||
|
class BalancedAccuracy:
    """Balanced accuracy of top-class predictions pooled over all clips.

    NOTE(review): ``exclude_noise`` defaults to True here but to False in
    ``BalancedAccuracyConfig``. Confirm which default is intended.
    """

    def __init__(
        self,
        exclude_noise: bool = True,
        noise_class: str = "noise",
        label: str = "balanced_accuracy",
    ):
        self.exclude_noise = exclude_noise
        self.noise_class = noise_class
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        """Return ``{label: balanced accuracy}`` for the given clips.

        Generic matches are always skipped. With ``exclude_noise`` set,
        matches lacking a ground truth or a predicted class are skipped
        too; otherwise the missing side is labelled ``noise_class``. The
        value is NaN when no match survives the filtering.
        """
        true_labels: List[str] = []
        pred_labels: List[str] = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic:
                    # Ignore matches that correspond to a sound event
                    # with unknown class
                    continue

                if self.exclude_noise and (
                    not m.is_ground_truth or m.pred_class is None
                ):
                    # Ignore unmatched predictions and non-predictions
                    continue

                true_labels.append(m.true_class or self.noise_class)
                pred_labels.append(m.pred_class or self.noise_class)

        if not true_labels:
            # Nothing to score. Report NaN like the other metrics in this
            # module rather than passing empty arrays to scikit-learn.
            return {self.label: np.nan}

        encoder = preprocessing.LabelEncoder()
        encoder.fit(list(set(true_labels) | set(pred_labels)))

        score = metrics.balanced_accuracy_score(
            encoder.transform(true_labels),
            encoder.transform(pred_labels),
        )
        return {self.label: float(score)}

    @top_class_metrics.register(BalancedAccuracyConfig)
    @staticmethod
    def from_config(config: BalancedAccuracyConfig):
        """Build the metric from its configuration object."""
        return BalancedAccuracy(
            exclude_noise=config.exclude_noise,
            noise_class=config.noise_class,
            label=config.label,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Tagged union of all top-class metric configs. The ``name`` field of each
# member is the discriminator used to pick the concrete config class.
TopClassMetricConfig = Annotated[
    Union[
        TopClassAveragePrecisionConfig,
        TopClassROCAUCConfig,
        TopClassRecallConfig,
        TopClassPrecisionConfig,
        BalancedAccuracyConfig,
    ],
    Field(discriminator="name"),
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_top_class_metric(config: TopClassMetricConfig):
    """Instantiate the top-class metric described by ``config``.

    The work is delegated to the ``top_class_metrics`` registry.
    """
    metric = top_class_metrics.build(config)
    return metric
|
||||||
@ -1,560 +0,0 @@
|
|||||||
import random
|
|
||||||
from collections import defaultdict
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from pydantic import Field
|
|
||||||
from sklearn import metrics
|
|
||||||
from sklearn.preprocessing import label_binarize
|
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
|
||||||
from batdetect2.core import BaseConfig, Registry
|
|
||||||
from batdetect2.plotting.gallery import plot_match_gallery
|
|
||||||
from batdetect2.plotting.matches import plot_matches
|
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
|
||||||
from batdetect2.typing import (
|
|
||||||
AudioLoader,
|
|
||||||
ClipEvaluation,
|
|
||||||
MatchEvaluation,
|
|
||||||
PlotterProtocol,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"build_plotter",
|
|
||||||
"ExampleGallery",
|
|
||||||
"ExampleGalleryConfig",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Registry of plotters. Each plotter class in this module is registered
# against its config class further down.
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleGalleryConfig(BaseConfig):
    """Settings for the per-class example gallery plot."""

    # Discriminator value identifying this plot type.
    name: Literal["example_gallery"] = "example_gallery"

    # Number of examples sampled per outcome category for each class.
    examples_per_class: int = 5

    # Audio loading settings passed to ``build_audio_loader``.
    audio: AudioConfig = Field(default_factory=AudioConfig)

    # Preprocessing settings passed to ``build_preprocessor``.
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleGallery(PlotterProtocol):
    """Plotter that yields one gallery figure of sampled matches per class."""

    def __init__(
        self,
        examples_per_class: int,
        preprocessor: Optional[PreprocessorProtocol] = None,
        audio_loader: Optional[AudioLoader] = None,
    ):
        self.examples_per_class = examples_per_class
        # Fall back to default-built components when none are supplied.
        self.preprocessor = preprocessor or build_preprocessor()
        self.audio_loader = audio_loader or build_audio_loader()

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        """Yield ``(name, figure)`` pairs, one per class.

        Each figure is closed once the consumer resumes the generator.
        """
        grouped = group_matches(clip_evaluations)

        for class_name, matches in grouped.items():
            # The sampling order is kept as-is because it consumes the
            # global ``random`` state.
            tp_examples = get_binned_sample(
                matches.true_positives,
                n_examples=self.examples_per_class,
            )
            fp_examples = get_binned_sample(
                matches.false_positives,
                n_examples=self.examples_per_class,
            )
            fn_count = min(
                self.examples_per_class,
                len(matches.false_negatives),
            )
            fn_examples = random.sample(matches.false_negatives, k=fn_count)
            ct_examples = get_binned_sample(
                matches.cross_triggers,
                n_examples=self.examples_per_class,
            )

            fig = plot_match_gallery(
                tp_examples,
                fp_examples,
                fn_examples,
                ct_examples,
                preprocessor=self.preprocessor,
                audio_loader=self.audio_loader,
                n_examples=self.examples_per_class,
            )

            yield f"example_gallery/{class_name}", fig

            plt.close(fig)

    @classmethod
    def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
        """Build the plotter from its config; ``class_names`` is unused."""
        loader = build_audio_loader(config.audio)
        return cls(
            examples_per_class=config.examples_per_class,
            preprocessor=build_preprocessor(
                config.preprocessing,
                input_samplerate=loader.samplerate,
            ),
            audio_loader=loader,
        )


plots_registry.register(ExampleGalleryConfig, ExampleGallery)
|
|
||||||
|
|
||||||
|
|
||||||
class ClipEvaluationPlotConfig(BaseConfig):
    """Settings for the example clip evaluation plots."""

    # Discriminator value identifying this plot type.
    name: Literal["example_clip"] = "example_clip"

    # Upper bound on the number of clips that are plotted.
    num_plots: int = 5

    # Audio loading settings passed to ``build_audio_loader``.
    audio: AudioConfig = Field(default_factory=AudioConfig)

    # Preprocessing settings passed to ``build_preprocessor``.
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
|
|
||||||
|
|
||||||
|
|
||||||
class PlotClipEvaluation(PlotterProtocol):
    """Plotter that draws the matches of a few randomly chosen clips."""

    def __init__(
        self,
        num_plots: int = 3,
        preprocessor: Optional[PreprocessorProtocol] = None,
        audio_loader: Optional[AudioLoader] = None,
    ):
        self.num_plots = num_plots
        self.audio_loader = audio_loader
        # Stored as given; not used by ``__call__``.
        self.preprocessor = preprocessor

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        """Yield ``(name, figure)`` for up to ``num_plots`` sampled clips.

        Each figure is closed once the consumer resumes the generator.
        """
        sample_size = min(self.num_plots, len(clip_evaluations))
        chosen = random.sample(clip_evaluations, k=sample_size)

        for position, clip_evaluation in enumerate(chosen):
            fig, ax = plt.subplots()
            plot_matches(
                clip_evaluation.matches,
                clip=clip_evaluation.clip,
                audio_loader=self.audio_loader,
                ax=ax,
            )
            yield f"clip_evaluation/example_{position}", fig
            plt.close(fig)

    @classmethod
    def from_config(
        cls,
        config: ClipEvaluationPlotConfig,
        class_names: List[str],
    ):
        """Build the plotter from its config; ``class_names`` is unused."""
        loader = build_audio_loader(config.audio)
        return cls(
            num_plots=config.num_plots,
            preprocessor=build_preprocessor(
                config.preprocessing,
                input_samplerate=loader.samplerate,
            ),
            audio_loader=loader,
        )


plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionPRCurveConfig(BaseConfig):
    """Settings for the detection precision-recall curve plot."""

    # Discriminator value identifying this plot type; no other options.
    name: Literal["detection_pr_curve"] = "detection_pr_curve"
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionPRCurve(PlotterProtocol):
    """Plotter that draws the detection precision-recall curve."""

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        """Yield a single ``("detection_pr_curve", figure)`` pair.

        Nothing is yielded when there are no matches. The figure is closed
        once the consumer resumes the generator, as the other plotters in
        this module do.
        """
        pairs = [
            (match.gt_det, match.pred_score)
            for clip_eval in clip_evaluations
            for match in clip_eval.matches
        ]

        if not pairs:
            # ``zip(*[])`` yields nothing, so unpacking it would raise.
            return

        y_true, y_score = zip(*pairs)
        precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)

        fig, ax = plt.subplots()
        ax.plot(recall, precision, label="Detector")
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.legend()

        yield "detection_pr_curve", fig

        # Release the figure so repeated evaluations do not pile up
        # open pyplot figures.
        plt.close(fig)

    @classmethod
    def from_config(
        cls,
        config: DetectionPRCurveConfig,
        class_names: List[str],
    ):
        """Build the plotter; neither argument carries any options."""
        return cls()


plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationPRCurvesConfig(BaseConfig):
    """Settings for the per-class precision-recall curves plot."""

    # Discriminator value identifying this plot type.
    name: Literal["classification_pr_curves"] = "classification_pr_curves"

    # If given, only these classes are drawn.
    include: Optional[List[str]] = None

    # If given, these classes are not drawn.
    exclude: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationPRCurves(PlotterProtocol):
    """Plotter that draws one precision-recall curve per selected class."""

    def __init__(
        self,
        class_names: List[str],
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.class_names = class_names

        # Start from all classes, then narrow by include and exclude.
        selected = class_names
        if include is not None:
            selected = [name for name in selected if name in include]
        if exclude is not None:
            selected = [name for name in selected if name not in exclude]
        self.selected = selected

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        """Yield a single ``("classification_pr_curve", figure)`` pair.

        Nothing is yielded when no match remains after filtering. The
        figure is closed once the consumer resumes the generator.
        """
        labels = []
        scores = []

        for clip_eval in clip_evaluations:
            for match in clip_eval.matches:
                # Ignore generic unclassified targets
                if match.gt_det and match.gt_class is None:
                    continue

                labels.append(
                    match.gt_class
                    if match.gt_class is not None
                    else "__NONE__"
                )
                scores.append(
                    [
                        match.pred_class_scores.get(name, 0)
                        for name in self.class_names
                    ]
                )

        if not labels:
            # No samples: there is nothing to stack or plot.
            return

        # Build the one-vs-rest indicator matrix by hand.
        # ``label_binarize`` returns a single column for binary targets,
        # which breaks per-class column indexing with two classes.
        y_true = np.array(
            [[label == name for name in self.class_names] for label in labels],
            dtype=int,
        )
        y_pred = np.asarray(scores, dtype=float)

        fig, ax = plt.subplots(figsize=(10, 10))
        for class_index, class_name in enumerate(self.class_names):
            if class_name not in self.selected:
                continue

            precision, recall, _ = metrics.precision_recall_curve(
                y_true[:, class_index],
                y_pred[:, class_index],
            )
            ax.plot(recall, precision, label=class_name)

        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )

        yield "classification_pr_curve", fig

        # Release the figure so repeated evaluations do not pile up
        # open pyplot figures.
        plt.close(fig)

    @classmethod
    def from_config(
        cls,
        config: ClassificationPRCurvesConfig,
        class_names: List[str],
    ):
        """Build the plotter from its config and the known class names."""
        return cls(
            class_names=class_names,
            include=config.include,
            exclude=config.exclude,
        )


plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionROCCurveConfig(BaseConfig):
    """Settings for the detection ROC curve plot."""

    # Discriminator value identifying this plot type; no other options.
    name: Literal["detection_roc_curve"] = "detection_roc_curve"
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionROCCurve(PlotterProtocol):
    """Plotter that draws the detection ROC curve."""

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        """Yield a single ``("detection_roc_curve", figure)`` pair.

        Nothing is yielded when there are no matches. The figure is closed
        once the consumer resumes the generator, as the other plotters in
        this module do.
        """
        pairs = [
            (match.gt_det, match.pred_score)
            for clip_eval in clip_evaluations
            for match in clip_eval.matches
        ]

        if not pairs:
            # ``zip(*[])`` yields nothing, so unpacking it would raise.
            return

        y_true, y_score = zip(*pairs)
        fpr, tpr, _ = metrics.roc_curve(y_true, y_score)

        fig, ax = plt.subplots()
        ax.plot(fpr, tpr, label="Detection")
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.legend()

        yield "detection_roc_curve", fig

        # Release the figure so repeated evaluations do not pile up
        # open pyplot figures.
        plt.close(fig)

    @classmethod
    def from_config(
        cls,
        config: DetectionROCCurveConfig,
        class_names: List[str],
    ):
        """Build the plotter; neither argument carries any options."""
        return cls()


plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationROCCurvesConfig(BaseConfig):
    """Configuration for the per-class classification ROC curves plot."""

    # Literal tag; ``PlotConfig`` uses the ``name`` field as its discriminator.
    name: Literal["classification_roc_curves"] = "classification_roc_curves"

    # Optional allow-list and deny-list of class names to draw; both are
    # forwarded unchanged to ``ClassificationROCCurves``.
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationROCCurves(PlotterProtocol):
    """Plot one-vs-rest ROC curves, one per selected class, on shared axes."""

    def __init__(
        self,
        class_names: List[str],
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        """Store the class list and work out which classes get a curve.

        Args:
            class_names: All class names; fixes the column order of the
                score matrix built in ``__call__``.
            include: If given, only classes listed here are drawn.
            exclude: If given, classes listed here are not drawn. Applied
                after ``include``.
        """
        self.class_names = class_names
        self.selected = class_names

        if include is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name in include
            ]

        if exclude is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name not in exclude
            ]

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        """Yield ``("classification_roc_curve", figure)``.

        Nothing is yielded when no match survives the filtering below.
        """
        labels = []
        scores = []

        for clip_eval in clip_evaluations:
            for match in clip_eval.matches:
                # Ignore generic unclassified targets
                if match.gt_det and match.gt_class is None:
                    continue

                # ``None`` (no ground-truth class) never equals a class name,
                # so such matches become all-zero rows of the indicator matrix.
                labels.append(match.gt_class)
                scores.append(
                    [
                        match.pred_class_scores.get(name, 0)
                        for name in self.class_names
                    ]
                )

        # Without any rows there is nothing to stack or plot.
        if not labels:
            return

        # Build the one-vs-rest indicator matrix by hand. ``label_binarize``
        # collapses binary problems to a single column, which would break the
        # per-class column indexing below when there are exactly two classes.
        y_true = np.array(
            [
                [label == name for name in self.class_names]
                for label in labels
            ],
            dtype=int,
        )
        y_pred = np.asarray(scores)

        selected = set(self.selected)

        fig, ax = plt.subplots(figsize=(10, 10))
        for class_index, class_name in enumerate(self.class_names):
            if class_name not in selected:
                continue

            fpr, tpr, _ = metrics.roc_curve(
                y_true[:, class_index],
                y_pred[:, class_index],
            )
            ax.plot(fpr, tpr, label=class_name)

        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        # Place the legend outside the axes, to the right of the plot.
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )

        yield "classification_roc_curve", fig

    @classmethod
    def from_config(
        cls,
        config: ClassificationROCCurvesConfig,
        class_names: List[str],
    ):
        """Build the plotter from its config and the full class list."""
        return cls(
            class_names=class_names,
            include=config.include,
            exclude=config.exclude,
        )
|
|
||||||
|
|
||||||
|
|
||||||
# Associate the config type with its plotter class in ``plots_registry``.
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
|
|
||||||
|
|
||||||
|
|
||||||
class ConfusionMatrixConfig(BaseConfig):
    """Configuration for the confusion matrix plot."""

    # Literal tag; ``PlotConfig`` uses the ``name`` field as its discriminator.
    name: Literal["confusion_matrix"] = "confusion_matrix"

    # Label substituted wherever a match has no ground-truth or no predicted
    # class.
    background_class: str = "noise"
|
|
||||||
|
|
||||||
|
|
||||||
class ConfusionMatrix(PlotterProtocol):
    """Plot a confusion matrix of true class against top predicted class."""

    def __init__(self, background_class: str, class_names: List[str]):
        """Remember the class list and the stand-in label for "no class"."""
        self.class_names = class_names
        self.background_class = background_class

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        """Yield ``("confusion_matrix", figure)`` built from all matches."""
        targets = []
        predictions = []

        for clip_eval in clip_evaluations:
            for match in clip_eval.matches:
                unlabelled = match.gt_class is None

                # Ignore generic unclassified targets
                if unlabelled and match.gt_det:
                    continue

                if unlabelled:
                    targets.append(self.background_class)
                else:
                    targets.append(match.gt_class)

                predicted = match.pred_class
                if predicted is None:
                    predicted = self.background_class
                predictions.append(predicted)

        # The background label is appended as the last row and column.
        display = metrics.ConfusionMatrixDisplay.from_predictions(
            targets,
            predictions,
            labels=[*self.class_names, self.background_class],
        )

        yield "confusion_matrix", display.figure_

    @classmethod
    def from_config(
        cls,
        config: ConfusionMatrixConfig,
        class_names: List[str],
    ):
        """Build the plotter from its config and the full class list."""
        return cls(
            class_names=class_names,
            background_class=config.background_class,
        )
|
|
||||||
|
|
||||||
|
|
||||||
# Associate the config type with its plotter class in ``plots_registry``.
plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
|
|
||||||
|
|
||||||
|
|
||||||
# Union of every plot configuration; the ``name`` field of each member is the
# discriminator that tells them apart.
PlotConfig = Annotated[
    Union[
        ExampleGalleryConfig,
        ClipEvaluationPlotConfig,
        DetectionPRCurveConfig,
        ClassificationPRCurvesConfig,
        DetectionROCCurveConfig,
        ClassificationROCCurvesConfig,
        ConfusionMatrixConfig,
    ],
    Field(discriminator="name"),
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_plotter(
    config: PlotConfig,
    class_names: List[str],
) -> PlotterProtocol:
    """Build the plotter registered for ``config`` through ``plots_registry``.

    Args:
        config: One of the plot configurations in ``PlotConfig``.
        class_names: Class names forwarded to the plotter factory.
    """
    plotter = plots_registry.build(config, class_names)
    return plotter
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class ClassMatches:
    """Matches that involve one class, bucketed by outcome.

    Each bucket starts out as its own empty list.
    """

    # Predictions of the class with no ground-truth class.
    false_positives: List[MatchEvaluation] = field(default_factory=list)
    # Ground truths of the class with no predicted class.
    false_negatives: List[MatchEvaluation] = field(default_factory=list)
    # Matches whose ground-truth and predicted class agree.
    true_positives: List[MatchEvaluation] = field(default_factory=list)
    # Matches whose ground-truth and predicted class differ.
    cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
def group_matches(
    clip_evaluations: Sequence[ClipEvaluation],
) -> Dict[str, ClassMatches]:
    """Bucket every match by class and outcome.

    For each match:

    * no predicted class: false negative of the ground-truth class;
    * no ground-truth class: false positive of the predicted class;
    * both present but different: cross trigger of *both* classes;
    * both present and equal: true positive of that class.

    Returns:
        A ``defaultdict`` keyed by class name, so looking up an unseen class
        creates an empty ``ClassMatches``.
    """
    grouped = defaultdict(ClassMatches)

    for clip_evaluation in clip_evaluations:
        for match in clip_evaluation.matches:
            truth, predicted = match.gt_class, match.pred_class

            if predicted is None:
                grouped[truth].false_negatives.append(match)
            elif truth is None:
                grouped[predicted].false_positives.append(match)
            elif truth != predicted:
                # A confusion is recorded on both sides.
                grouped[truth].cross_triggers.append(match)
                grouped[predicted].cross_triggers.append(match)
            else:
                grouped[truth].true_positives.append(match)

    return grouped
|
|
||||||
|
|
||||||
|
|
||||||
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
    """Pick example matches spread across the predicted-score range.

    Matches that have a predicted class are binned into ``n_examples``
    quantile bins of their top-class score, and one match is drawn at random
    from every non-empty bin.

    Args:
        matches: Candidate matches.
        n_examples: Number of quantile bins, and so the maximum sample size.

    Returns:
        ``matches`` itself when it holds fewer than ``n_examples`` items; the
        first ``n_examples`` matches when none has a predicted class;
        otherwise the list of sampled matches.
    """
    if len(matches) < n_examples:
        return matches

    scored = [
        (index, match.pred_class_scores[pred_class])
        for index, match in enumerate(matches)
        if (pred_class := match.pred_class) is not None
    ]

    # Without any predicted class there are no scores to bin, and unpacking
    # ``zip(*[])`` would raise ValueError; fall back to a plain prefix.
    if not scored:
        return matches[:n_examples]

    indices, pred_scores = zip(*scored)

    # ``duplicates="drop"`` merges repeated bin edges instead of raising, so
    # fewer than ``n_examples`` bins may remain.
    bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
    df = pd.DataFrame({"indices": indices, "bins": bins})
    sample = df.groupby("bins").sample(1)
    return [matches[ind] for ind in sample["indices"]]
|
|
||||||
0
src/batdetect2/evaluate/plots/__init__.py
Normal file
0
src/batdetect2/evaluate/plots/__init__.py
Normal file
54
src/batdetect2/evaluate/plots/base.py
Normal file
54
src/batdetect2/evaluate/plots/base.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
|
from batdetect2.core import BaseConfig
|
||||||
|
from batdetect2.typing import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
class BasePlotConfig(BaseConfig):
    """Options shared by every plot configuration."""

    # Stored on the plot as ``label``.
    label: str = "plot"
    # Matplotlib style name applied before a figure is created.
    theme: str = "default"
    # Figure-level title; ``None`` leaves the figure untitled.
    title: Optional[str] = None
    # ``figsize`` and ``dpi`` are handed to ``plt.figure``.
    figsize: tuple[int, int] = (10, 10)
    dpi: int = 100
|
||||||
|
|
||||||
|
|
||||||
|
class BasePlot:
    """Holds the presentation settings of a plot and creates its figures."""

    def __init__(
        self,
        targets: TargetProtocol,
        label: str = "plot",
        figsize: tuple[int, int] = (10, 10),
        title: Optional[str] = None,
        dpi: int = 100,
        theme: str = "default",
    ):
        """Store the targets and the figure settings as attributes."""
        self.targets = targets
        self.label, self.title = label, title
        self.figsize, self.dpi = figsize, dpi
        self.theme = theme

    def create_figure(self) -> Figure:
        """Create a new pyplot figure using the stored settings.

        Note that ``plt.style.use`` changes matplotlib's global style.
        """
        plt.style.use(self.theme)
        figure = plt.figure(figsize=self.figsize, dpi=self.dpi)

        if self.title is None:
            return figure

        figure.suptitle(self.title)
        return figure

    @classmethod
    def build(cls, config: BasePlotConfig, targets: TargetProtocol, **kwargs):
        """Instantiate ``cls`` from the shared config fields.

        Extra keyword arguments are passed through to the constructor, which
        lets subclasses receive their own options.
        """
        shared = dict(
            targets=targets,
            figsize=config.figsize,
            dpi=config.dpi,
            theme=config.theme,
            label=config.label,
            title=config.title,
        )
        return cls(**shared, **kwargs)
|
||||||
370
src/batdetect2/evaluate/plots/classification.py
Normal file
370
src/batdetect2/evaluate/plots/classification.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
from pydantic import Field
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
from batdetect2.core import Registry
|
||||||
|
from batdetect2.evaluate.metrics.classification import (
|
||||||
|
ClipEval,
|
||||||
|
_extract_per_class_metric_data,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
|
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||||
|
from batdetect2.plotting.metrics import (
|
||||||
|
plot_pr_curve,
|
||||||
|
plot_pr_curves,
|
||||||
|
plot_roc_curve,
|
||||||
|
plot_roc_curves,
|
||||||
|
plot_threshold_precision_curve,
|
||||||
|
plot_threshold_precision_curves,
|
||||||
|
plot_threshold_recall_curve,
|
||||||
|
plot_threshold_recall_curves,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import TargetProtocol
|
||||||
|
|
||||||
|
# A classification plotter takes the clip evaluations and produces
# ``(name, figure)`` pairs.
ClassificationPlotter = Callable[
    [Sequence[ClipEval]],
    Iterable[Tuple[str, Figure]],
]

# Registry of classification plot factories.
classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
    Registry("classification_plot")
)
|
||||||
|
|
||||||
|
|
||||||
|
class PRCurveConfig(BasePlotConfig):
    """Configuration for the classification precision-recall curve plot."""

    # Literal tag; the discriminator of ``ClassificationPlotConfig``.
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Classification Precision-Recall Curve"
    # Both flags are forwarded to ``_extract_per_class_metric_data``.
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    # ``True`` yields one figure per class instead of one shared figure.
    separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class PRCurve(BasePlot):
    """Per-class precision-recall curves for the classification task."""

    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        """Store the plot-specific flags; the rest goes to ``BasePlot``."""
        super().__init__(*args, **kwargs)
        self.separate_figures = separate_figures
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield ``(name, figure)`` pairs for the precision-recall curves.

        With ``separate_figures`` one figure named ``"<label>/<class>"`` is
        yielded per class and closed once the consumer resumes the generator;
        otherwise a single figure named ``label`` holds all curves.
        """
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        curves = {}
        for class_name in self.targets.class_names:
            curves[class_name] = compute_precision_recall(
                y_true[class_name],
                y_score[class_name],
                num_positives=num_positives[class_name],
            )

        if self.separate_figures:
            for class_name, (precision, recall, thresholds) in curves.items():
                figure = self.create_figure()
                axes = plot_pr_curve(
                    precision,
                    recall,
                    thresholds,
                    ax=figure.subplots(),
                )
                axes.set_title(class_name)

                yield f"{self.label}/{class_name}", figure

                plt.close(figure)
            return

        figure = self.create_figure()
        plot_pr_curves(curves, ax=figure.subplots())
        yield self.label, figure

    @classification_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        """Build a ``PRCurve`` from its config."""
        return PRCurve.build(
            config=config,
            targets=targets,
            separate_figures=config.separate_figures,
            ignore_generic=config.ignore_generic,
            ignore_non_predictions=config.ignore_non_predictions,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdPrecisionCurveConfig(BasePlotConfig):
    """Configuration for the classification threshold-precision plot."""

    # Literal tag; the discriminator of ``ClassificationPlotConfig``.
    name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
    label: str = "threshold_precision_curve"
    title: Optional[str] = "Classification Threshold-Precision Curve"
    # Both flags are forwarded to ``_extract_per_class_metric_data``.
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    # ``True`` yields one figure per class instead of one shared figure.
    separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdPrecisionCurve(BasePlot):
    """Per-class precision as a function of the score threshold."""

    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        """Store the plot-specific flags; the rest goes to ``BasePlot``."""
        super().__init__(*args, **kwargs)
        self.separate_figures = separate_figures
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield ``(name, figure)`` pairs for the threshold-precision curves.

        With ``separate_figures`` one figure named ``"<label>/<class>"`` is
        yielded per class and closed once the consumer resumes the generator;
        otherwise a single figure named ``label`` holds all curves.
        """
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        curves = {}
        for class_name in self.targets.class_names:
            curves[class_name] = compute_precision_recall(
                y_true[class_name],
                y_score[class_name],
                num_positives[class_name],
            )

        if self.separate_figures:
            # Recall is computed but not needed for this plot.
            for class_name, (precision, _, thresholds) in curves.items():
                figure = self.create_figure()
                axes = plot_threshold_precision_curve(
                    thresholds,
                    precision,
                    ax=figure.subplots(),
                )
                axes.set_title(class_name)

                yield f"{self.label}/{class_name}", figure

                plt.close(figure)
            return

        figure = self.create_figure()
        plot_threshold_precision_curves(curves, ax=figure.subplots())
        yield self.label, figure

    @classification_plots.register(ThresholdPrecisionCurveConfig)
    @staticmethod
    def from_config(
        config: ThresholdPrecisionCurveConfig,
        targets: TargetProtocol,
    ):
        """Build a ``ThresholdPrecisionCurve`` from its config."""
        return ThresholdPrecisionCurve.build(
            config=config,
            targets=targets,
            separate_figures=config.separate_figures,
            ignore_generic=config.ignore_generic,
            ignore_non_predictions=config.ignore_non_predictions,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdRecallCurveConfig(BasePlotConfig):
    """Configuration for the classification threshold-recall plot."""

    # Literal tag; the discriminator of ``ClassificationPlotConfig``.
    name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
    label: str = "threshold_recall_curve"
    title: Optional[str] = "Classification Threshold-Recall Curve"
    # Both flags are forwarded to ``_extract_per_class_metric_data``.
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    # ``True`` yields one figure per class instead of one shared figure.
    separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdRecallCurve(BasePlot):
    """Per-class recall as a function of the score threshold."""

    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        """Store the plot-specific flags; the rest goes to ``BasePlot``."""
        super().__init__(*args, **kwargs)
        self.separate_figures = separate_figures
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield ``(name, figure)`` pairs for the threshold-recall curves.

        With ``separate_figures`` one figure named ``"<label>/<class>"`` is
        yielded per class and closed once the consumer resumes the generator;
        otherwise a single figure named ``label`` holds all curves.
        """
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        curves = {}
        for class_name in self.targets.class_names:
            curves[class_name] = compute_precision_recall(
                y_true[class_name],
                y_score[class_name],
                num_positives[class_name],
            )

        if self.separate_figures:
            # Precision is computed but not needed for this plot.
            for class_name, (_, recall, thresholds) in curves.items():
                figure = self.create_figure()
                axes = plot_threshold_recall_curve(
                    thresholds,
                    recall,
                    ax=figure.subplots(),
                )
                axes.set_title(class_name)

                yield f"{self.label}/{class_name}", figure

                plt.close(figure)
            return

        figure = self.create_figure()
        plot_threshold_recall_curves(
            curves,
            ax=figure.subplots(),
            add_legend=True,
        )
        yield self.label, figure

    @classification_plots.register(ThresholdRecallCurveConfig)
    @staticmethod
    def from_config(
        config: ThresholdRecallCurveConfig,
        targets: TargetProtocol,
    ):
        """Build a ``ThresholdRecallCurve`` from its config."""
        return ThresholdRecallCurve.build(
            config=config,
            targets=targets,
            separate_figures=config.separate_figures,
            ignore_generic=config.ignore_generic,
            ignore_non_predictions=config.ignore_non_predictions,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ROCCurveConfig(BasePlotConfig):
    """Configuration for the classification ROC curve plot."""

    # Literal tag; the discriminator of ``ClassificationPlotConfig``.
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Classification ROC Curve"
    # Both flags are forwarded to ``_extract_per_class_metric_data``.
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    # ``True`` yields one figure per class instead of one shared figure.
    separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ROCCurve(BasePlot):
    """Per-class ROC curves for the classification task."""

    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        """Store the plot-specific flags; the rest goes to ``BasePlot``."""
        super().__init__(*args, **kwargs)
        self.separate_figures = separate_figures
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield ``(name, figure)`` pairs for the ROC curves.

        With ``separate_figures`` one figure named ``"<label>/<class>"`` is
        yielded per class and closed once the consumer resumes the generator;
        otherwise a single figure named ``label`` holds all curves.
        """
        # The third returned value is not needed for ROC curves.
        y_true, y_score, _ = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        curves = {}
        for class_name in self.targets.class_names:
            curves[class_name] = metrics.roc_curve(
                y_true[class_name],
                y_score[class_name],
            )

        if self.separate_figures:
            for class_name, (fpr, tpr, thresholds) in curves.items():
                figure = self.create_figure()
                axes = plot_roc_curve(
                    fpr,
                    tpr,
                    thresholds,
                    ax=figure.subplots(),
                )
                axes.set_title(class_name)

                yield f"{self.label}/{class_name}", figure

                plt.close(figure)
            return

        figure = self.create_figure()
        plot_roc_curves(curves, ax=figure.subplots())
        yield self.label, figure

    @classification_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        """Build a ``ROCCurve`` from its config."""
        return ROCCurve.build(
            config=config,
            targets=targets,
            separate_figures=config.separate_figures,
            ignore_generic=config.ignore_generic,
            ignore_non_predictions=config.ignore_non_predictions,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Union of the classification plot configurations; the ``name`` field of each
# member is the discriminator that tells them apart.
ClassificationPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ThresholdPrecisionCurveConfig,
        ThresholdRecallCurveConfig,
    ],
    Field(discriminator="name"),
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_classification_plotter(
    config: ClassificationPlotConfig,
    targets: TargetProtocol,
) -> ClassificationPlotter:
    """Build the classification plotter registered for ``config``.

    Args:
        config: One of the configurations in ``ClassificationPlotConfig``.
        targets: Targets object forwarded to the plotter factory.
    """
    plotter = classification_plots.build(config, targets)
    return plotter
|
||||||
189
src/batdetect2/evaluate/plots/clip_classification.py
Normal file
189
src/batdetect2/evaluate/plots/clip_classification.py
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
from pydantic import Field
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
from batdetect2.core import Registry
|
||||||
|
from batdetect2.evaluate.metrics.clip_classification import ClipEval
|
||||||
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
|
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||||
|
from batdetect2.plotting.metrics import (
|
||||||
|
plot_pr_curve,
|
||||||
|
plot_pr_curves,
|
||||||
|
plot_roc_curve,
|
||||||
|
plot_roc_curves,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import TargetProtocol
|
||||||
|
|
||||||
|
__all__ = [
    "ClipClassificationPlotConfig",
    "ClipClassificationPlotter",
    "build_clip_classification_plotter",
]

# A clip classification plotter takes the clip evaluations and produces
# ``(name, figure)`` pairs.
ClipClassificationPlotter = Callable[
    [Sequence[ClipEval]],
    Iterable[Tuple[str, Figure]],
]

# Registry of clip classification plot factories.
clip_classification_plots: Registry[
    ClipClassificationPlotter, [TargetProtocol]
] = Registry("clip_classification_plot")
|
||||||
|
|
||||||
|
|
||||||
|
class PRCurveConfig(BasePlotConfig):
    """Configuration for the clip classification precision-recall plot."""

    # Literal tag; the discriminator of ``ClipClassificationPlotConfig``.
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Clip Classification Precision-Recall Curve"
    # ``True`` yields one figure per class instead of one shared figure.
    separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class PRCurve(BasePlot):
    """Per-class precision-recall curves at clip level."""

    def __init__(
        self,
        *args,
        separate_figures: bool = False,
        **kwargs,
    ):
        """Store ``separate_figures``; the rest goes to ``BasePlot``."""
        super().__init__(*args, **kwargs)
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield ``(name, figure)`` pairs for the precision-recall curves.

        For each class a clip counts as positive when the class is in its
        ``true_classes``; its score is the class entry of ``class_scores``,
        or 0 when that entry is missing.
        """
        curves = {}

        for class_name in self.targets.class_names:
            is_present = [
                class_name in clip.true_classes for clip in clip_evaluations
            ]
            scores = [
                clip.class_scores.get(class_name, 0)
                for clip in clip_evaluations
            ]

            precision, recall, thresholds = compute_precision_recall(
                is_present,
                scores,
            )
            curves[class_name] = (precision, recall, thresholds)

        if self.separate_figures:
            for class_name, (precision, recall, thresholds) in curves.items():
                figure = self.create_figure()
                axes = plot_pr_curve(
                    precision,
                    recall,
                    thresholds,
                    ax=figure.subplots(),
                )
                axes.set_title(class_name)

                yield f"{self.label}/{class_name}", figure

                plt.close(figure)
            return

        figure = self.create_figure()
        plot_pr_curves(curves, ax=figure.subplots())
        yield self.label, figure

    @clip_classification_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        """Build a ``PRCurve`` from its config."""
        return PRCurve.build(
            config=config,
            targets=targets,
            separate_figures=config.separate_figures,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ROCCurveConfig(BasePlotConfig):
    """Configuration for the clip classification ROC curve plot."""

    # Literal tag; the discriminator of ``ClipClassificationPlotConfig``.
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Clip Classification ROC Curve"
    # ``True`` yields one figure per class instead of one shared figure.
    separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ROCCurve(BasePlot):
    """Per-class ROC curves at clip level."""

    def __init__(
        self,
        *args,
        separate_figures: bool = False,
        **kwargs,
    ):
        """Store ``separate_figures``; the rest goes to ``BasePlot``."""
        super().__init__(*args, **kwargs)
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield ``(name, figure)`` pairs for the ROC curves.

        For each class a clip counts as positive when the class is in its
        ``true_classes``; its score is the class entry of ``class_scores``,
        or 0 when that entry is missing.
        """
        curves = {}

        for class_name in self.targets.class_names:
            is_present = [
                class_name in clip.true_classes for clip in clip_evaluations
            ]
            scores = [
                clip.class_scores.get(class_name, 0)
                for clip in clip_evaluations
            ]

            fpr, tpr, thresholds = metrics.roc_curve(
                is_present,
                scores,
            )
            curves[class_name] = (fpr, tpr, thresholds)

        if self.separate_figures:
            for class_name, (fpr, tpr, thresholds) in curves.items():
                figure = self.create_figure()
                axes = plot_roc_curve(
                    fpr,
                    tpr,
                    thresholds,
                    ax=figure.subplots(),
                )
                axes.set_title(class_name)

                yield f"{self.label}/{class_name}", figure

                plt.close(figure)
            return

        figure = self.create_figure()
        plot_roc_curves(curves, ax=figure.subplots())
        yield self.label, figure

    @clip_classification_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        """Build a ``ROCCurve`` from its config."""
        return ROCCurve.build(
            config=config,
            targets=targets,
            separate_figures=config.separate_figures,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Union of the clip classification plot configurations; the ``name`` field of
# each member is the discriminator that tells them apart.
ClipClassificationPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
    ],
    Field(discriminator="name"),
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_clip_classification_plotter(
    config: ClipClassificationPlotConfig,
    targets: TargetProtocol,
) -> ClipClassificationPlotter:
    """Build the clip classification plotter registered for ``config``.

    Args:
        config: One of the configurations in ``ClipClassificationPlotConfig``.
        targets: Targets object forwarded to the plotter factory.
    """
    plotter = clip_classification_plots.build(config, targets)
    return plotter
|
||||||
163
src/batdetect2/evaluate/plots/clip_detection.py
Normal file
163
src/batdetect2/evaluate/plots/clip_detection.py
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
from pydantic import Field
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
from batdetect2.core import Registry
|
||||||
|
from batdetect2.evaluate.metrics.clip_detection import ClipEval
|
||||||
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
|
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||||
|
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||||
|
from batdetect2.typing import TargetProtocol
|
||||||
|
|
||||||
|
__all__ = [
    "ClipDetectionPlotConfig",
    "ClipDetectionPlotter",
    "build_clip_detection_plotter",
]

# A clip detection plotter takes the clip evaluations and produces
# ``(name, figure)`` pairs.
ClipDetectionPlotter = Callable[
    [Sequence[ClipEval]],
    Iterable[Tuple[str, Figure]],
]


# Registry of clip detection plot factories.
clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
    Registry("clip_detection_plot")
)
|
||||||
|
|
||||||
|
|
||||||
|
class PRCurveConfig(BasePlotConfig):
    """Configuration for the clip detection precision-recall plot."""

    # Literal tag; the discriminator of ``ClipDetectionPlotConfig``.
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Clip Detection Precision-Recall Curve"
|
||||||
|
|
||||||
|
|
||||||
|
class PRCurve(BasePlot):
    """Precision-recall curve for clip-level detection."""

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield ``(label, figure)`` with the precision-recall curve.

        Each clip contributes its ``gt_det`` flag as the true label and its
        ``score`` as the prediction score.
        """
        truths = [clip.gt_det for clip in clip_evaluations]
        scores = [clip.score for clip in clip_evaluations]

        precision, recall, thresholds = compute_precision_recall(
            truths,
            scores,
        )

        figure = self.create_figure()
        plot_pr_curve(precision, recall, thresholds, ax=figure.subplots())
        yield self.label, figure

    @clip_detection_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        """Build a ``PRCurve`` from its config."""
        return PRCurve.build(config=config, targets=targets)
|
||||||
|
|
||||||
|
|
||||||
|
class ROCCurveConfig(BasePlotConfig):
    """Configuration for the clip detection ROC curve plot."""

    # Literal tag; the discriminator of ``ClipDetectionPlotConfig``.
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Clip Detection ROC Curve"
|
||||||
|
|
||||||
|
|
||||||
|
class ROCCurve(BasePlot):
    """ROC curve for clip-level detection."""

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield ``(label, figure)`` with the ROC curve.

        Each clip contributes its ``gt_det`` flag as the true label and its
        ``score`` as the prediction score.
        """
        truths = [clip.gt_det for clip in clip_evaluations]
        scores = [clip.score for clip in clip_evaluations]

        fpr, tpr, thresholds = metrics.roc_curve(
            truths,
            scores,
        )

        figure = self.create_figure()
        plot_roc_curve(fpr, tpr, thresholds, ax=figure.subplots())
        yield self.label, figure

    @clip_detection_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        """Build a ``ROCCurve`` from its config."""
        return ROCCurve.build(config=config, targets=targets)
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreDistributionPlotConfig(BasePlotConfig):
    """Configuration selecting the clip detection score histogram."""

    # Discriminator value: ClipDetectionPlotConfig picks this class when
    # the `name` field equals "score_distribution".
    name: Literal["score_distribution"] = "score_distribution"
    # Presumably forwarded to the plot by BasePlot.build and used as the
    # key yielded next to the figure -- confirm in BasePlot.
    label: str = "score_distribution"
    title: Optional[str] = "Clip Detection Score Distribution"
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreDistributionPlot(BasePlot):
    """Histogram of clip scores, split by the ground-truth flag."""

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield one ``(label, figure)`` pair holding the histogram."""
        labels = [clip_eval.gt_det for clip_eval in clip_evaluations]
        scores = [clip_eval.score for clip_eval in clip_evaluations]

        figure = self.create_figure()
        axes = figure.subplots()

        frame = pd.DataFrame({"is_true": labels, "score": scores})

        # common_norm=False normalises each hue group on its own, so the
        # two distributions stay comparable when the classes are unbalanced.
        histogram_options = {
            "x": "score",
            "binwidth": 0.025,
            "binrange": (0, 1),
            "hue": "is_true",
            "stat": "probability",
            "common_norm": False,
        }
        sns.histplot(data=frame, ax=axes, **histogram_options)

        yield self.label, figure

    @clip_detection_plots.register(ScoreDistributionPlotConfig)
    @staticmethod
    def from_config(
        config: ScoreDistributionPlotConfig, targets: TargetProtocol
    ):
        """Build the plotter registered for ``ScoreDistributionPlotConfig``."""
        return ScoreDistributionPlot.build(config=config, targets=targets)
|
||||||
|
|
||||||
|
|
||||||
|
# Union of the clip detection plot configurations; pydantic selects the
# concrete class from the value of the `name` field.
ClipDetectionPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ScoreDistributionPlotConfig,
    ],
    Field(discriminator="name"),
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_clip_detection_plotter(
    config: ClipDetectionPlotConfig,
    targets: TargetProtocol,
) -> ClipDetectionPlotter:
    """Build the clip detection plotter registered for ``config``.

    The lookup and construction are delegated to the
    ``clip_detection_plots`` registry.
    """
    plotter = clip_detection_plots.build(config, targets)
    return plotter
|
||||||
309
src/batdetect2/evaluate/plots/detection.py
Normal file
309
src/batdetect2/evaluate/plots/detection.py
Normal file
@ -0,0 +1,309 @@
|
|||||||
|
import random
|
||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
from pydantic import Field
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
|
from batdetect2.core import Registry
|
||||||
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
|
from batdetect2.evaluate.metrics.detection import ClipEval
|
||||||
|
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||||
|
from batdetect2.plotting.detections import plot_clip_detections
|
||||||
|
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||||
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
|
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
||||||
|
|
||||||
|
# Call signature shared by the detection plotters: they take the clip
# evaluations and produce (name, figure) pairs.
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]

# Registry of detection plot builders; the plot classes below register
# their `from_config` builders against their config classes.
detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
    name="detection_plot"
)
|
||||||
|
|
||||||
|
|
||||||
|
class PRCurveConfig(BasePlotConfig):
    """Configuration selecting the detection precision-recall curve."""

    # Discriminator value: DetectionPlotConfig picks this class when the
    # `name` field equals "pr_curve".
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Detection Precision-Recall Curve"
    # Forwarded to PRCurve; when true, matches without a prediction are
    # left out of the curve.
    ignore_non_predictions: bool = True
    # Forwarded to PRCurve, which stores the flag.
    ignore_generic: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class PRCurve(BasePlot):
    """Precision-recall curve over matched sound event detections."""

    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        """Store the filtering flags; other arguments go to ``BasePlot``."""
        super().__init__(*args, **kwargs)
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evals: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield one ``(label, figure)`` pair holding the PR curve."""
        all_matches = [
            match for clip_eval in clip_evals for match in clip_eval.matches
        ]

        # Every ground-truth match counts as a positive, including those
        # dropped below for lacking a prediction.
        num_positives = sum(
            int(match.is_ground_truth) for match in all_matches
        )

        # Keep a match unless it has no prediction and such matches are
        # configured to be ignored.
        scored = [
            match
            for match in all_matches
            if match.is_prediction or not self.ignore_non_predictions
        ]
        labels = [match.is_ground_truth for match in scored]
        scores = [match.score for match in scored]

        precision, recall, thresholds = compute_precision_recall(
            labels,
            scores,
            num_positives=num_positives,
        )

        figure = self.create_figure()
        axes = figure.subplots()
        plot_pr_curve(precision, recall, thresholds, ax=axes)

        yield self.label, figure

    @detection_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        """Build the plotter registered for ``PRCurveConfig``."""
        options = {
            "ignore_non_predictions": config.ignore_non_predictions,
            "ignore_generic": config.ignore_generic,
        }
        return PRCurve.build(config=config, targets=targets, **options)
|
||||||
|
|
||||||
|
|
||||||
|
class ROCCurveConfig(BasePlotConfig):
    """Configuration selecting the detection ROC curve plot."""

    # Discriminator value: DetectionPlotConfig picks this class when the
    # `name` field equals "roc_curve".
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Detection ROC Curve"
    # Forwarded to ROCCurve; when true, matches without a prediction are
    # left out of the curve.
    ignore_non_predictions: bool = True
    # Forwarded to ROCCurve, which stores the flag.
    ignore_generic: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ROCCurve(BasePlot):
    """ROC curve over matched sound event detections."""

    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        """Store the filtering flags; other arguments go to ``BasePlot``."""
        super().__init__(*args, **kwargs)
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield one ``(label, figure)`` pair holding the ROC curve."""
        # Keep a match unless it has no prediction and such matches are
        # configured to be ignored.
        scored = [
            match
            for clip_eval in clip_evaluations
            for match in clip_eval.matches
            if match.is_prediction or not self.ignore_non_predictions
        ]
        labels = [match.is_ground_truth for match in scored]
        scores = [match.score for match in scored]

        fpr, tpr, thresholds = metrics.roc_curve(labels, scores)

        figure = self.create_figure()
        axes = figure.subplots()
        plot_roc_curve(fpr, tpr, thresholds, ax=axes)

        yield self.label, figure

    @detection_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        """Build the plotter registered for ``ROCCurveConfig``."""
        options = {
            "ignore_non_predictions": config.ignore_non_predictions,
            "ignore_generic": config.ignore_generic,
        }
        return ROCCurve.build(config=config, targets=targets, **options)
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreDistributionPlotConfig(BasePlotConfig):
    """Configuration selecting the detection score histogram."""

    # Discriminator value: DetectionPlotConfig picks this class when the
    # `name` field equals "score_distribution".
    name: Literal["score_distribution"] = "score_distribution"
    label: str = "score_distribution"
    title: Optional[str] = "Detection Score Distribution"
    # Forwarded to ScoreDistributionPlot; when true, matches without a
    # prediction are left out of the histogram.
    ignore_non_predictions: bool = True
    # Forwarded to ScoreDistributionPlot, which stores the flag.
    ignore_generic: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreDistributionPlot(BasePlot):
    """Histogram of detection scores, split by ground-truth status."""

    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        """Store the filtering flags; other arguments go to ``BasePlot``."""
        super().__init__(*args, **kwargs)
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield one ``(label, figure)`` pair holding the histogram."""
        # Keep a match unless it has no prediction and such matches are
        # configured to be ignored.
        scored = [
            match
            for clip_eval in clip_evaluations
            for match in clip_eval.matches
            if match.is_prediction or not self.ignore_non_predictions
        ]
        frame = pd.DataFrame(
            {
                "is_true": [match.is_ground_truth for match in scored],
                "score": [match.score for match in scored],
            }
        )

        figure = self.create_figure()
        axes = figure.subplots()

        # common_norm=False normalises each hue group on its own, so the
        # two distributions stay comparable when the classes are unbalanced.
        histogram_options = {
            "x": "score",
            "binwidth": 0.025,
            "binrange": (0, 1),
            "hue": "is_true",
            "stat": "probability",
            "common_norm": False,
        }
        sns.histplot(data=frame, ax=axes, **histogram_options)

        yield self.label, figure

    @detection_plots.register(ScoreDistributionPlotConfig)
    @staticmethod
    def from_config(
        config: ScoreDistributionPlotConfig, targets: TargetProtocol
    ):
        """Build the plotter registered for ``ScoreDistributionPlotConfig``."""
        options = {
            "ignore_non_predictions": config.ignore_non_predictions,
            "ignore_generic": config.ignore_generic,
        }
        return ScoreDistributionPlot.build(
            config=config,
            targets=targets,
            **options,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleDetectionPlotConfig(BasePlotConfig):
    """Configuration selecting the per-clip example detection plots."""

    # Discriminator value: DetectionPlotConfig picks this class when the
    # `name` field equals "example_detection".
    name: Literal["example_detection"] = "example_detection"
    label: str = "example_detection"
    title: Optional[str] = "Example Detection"
    figsize: tuple[int, int] = (10, 4)
    # Upper bound on how many clips ExampleDetectionPlot draws.
    num_examples: int = 5
    # Threshold value matching the `threshold` argument of
    # ExampleDetectionPlot.
    threshold: float = 0.2
    # Settings for the audio loader and preprocessor that
    # ExampleDetectionPlot.from_config builds for the plot.
    audio: AudioConfig = Field(default_factory=AudioConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleDetectionPlot(BasePlot):
    """Draw the detections of a handful of example clips, one figure each."""

    def __init__(
        self,
        *args,
        num_examples: int = 5,
        threshold: float = 0.2,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        **kwargs,
    ):
        """Store the plot settings; other arguments go to ``BasePlot``.

        Parameters
        ----------
        num_examples
            Maximum number of clips to draw.
        threshold
            Threshold value kept on the instance as ``self.threshold``.
        audio_loader, preprocessor
            Handed to ``plot_clip_detections`` for every drawn clip.
        """
        super().__init__(*args, **kwargs)
        self.num_examples = num_examples
        self.audio_loader = audio_loader
        self.threshold = threshold
        self.preprocessor = preprocessor

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield ``("<label>/example_<i>", figure)`` for each drawn clip.

        When more clips are available than ``num_examples``, a random
        subset of that size is drawn; otherwise every clip is used.
        """
        sample = clip_evaluations

        if self.num_examples < len(sample):
            sample = random.sample(sample, self.num_examples)

        for num_example, clip_eval in enumerate(sample):
            fig = self.create_figure()
            ax = fig.subplots()

            plot_clip_detections(
                clip_eval,
                ax=ax,
                audio_loader=self.audio_loader,
                preprocessor=self.preprocessor,
            )

            yield f"{self.label}/example_{num_example}", fig

            # Runs once the consumer asks for the next item, so figures do
            # not pile up in pyplot's global state.
            plt.close(fig)

    @detection_plots.register(ExampleDetectionPlotConfig)
    @staticmethod
    def from_config(
        config: ExampleDetectionPlotConfig,
        targets: TargetProtocol,
    ):
        """Build the plotter registered for ``ExampleDetectionPlotConfig``."""
        return ExampleDetectionPlot.build(
            config=config,
            targets=targets,
            num_examples=config.num_examples,
            # Forward the configured threshold; it was previously dropped,
            # leaving the plot on the constructor default.
            threshold=config.threshold,
            audio_loader=build_audio_loader(config.audio),
            preprocessor=build_preprocessor(config.preprocessing),
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Union of the detection plot configurations; pydantic selects the
# concrete class from the value of the `name` field.
DetectionPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ScoreDistributionPlotConfig,
        ExampleDetectionPlotConfig,
    ],
    Field(discriminator="name"),
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_detection_plotter(
    config: DetectionPlotConfig,
    targets: TargetProtocol,
) -> DetectionPlotter:
    """Build the detection plotter registered for ``config``.

    The lookup and construction are delegated to the ``detection_plots``
    registry.
    """
    plotter = detection_plots.build(config, targets)
    return plotter
|
||||||
444
src/batdetect2/evaluate/plots/top_class.py
Normal file
444
src/batdetect2/evaluate/plots/top_class.py
Normal file
@ -0,0 +1,444 @@
|
|||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
from pydantic import Field
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
|
from batdetect2.core import Registry
|
||||||
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
|
from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
|
||||||
|
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||||
|
from batdetect2.plotting.gallery import plot_match_gallery
|
||||||
|
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||||
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
|
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
||||||
|
|
||||||
|
# Call signature shared by the top-class plotters: they take the clip
# evaluations and produce (name, figure) pairs.
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]

# Registry of top-class plot builders; the plot classes below register
# their `from_config` builders against their config classes.
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
    name="top_class_plot"
)
|
||||||
|
|
||||||
|
|
||||||
|
class PRCurveConfig(BasePlotConfig):
    """Configuration selecting the top-class precision-recall curve."""

    # Discriminator value: TopClassPlotConfig picks this class when the
    # `name` field equals "pr_curve".
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Top Class Precision-Recall Curve"
    # Forwarded to PRCurve; when true, matches without a prediction are
    # left out of the curve.
    ignore_non_predictions: bool = True
    # Forwarded to PRCurve; when true, matches flagged `is_generic` are
    # skipped entirely.
    ignore_generic: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class PRCurve(BasePlot):
    """Precision-recall curve of top-class predictions."""

    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        """Store the filtering flags; other arguments go to ``BasePlot``."""
        super().__init__(*args, **kwargs)
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield one ``(label, figure)`` pair holding the PR curve.

        A match counts as correct when its predicted class equals its true
        class.
        """
        # Generic matches (ground truth with unknown class) are removed
        # before anything is counted.
        candidates = [
            match
            for clip_eval in clip_evaluations
            for match in clip_eval.matches
            if not (match.is_generic and self.ignore_generic)
        ]

        # Positives include ground truths that end up dropped below for
        # lacking a prediction.
        num_positives = sum(
            int(match.is_ground_truth) for match in candidates
        )

        scored = [
            match
            for match in candidates
            if match.is_prediction or not self.ignore_non_predictions
        ]
        labels = [match.pred_class == match.true_class for match in scored]
        scores = [match.score for match in scored]

        precision, recall, thresholds = compute_precision_recall(
            labels,
            scores,
            num_positives=num_positives,
        )

        figure = self.create_figure()
        axes = figure.subplots()
        plot_pr_curve(precision, recall, thresholds, ax=axes)

        yield self.label, figure

    @top_class_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        """Build the plotter registered for ``PRCurveConfig``."""
        options = {
            "ignore_non_predictions": config.ignore_non_predictions,
            "ignore_generic": config.ignore_generic,
        }
        return PRCurve.build(config=config, targets=targets, **options)
|
||||||
|
|
||||||
|
|
||||||
|
class ROCCurveConfig(BasePlotConfig):
    """Configuration selecting the top-class ROC curve plot."""

    # Discriminator value: TopClassPlotConfig picks this class when the
    # `name` field equals "roc_curve".
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Top Class ROC Curve"
    # Forwarded to ROCCurve; when true, matches without a prediction are
    # left out of the curve.
    ignore_non_predictions: bool = True
    # Forwarded to ROCCurve; when true, matches flagged `is_generic` are
    # skipped entirely.
    ignore_generic: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ROCCurve(BasePlot):
    """ROC curve of top-class predictions."""

    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        """Store the filtering flags; other arguments go to ``BasePlot``."""
        super().__init__(*args, **kwargs)
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield one ``(label, figure)`` pair holding the ROC curve.

        A match counts as correct when its predicted class equals its true
        class.
        """
        # Drop generic matches (ground truth with unknown class) first,
        # then matches without a prediction, each only when its flag is on.
        scored = [
            match
            for clip_eval in clip_evaluations
            for match in clip_eval.matches
            if not (match.is_generic and self.ignore_generic)
            and (match.is_prediction or not self.ignore_non_predictions)
        ]
        labels = [match.pred_class == match.true_class for match in scored]
        scores = [match.score for match in scored]

        fpr, tpr, thresholds = metrics.roc_curve(labels, scores)

        figure = self.create_figure()
        axes = figure.subplots()
        plot_roc_curve(fpr, tpr, thresholds, ax=axes)

        yield self.label, figure

    @top_class_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        """Build the plotter registered for ``ROCCurveConfig``."""
        options = {
            "ignore_non_predictions": config.ignore_non_predictions,
            "ignore_generic": config.ignore_generic,
        }
        return ROCCurve.build(config=config, targets=targets, **options)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfusionMatrixConfig(BasePlotConfig):
    """Configuration selecting the top-class confusion matrix plot."""

    # Discriminator value: TopClassPlotConfig picks this class when the
    # `name` field equals "confusion_matrix".
    name: Literal["confusion_matrix"] = "confusion_matrix"
    title: Optional[str] = "Top Class Confusion Matrix"
    figsize: tuple[int, int] = (10, 10)
    label: str = "confusion_matrix"
    # Forwarded to ConfusionMatrix; when true, matches flagged
    # `is_generic` are skipped, otherwise they are filed under the
    # targets' detection class name.
    exclude_generic: bool = True
    # Forwarded to ConfusionMatrix; when true, unmatched and
    # below-threshold entries are skipped instead of being filed under
    # `noise_class`.
    exclude_noise: bool = False
    noise_class: str = "noise"
    # "none" is translated to None before being handed to scikit-learn.
    normalize: Literal["true", "pred", "all", "none"] = "true"
    # Threshold value matching the `threshold` argument of ConfusionMatrix.
    threshold: float = 0.2
    add_colorbar: bool = True
    cmap: str = "Blues"
|
||||||
|
|
||||||
|
|
||||||
|
class ConfusionMatrix(BasePlot):
    """Confusion matrix between true classes and top predicted classes."""

    def __init__(
        self,
        *args,
        exclude_generic: bool = True,
        exclude_noise: bool = False,
        noise_class: str = "noise",
        add_colorbar: bool = True,
        normalize: Literal["true", "pred", "all", "none"] = "true",
        cmap: str = "Blues",
        threshold: float = 0.2,
        **kwargs,
    ):
        """Store the matrix settings; other arguments go to ``BasePlot``.

        Parameters
        ----------
        exclude_generic
            Skip matches flagged ``is_generic`` instead of filing them
            under the targets' detection class name.
        exclude_noise
            Skip unmatched and below-threshold entries instead of filing
            them under ``noise_class``.
        noise_class
            Label used when a true or predicted class is missing.
        add_colorbar, cmap
            Passed to scikit-learn's confusion matrix display.
        normalize
            Normalisation mode; ``"none"`` is translated to ``None``.
        threshold
            Predictions scoring below this value count as noise.
        """
        super().__init__(*args, **kwargs)
        self.exclude_generic = exclude_generic
        self.exclude_noise = exclude_noise
        self.noise_class = noise_class
        self.normalize = normalize
        self.add_colorbar = add_colorbar
        self.threshold = threshold
        self.cmap = cmap

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield one ``(label, figure)`` pair holding the confusion matrix."""
        y_true: List[str] = []
        y_pred: List[str] = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                true_class = m.true_class
                pred_class = m.pred_class

                if not m.is_prediction and self.exclude_noise:
                    # Ignore matches that don't correspond to a prediction
                    continue

                if not m.is_ground_truth and self.exclude_noise:
                    # Ignore matches that don't correspond to a ground truth
                    continue

                if m.score < self.threshold:
                    if self.exclude_noise:
                        continue

                    # A prediction that is too weak counts as noise.
                    pred_class = self.noise_class

                if m.is_generic:
                    if self.exclude_generic:
                        # Ignore gt sounds with unknown class
                        continue

                    true_class = self.targets.detection_class_name

                # A missing class on either side falls back to noise.
                y_true.append(true_class or self.noise_class)
                y_pred.append(pred_class or self.noise_class)

        fig = self.create_figure()
        ax = fig.subplots()

        # Axis labels: the target classes, plus the generic and noise
        # labels when those categories are not excluded.
        class_names = [*self.targets.class_names]

        if not self.exclude_generic:
            class_names.append(self.targets.detection_class_name)

        if not self.exclude_noise:
            class_names.append(self.noise_class)

        metrics.ConfusionMatrixDisplay.from_predictions(
            y_true,
            y_pred,
            labels=class_names,
            ax=ax,
            xticks_rotation="vertical",
            cmap=self.cmap,
            colorbar=self.add_colorbar,
            normalize=self.normalize if self.normalize != "none" else None,
            values_format=".2f",
        )

        yield self.label, fig

    @top_class_plots.register(ConfusionMatrixConfig)
    @staticmethod
    def from_config(config: ConfusionMatrixConfig, targets: TargetProtocol):
        """Build the plotter registered for ``ConfusionMatrixConfig``."""
        return ConfusionMatrix.build(
            config=config,
            targets=targets,
            exclude_generic=config.exclude_generic,
            exclude_noise=config.exclude_noise,
            noise_class=config.noise_class,
            add_colorbar=config.add_colorbar,
            normalize=config.normalize,
            cmap=config.cmap,
            # Forward the configured threshold; it was previously dropped,
            # leaving the matrix on the constructor default.
            threshold=config.threshold,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleClassificationPlotConfig(BasePlotConfig):
    """Configuration selecting the per-class example gallery plots."""

    # Discriminator value: TopClassPlotConfig picks this class when the
    # `name` field equals "example_classification".
    name: Literal["example_classification"] = "example_classification"
    label: str = "example_classification"
    title: Optional[str] = "Example Classification"
    # Number of examples sampled per outcome category for each class.
    num_examples: int = 4
    # Forwarded to ExampleClassificationPlot, which passes it to
    # group_matches as the score cut-off for a prediction.
    threshold: float = 0.2
    # Settings for the audio loader and preprocessor that
    # ExampleClassificationPlot.from_config builds for the plot.
    audio: AudioConfig = Field(default_factory=AudioConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleClassificationPlot(BasePlot):
    """Gallery of example matches for every class, one figure per class."""

    def __init__(
        self,
        *args,
        num_examples: int = 4,
        threshold: float = 0.2,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        **kwargs,
    ):
        """Store the plot settings; other arguments go to ``BasePlot``.

        Parameters
        ----------
        num_examples
            Number of examples sampled per outcome category.
        threshold
            Score cut-off passed to ``group_matches``.
        audio_loader, preprocessor
            Handed to ``plot_match_gallery`` for every figure.
        """
        super().__init__(*args, **kwargs)
        # `num_examples` used to be assigned twice; once is enough.
        self.num_examples = num_examples
        self.audio_loader = audio_loader
        self.threshold = threshold
        self.preprocessor = preprocessor

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield ``("<label>/<class_name>", figure)`` for every class.

        Matches are grouped per class into true positives, false
        positives, false negatives and cross triggers, and up to
        ``num_examples`` of each are drawn.
        """
        grouped = group_matches(clip_evaluations, threshold=self.threshold)

        for class_name, matches in grouped.items():
            true_positives: List[MatchEval] = get_binned_sample(
                matches.true_positives,
                n_examples=self.num_examples,
            )

            false_positives: List[MatchEval] = get_binned_sample(
                matches.false_positives,
                n_examples=self.num_examples,
            )

            # False negatives are sampled uniformly rather than binned by
            # score.
            false_negatives: List[MatchEval] = random.sample(
                matches.false_negatives,
                k=min(self.num_examples, len(matches.false_negatives)),
            )

            cross_triggers: List[MatchEval] = get_binned_sample(
                matches.cross_triggers, n_examples=self.num_examples
            )

            fig = self.create_figure()

            fig = plot_match_gallery(
                true_positives,
                false_positives,
                false_negatives,
                cross_triggers,
                preprocessor=self.preprocessor,
                audio_loader=self.audio_loader,
                n_examples=self.num_examples,
                fig=fig,
            )

            if self.title is not None:
                fig.suptitle(f"{self.title}: {class_name}")
            else:
                fig.suptitle(class_name)

            yield f"{self.label}/{class_name}", fig

            # Runs once the consumer asks for the next item, so figures do
            # not pile up in pyplot's global state.
            plt.close(fig)

    @top_class_plots.register(ExampleClassificationPlotConfig)
    @staticmethod
    def from_config(
        config: ExampleClassificationPlotConfig,
        targets: TargetProtocol,
    ):
        """Build the plotter for ``ExampleClassificationPlotConfig``."""
        return ExampleClassificationPlot.build(
            config=config,
            targets=targets,
            num_examples=config.num_examples,
            threshold=config.threshold,
            audio_loader=build_audio_loader(config.audio),
            preprocessor=build_preprocessor(config.preprocessing),
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Union of the top-class plot configurations; pydantic selects the
# concrete class from the value of the `name` field.
TopClassPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ConfusionMatrixConfig,
        ExampleClassificationPlotConfig,
    ],
    Field(discriminator="name"),
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_top_class_plotter(
    config: TopClassPlotConfig,
    targets: TargetProtocol,
) -> TopClassPlotter:
    """Build the top-class plotter registered for ``config``.

    The lookup and construction are delegated to the ``top_class_plots``
    registry.
    """
    plotter = top_class_plots.build(config, targets)
    return plotter
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ClassMatches:
    """Matches collected for one class, split by outcome (see group_matches)."""

    # Predictions of this class with no ground-truth class.
    false_positives: List[MatchEval] = field(default_factory=list)
    # Ground truths of this class whose score fell below the threshold.
    false_negatives: List[MatchEval] = field(default_factory=list)
    # Predictions of this class whose ground truth has the same class.
    true_positives: List[MatchEval] = field(default_factory=list)
    # Predictions of this class whose ground truth has a different class.
    cross_triggers: List[MatchEval] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def group_matches(
    clip_evals: Sequence[ClipEval],
    threshold: float = 0.2,
) -> Dict[str, ClassMatches]:
    """Sort every match into per-class outcome buckets.

    A match counts as a prediction when its score is at least
    ``threshold``. Below the threshold it is a false negative of its true
    class, or is dropped when it has no true class. At or above the
    threshold it is filed under the predicted class as a false positive
    (no true class) or a cross trigger (a different true class), and under
    the true class as a true positive when both agree.

    Returns the populated ``defaultdict`` keyed by class name.
    """
    grouped = defaultdict(ClassMatches)

    for clip_eval in clip_evals:
        for match in clip_eval.matches:
            true_class = match.true_class
            predicted_class = match.pred_class
            is_pred = match.score >= threshold

            if not is_pred:
                if true_class is not None:
                    grouped[true_class].false_negatives.append(match)
                continue

            if true_class is None:
                bucket = grouped[predicted_class].false_positives
            elif true_class != predicted_class:
                bucket = grouped[predicted_class].cross_triggers
            else:
                bucket = grouped[true_class].true_positives

            bucket.append(match)

    return grouped
|
||||||
|
|
||||||
|
|
||||||
|
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
    """Pick up to ``n_examples`` matches spread across the score range.

    The scores are cut into ``n_examples`` quantile bins and one match is
    drawn at random from each bin. Duplicate bin edges are dropped, so
    fewer than ``n_examples`` matches may be returned.

    Parameters
    ----------
    matches
        Candidate matches; only their ``score`` attribute is read.
    n_examples
        Number of quantile bins, i.e. the maximum sample size.

    Returns
    -------
    list
        An empty list when ``n_examples`` is not positive, ``matches``
        itself when it holds fewer than ``n_examples`` items, otherwise
        the sampled matches.
    """
    if n_examples <= 0:
        # Nothing was asked for; this used to fail on the tuple unpacking
        # (empty input) or inside pd.qcut (q=0).
        return []

    if len(matches) < n_examples:
        return matches

    indices = list(range(len(matches)))
    pred_scores = [match.score for match in matches]

    bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
    df = pd.DataFrame({"indices": indices, "bins": bins})
    sample = df.groupby("bins").sample(1)
    return [matches[ind] for ind in sample["indices"]]
|
||||||
@ -5,9 +5,9 @@ from pydantic import Field
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig, Registry
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.typing import ClipEvaluation
|
from batdetect2.typing import ClipMatches
|
||||||
|
|
||||||
EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame]
|
EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]
|
||||||
|
|
||||||
|
|
||||||
tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
|
tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
|
||||||
@ -21,20 +21,18 @@ class FullEvaluationTableConfig(BaseConfig):
|
|||||||
|
|
||||||
class FullEvaluationTable:
|
class FullEvaluationTable:
|
||||||
def __call__(
|
def __call__(
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
self, clip_evaluations: Sequence[ClipMatches]
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
return extract_matches_dataframe(clip_evaluations)
|
return extract_matches_dataframe(clip_evaluations)
|
||||||
|
|
||||||
@classmethod
|
@tables_registry.register(FullEvaluationTableConfig)
|
||||||
def from_config(cls, config: FullEvaluationTableConfig):
|
@staticmethod
|
||||||
return cls()
|
def from_config(config: FullEvaluationTableConfig):
|
||||||
|
return FullEvaluationTable()
|
||||||
|
|
||||||
tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_matches_dataframe(
|
def extract_matches_dataframe(
|
||||||
clip_evaluations: Sequence[ClipEvaluation],
|
clip_evaluations: Sequence[ClipMatches],
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
data = []
|
data = []
|
||||||
|
|
||||||
@ -78,8 +76,8 @@ def extract_matches_dataframe(
|
|||||||
("gt", "low_freq"): gt_low_freq,
|
("gt", "low_freq"): gt_low_freq,
|
||||||
("gt", "high_freq"): gt_high_freq,
|
("gt", "high_freq"): gt_high_freq,
|
||||||
("pred", "score"): match.pred_score,
|
("pred", "score"): match.pred_score,
|
||||||
("pred", "class"): match.pred_class,
|
("pred", "class"): match.top_class,
|
||||||
("pred", "class_score"): match.pred_class_score,
|
("pred", "class_score"): match.top_class_score,
|
||||||
("pred", "start_time"): pred_start_time,
|
("pred", "start_time"): pred_start_time,
|
||||||
("pred", "end_time"): pred_end_time,
|
("pred", "end_time"): pred_end_time,
|
||||||
("pred", "low_freq"): pred_low_freq,
|
("pred", "low_freq"): pred_low_freq,
|
||||||
|
|||||||
39
src/batdetect2/evaluate/tasks/__init__.py
Normal file
39
src/batdetect2/evaluate/tasks/__init__.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from typing import Annotated, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from batdetect2.evaluate.tasks.base import tasks_registry
|
||||||
|
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
||||||
|
from batdetect2.evaluate.tasks.clip_classification import (
|
||||||
|
ClipClassificationTaskConfig,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
|
||||||
|
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||||
|
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
|
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
|
||||||
|
|
||||||
|
__all__ = [
    "TaskConfig",
    "build_task",
]


# Union of the task configurations imported above; pydantic selects the
# concrete class from the value of the `name` field.
TaskConfig = Annotated[
    Union[
        ClassificationTaskConfig,
        DetectionTaskConfig,
        ClipDetectionTaskConfig,
        ClipClassificationTaskConfig,
        TopClassDetectionTaskConfig,
    ],
    Field(discriminator="name"),
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_task(
    config: TaskConfig,
    targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
    """Build the evaluation task described by ``config``.

    Parameters
    ----------
    config
        One of the task configurations in ``TaskConfig``.
    targets
        Target definitions handed to the task builder. When ``None``,
        ``build_targets()`` is called without arguments to create them.

    Returns
    -------
    EvaluatorProtocol
        Whatever ``tasks_registry.build`` produces for ``config``.
    """
    # Compare against None explicitly so that a caller-supplied targets
    # object is never swapped for the defaults merely because it happens
    # to evaluate as falsy.
    if targets is None:
        targets = build_targets()

    return tasks_registry.build(config, targets)
|
||||||
175
src/batdetect2/evaluate/tasks/base.py
Normal file
175
src/batdetect2/evaluate/tasks/base.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
from typing import (
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.core import BaseConfig
|
||||||
|
from batdetect2.core.registries import Registry
|
||||||
|
from batdetect2.evaluate.match import (
|
||||||
|
MatchConfig,
|
||||||
|
StartTimeMatchConfig,
|
||||||
|
build_matcher,
|
||||||
|
)
|
||||||
|
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
|
||||||
|
from batdetect2.typing.postprocess import RawPrediction
|
||||||
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseTaskConfig",
|
||||||
|
"BaseTask",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Registry of evaluation tasks. Task modules register a builder for their
# config class; builders are called with the config and a TargetProtocol.
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
    "tasks"
)


# Type of the per-clip evaluation result a task produces.
T_Output = TypeVar("T_Output")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTaskConfig(BaseConfig):
    """Configuration fields shared by every evaluation task."""

    # Namespace prepended (as "<prefix>/<name>") to metric and plot names.
    prefix: str

    # Buffer applied at both ends of a clip: sound events and predictions
    # whose start time is closer than this to the clip start or end are
    # left out of the evaluation. Expressed in the same units as the clip
    # start/end times.
    ignore_start_end: float = 0.01

    # Configuration handed to ``build_matcher``; defaults to a
    # ``StartTimeMatchConfig``.
    matching_strategy: MatchConfig = Field(
        default_factory=StartTimeMatchConfig
    )
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
    """Common scaffolding shared by the evaluation tasks.

    Subclasses implement :meth:`evaluate_clip` to turn one clip's
    annotations and predictions into a ``T_Output``. This base class maps
    that over a collection of clips, runs the configured metrics and
    plotters on the collected outputs, and prefixes every metric and plot
    name with ``prefix``.
    """

    # Decides which annotated sound events take part in the evaluation.
    targets: TargetProtocol

    # Pairs predicted geometries with ground-truth geometries.
    matcher: MatcherProtocol

    # Each metric maps the per-clip outputs to ``{name: value}``.
    metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]

    # Each plotter maps the per-clip outputs to ``(name, figure)`` pairs.
    plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]

    # Buffer used to drop events that start too close to the clip edges.
    ignore_start_end: float

    # Namespace prepended to metric and plot names.
    prefix: str

    def __init__(
        self,
        matcher: MatcherProtocol,
        targets: TargetProtocol,
        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
        prefix: str,
        ignore_start_end: float = 0.01,
        plots: Optional[
            List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
        ] = None,
    ):
        """Store the task components.

        ``plots`` may be omitted, in which case the task has no plotters.
        """
        self.matcher = matcher
        self.metrics = metrics
        self.plots = plots or []
        self.targets = targets
        self.prefix = prefix
        self.ignore_start_end = ignore_start_end

    def compute_metrics(
        self,
        eval_outputs: List[T_Output],
    ) -> Dict[str, float]:
        """Run every metric and merge the results into one dictionary.

        Keys are ``"<prefix>/<metric name>"``. If two metrics emit the
        same name, the value from the later metric is kept.
        """
        scores = [metric(eval_outputs) for metric in self.metrics]
        return {
            f"{self.prefix}/{name}": score
            for metric_output in scores
            for name, score in metric_output.items()
        }

    def generate_plots(
        self, eval_outputs: List[T_Output]
    ) -> Iterable[Tuple[str, Figure]]:
        """Lazily yield ``("<prefix>/<plot name>", figure)`` pairs."""
        for plot in self.plots:
            for name, fig in plot(eval_outputs):
                yield f"{self.prefix}/{name}", fig

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[T_Output]:
        """Evaluate every clip and return the per-clip outputs in order.

        ``clip_annotations`` and ``predictions`` are paired positionally.
        ``zip`` stops at the shorter input, so both are expected to have
        the same length.
        """
        return [
            self.evaluate_clip(clip_annotation, preds)
            for clip_annotation, preds in zip(clip_annotations, predictions)
        ]

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> T_Output:
        """Evaluate the predictions of a single clip.

        Must be overridden by every concrete task.

        Raises
        ------
        NotImplementedError
            Always, in this base class. Failing loudly prevents a task
            that forgot to override this method from silently producing
            ``None`` outputs for every clip.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement evaluate_clip"
        )

    def include_sound_event_annotation(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
        clip: data.Clip,
    ) -> bool:
        """Tell whether an annotated sound event takes part in evaluation.

        The annotation is kept only if ``targets.filter`` accepts it, it
        has a geometry, and that geometry starts inside the clip once
        ``ignore_start_end`` is trimmed from both ends.
        """
        if not self.targets.filter(sound_event_annotation):
            return False

        geometry = sound_event_annotation.sound_event.geometry
        if geometry is None:
            return False

        return is_in_bounds(
            geometry,
            clip,
            self.ignore_start_end,
        )

    def include_prediction(
        self,
        prediction: RawPrediction,
        clip: data.Clip,
    ) -> bool:
        """Tell whether a prediction takes part in evaluation.

        Uses the same clip-edge rule as ground-truth annotations.
        """
        return is_in_bounds(
            prediction.geometry,
            clip,
            self.ignore_start_end,
        )

    @classmethod
    def build(
        cls,
        config: BaseTaskConfig,
        targets: TargetProtocol,
        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
        plots: Optional[
            List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
        ] = None,
        **kwargs,
    ):
        """Instantiate ``cls`` from a task config.

        The matcher is built from ``config.matching_strategy``, while
        ``prefix`` and ``ignore_start_end`` are copied from ``config``.
        Extra keyword arguments are forwarded unchanged to the
        constructor of ``cls``.
        """
        matcher = build_matcher(config.matching_strategy)
        return cls(
            matcher=matcher,
            targets=targets,
            metrics=metrics,
            plots=plots,
            prefix=config.prefix,
            ignore_start_end=config.ignore_start_end,
            **kwargs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def is_in_bounds(
    geometry: data.Geometry,
    clip: data.Clip,
    buffer: float,
) -> bool:
    """Check that a geometry starts inside the clip, away from its edges.

    Only the first value returned by ``compute_bounds`` is considered. It
    must lie between ``clip.start_time + buffer`` and
    ``clip.end_time - buffer``, both limits included.
    """
    earliest = clip.start_time + buffer
    latest = clip.end_time - buffer
    onset = compute_bounds(geometry)[0]
    return earliest <= onset <= latest
|
||||||
149
src/batdetect2/evaluate/tasks/classification.py
Normal file
149
src/batdetect2/evaluate/tasks/classification.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
from typing import (
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Sequence,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.evaluate.metrics.classification import (
|
||||||
|
ClassificationAveragePrecisionConfig,
|
||||||
|
ClassificationMetricConfig,
|
||||||
|
ClipEval,
|
||||||
|
MatchEval,
|
||||||
|
build_classification_metric,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.plots.classification import (
|
||||||
|
ClassificationPlotConfig,
|
||||||
|
build_classification_plotter,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.tasks.base import (
|
||||||
|
BaseTask,
|
||||||
|
BaseTaskConfig,
|
||||||
|
tasks_registry,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationTaskConfig(BaseTaskConfig):
    """Configuration of the sound event classification task."""

    # Discriminator value identifying this task in a ``TaskConfig`` union.
    name: Literal["sound_event_classification"] = "sound_event_classification"

    # Namespace for the metric and plot names of this task.
    prefix: str = "classification"

    # Metrics to compute; a single average-precision metric by default.
    metrics: List[ClassificationMetricConfig] = Field(
        default_factory=lambda: [ClassificationAveragePrecisionConfig()]
    )

    # Plots to generate; none by default.
    plots: List[ClassificationPlotConfig] = Field(default_factory=list)

    # Mirrors the ``include_generics`` option of ``ClassificationTask``:
    # whether ground-truth sound events that encode to no class count as
    # candidates for every class.
    include_generics: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationTask(BaseTask[ClipEval]):
    """Per-class evaluation of sound event classification.

    For every target class, the predictions of a clip are matched against
    the ground-truth sound events of that class, using the class score of
    each prediction as its matching score.
    """

    def __init__(
        self,
        *args,
        include_generics: bool = True,
        **kwargs,
    ):
        """Create the task.

        Parameters
        ----------
        include_generics
            When true, ground-truth sound events that encode to no class
            are treated as candidates for every class.
        *args, **kwargs
            Forwarded to ``BaseTask.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.include_generics = include_generics

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        """Match predictions to ground truth once per class.

        Returns a ``ClipEval`` whose ``matches`` maps each class name to
        the list of ``MatchEval`` entries produced for that class.
        """
        clip = clip_annotation.clip

        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]

        all_gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]

        per_class_matches = {}

        # enumerate gives the score-vector index directly, instead of a
        # linear ``class_names.index`` lookup on every iteration.
        for class_idx, class_name in enumerate(self.targets.class_names):
            # Only match to targets of the given class
            gts = [
                sound_event
                for sound_event in all_gts
                if self.is_class(sound_event, class_name)
            ]
            scores = [float(pred.class_scores[class_idx]) for pred in preds]

            matches = []

            for pred_idx, gt_idx, _ in self.matcher(
                ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
                predictions=[pred.geometry for pred in preds],
                scores=scores,
            ):
                gt = gts[gt_idx] if gt_idx is not None else None
                pred = preds[pred_idx] if pred_idx is not None else None

                true_class = (
                    self.targets.encode_class(gt) if gt is not None else None
                )

                # Unmatched ground truth gets a score of 0.
                score = (
                    float(pred.class_scores[class_idx])
                    if pred is not None
                    else 0
                )

                matches.append(
                    MatchEval(
                        clip=clip,
                        gt=gt,
                        pred=pred,
                        is_prediction=pred is not None,
                        is_ground_truth=gt is not None,
                        is_generic=gt is not None and true_class is None,
                        true_class=true_class,
                        score=score,
                    )
                )

            per_class_matches[class_name] = matches

        return ClipEval(clip=clip, matches=per_class_matches)

    def is_class(
        self,
        sound_event: data.SoundEventAnnotation,
        class_name: str,
    ) -> bool:
        """Tell whether ``sound_event`` counts as ground truth for a class.

        Events that encode to no class count for every class when
        ``include_generics`` is set; otherwise the encoded class must
        equal ``class_name``.
        """
        sound_event_class = self.targets.encode_class(sound_event)

        if sound_event_class is None and self.include_generics:
            # Sound events that are generic could be of the given
            # class
            return True

        return sound_event_class == class_name

    @tasks_registry.register(ClassificationTaskConfig)
    @staticmethod
    def from_config(
        config: ClassificationTaskConfig,
        targets: TargetProtocol,
    ):
        """Build the task, its metrics and its plotters from ``config``."""
        metrics = [
            build_classification_metric(metric, targets)
            for metric in config.metrics
        ]
        plots = [
            build_classification_plotter(plot, targets)
            for plot in config.plots
        ]
        return ClassificationTask.build(
            config=config,
            plots=plots,
            targets=targets,
            metrics=metrics,
            # Forward the option explicitly; otherwise the value set in
            # the config is ignored and the constructor default is used.
            include_generics=config.include_generics,
        )
|
||||||
85
src/batdetect2/evaluate/tasks/clip_classification.py
Normal file
85
src/batdetect2/evaluate/tasks/clip_classification.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from typing import List, Literal, Sequence
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.evaluate.metrics.clip_classification import (
|
||||||
|
ClipClassificationAveragePrecisionConfig,
|
||||||
|
ClipClassificationMetricConfig,
|
||||||
|
ClipEval,
|
||||||
|
build_clip_metric,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.plots.clip_classification import (
|
||||||
|
ClipClassificationPlotConfig,
|
||||||
|
build_clip_classification_plotter,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.tasks.base import (
|
||||||
|
BaseTask,
|
||||||
|
BaseTaskConfig,
|
||||||
|
tasks_registry,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
class ClipClassificationTaskConfig(BaseTaskConfig):
    """Configuration of the clip-level classification task."""

    # Discriminator value identifying this task in a ``TaskConfig`` union.
    name: Literal["clip_classification"] = "clip_classification"

    # Namespace for the metric and plot names of this task.
    prefix: str = "clip_classification"

    # Metrics to compute; a single average-precision metric by default.
    metrics: List[ClipClassificationMetricConfig] = Field(
        default_factory=lambda: [
            ClipClassificationAveragePrecisionConfig(),
        ]
    )

    # Plots to generate; none by default.
    plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipClassificationTask(BaseTask[ClipEval]):
    """Clip-level classification: which classes occur in each clip."""

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        """Summarise one clip as a set of true classes and per-class scores.

        The true classes are those of the included ground-truth sound
        events that encode to a class. The score of a class is the
        largest class score among the included predictions.
        """
        clip = clip_annotation.clip

        # Encoded class of every ground-truth event kept for evaluation.
        encoded_labels = (
            self.targets.encode_class(annotation)
            for annotation in clip_annotation.sound_events
            if self.include_sound_event_annotation(annotation, clip)
        )
        true_classes = {
            label for label in encoded_labels if label is not None
        }

        # Running maximum of each class score over the kept predictions.
        best_scores = defaultdict(float)
        kept_predictions = (
            candidate
            for candidate in predictions
            if self.include_prediction(candidate, clip)
        )
        for kept in kept_predictions:
            for position, label in enumerate(self.targets.class_names):
                best_scores[label] = max(
                    float(kept.class_scores[position]),
                    best_scores[label],
                )

        return ClipEval(true_classes=true_classes, class_scores=best_scores)

    @tasks_registry.register(ClipClassificationTaskConfig)
    @staticmethod
    def from_config(
        config: ClipClassificationTaskConfig,
        targets: TargetProtocol,
    ):
        """Build the task, its metrics and its plotters from ``config``."""
        metric_fns = [build_clip_metric(spec) for spec in config.metrics]
        plot_fns = [
            build_clip_classification_plotter(spec, targets)
            for spec in config.plots
        ]
        return ClipClassificationTask.build(
            config=config,
            plots=plot_fns,
            metrics=metric_fns,
            targets=targets,
        )
|
||||||
76
src/batdetect2/evaluate/tasks/clip_detection.py
Normal file
76
src/batdetect2/evaluate/tasks/clip_detection.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
from typing import List, Literal, Sequence
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.evaluate.metrics.clip_detection import (
|
||||||
|
ClipDetectionAveragePrecisionConfig,
|
||||||
|
ClipDetectionMetricConfig,
|
||||||
|
ClipEval,
|
||||||
|
build_clip_metric,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.plots.clip_detection import (
|
||||||
|
ClipDetectionPlotConfig,
|
||||||
|
build_clip_detection_plotter,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.tasks.base import (
|
||||||
|
BaseTask,
|
||||||
|
BaseTaskConfig,
|
||||||
|
tasks_registry,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionTaskConfig(BaseTaskConfig):
    """Configuration of the clip-level detection task."""

    # Discriminator value identifying this task in a ``TaskConfig`` union.
    name: Literal["clip_detection"] = "clip_detection"

    # Namespace for the metric and plot names of this task.
    prefix: str = "clip_detection"

    # Metrics to compute; a single average-precision metric by default.
    metrics: List[ClipDetectionMetricConfig] = Field(
        default_factory=lambda: [
            ClipDetectionAveragePrecisionConfig(),
        ]
    )

    # Plots to generate; none by default.
    plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionTask(BaseTask[ClipEval]):
    """Clip-level detection: does the clip contain any sound event."""

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        """Summarise one clip as a ground-truth flag and a single score.

        The flag is true when at least one ground-truth sound event is
        included in the evaluation. The score is the largest detection
        score among the included predictions, and never drops below 0.
        """
        clip = clip_annotation.clip

        has_ground_truth = False
        for annotation in clip_annotation.sound_events:
            if self.include_sound_event_annotation(annotation, clip):
                has_ground_truth = True
                break

        kept_scores = [
            candidate.detection_score
            for candidate in predictions
            if self.include_prediction(candidate, clip)
        ]
        # Seeding with 0 reproduces a running maximum that starts at 0.
        top_score = max([0, *kept_scores])

        return ClipEval(gt_det=has_ground_truth, score=top_score)

    @tasks_registry.register(ClipDetectionTaskConfig)
    @staticmethod
    def from_config(
        config: ClipDetectionTaskConfig,
        targets: TargetProtocol,
    ):
        """Build the task, its metrics and its plotters from ``config``."""
        metric_fns = [build_clip_metric(spec) for spec in config.metrics]
        plot_fns = [
            build_clip_detection_plotter(spec, targets)
            for spec in config.plots
        ]
        return ClipDetectionTask.build(
            config=config,
            metrics=metric_fns,
            targets=targets,
            plots=plot_fns,
        )
|
||||||
88
src/batdetect2/evaluate/tasks/detection.py
Normal file
88
src/batdetect2/evaluate/tasks/detection.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from typing import List, Literal, Sequence
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.evaluate.metrics.detection import (
|
||||||
|
ClipEval,
|
||||||
|
DetectionAveragePrecisionConfig,
|
||||||
|
DetectionMetricConfig,
|
||||||
|
MatchEval,
|
||||||
|
build_detection_metric,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.plots.detection import (
|
||||||
|
DetectionPlotConfig,
|
||||||
|
build_detection_plotter,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.tasks.base import (
|
||||||
|
BaseTask,
|
||||||
|
BaseTaskConfig,
|
||||||
|
tasks_registry,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionTaskConfig(BaseTaskConfig):
    """Configuration of the sound event detection task."""

    # Discriminator value identifying this task in a ``TaskConfig`` union.
    name: Literal["sound_event_detection"] = "sound_event_detection"

    # Namespace for the metric and plot names of this task.
    prefix: str = "detection"

    # Metrics to compute; a single average-precision metric by default.
    metrics: List[DetectionMetricConfig] = Field(
        default_factory=lambda: [DetectionAveragePrecisionConfig()]
    )

    # Plots to generate; none by default.
    plots: List[DetectionPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionTask(BaseTask[ClipEval]):
    """Class-agnostic sound event detection evaluation."""

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        """Match the clip's predictions against its ground truth.

        Predictions are ranked by detection score for matching. Each
        pairing the matcher yields becomes one ``MatchEval``; a pairing
        without a prediction carries a score of 0.
        """
        clip = clip_annotation.clip

        kept_gts = [
            annotation
            for annotation in clip_annotation.sound_events
            if self.include_sound_event_annotation(annotation, clip)
        ]
        kept_preds = [
            candidate
            for candidate in predictions
            if self.include_prediction(candidate, clip)
        ]
        detection_scores = [
            candidate.detection_score for candidate in kept_preds
        ]

        pairings = self.matcher(
            ground_truth=[se.sound_event.geometry for se in kept_gts],  # type: ignore
            predictions=[candidate.geometry for candidate in kept_preds],
            scores=detection_scores,
        )

        results = []
        for pred_pos, gt_pos, _ in pairings:
            matched_gt = None if gt_pos is None else kept_gts[gt_pos]
            matched_pred = None if pred_pos is None else kept_preds[pred_pos]

            if matched_pred is None:
                match_score = 0
            else:
                match_score = matched_pred.detection_score

            results.append(
                MatchEval(
                    gt=matched_gt,
                    pred=matched_pred,
                    is_prediction=matched_pred is not None,
                    is_ground_truth=matched_gt is not None,
                    score=match_score,
                )
            )

        return ClipEval(clip=clip, matches=results)

    @tasks_registry.register(DetectionTaskConfig)
    @staticmethod
    def from_config(
        config: DetectionTaskConfig,
        targets: TargetProtocol,
    ):
        """Build the task, its metrics and its plotters from ``config``."""
        metric_fns = [build_detection_metric(spec) for spec in config.metrics]
        plot_fns = [
            build_detection_plotter(spec, targets) for spec in config.plots
        ]
        return DetectionTask.build(
            config=config,
            metrics=metric_fns,
            targets=targets,
            plots=plot_fns,
        )
|
||||||
111
src/batdetect2/evaluate/tasks/top_class.py
Normal file
111
src/batdetect2/evaluate/tasks/top_class.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
from typing import List, Literal, Sequence
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.evaluate.metrics.top_class import (
|
||||||
|
ClipEval,
|
||||||
|
MatchEval,
|
||||||
|
TopClassAveragePrecisionConfig,
|
||||||
|
TopClassMetricConfig,
|
||||||
|
build_top_class_metric,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.plots.top_class import (
|
||||||
|
TopClassPlotConfig,
|
||||||
|
build_top_class_plotter,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.tasks.base import (
|
||||||
|
BaseTask,
|
||||||
|
BaseTaskConfig,
|
||||||
|
tasks_registry,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassDetectionTaskConfig(BaseTaskConfig):
    """Configuration of the top-class detection task."""

    # Discriminator value identifying this task in a ``TaskConfig`` union.
    name: Literal["top_class_detection"] = "top_class_detection"

    # Namespace for the metric and plot names of this task.
    prefix: str = "top_class"

    # Metrics to compute; a single average-precision metric by default.
    metrics: List[TopClassMetricConfig] = Field(
        default_factory=lambda: [TopClassAveragePrecisionConfig()]
    )

    # Plots to generate; none by default.
    plots: List[TopClassPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassDetectionTask(BaseTask[ClipEval]):
    """Detection evaluation based on each prediction's best class."""

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        """Match the clip's predictions using their highest class score.

        Each pairing the matcher yields becomes one ``MatchEval`` holding
        the encoded ground-truth class (if any) and the name and score of
        the prediction's top class (if any).
        """
        clip = clip_annotation.clip

        kept_gts = [
            annotation
            for annotation in clip_annotation.sound_events
            if self.include_sound_event_annotation(annotation, clip)
        ]
        kept_preds = [
            candidate
            for candidate in predictions
            if self.include_prediction(candidate, clip)
        ]
        # Every prediction is ranked by its best class score.
        top_scores = [
            candidate.class_scores.max() for candidate in kept_preds
        ]

        pairings = self.matcher(
            ground_truth=[se.sound_event.geometry for se in kept_gts],  # type: ignore
            predictions=[candidate.geometry for candidate in kept_preds],
            scores=top_scores,
        )

        results = []
        for pred_pos, gt_pos, _ in pairings:
            matched_gt = None if gt_pos is None else kept_gts[gt_pos]
            matched_pred = None if pred_pos is None else kept_preds[pred_pos]

            if matched_gt is None:
                gt_label = None
            else:
                gt_label = self.targets.encode_class(matched_gt)

            if matched_pred is None:
                # Nothing was predicted: no class and a score of 0.
                match_score = 0
                top_label = None
            else:
                top_pos = matched_pred.class_scores.argmax()
                match_score = float(matched_pred.class_scores[top_pos])
                top_label = self.targets.class_names[top_pos]

            results.append(
                MatchEval(
                    clip=clip,
                    gt=matched_gt,
                    pred=matched_pred,
                    is_ground_truth=matched_gt is not None,
                    is_prediction=matched_pred is not None,
                    true_class=gt_label,
                    is_generic=matched_gt is not None and gt_label is None,
                    pred_class=top_label,
                    score=match_score,
                )
            )

        return ClipEval(clip=clip, matches=results)

    @tasks_registry.register(TopClassDetectionTaskConfig)
    @staticmethod
    def from_config(
        config: TopClassDetectionTaskConfig,
        targets: TargetProtocol,
    ):
        """Build the task, its metrics and its plotters from ``config``."""
        metric_fns = [build_top_class_metric(spec) for spec in config.metrics]
        plot_fns = [
            build_top_class_plotter(spec, targets) for spec in config.plots
        ]
        return TopClassDetectionTask.build(
            config=config,
            plots=plot_fns,
            metrics=metric_fns,
            targets=targets,
        )
|
||||||
@ -11,7 +11,6 @@ from batdetect2.plotting.matches import (
|
|||||||
plot_cross_trigger_match,
|
plot_cross_trigger_match,
|
||||||
plot_false_negative_match,
|
plot_false_negative_match,
|
||||||
plot_false_positive_match,
|
plot_false_positive_match,
|
||||||
plot_matches,
|
|
||||||
plot_true_positive_match,
|
plot_true_positive_match,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,7 +21,6 @@ __all__ = [
|
|||||||
"plot_cross_trigger_match",
|
"plot_cross_trigger_match",
|
||||||
"plot_false_negative_match",
|
"plot_false_negative_match",
|
||||||
"plot_false_positive_match",
|
"plot_false_positive_match",
|
||||||
"plot_matches",
|
|
||||||
"plot_spectrogram",
|
"plot_spectrogram",
|
||||||
"plot_true_positive_match",
|
"plot_true_positive_match",
|
||||||
"plot_detection_heatmap",
|
"plot_detection_heatmap",
|
||||||
|
|||||||
@ -65,8 +65,6 @@ def plot_anchor_points(
|
|||||||
if not targets.filter(sound_event):
|
if not targets.filter(sound_event):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sound_event = targets.transform(sound_event)
|
|
||||||
|
|
||||||
position, _ = targets.encode_roi(sound_event)
|
position, _ = targets.encode_roi(sound_event)
|
||||||
positions.append(position)
|
positions.append(position)
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,7 @@ def create_ax(
|
|||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
"""Create a new axis if none is provided"""
|
"""Create a new axis if none is provided"""
|
||||||
if ax is None:
|
if ax is None:
|
||||||
_, ax = plt.subplots(figsize=figsize, **kwargs) # type: ignore
|
_, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs) # type: ignore
|
||||||
|
|
||||||
return ax # type: ignore
|
return ax # type: ignore
|
||||||
|
|
||||||
@ -66,6 +66,9 @@ def plot_spectrogram(
|
|||||||
vmax=vmax,
|
vmax=vmax,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ax.set_xlim(start_time, end_time)
|
||||||
|
ax.set_ylim(min_freq, max_freq)
|
||||||
|
|
||||||
if add_colorbar:
|
if add_colorbar:
|
||||||
plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))
|
plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))
|
||||||
|
|
||||||
|
|||||||
113
src/batdetect2/plotting/detections.py
Normal file
113
src/batdetect2/plotting/detections.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from matplotlib import axes, patches
|
||||||
|
from soundevent.plot import plot_geometry
|
||||||
|
|
||||||
|
from batdetect2.evaluate.metrics.detection import ClipEval
|
||||||
|
from batdetect2.plotting.clips import (
|
||||||
|
AudioLoader,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
plot_clip,
|
||||||
|
)
|
||||||
|
from batdetect2.plotting.common import create_ax
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"plot_clip_detections",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def plot_clip_detections(
    clip_eval: ClipEval,
    figsize: tuple[int, int] = (10, 10),
    ax: Optional[axes.Axes] = None,
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    threshold: float = 0.2,
    add_legend: bool = True,
    add_title: bool = True,
    fill: bool = False,
    linewidth: float = 1.0,
    gt_color: str = "green",
    gt_linestyle: str = "-",
    true_pred_color: str = "yellow",
    true_pred_linestyle: str = "--",
    false_pred_color: str = "blue",
    false_pred_linestyle: str = "-",
    missed_gt_color: str = "red",
    missed_gt_linestyle: str = "-",
) -> axes.Axes:
    """Plot a clip with its ground truth and predicted geometries overlaid.

    The clip is rendered with ``plot_clip`` and every entry of
    ``clip_eval.matches`` is then drawn on top. An entry counts as a match
    when it has both a prediction and a ground truth and its score is at
    least ``threshold``.

    Parameters
    ----------
    clip_eval
        Evaluation result of one clip.
    figsize
        Figure size used when a new axis has to be created.
    ax
        Axis to draw on. A new one is created when ``None``.
    audio_loader, preprocessor
        Forwarded to ``plot_clip``.
    threshold
        Minimum score for a paired prediction to count as a match.
    add_legend
        Whether to add a legend explaining the four outline styles.
    add_title
        Whether to title the axis with the recording file name.
    fill
        Whether to fill the geometries with their outline color.
    linewidth
        Outline width of every geometry.
    gt_color, gt_linestyle
        Style of ground truth that was matched.
    true_pred_color, true_pred_linestyle
        Style of predictions that were matched.
    false_pred_color, false_pred_linestyle
        Style of predictions that were not matched.
    missed_gt_color, missed_gt_linestyle
        Style of ground truth that was not matched.

    Returns
    -------
    axes.Axes
        The axis that was drawn on.
    """
    ax = create_ax(figsize=figsize, ax=ax)

    plot_clip(
        clip_eval.clip,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        ax=ax,
    )

    for m in clip_eval.matches:
        is_match = (
            m.pred is not None and m.gt is not None and m.score >= threshold
        )

        if m.pred is not None:
            color = true_pred_color if is_match else false_pred_color
            plot_geometry(
                m.pred.geometry,
                ax=ax,
                add_points=False,
                facecolor="none" if not fill else color,
                # Outline opacity follows the detection score.
                alpha=m.pred.detection_score,
                linewidth=linewidth,
                # Unmatched predictions use the false-prediction style so
                # the drawing agrees with the legend.
                linestyle=true_pred_linestyle
                if is_match
                else false_pred_linestyle,
                color=color,
            )

        if m.gt is not None:
            color = gt_color if is_match else missed_gt_color
            plot_geometry(
                m.gt.sound_event.geometry,  # type: ignore
                ax=ax,
                add_points=False,
                linewidth=linewidth,
                facecolor="none" if not fill else color,
                # Unmatched ground truth uses the missed-GT style so the
                # drawing agrees with the legend.
                linestyle=gt_linestyle if is_match else missed_gt_linestyle,
                color=color,
            )

    if add_title:
        ax.set_title(clip_eval.clip.recording.path.name)

    if add_legend:
        ax.legend(
            handles=[
                patches.Patch(
                    label="found GT",
                    edgecolor=gt_color,
                    facecolor="none" if not fill else gt_color,
                    linestyle=gt_linestyle,
                ),
                patches.Patch(
                    label="missed GT",
                    edgecolor=missed_gt_color,
                    facecolor="none" if not fill else missed_gt_color,
                    linestyle=missed_gt_linestyle,
                ),
                patches.Patch(
                    label="true Det",
                    edgecolor=true_pred_color,
                    facecolor="none" if not fill else true_pred_color,
                    linestyle=true_pred_linestyle,
                ),
                patches.Patch(
                    label="false Det",
                    edgecolor=false_pred_color,
                    facecolor="none" if not fill else false_pred_color,
                    linestyle=false_pred_linestyle,
                ),
            ]
        )

    return ax
|
||||||
@ -1,81 +1,109 @@
|
|||||||
from typing import List, Optional
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
from batdetect2.plotting.matches import (
|
from batdetect2.plotting.matches import (
|
||||||
|
MatchProtocol,
|
||||||
plot_cross_trigger_match,
|
plot_cross_trigger_match,
|
||||||
plot_false_negative_match,
|
plot_false_negative_match,
|
||||||
plot_false_positive_match,
|
plot_false_positive_match,
|
||||||
plot_true_positive_match,
|
plot_true_positive_match,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.evaluate import MatchEvaluation
|
|
||||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = ["plot_match_gallery"]
|
__all__ = ["plot_match_gallery"]
|
||||||
|
|
||||||
|
|
||||||
def plot_match_gallery(
|
def plot_match_gallery(
|
||||||
true_positives: List[MatchEvaluation],
|
true_positives: Sequence[MatchProtocol],
|
||||||
false_positives: List[MatchEvaluation],
|
false_positives: Sequence[MatchProtocol],
|
||||||
false_negatives: List[MatchEvaluation],
|
false_negatives: Sequence[MatchProtocol],
|
||||||
cross_triggers: List[MatchEvaluation],
|
cross_triggers: Sequence[MatchProtocol],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
n_examples: int = 5,
|
n_examples: int = 5,
|
||||||
duration: float = 0.1,
|
duration: float = 0.1,
|
||||||
|
fig: Optional[Figure] = None,
|
||||||
):
|
):
|
||||||
|
if fig is None:
|
||||||
fig = plt.figure(figsize=(20, 20))
|
fig = plt.figure(figsize=(20, 20))
|
||||||
|
|
||||||
for index, match in enumerate(true_positives[:n_examples]):
|
axes = fig.subplots(
|
||||||
ax = plt.subplot(4, n_examples, index + 1)
|
nrows=4,
|
||||||
|
ncols=n_examples,
|
||||||
|
sharex="none",
|
||||||
|
sharey="row",
|
||||||
|
)
|
||||||
|
|
||||||
|
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
|
||||||
try:
|
try:
|
||||||
plot_true_positive_match(
|
plot_true_positive_match(
|
||||||
match,
|
tp_match,
|
||||||
ax=ax,
|
ax=tp_ax,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
|
except (
|
||||||
|
ValueError,
|
||||||
|
AssertionError,
|
||||||
|
RuntimeError,
|
||||||
|
FileNotFoundError,
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for index, match in enumerate(false_positives[:n_examples]):
|
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
|
||||||
ax = plt.subplot(4, n_examples, n_examples + index + 1)
|
|
||||||
try:
|
try:
|
||||||
plot_false_positive_match(
|
plot_false_positive_match(
|
||||||
match,
|
fp_match,
|
||||||
ax=ax,
|
ax=fp_ax,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
|
except (
|
||||||
|
ValueError,
|
||||||
|
AssertionError,
|
||||||
|
RuntimeError,
|
||||||
|
FileNotFoundError,
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for index, match in enumerate(false_negatives[:n_examples]):
|
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
|
||||||
ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
|
|
||||||
try:
|
try:
|
||||||
plot_false_negative_match(
|
plot_false_negative_match(
|
||||||
match,
|
fn_match,
|
||||||
ax=ax,
|
ax=fn_ax,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
|
except (
|
||||||
|
ValueError,
|
||||||
|
AssertionError,
|
||||||
|
RuntimeError,
|
||||||
|
FileNotFoundError,
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for index, match in enumerate(cross_triggers[:n_examples]):
|
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
|
||||||
ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
|
|
||||||
try:
|
try:
|
||||||
plot_cross_trigger_match(
|
plot_cross_trigger_match(
|
||||||
match,
|
ct_match,
|
||||||
ax=ax,
|
ax=ct_ax,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
|
except (
|
||||||
|
ValueError,
|
||||||
|
AssertionError,
|
||||||
|
RuntimeError,
|
||||||
|
FileNotFoundError,
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
fig.tight_layout()
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|||||||
@ -1,16 +1,17 @@
|
|||||||
from typing import List, Optional, Tuple, Union
|
from typing import Optional, Protocol, Tuple, Union
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from soundevent import data, plot
|
from soundevent import data, plot
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
from soundevent.plot.tags import TagColorMapper
|
|
||||||
|
|
||||||
from batdetect2.plotting.clips import AudioLoader, plot_clip
|
from batdetect2.plotting.clips import plot_clip
|
||||||
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
|
from batdetect2.typing import (
|
||||||
|
AudioLoader,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
RawPrediction,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_matches",
|
|
||||||
"plot_false_positive_match",
|
"plot_false_positive_match",
|
||||||
"plot_true_positive_match",
|
"plot_true_positive_match",
|
||||||
"plot_false_negative_match",
|
"plot_false_negative_match",
|
||||||
@ -18,6 +19,14 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class MatchProtocol(Protocol):
|
||||||
|
clip: data.Clip
|
||||||
|
gt: Optional[data.SoundEventAnnotation]
|
||||||
|
pred: Optional[RawPrediction]
|
||||||
|
score: float
|
||||||
|
true_class: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_DURATION = 0.05
|
DEFAULT_DURATION = 0.05
|
||||||
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
|
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
|
||||||
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
|
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
|
||||||
@ -27,88 +36,8 @@ DEFAULT_ANNOTATION_LINE_STYLE = "-"
|
|||||||
DEFAULT_PREDICTION_LINE_STYLE = "--"
|
DEFAULT_PREDICTION_LINE_STYLE = "--"
|
||||||
|
|
||||||
|
|
||||||
def plot_matches(
|
|
||||||
matches: List[MatchEvaluation],
|
|
||||||
clip: data.Clip,
|
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
|
||||||
ax: Optional[Axes] = None,
|
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
|
||||||
color_mapper: Optional[TagColorMapper] = None,
|
|
||||||
add_points: bool = False,
|
|
||||||
fill: bool = False,
|
|
||||||
spec_cmap: str = "gray",
|
|
||||||
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
|
||||||
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
|
||||||
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
|
||||||
cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
|
||||||
) -> Axes:
|
|
||||||
ax = plot_clip(
|
|
||||||
clip,
|
|
||||||
ax=ax,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
figsize=figsize,
|
|
||||||
audio_dir=audio_dir,
|
|
||||||
spec_cmap=spec_cmap,
|
|
||||||
)
|
|
||||||
|
|
||||||
if color_mapper is None:
|
|
||||||
color_mapper = TagColorMapper()
|
|
||||||
|
|
||||||
for match in matches:
|
|
||||||
if match.is_cross_trigger():
|
|
||||||
plot_cross_trigger_match(
|
|
||||||
match,
|
|
||||||
ax=ax,
|
|
||||||
fill=fill,
|
|
||||||
add_points=add_points,
|
|
||||||
add_spectrogram=False,
|
|
||||||
use_score=True,
|
|
||||||
color=cross_trigger_color,
|
|
||||||
add_text=False,
|
|
||||||
)
|
|
||||||
elif match.is_true_positive():
|
|
||||||
plot_true_positive_match(
|
|
||||||
match,
|
|
||||||
ax=ax,
|
|
||||||
fill=fill,
|
|
||||||
add_spectrogram=False,
|
|
||||||
use_score=True,
|
|
||||||
add_points=add_points,
|
|
||||||
color=true_positive_color,
|
|
||||||
add_text=False,
|
|
||||||
)
|
|
||||||
elif match.is_false_negative():
|
|
||||||
plot_false_negative_match(
|
|
||||||
match,
|
|
||||||
ax=ax,
|
|
||||||
fill=fill,
|
|
||||||
add_spectrogram=False,
|
|
||||||
add_points=add_points,
|
|
||||||
color=false_negative_color,
|
|
||||||
add_text=False,
|
|
||||||
)
|
|
||||||
elif match.is_false_positive:
|
|
||||||
plot_false_positive_match(
|
|
||||||
match,
|
|
||||||
ax=ax,
|
|
||||||
fill=fill,
|
|
||||||
add_spectrogram=False,
|
|
||||||
use_score=True,
|
|
||||||
add_points=add_points,
|
|
||||||
color=false_positive_color,
|
|
||||||
add_text=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return ax
|
|
||||||
|
|
||||||
|
|
||||||
def plot_false_positive_match(
|
def plot_false_positive_match(
|
||||||
match: MatchEvaluation,
|
match: MatchProtocol,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
@ -119,21 +48,24 @@ def plot_false_positive_match(
|
|||||||
add_spectrogram: bool = True,
|
add_spectrogram: bool = True,
|
||||||
add_text: bool = True,
|
add_text: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
|
add_title: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
||||||
fontsize: Union[float, str] = "small",
|
fontsize: Union[float, str] = "small",
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
assert match.pred_geometry is not None
|
assert match.pred is not None
|
||||||
assert match.sound_event_annotation is None
|
|
||||||
|
|
||||||
start_time, _, _, high_freq = compute_bounds(match.pred_geometry)
|
start_time, _, _, high_freq = compute_bounds(match.pred.geometry)
|
||||||
|
|
||||||
clip = data.Clip(
|
clip = data.Clip(
|
||||||
start_time=max(start_time - duration / 2, 0),
|
start_time=max(
|
||||||
|
start_time - duration / 2,
|
||||||
|
0,
|
||||||
|
),
|
||||||
end_time=min(
|
end_time=min(
|
||||||
start_time + duration / 2,
|
start_time + duration / 2,
|
||||||
match.clip.end_time,
|
match.clip.recording.duration,
|
||||||
),
|
),
|
||||||
recording=match.clip.recording,
|
recording=match.clip.recording,
|
||||||
)
|
)
|
||||||
@ -150,30 +82,33 @@ def plot_false_positive_match(
|
|||||||
)
|
)
|
||||||
|
|
||||||
ax = plot.plot_geometry(
|
ax = plot.plot_geometry(
|
||||||
match.pred_geometry,
|
match.pred.geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=match.pred_score if use_score else 1,
|
alpha=match.score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_text:
|
if add_text:
|
||||||
plt.text(
|
ax.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
f"score={match.score:.2f}",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
fontsize=fontsize,
|
fontsize=fontsize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if add_title:
|
||||||
|
ax.set_title("False Positive")
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
def plot_false_negative_match(
|
def plot_false_negative_match(
|
||||||
match: MatchEvaluation,
|
match: MatchProtocol,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
@ -182,26 +117,28 @@ def plot_false_negative_match(
|
|||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
add_spectrogram: bool = True,
|
add_spectrogram: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
add_text: bool = True,
|
add_title: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||||
fontsize: Union[float, str] = "small",
|
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
assert match.pred_geometry is None
|
assert match.gt is not None
|
||||||
assert match.sound_event_annotation is not None
|
|
||||||
sound_event = match.sound_event_annotation.sound_event
|
geometry = match.gt.sound_event.geometry
|
||||||
geometry = sound_event.geometry
|
|
||||||
assert geometry is not None
|
assert geometry is not None
|
||||||
|
|
||||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
start_time = compute_bounds(geometry)[0]
|
||||||
|
|
||||||
clip = data.Clip(
|
clip = data.Clip(
|
||||||
start_time=max(start_time - duration / 2, 0),
|
start_time=max(
|
||||||
end_time=min(
|
start_time - duration / 2,
|
||||||
start_time + duration / 2, sound_event.recording.duration
|
0,
|
||||||
),
|
),
|
||||||
recording=sound_event.recording,
|
end_time=min(
|
||||||
|
start_time + duration / 2,
|
||||||
|
match.clip.recording.duration,
|
||||||
|
),
|
||||||
|
recording=match.clip.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_spectrogram:
|
if add_spectrogram:
|
||||||
@ -215,33 +152,23 @@ def plot_false_negative_match(
|
|||||||
spec_cmap=spec_cmap,
|
spec_cmap=spec_cmap,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot.plot_annotation(
|
ax = plot.plot_geometry(
|
||||||
match.sound_event_annotation,
|
geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
|
||||||
freq_offset=2_000,
|
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=1,
|
||||||
color=color,
|
color=color,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_text:
|
if add_title:
|
||||||
plt.text(
|
ax.set_title("False Negative")
|
||||||
start_time,
|
|
||||||
high_freq,
|
|
||||||
f"False Negative \nClass: {match.gt_class} ",
|
|
||||||
va="top",
|
|
||||||
ha="right",
|
|
||||||
color=color,
|
|
||||||
fontsize=fontsize,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
def plot_true_positive_match(
|
def plot_true_positive_match(
|
||||||
match: MatchEvaluation,
|
match: MatchProtocol,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
@ -258,39 +185,42 @@ def plot_true_positive_match(
|
|||||||
fontsize: Union[float, str] = "small",
|
fontsize: Union[float, str] = "small",
|
||||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||||
|
add_title: bool = True,
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
assert match.sound_event_annotation is not None
|
assert match.gt is not None
|
||||||
assert match.pred_geometry is not None
|
assert match.pred is not None
|
||||||
sound_event = match.sound_event_annotation.sound_event
|
|
||||||
geometry = sound_event.geometry
|
geometry = match.gt.sound_event.geometry
|
||||||
assert geometry is not None
|
assert geometry is not None
|
||||||
|
|
||||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
start_time, _, _, high_freq = compute_bounds(geometry)
|
||||||
|
|
||||||
clip = data.Clip(
|
clip = data.Clip(
|
||||||
start_time=max(start_time - duration / 2, 0),
|
start_time=max(
|
||||||
end_time=min(
|
start_time - duration / 2,
|
||||||
start_time + duration / 2, sound_event.recording.duration
|
0,
|
||||||
),
|
),
|
||||||
recording=sound_event.recording,
|
end_time=min(
|
||||||
|
start_time + duration / 2,
|
||||||
|
match.clip.recording.duration,
|
||||||
|
),
|
||||||
|
recording=match.clip.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_spectrogram:
|
if add_spectrogram:
|
||||||
ax = plot_clip(
|
ax = plot_clip(
|
||||||
clip,
|
clip,
|
||||||
|
ax=ax,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
figsize=figsize,
|
figsize=figsize,
|
||||||
ax=ax,
|
|
||||||
audio_dir=audio_dir,
|
audio_dir=audio_dir,
|
||||||
spec_cmap=spec_cmap,
|
spec_cmap=spec_cmap,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot.plot_annotation(
|
ax = plot.plot_geometry(
|
||||||
match.sound_event_annotation,
|
geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
|
||||||
freq_offset=2_000,
|
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=1,
|
||||||
@ -299,31 +229,34 @@ def plot_true_positive_match(
|
|||||||
)
|
)
|
||||||
|
|
||||||
plot.plot_geometry(
|
plot.plot_geometry(
|
||||||
match.pred_geometry,
|
match.pred.geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=match.pred_score if use_score else 1,
|
alpha=match.score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
linestyle=prediction_linestyle,
|
linestyle=prediction_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_text:
|
if add_text:
|
||||||
plt.text(
|
ax.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
f"score={match.score:.2f}",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
fontsize=fontsize,
|
fontsize=fontsize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if add_title:
|
||||||
|
ax.set_title("True Positive")
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
def plot_cross_trigger_match(
|
def plot_cross_trigger_match(
|
||||||
match: MatchEvaluation,
|
match: MatchProtocol,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
@ -334,6 +267,7 @@ def plot_cross_trigger_match(
|
|||||||
add_spectrogram: bool = True,
|
add_spectrogram: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
add_text: bool = True,
|
add_text: bool = True,
|
||||||
|
add_title: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
||||||
@ -341,20 +275,24 @@ def plot_cross_trigger_match(
|
|||||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
assert match.sound_event_annotation is not None
|
assert match.gt is not None
|
||||||
assert match.pred_geometry is not None
|
assert match.pred is not None
|
||||||
sound_event = match.sound_event_annotation.sound_event
|
|
||||||
geometry = sound_event.geometry
|
geometry = match.gt.sound_event.geometry
|
||||||
assert geometry is not None
|
assert geometry is not None
|
||||||
|
|
||||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
start_time, _, _, high_freq = compute_bounds(geometry)
|
||||||
|
|
||||||
clip = data.Clip(
|
clip = data.Clip(
|
||||||
start_time=max(start_time - duration / 2, 0),
|
start_time=max(
|
||||||
end_time=min(
|
start_time - duration / 2,
|
||||||
start_time + duration / 2, sound_event.recording.duration
|
0,
|
||||||
),
|
),
|
||||||
recording=sound_event.recording,
|
end_time=min(
|
||||||
|
start_time + duration / 2,
|
||||||
|
match.clip.recording.duration,
|
||||||
|
),
|
||||||
|
recording=match.clip.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_spectrogram:
|
if add_spectrogram:
|
||||||
@ -368,11 +306,9 @@ def plot_cross_trigger_match(
|
|||||||
spec_cmap=spec_cmap,
|
spec_cmap=spec_cmap,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot.plot_annotation(
|
ax = plot.plot_geometry(
|
||||||
match.sound_event_annotation,
|
geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
|
||||||
freq_offset=2_000,
|
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=1,
|
||||||
@ -381,24 +317,28 @@ def plot_cross_trigger_match(
|
|||||||
)
|
)
|
||||||
|
|
||||||
ax = plot.plot_geometry(
|
ax = plot.plot_geometry(
|
||||||
match.pred_geometry,
|
match.pred.geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=match.pred_score if use_score else 1,
|
alpha=match.score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
linestyle=prediction_linestyle,
|
linestyle=prediction_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_text:
|
if add_text:
|
||||||
plt.text(
|
ax.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
f"score={match.score:.2f}\nclass={match.true_class}",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
fontsize=fontsize,
|
fontsize=fontsize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if add_title:
|
||||||
|
ax.set_title("Cross Trigger")
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|||||||
286
src/batdetect2/plotting/metrics.py
Normal file
286
src/batdetect2/plotting/metrics.py
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import seaborn as sns
|
||||||
|
from cycler import cycler
|
||||||
|
from matplotlib import axes
|
||||||
|
|
||||||
|
from batdetect2.plotting.common import create_ax
|
||||||
|
|
||||||
|
|
||||||
|
def set_default_styler(ax: axes.Axes) -> axes.Axes:
|
||||||
|
color_cycler = cycler(color=sns.color_palette("muted"))
|
||||||
|
style_cycler = cycler(linestyle=["-", "--", ":"]) * cycler(
|
||||||
|
marker=["o", "s", "^"]
|
||||||
|
)
|
||||||
|
custom_cycler = color_cycler * len(style_cycler) + style_cycler * len(
|
||||||
|
color_cycler
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_prop_cycle(custom_cycler)
|
||||||
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
def set_default_style(ax: axes.Axes) -> axes.Axes:
|
||||||
|
ax = set_default_styler(ax)
|
||||||
|
ax.spines.right.set_visible(False)
|
||||||
|
ax.spines.top.set_visible(False)
|
||||||
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
def plot_pr_curve(
|
||||||
|
precision: np.ndarray,
|
||||||
|
recall: np.ndarray,
|
||||||
|
thresholds: np.ndarray,
|
||||||
|
ax: Optional[axes.Axes] = None,
|
||||||
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
|
add_labels: bool = True,
|
||||||
|
add_legend: bool = False,
|
||||||
|
label: str = "PR Curve",
|
||||||
|
) -> axes.Axes:
|
||||||
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
|
|
||||||
|
ax = set_default_style(ax)
|
||||||
|
|
||||||
|
ax.plot(
|
||||||
|
recall,
|
||||||
|
precision,
|
||||||
|
label=label,
|
||||||
|
marker="o",
|
||||||
|
markevery=_get_marker_positions(thresholds),
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xlim(0, 1.05)
|
||||||
|
ax.set_ylim(0, 1.05)
|
||||||
|
|
||||||
|
if add_legend:
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
if add_labels:
|
||||||
|
ax.set_xlabel("Recall")
|
||||||
|
ax.set_ylabel("Precision")
|
||||||
|
|
||||||
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
def plot_pr_curves(
|
||||||
|
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
|
ax: Optional[axes.Axes] = None,
|
||||||
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
|
add_legend: bool = True,
|
||||||
|
add_labels: bool = True,
|
||||||
|
) -> axes.Axes:
|
||||||
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
|
ax = set_default_style(ax)
|
||||||
|
|
||||||
|
for name, (precision, recall, thresholds) in data.items():
|
||||||
|
ax.plot(
|
||||||
|
recall,
|
||||||
|
precision,
|
||||||
|
label=name,
|
||||||
|
markevery=_get_marker_positions(thresholds),
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xlim(0, 1.05)
|
||||||
|
ax.set_ylim(0, 1.05)
|
||||||
|
|
||||||
|
if add_labels:
|
||||||
|
ax.set_xlabel("Recall")
|
||||||
|
ax.set_ylabel("Precision")
|
||||||
|
|
||||||
|
if add_legend:
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(1.05, 1),
|
||||||
|
loc="upper left",
|
||||||
|
borderaxespad=0.0,
|
||||||
|
)
|
||||||
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
def plot_threshold_precision_curve(
|
||||||
|
threshold: np.ndarray,
|
||||||
|
precision: np.ndarray,
|
||||||
|
ax: Optional[axes.Axes] = None,
|
||||||
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
|
add_labels: bool = True,
|
||||||
|
):
|
||||||
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
|
|
||||||
|
ax = set_default_style(ax)
|
||||||
|
|
||||||
|
ax.plot(threshold, precision, markevery=_get_marker_positions(threshold))
|
||||||
|
|
||||||
|
ax.set_xlim(0, 1.05)
|
||||||
|
ax.set_ylim(0, 1.05)
|
||||||
|
|
||||||
|
if add_labels:
|
||||||
|
ax.set_xlabel("Threshold")
|
||||||
|
ax.set_ylabel("Precision")
|
||||||
|
|
||||||
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
def plot_threshold_precision_curves(
|
||||||
|
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
|
ax: Optional[axes.Axes] = None,
|
||||||
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
|
add_legend: bool = True,
|
||||||
|
add_labels: bool = True,
|
||||||
|
):
|
||||||
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
|
ax = set_default_style(ax)
|
||||||
|
|
||||||
|
for name, (precision, _, thresholds) in data.items():
|
||||||
|
ax.plot(
|
||||||
|
thresholds,
|
||||||
|
precision,
|
||||||
|
label=name,
|
||||||
|
markevery=_get_marker_positions(thresholds),
|
||||||
|
)
|
||||||
|
|
||||||
|
if add_legend:
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(1.05, 1),
|
||||||
|
loc="upper left",
|
||||||
|
borderaxespad=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xlim(0, 1.05)
|
||||||
|
ax.set_ylim(0, 1.05)
|
||||||
|
|
||||||
|
if add_labels:
|
||||||
|
ax.set_xlabel("Threshold")
|
||||||
|
ax.set_ylabel("Precision")
|
||||||
|
|
||||||
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
def plot_threshold_recall_curve(
|
||||||
|
threshold: np.ndarray,
|
||||||
|
recall: np.ndarray,
|
||||||
|
ax: Optional[axes.Axes] = None,
|
||||||
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
|
add_labels: bool = True,
|
||||||
|
):
|
||||||
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
|
|
||||||
|
ax = set_default_style(ax)
|
||||||
|
|
||||||
|
ax.plot(threshold, recall, markevery=_get_marker_positions(threshold))
|
||||||
|
|
||||||
|
ax.set_xlim(0, 1.05)
|
||||||
|
ax.set_ylim(0, 1.05)
|
||||||
|
|
||||||
|
if add_labels:
|
||||||
|
ax.set_xlabel("Threshold")
|
||||||
|
ax.set_ylabel("Recall")
|
||||||
|
|
||||||
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
def plot_threshold_recall_curves(
|
||||||
|
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
|
ax: Optional[axes.Axes] = None,
|
||||||
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
|
add_legend: bool = True,
|
||||||
|
add_labels: bool = True,
|
||||||
|
):
|
||||||
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
|
ax = set_default_style(ax)
|
||||||
|
|
||||||
|
for name, (_, recall, thresholds) in data.items():
|
||||||
|
ax.plot(
|
||||||
|
thresholds,
|
||||||
|
recall,
|
||||||
|
label=name,
|
||||||
|
markevery=_get_marker_positions(thresholds),
|
||||||
|
)
|
||||||
|
|
||||||
|
if add_legend:
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(1.05, 1),
|
||||||
|
loc="upper left",
|
||||||
|
borderaxespad=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xlim(0, 1.05)
|
||||||
|
ax.set_ylim(0, 1.05)
|
||||||
|
|
||||||
|
if add_labels:
|
||||||
|
ax.set_xlabel("Threshold")
|
||||||
|
ax.set_ylabel("Recall")
|
||||||
|
|
||||||
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
def plot_roc_curve(
|
||||||
|
fpr: np.ndarray,
|
||||||
|
tpr: np.ndarray,
|
||||||
|
thresholds: np.ndarray,
|
||||||
|
ax: Optional[axes.Axes] = None,
|
||||||
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
|
add_labels: bool = True,
|
||||||
|
) -> axes.Axes:
|
||||||
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
|
|
||||||
|
ax = set_default_style(ax)
|
||||||
|
|
||||||
|
ax.plot(
|
||||||
|
fpr,
|
||||||
|
tpr,
|
||||||
|
markevery=_get_marker_positions(thresholds),
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xlim(0, 1.05)
|
||||||
|
ax.set_ylim(0, 1.05)
|
||||||
|
|
||||||
|
if add_labels:
|
||||||
|
ax.set_xlabel("False Positive Rate")
|
||||||
|
ax.set_ylabel("True Positive Rate")
|
||||||
|
|
||||||
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
def plot_roc_curves(
|
||||||
|
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
|
||||||
|
ax: Optional[axes.Axes] = None,
|
||||||
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
|
add_legend: bool = True,
|
||||||
|
add_labels: bool = True,
|
||||||
|
) -> axes.Axes:
|
||||||
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
|
ax = set_default_style(ax)
|
||||||
|
|
||||||
|
for name, (fpr, tpr, thresholds) in data.items():
|
||||||
|
ax.plot(
|
||||||
|
fpr,
|
||||||
|
tpr,
|
||||||
|
label=name,
|
||||||
|
markevery=_get_marker_positions(thresholds),
|
||||||
|
)
|
||||||
|
|
||||||
|
if add_legend:
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(1.05, 1),
|
||||||
|
loc="upper left",
|
||||||
|
borderaxespad=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xlim(0, 1.05)
|
||||||
|
ax.set_ylim(0, 1.05)
|
||||||
|
|
||||||
|
if add_labels:
|
||||||
|
ax.set_xlabel("False Positive Rate")
|
||||||
|
ax.set_ylabel("True Positive Rate")
|
||||||
|
|
||||||
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
def _get_marker_positions(
|
||||||
|
thresholds: np.ndarray,
|
||||||
|
n_points: int = 11,
|
||||||
|
) -> np.ndarray:
|
||||||
|
size = len(thresholds)
|
||||||
|
cut_points = np.linspace(0, 1, n_points)
|
||||||
|
indices = np.searchsorted(thresholds[::-1], cut_points)
|
||||||
|
return np.clip(size - indices, 0, size - 1) # type: ignore
|
||||||
@ -28,12 +28,10 @@ class CenterAudio(torch.nn.Module):
|
|||||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
return center_tensor(wav)
|
return center_tensor(wav)
|
||||||
|
|
||||||
@classmethod
|
@audio_transforms.register(CenterAudioConfig)
|
||||||
def from_config(cls, config: CenterAudioConfig, samplerate: int):
|
@staticmethod
|
||||||
return cls()
|
def from_config(config: CenterAudioConfig, samplerate: int):
|
||||||
|
return CenterAudio()
|
||||||
|
|
||||||
audio_transforms.register(CenterAudioConfig, CenterAudio)
|
|
||||||
|
|
||||||
|
|
||||||
class ScaleAudioConfig(BaseConfig):
|
class ScaleAudioConfig(BaseConfig):
|
||||||
@ -44,12 +42,10 @@ class ScaleAudio(torch.nn.Module):
|
|||||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
return peak_normalize(wav)
|
return peak_normalize(wav)
|
||||||
|
|
||||||
@classmethod
|
@audio_transforms.register(ScaleAudioConfig)
|
||||||
def from_config(cls, config: ScaleAudioConfig, samplerate: int):
|
@staticmethod
|
||||||
return cls()
|
def from_config(config: ScaleAudioConfig, samplerate: int):
|
||||||
|
return ScaleAudio()
|
||||||
|
|
||||||
audio_transforms.register(ScaleAudioConfig, ScaleAudio)
|
|
||||||
|
|
||||||
|
|
||||||
class FixDurationConfig(BaseConfig):
|
class FixDurationConfig(BaseConfig):
|
||||||
@ -75,13 +71,12 @@ class FixDuration(torch.nn.Module):
|
|||||||
|
|
||||||
return torch.nn.functional.pad(wav, (0, self.length - length))
|
return torch.nn.functional.pad(wav, (0, self.length - length))
|
||||||
|
|
||||||
@classmethod
|
@audio_transforms.register(FixDurationConfig)
|
||||||
def from_config(cls, config: FixDurationConfig, samplerate: int):
|
@staticmethod
|
||||||
return cls(samplerate=samplerate, duration=config.duration)
|
def from_config(config: FixDurationConfig, samplerate: int):
|
||||||
|
return FixDuration(samplerate=samplerate, duration=config.duration)
|
||||||
|
|
||||||
|
|
||||||
audio_transforms.register(FixDurationConfig, FixDuration)
|
|
||||||
|
|
||||||
AudioTransform = Annotated[
|
AudioTransform = Annotated[
|
||||||
Union[
|
Union[
|
||||||
FixDurationConfig,
|
FixDurationConfig,
|
||||||
|
|||||||
@ -285,10 +285,11 @@ class PCEN(torch.nn.Module):
|
|||||||
* torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
|
* torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
|
||||||
).to(spec.dtype)
|
).to(spec.dtype)
|
||||||
|
|
||||||
@classmethod
|
@spectrogram_transforms.register(PcenConfig)
|
||||||
def from_config(cls, config: PcenConfig, samplerate: int):
|
@staticmethod
|
||||||
|
def from_config(config: PcenConfig, samplerate: int):
|
||||||
smooth = _compute_smoothing_constant(samplerate, config.time_constant)
|
smooth = _compute_smoothing_constant(samplerate, config.time_constant)
|
||||||
return cls(
|
return PCEN(
|
||||||
smoothing_constant=smooth,
|
smoothing_constant=smooth,
|
||||||
gain=config.gain,
|
gain=config.gain,
|
||||||
bias=config.bias,
|
bias=config.bias,
|
||||||
@ -296,9 +297,6 @@ class PCEN(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
spectrogram_transforms.register(PcenConfig, PCEN)
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_smoothing_constant(
|
def _compute_smoothing_constant(
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
time_constant: float,
|
time_constant: float,
|
||||||
@ -335,12 +333,10 @@ class ScaleAmplitude(torch.nn.Module):
|
|||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
return self.scaler(spec)
|
return self.scaler(spec)
|
||||||
|
|
||||||
@classmethod
|
@spectrogram_transforms.register(ScaleAmplitudeConfig)
|
||||||
def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int):
|
@staticmethod
|
||||||
return cls(scale=config.scale)
|
def from_config(config: ScaleAmplitudeConfig, samplerate: int):
|
||||||
|
return ScaleAmplitude(scale=config.scale)
|
||||||
|
|
||||||
spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude)
|
|
||||||
|
|
||||||
|
|
||||||
class SpectralMeanSubstractionConfig(BaseConfig):
|
class SpectralMeanSubstractionConfig(BaseConfig):
|
||||||
@ -352,19 +348,13 @@ class SpectralMeanSubstraction(torch.nn.Module):
|
|||||||
mean = spec.mean(-1, keepdim=True)
|
mean = spec.mean(-1, keepdim=True)
|
||||||
return (spec - mean).clamp(min=0)
|
return (spec - mean).clamp(min=0)
|
||||||
|
|
||||||
@classmethod
|
@spectrogram_transforms.register(SpectralMeanSubstractionConfig)
|
||||||
|
@staticmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
cls,
|
|
||||||
config: SpectralMeanSubstractionConfig,
|
config: SpectralMeanSubstractionConfig,
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
):
|
):
|
||||||
return cls()
|
return SpectralMeanSubstraction()
|
||||||
|
|
||||||
|
|
||||||
spectrogram_transforms.register(
|
|
||||||
SpectralMeanSubstractionConfig,
|
|
||||||
SpectralMeanSubstraction,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PeakNormalizeConfig(BaseConfig):
|
class PeakNormalizeConfig(BaseConfig):
|
||||||
@ -375,13 +365,12 @@ class PeakNormalize(torch.nn.Module):
|
|||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
return peak_normalize(spec)
|
return peak_normalize(spec)
|
||||||
|
|
||||||
@classmethod
|
@spectrogram_transforms.register(PeakNormalizeConfig)
|
||||||
def from_config(cls, config: PeakNormalizeConfig, samplerate: int):
|
@staticmethod
|
||||||
return cls()
|
def from_config(config: PeakNormalizeConfig, samplerate: int):
|
||||||
|
return PeakNormalize()
|
||||||
|
|
||||||
|
|
||||||
spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize)
|
|
||||||
|
|
||||||
SpectrogramTransform = Annotated[
|
SpectrogramTransform = Annotated[
|
||||||
Union[
|
Union[
|
||||||
PcenConfig,
|
PcenConfig,
|
||||||
|
|||||||
@ -99,7 +99,7 @@ DEFAULT_DETECTION_CLASS = TargetClassConfig(
|
|||||||
DEFAULT_CLASSES = [
|
DEFAULT_CLASSES = [
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="barbar",
|
name="barbar",
|
||||||
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
|
tags=[data.Tag(key="class", value="Barbastella barbastellus")],
|
||||||
),
|
),
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="eptser",
|
name="eptser",
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
AugmentationsConfig,
|
AugmentationsConfig,
|
||||||
EchoAugmentationConfig,
|
AddEchoConfig,
|
||||||
FrequencyMaskAugmentationConfig,
|
MaskFrequencyConfig,
|
||||||
RandomAudioSource,
|
RandomAudioSource,
|
||||||
TimeMaskAugmentationConfig,
|
MaskTimeConfig,
|
||||||
VolumeAugmentationConfig,
|
ScaleVolumeConfig,
|
||||||
WarpAugmentationConfig,
|
WarpConfig,
|
||||||
add_echo,
|
add_echo,
|
||||||
build_augmentations,
|
build_augmentations,
|
||||||
mask_frequency,
|
mask_frequency,
|
||||||
@ -43,20 +43,20 @@ __all__ = [
|
|||||||
"AugmentationsConfig",
|
"AugmentationsConfig",
|
||||||
"ClassificationLossConfig",
|
"ClassificationLossConfig",
|
||||||
"DetectionLossConfig",
|
"DetectionLossConfig",
|
||||||
"EchoAugmentationConfig",
|
"AddEchoConfig",
|
||||||
"FrequencyMaskAugmentationConfig",
|
"MaskFrequencyConfig",
|
||||||
"LossConfig",
|
"LossConfig",
|
||||||
"LossFunction",
|
"LossFunction",
|
||||||
"PLTrainerConfig",
|
"PLTrainerConfig",
|
||||||
"RandomAudioSource",
|
"RandomAudioSource",
|
||||||
"SizeLossConfig",
|
"SizeLossConfig",
|
||||||
"TimeMaskAugmentationConfig",
|
"MaskTimeConfig",
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"TrainingDataset",
|
"TrainingDataset",
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
"ValidationDataset",
|
"ValidationDataset",
|
||||||
"VolumeAugmentationConfig",
|
"ScaleVolumeConfig",
|
||||||
"WarpAugmentationConfig",
|
"WarpConfig",
|
||||||
"add_echo",
|
"add_echo",
|
||||||
"build_augmentations",
|
"build_augmentations",
|
||||||
"build_clip_labeler",
|
"build_clip_labeler",
|
||||||
|
|||||||
@ -12,21 +12,23 @@ from soundevent import data
|
|||||||
from soundevent.geometry import scale_geometry, shift_geometry
|
from soundevent.geometry import scale_geometry, shift_geometry
|
||||||
|
|
||||||
from batdetect2.audio.clips import get_subclip_annotation
|
from batdetect2.audio.clips import get_subclip_annotation
|
||||||
|
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
||||||
from batdetect2.core.arrays import adjust_width
|
from batdetect2.core.arrays import adjust_width
|
||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.core.registries import Registry
|
||||||
from batdetect2.typing import AudioLoader, Augmentation
|
from batdetect2.typing import AudioLoader, Augmentation
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AugmentationConfig",
|
"AugmentationConfig",
|
||||||
"AugmentationsConfig",
|
"AugmentationsConfig",
|
||||||
"DEFAULT_AUGMENTATION_CONFIG",
|
"DEFAULT_AUGMENTATION_CONFIG",
|
||||||
"EchoAugmentationConfig",
|
"AddEchoConfig",
|
||||||
"AudioSource",
|
"AudioSource",
|
||||||
"FrequencyMaskAugmentationConfig",
|
"MaskFrequencyConfig",
|
||||||
"MixAugmentationConfig",
|
"MixAudioConfig",
|
||||||
"TimeMaskAugmentationConfig",
|
"MaskTimeConfig",
|
||||||
"VolumeAugmentationConfig",
|
"ScaleVolumeConfig",
|
||||||
"WarpAugmentationConfig",
|
"WarpConfig",
|
||||||
"add_echo",
|
"add_echo",
|
||||||
"build_augmentations",
|
"build_augmentations",
|
||||||
"load_augmentation_config",
|
"load_augmentation_config",
|
||||||
@ -37,10 +39,19 @@ __all__ = [
|
|||||||
"warp_spectrogram",
|
"warp_spectrogram",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
|
AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
|
||||||
|
|
||||||
|
audio_augmentations: Registry[Augmentation, [int, Optional[AudioSource]]] = (
|
||||||
|
Registry(name="audio_augmentation")
|
||||||
|
)
|
||||||
|
|
||||||
class MixAugmentationConfig(BaseConfig):
|
spec_augmentations: Registry[Augmentation, []] = Registry(
|
||||||
|
name="spec_augmentation"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MixAudioConfig(BaseConfig):
|
||||||
"""Configuration for MixUp augmentation (mixing two examples)."""
|
"""Configuration for MixUp augmentation (mixing two examples)."""
|
||||||
|
|
||||||
name: Literal["mix_audio"] = "mix_audio"
|
name: Literal["mix_audio"] = "mix_audio"
|
||||||
@ -87,6 +98,27 @@ class MixAudio(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
return mixed_audio, mixed_annotations
|
return mixed_audio, mixed_annotations
|
||||||
|
|
||||||
|
@audio_augmentations.register(MixAudioConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: MixAudioConfig,
|
||||||
|
samplerate: int,
|
||||||
|
source: Optional[AudioSource],
|
||||||
|
):
|
||||||
|
if source is None:
|
||||||
|
warnings.warn(
|
||||||
|
"Mix audio augmentation ('mix_audio') requires an "
|
||||||
|
"'example_source' callable to be provided.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return lambda wav, clip_annotation: (wav, clip_annotation)
|
||||||
|
|
||||||
|
return MixAudio(
|
||||||
|
example_source=source,
|
||||||
|
min_weight=config.min_weight,
|
||||||
|
max_weight=config.max_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def mix_audio(
|
def mix_audio(
|
||||||
wav1: torch.Tensor,
|
wav1: torch.Tensor,
|
||||||
@ -136,7 +168,7 @@ def combine_clip_annotations(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class EchoAugmentationConfig(BaseConfig):
|
class AddEchoConfig(BaseConfig):
|
||||||
"""Configuration for adding synthetic echo/reverb."""
|
"""Configuration for adding synthetic echo/reverb."""
|
||||||
|
|
||||||
name: Literal["add_echo"] = "add_echo"
|
name: Literal["add_echo"] = "add_echo"
|
||||||
@ -149,14 +181,17 @@ class EchoAugmentationConfig(BaseConfig):
|
|||||||
class AddEcho(torch.nn.Module):
|
class AddEcho(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
min_weight: float = 0.1,
|
min_weight: float = 0.1,
|
||||||
max_weight: float = 1.0,
|
max_weight: float = 1.0,
|
||||||
max_delay: int = 2560,
|
max_delay: float = 0.005,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.samplerate = samplerate
|
||||||
self.min_weight = min_weight
|
self.min_weight = min_weight
|
||||||
self.max_weight = max_weight
|
self.max_weight = max_weight
|
||||||
self.max_delay = max_delay
|
self.max_delay_s = max_delay
|
||||||
|
self.max_delay = int(max_delay * samplerate)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -167,6 +202,20 @@ class AddEcho(torch.nn.Module):
|
|||||||
weight = np.random.uniform(self.min_weight, self.max_weight)
|
weight = np.random.uniform(self.min_weight, self.max_weight)
|
||||||
return add_echo(wav, delay=delay, weight=weight), clip_annotation
|
return add_echo(wav, delay=delay, weight=weight), clip_annotation
|
||||||
|
|
||||||
|
@audio_augmentations.register(AddEchoConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: AddEchoConfig,
|
||||||
|
samplerate: int,
|
||||||
|
source: Optional[AudioSource],
|
||||||
|
):
|
||||||
|
return AddEcho(
|
||||||
|
samplerate=samplerate,
|
||||||
|
min_weight=config.min_weight,
|
||||||
|
max_weight=config.max_weight,
|
||||||
|
max_delay=config.max_delay,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_echo(
|
def add_echo(
|
||||||
wav: torch.Tensor,
|
wav: torch.Tensor,
|
||||||
@ -183,7 +232,7 @@ def add_echo(
|
|||||||
return mix_audio(wav, audio_delay, weight)
|
return mix_audio(wav, audio_delay, weight)
|
||||||
|
|
||||||
|
|
||||||
class VolumeAugmentationConfig(BaseConfig):
|
class ScaleVolumeConfig(BaseConfig):
|
||||||
"""Configuration for random volume scaling of the spectrogram."""
|
"""Configuration for random volume scaling of the spectrogram."""
|
||||||
|
|
||||||
name: Literal["scale_volume"] = "scale_volume"
|
name: Literal["scale_volume"] = "scale_volume"
|
||||||
@ -206,19 +255,27 @@ class ScaleVolume(torch.nn.Module):
|
|||||||
factor = np.random.uniform(self.min_scaling, self.max_scaling)
|
factor = np.random.uniform(self.min_scaling, self.max_scaling)
|
||||||
return scale_volume(spec, factor=factor), clip_annotation
|
return scale_volume(spec, factor=factor), clip_annotation
|
||||||
|
|
||||||
|
@spec_augmentations.register(ScaleVolumeConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: ScaleVolumeConfig):
|
||||||
|
return ScaleVolume(
|
||||||
|
min_scaling=config.min_scaling,
|
||||||
|
max_scaling=config.max_scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
|
def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
|
||||||
"""Scale the amplitude of the spectrogram by a factor."""
|
"""Scale the amplitude of the spectrogram by a factor."""
|
||||||
return spec * factor
|
return spec * factor
|
||||||
|
|
||||||
|
|
||||||
class WarpAugmentationConfig(BaseConfig):
|
class WarpConfig(BaseConfig):
|
||||||
name: Literal["warp"] = "warp"
|
name: Literal["warp"] = "warp"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
delta: float = 0.04
|
delta: float = 0.04
|
||||||
|
|
||||||
|
|
||||||
class WarpSpectrogram(torch.nn.Module):
|
class Warp(torch.nn.Module):
|
||||||
def __init__(self, delta: float = 0.04) -> None:
|
def __init__(self, delta: float = 0.04) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.delta = delta
|
self.delta = delta
|
||||||
@ -234,6 +291,11 @@ class WarpSpectrogram(torch.nn.Module):
|
|||||||
warp_clip_annotation(clip_annotation, factor=factor),
|
warp_clip_annotation(clip_annotation, factor=factor),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@spec_augmentations.register(WarpConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: WarpConfig):
|
||||||
|
return Warp(delta=config.delta)
|
||||||
|
|
||||||
|
|
||||||
def warp_sound_event_annotation(
|
def warp_sound_event_annotation(
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
@ -294,7 +356,7 @@ def warp_spectrogram(
|
|||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
class TimeMaskAugmentationConfig(BaseConfig):
|
class MaskTimeConfig(BaseConfig):
|
||||||
name: Literal["mask_time"] = "mask_time"
|
name: Literal["mask_time"] = "mask_time"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
max_perc: float = 0.05
|
max_perc: float = 0.05
|
||||||
@ -336,6 +398,14 @@ class MaskTime(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
return mask_time(spec, masks), clip_annotation
|
return mask_time(spec, masks), clip_annotation
|
||||||
|
|
||||||
|
@spec_augmentations.register(MaskTimeConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: MaskTimeConfig):
|
||||||
|
return MaskTime(
|
||||||
|
max_perc=config.max_perc,
|
||||||
|
max_masks=config.max_masks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def mask_time(
|
def mask_time(
|
||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
@ -351,7 +421,7 @@ def mask_time(
|
|||||||
return spec
|
return spec
|
||||||
|
|
||||||
|
|
||||||
class FrequencyMaskAugmentationConfig(BaseConfig):
|
class MaskFrequencyConfig(BaseConfig):
|
||||||
name: Literal["mask_freq"] = "mask_freq"
|
name: Literal["mask_freq"] = "mask_freq"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
max_perc: float = 0.10
|
max_perc: float = 0.10
|
||||||
@ -394,6 +464,14 @@ class MaskFrequency(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
return mask_frequency(spec, masks), clip_annotation
|
return mask_frequency(spec, masks), clip_annotation
|
||||||
|
|
||||||
|
@spec_augmentations.register(MaskFrequencyConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: MaskFrequencyConfig):
|
||||||
|
return MaskFrequency(
|
||||||
|
max_perc=config.max_perc,
|
||||||
|
max_masks=config.max_masks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def mask_frequency(
|
def mask_frequency(
|
||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
@ -410,8 +488,8 @@ def mask_frequency(
|
|||||||
|
|
||||||
AudioAugmentationConfig = Annotated[
|
AudioAugmentationConfig = Annotated[
|
||||||
Union[
|
Union[
|
||||||
MixAugmentationConfig,
|
MixAudioConfig,
|
||||||
EchoAugmentationConfig,
|
AddEchoConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
@ -419,22 +497,22 @@ AudioAugmentationConfig = Annotated[
|
|||||||
|
|
||||||
SpectrogramAugmentationConfig = Annotated[
|
SpectrogramAugmentationConfig = Annotated[
|
||||||
Union[
|
Union[
|
||||||
VolumeAugmentationConfig,
|
ScaleVolumeConfig,
|
||||||
WarpAugmentationConfig,
|
WarpConfig,
|
||||||
FrequencyMaskAugmentationConfig,
|
MaskFrequencyConfig,
|
||||||
TimeMaskAugmentationConfig,
|
MaskTimeConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
AugmentationConfig = Annotated[
|
AugmentationConfig = Annotated[
|
||||||
Union[
|
Union[
|
||||||
MixAugmentationConfig,
|
MixAudioConfig,
|
||||||
EchoAugmentationConfig,
|
AddEchoConfig,
|
||||||
VolumeAugmentationConfig,
|
ScaleVolumeConfig,
|
||||||
WarpAugmentationConfig,
|
WarpConfig,
|
||||||
FrequencyMaskAugmentationConfig,
|
MaskFrequencyConfig,
|
||||||
TimeMaskAugmentationConfig,
|
MaskTimeConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
@ -513,7 +591,7 @@ def build_augmentation_from_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "warp":
|
if config.name == "warp":
|
||||||
return WarpSpectrogram(
|
return Warp(
|
||||||
delta=config.delta,
|
delta=config.delta,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -538,14 +616,14 @@ def build_augmentation_from_config(
|
|||||||
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
|
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
|
||||||
enabled=True,
|
enabled=True,
|
||||||
audio=[
|
audio=[
|
||||||
MixAugmentationConfig(),
|
MixAudioConfig(),
|
||||||
EchoAugmentationConfig(),
|
AddEchoConfig(),
|
||||||
],
|
],
|
||||||
spectrogram=[
|
spectrogram=[
|
||||||
VolumeAugmentationConfig(),
|
ScaleVolumeConfig(),
|
||||||
WarpAugmentationConfig(),
|
WarpConfig(),
|
||||||
TimeMaskAugmentationConfig(),
|
MaskTimeConfig(),
|
||||||
FrequencyMaskAugmentationConfig(),
|
MaskFrequencyConfig(),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -566,9 +644,9 @@ class AugmentationSequence(torch.nn.Module):
|
|||||||
return tensor, clip_annotation
|
return tensor, clip_annotation
|
||||||
|
|
||||||
|
|
||||||
def build_augmentation_sequence(
|
def build_audio_augmentations(
|
||||||
samplerate: int,
|
steps: Optional[Sequence[AudioAugmentationConfig]] = None,
|
||||||
steps: Optional[Sequence[AugmentationConfig]] = None,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
audio_source: Optional[AudioSource] = None,
|
audio_source: Optional[AudioSource] = None,
|
||||||
) -> Optional[Augmentation]:
|
) -> Optional[Augmentation]:
|
||||||
if not steps:
|
if not steps:
|
||||||
@ -577,10 +655,8 @@ def build_augmentation_sequence(
|
|||||||
augmentations = []
|
augmentations = []
|
||||||
|
|
||||||
for step_config in steps:
|
for step_config in steps:
|
||||||
augmentation = build_augmentation_from_config(
|
augmentation = audio_augmentations.build(
|
||||||
step_config,
|
step_config, samplerate, audio_source
|
||||||
samplerate=samplerate,
|
|
||||||
audio_source=audio_source,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if augmentation is None:
|
if augmentation is None:
|
||||||
@ -596,6 +672,30 @@ def build_augmentation_sequence(
|
|||||||
return AugmentationSequence(augmentations)
|
return AugmentationSequence(augmentations)
|
||||||
|
|
||||||
|
|
||||||
|
def build_spectrogram_augmentations(
|
||||||
|
steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None,
|
||||||
|
) -> Optional[Augmentation]:
|
||||||
|
if not steps:
|
||||||
|
return None
|
||||||
|
|
||||||
|
augmentations = []
|
||||||
|
|
||||||
|
for step_config in steps:
|
||||||
|
augmentation = spec_augmentations.build(step_config)
|
||||||
|
|
||||||
|
if augmentation is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
augmentations.append(
|
||||||
|
MaybeApply(
|
||||||
|
augmentation=augmentation,
|
||||||
|
probability=step_config.probability,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return AugmentationSequence(augmentations)
|
||||||
|
|
||||||
|
|
||||||
def build_augmentations(
|
def build_augmentations(
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
config: Optional[AugmentationsConfig] = None,
|
config: Optional[AugmentationsConfig] = None,
|
||||||
@ -609,16 +709,14 @@ def build_augmentations(
|
|||||||
lambda: config.to_yaml_string(),
|
lambda: config.to_yaml_string(),
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_augmentation = build_augmentation_sequence(
|
audio_augmentation = build_audio_augmentations(
|
||||||
samplerate,
|
|
||||||
steps=config.audio,
|
steps=config.audio,
|
||||||
|
samplerate=samplerate,
|
||||||
audio_source=audio_source,
|
audio_source=audio_source,
|
||||||
)
|
)
|
||||||
|
|
||||||
spectrogram_augmentation = build_augmentation_sequence(
|
spectrogram_augmentation = build_spectrogram_augmentations(
|
||||||
samplerate,
|
steps=config.spectrogram,
|
||||||
steps=config.audio,
|
|
||||||
audio_source=audio_source,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return audio_augmentation, spectrogram_augmentation
|
return audio_augmentation, spectrogram_augmentation
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List
|
from typing import Any, List
|
||||||
|
|
||||||
from lightning import LightningModule, Trainer
|
from lightning import LightningModule, Trainer
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch.callbacks import Callback
|
||||||
@ -10,7 +10,6 @@ from batdetect2.postprocess import to_raw_predictions
|
|||||||
from batdetect2.train.dataset import ValidationDataset
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
ClipEvaluation,
|
|
||||||
EvaluatorProtocol,
|
EvaluatorProtocol,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
RawPrediction,
|
RawPrediction,
|
||||||
@ -36,23 +35,23 @@ class ValidationMetrics(Callback):
|
|||||||
|
|
||||||
def generate_plots(
|
def generate_plots(
|
||||||
self,
|
self,
|
||||||
|
eval_outputs: Any,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
evaluated_clips: List[ClipEvaluation],
|
|
||||||
):
|
):
|
||||||
plotter = get_image_logger(pl_module.logger) # type: ignore
|
plotter = get_image_logger(pl_module.logger) # type: ignore
|
||||||
|
|
||||||
if plotter is None:
|
if plotter is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
|
for figure_name, fig in self.evaluator.generate_plots(eval_outputs):
|
||||||
plotter(figure_name, fig, pl_module.global_step)
|
plotter(figure_name, fig, pl_module.global_step)
|
||||||
|
|
||||||
def log_metrics(
|
def log_metrics(
|
||||||
self,
|
self,
|
||||||
|
eval_outputs: Any,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
evaluated_clips: List[ClipEvaluation],
|
|
||||||
):
|
):
|
||||||
metrics = self.evaluator.compute_metrics(evaluated_clips)
|
metrics = self.evaluator.compute_metrics(eval_outputs)
|
||||||
pl_module.log_dict(metrics)
|
pl_module.log_dict(metrics)
|
||||||
|
|
||||||
def on_validation_epoch_end(
|
def on_validation_epoch_end(
|
||||||
@ -60,13 +59,13 @@ class ValidationMetrics(Callback):
|
|||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
) -> None:
|
) -> None:
|
||||||
clip_evaluations = self.evaluator.evaluate(
|
eval_outputs = self.evaluator.evaluate(
|
||||||
self._clip_annotations,
|
self._clip_annotations,
|
||||||
self._predictions,
|
self._predictions,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log_metrics(pl_module, clip_evaluations)
|
self.log_metrics(eval_outputs, pl_module)
|
||||||
self.generate_plots(pl_module, clip_evaluations)
|
self.generate_plots(eval_outputs, pl_module)
|
||||||
|
|
||||||
return super().on_validation_epoch_end(trainer, pl_module)
|
return super().on_validation_epoch_end(trainer, pl_module)
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from loguru import logger
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
from batdetect2.evaluate.evaluator import build_evaluator
|
from batdetect2.evaluate import build_evaluator
|
||||||
from batdetect2.logging import build_logger
|
from batdetect2.logging import build_logger
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
@ -105,7 +105,10 @@ def train(
|
|||||||
trainer = trainer or build_trainer(
|
trainer = trainer or build_trainer(
|
||||||
config,
|
config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
evaluator=build_evaluator(config.train.validation, targets=targets),
|
evaluator=build_evaluator(
|
||||||
|
config.train.validation,
|
||||||
|
targets=targets,
|
||||||
|
),
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
@ -143,7 +146,7 @@ def build_trainer_callbacks(
|
|||||||
ModelCheckpoint(
|
ModelCheckpoint(
|
||||||
dirpath=str(checkpoint_dir),
|
dirpath=str(checkpoint_dir),
|
||||||
save_top_k=1,
|
save_top_k=1,
|
||||||
monitor="total_loss/val",
|
monitor="classification/mean_average_precision",
|
||||||
),
|
),
|
||||||
ValidationMetrics(evaluator),
|
ValidationMetrics(evaluator),
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
from batdetect2.typing.evaluate import (
|
from batdetect2.typing.evaluate import (
|
||||||
ClipEvaluation,
|
AffinityFunction,
|
||||||
|
ClipMatches,
|
||||||
EvaluatorProtocol,
|
EvaluatorProtocol,
|
||||||
|
MatcherProtocol,
|
||||||
MatchEvaluation,
|
MatchEvaluation,
|
||||||
MetricsProtocol,
|
MetricsProtocol,
|
||||||
PlotterProtocol,
|
PlotterProtocol,
|
||||||
@ -36,19 +38,22 @@ from batdetect2.typing.train import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AffinityFunction",
|
||||||
"AudioLoader",
|
"AudioLoader",
|
||||||
"Augmentation",
|
"Augmentation",
|
||||||
"BackboneModel",
|
"BackboneModel",
|
||||||
"BatDetect2Prediction",
|
"BatDetect2Prediction",
|
||||||
"ClipEvaluation",
|
"ClipMatches",
|
||||||
"ClipLabeller",
|
"ClipLabeller",
|
||||||
"ClipperProtocol",
|
"ClipperProtocol",
|
||||||
"DetectionModel",
|
"DetectionModel",
|
||||||
|
"EvaluatorProtocol",
|
||||||
"GeometryDecoder",
|
"GeometryDecoder",
|
||||||
"Heatmaps",
|
"Heatmaps",
|
||||||
"LossProtocol",
|
"LossProtocol",
|
||||||
"Losses",
|
"Losses",
|
||||||
"MatchEvaluation",
|
"MatchEvaluation",
|
||||||
|
"MatcherProtocol",
|
||||||
"MetricsProtocol",
|
"MetricsProtocol",
|
||||||
"ModelOutput",
|
"ModelOutput",
|
||||||
"PlotterProtocol",
|
"PlotterProtocol",
|
||||||
@ -63,5 +68,4 @@ __all__ = [
|
|||||||
"SoundEventFilter",
|
"SoundEventFilter",
|
||||||
"TargetProtocol",
|
"TargetProtocol",
|
||||||
"TrainExample",
|
"TrainExample",
|
||||||
"EvaluatorProtocol",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@ -31,6 +31,7 @@ class MatchEvaluation:
|
|||||||
sound_event_annotation: Optional[data.SoundEventAnnotation]
|
sound_event_annotation: Optional[data.SoundEventAnnotation]
|
||||||
gt_det: bool
|
gt_det: bool
|
||||||
gt_class: Optional[str]
|
gt_class: Optional[str]
|
||||||
|
gt_geometry: Optional[data.Geometry]
|
||||||
|
|
||||||
pred_score: float
|
pred_score: float
|
||||||
pred_class_scores: Dict[str, float]
|
pred_class_scores: Dict[str, float]
|
||||||
@ -39,44 +40,32 @@ class MatchEvaluation:
|
|||||||
affinity: float
|
affinity: float
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pred_class(self) -> Optional[str]:
|
def top_class(self) -> Optional[str]:
|
||||||
if not self.pred_class_scores:
|
if not self.pred_class_scores:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
|
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pred_class_score(self) -> float:
|
def is_prediction(self) -> bool:
|
||||||
pred_class = self.pred_class
|
return self.pred_geometry is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_generic(self) -> bool:
|
||||||
|
return self.gt_det and self.gt_class is None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def top_class_score(self) -> float:
|
||||||
|
pred_class = self.top_class
|
||||||
|
|
||||||
if pred_class is None:
|
if pred_class is None:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
return self.pred_class_scores[pred_class]
|
return self.pred_class_scores[pred_class]
|
||||||
|
|
||||||
def is_true_positive(self, threshold: float = 0) -> bool:
|
|
||||||
return (
|
|
||||||
self.gt_det
|
|
||||||
and self.pred_score > threshold
|
|
||||||
and self.gt_class == self.pred_class
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_false_positive(self, threshold: float = 0) -> bool:
|
|
||||||
return self.gt_det is None and self.pred_score > threshold
|
|
||||||
|
|
||||||
def is_false_negative(self, threshold: float = 0) -> bool:
|
|
||||||
return self.gt_det and self.pred_score <= threshold
|
|
||||||
|
|
||||||
def is_cross_trigger(self, threshold: float = 0) -> bool:
|
|
||||||
return (
|
|
||||||
self.gt_det
|
|
||||||
and self.pred_score > threshold
|
|
||||||
and self.gt_class != self.pred_class
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClipEvaluation:
|
class ClipMatches:
|
||||||
clip: data.Clip
|
clip: data.Clip
|
||||||
matches: List[MatchEvaluation]
|
matches: List[MatchEvaluation]
|
||||||
|
|
||||||
@ -103,29 +92,36 @@ class AffinityFunction(Protocol, Generic[Geom]):
|
|||||||
|
|
||||||
class MetricsProtocol(Protocol):
|
class MetricsProtocol(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
self,
|
||||||
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
predictions: Sequence[Sequence[RawPrediction]],
|
||||||
) -> Dict[str, float]: ...
|
) -> Dict[str, float]: ...
|
||||||
|
|
||||||
|
|
||||||
class PlotterProtocol(Protocol):
|
class PlotterProtocol(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
self,
|
||||||
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
predictions: Sequence[Sequence[RawPrediction]],
|
||||||
) -> Iterable[Tuple[str, Figure]]: ...
|
) -> Iterable[Tuple[str, Figure]]: ...
|
||||||
|
|
||||||
|
|
||||||
class EvaluatorProtocol(Protocol):
|
EvaluationOutput = TypeVar("EvaluationOutput")
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
||||||
targets: TargetProtocol
|
targets: TargetProtocol
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[Sequence[RawPrediction]],
|
predictions: Sequence[Sequence[RawPrediction]],
|
||||||
) -> List[ClipEvaluation]: ...
|
) -> EvaluationOutput: ...
|
||||||
|
|
||||||
def compute_metrics(
|
def compute_metrics(
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
self, eval_outputs: EvaluationOutput
|
||||||
) -> Dict[str, float]: ...
|
) -> Dict[str, float]: ...
|
||||||
|
|
||||||
def generate_plots(
|
def generate_plots(
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
self, eval_outputs: EvaluationOutput
|
||||||
) -> Iterable[Tuple[str, Figure]]: ...
|
) -> Iterable[Tuple[str, Figure]]: ...
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user