mirror of https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Better evaluation organisation
parent 4cd983a2c2
commit d6ddc4514c
@@ -140,13 +140,14 @@ train:
   validation:
     metrics:
       - name: detection_ap
      - name: detection_roc_auc
      - name: classification_ap
      - name: classification_roc_auc
      - name: top_class_ap
      - name: classification_balanced_accuracy
      - name: clip_ap
      - name: clip_roc_auc
     plots:
       - name: example_gallery
      - name: example_clip
      - name: detection_pr_curve
      - name: classification_pr_curves
      - name: detection_roc_curve
      - name: classification_roc_curves

 evaluation:
   match_strategy:
@@ -155,6 +156,14 @@ evaluation:
   metrics:
     - name: classification_ap
     - name: detection_ap
+    - name: detection_roc_auc
+    - name: classification_roc_auc
+    - name: top_class_ap
+    - name: classification_balanced_accuracy
+    - name: clip_multiclass_ap
+    - name: clip_multiclass_roc_auc
+    - name: clip_detection_ap
+    - name: clip_detection_roc_auc
   plots:
     - name: example_gallery
     - name: example_clip
@@ -1,6 +1,7 @@
 from pathlib import Path
-from typing import Optional, Sequence
+from typing import List, Optional, Sequence

 import torch
 from soundevent import data

 from batdetect2.audio import build_audio_loader
@@ -8,6 +9,7 @@ from batdetect2.config import BatDetect2Config
 from batdetect2.evaluate import build_evaluator, evaluate
 from batdetect2.models import Model, build_model
 from batdetect2.postprocess import build_postprocessor
+from batdetect2.postprocess.decoding import to_raw_predictions
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets.targets import build_targets
 from batdetect2.train import train
@@ -19,6 +21,7 @@ from batdetect2.typing import (
     PreprocessorProtocol,
     TargetProtocol,
 )
+from batdetect2.typing.postprocess import RawPrediction


 class BatDetect2API:
@@ -92,6 +95,18 @@ class BatDetect2API:
             run_name=run_name,
         )

+    def process_spectrogram(
+        self,
+        spec: torch.Tensor,
+        start_times: Optional[Sequence[float]] = None,
+    ) -> List[List[RawPrediction]]:
+        outputs = self.model.detector(spec)
+        clip_detections = self.postprocessor(outputs, start_times=start_times)
+        return [
+            to_raw_predictions(clip_dets.numpy(), self.targets)
+            for clip_dets in clip_detections
+        ]
+
     @classmethod
     def from_config(cls, config: BatDetect2Config):
         targets = build_targets(config=config.targets)
@@ -109,7 +124,7 @@ class BatDetect2API:
         )

         evaluator = build_evaluator(
-            config=config.evaluation,
+            config=config.evaluation.evaluator,
             targets=targets,
         )

@@ -164,7 +179,7 @@ class BatDetect2API:
         )

         evaluator = build_evaluator(
-            config=config.evaluation,
+            config=config.evaluation.evaluator,
             targets=targets,
         )

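The new `process_spectrogram` method exposes the detector directly from a precomputed spectrogram batch, bypassing audio loading. A minimal usage sketch; the module path `batdetect2.api`, the default-constructed config, and the tensor shape are assumptions not spelled out in this diff:

import torch

from batdetect2.api import BatDetect2API  # module path assumed
from batdetect2.config import BatDetect2Config

api = BatDetect2API.from_config(BatDetect2Config())

# A fake single-clip spectrogram batch; real shapes come from the
# configured preprocessor, which this sketch does not reproduce.
spec = torch.rand(1, 1, 128, 512)

predictions = api.process_spectrogram(spec, start_times=[0.0])
for raw in predictions[0]:
    print(raw.detection_score, raw.class_scores)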
@@ -56,18 +56,16 @@ class RandomClip:
             min_sound_event_overlap=self.min_sound_event_overlap,
         )

-    @classmethod
-    def from_config(cls, config: RandomClipConfig):
-        return cls(
+    @clipper_registry.register(RandomClipConfig)
+    @staticmethod
+    def from_config(config: RandomClipConfig):
+        return RandomClip(
             duration=config.duration,
             max_empty=config.max_empty,
             min_sound_event_overlap=config.min_sound_event_overlap,
         )


-clipper_registry.register(RandomClipConfig, RandomClip)
-
-
 def get_subclip_annotation(
     clip_annotation: data.ClipAnnotation,
     random: bool = True,
@@ -184,13 +182,12 @@ class PaddedClip:
         )
         return clip_annotation.model_copy(update=dict(clip=clip))

-    @classmethod
-    def from_config(cls, config: PaddedClipConfig):
-        return cls(chunk_size=config.chunk_size)
+    @clipper_registry.register(PaddedClipConfig)
+    @staticmethod
+    def from_config(config: PaddedClipConfig):
+        return PaddedClip(chunk_size=config.chunk_size)


-clipper_registry.register(PaddedClipConfig, PaddedClip)
-
-
 ClipConfig = Annotated[
     Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
 ]
@@ -53,6 +53,7 @@ class BaseConfig(BaseModel):
         """
         return yaml.dump(
             self.model_dump(
                 mode="json",
                 exclude_none=exclude_none,
                 exclude_unset=exclude_unset,
+                exclude_defaults=exclude_defaults,
@@ -1,16 +1,16 @@
 import sys
-from typing import Generic, Protocol, Type, TypeVar
+from typing import Callable, Dict, Generic, Tuple, Type, TypeVar

 from pydantic import BaseModel
 from typing_extensions import assert_type

 if sys.version_info >= (3, 10):
-    from typing import ParamSpec
+    from typing import Concatenate, ParamSpec
 else:
-    from typing_extensions import ParamSpec
+    from typing_extensions import Concatenate, ParamSpec

 __all__ = [
     "Registry",
     "SimpleRegistry",
 ]

 T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
@@ -18,19 +18,26 @@ T_Type = TypeVar("T_Type", covariant=True)
 P_Type = ParamSpec("P_Type")


-class LogicProtocol(Generic[T_Config, T_Type, P_Type], Protocol):
-    """A generic protocol for the logic classes."""
-
-    @classmethod
-    def from_config(
-        cls,
-        config: T_Config,
-        *args: P_Type.args,
-        **kwargs: P_Type.kwargs,
-    ) -> T_Type: ...
+T = TypeVar("T")


-T_Proto = TypeVar("T_Proto", bound=LogicProtocol)
+class SimpleRegistry(Generic[T]):
+    def __init__(self, name: str):
+        self._name = name
+        self._registry = {}
+
+    def register(self, name: str):
+        def decorator(obj: T) -> T:
+            self._registry[name] = obj
+            return obj
+
+        return decorator
+
+    def get(self, name: str) -> T:
+        return self._registry[name]
+
+    def has(self, name: str) -> bool:
+        return name in self._registry


 class Registry(Generic[T_Type, P_Type]):
@@ -38,13 +45,15 @@ class Registry(Generic[T_Type, P_Type]):

     def __init__(self, name: str):
         self._name = name
-        self._registry = {}
+        self._registry: Dict[
+            str, Callable[Concatenate[..., P_Type], T_Type]
+        ] = {}
+        self._config_types: Dict[str, Type[BaseModel]] = {}

     def register(
         self,
         config_cls: Type[T_Config],
-        logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
-    ) -> None:
+    ):
         fields = config_cls.model_fields

         if "name" not in fields:
@@ -52,10 +61,21 @@ class Registry(Generic[T_Type, P_Type]):

         name = fields["name"].default

+        self._config_types[name] = config_cls
+
         if not isinstance(name, str):
             raise ValueError("'name' field must be a string literal.")

-        self._registry[name] = logic_cls
+        def decorator(
+            func: Callable[Concatenate[T_Config, P_Type], T_Type],
+        ):
+            self._registry[name] = func
+            return func
+
+        return decorator
+
+    def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
+        return tuple(self._config_types.values())

     def build(
         self,
@@ -75,4 +95,4 @@ class Registry(Generic[T_Type, P_Type]):
                 f"No {self._name} with name '{name}' is registered."
             )

-        return self._registry[name].from_config(config, *args, **kwargs)
+        return self._registry[name](config, *args, **kwargs)
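This registry change is the heart of the commit: `Registry.register` now takes only the config class and returns a decorator, so each component registers its own factory at definition time, instead of a separate `registry.register(ConfigCls, LogicCls)` call (and the `LogicProtocol` indirection) after the class body. A self-contained sketch of the pattern, using hypothetical `Widget` names that are not from the codebase; note that stacking `register` on top of `@staticmethod` relies on staticmethod objects being callable, which requires Python 3.10+:

from typing import Literal

from pydantic import BaseModel


class WidgetRegistry:
    """Simplified stand-in for the Registry class above."""

    def __init__(self, name: str):
        self._name = name
        self._registry = {}

    def register(self, config_cls):
        # The registration key is the default of the config's `name` field.
        name = config_cls.model_fields["name"].default

        def decorator(func):
            self._registry[name] = func
            return func

        return decorator

    def build(self, config, *args, **kwargs):
        return self._registry[config.name](config, *args, **kwargs)


widgets = WidgetRegistry("widget")


class FixedWidgetConfig(BaseModel):
    name: Literal["fixed"] = "fixed"
    size: int = 3


class FixedWidget:
    def __init__(self, size: int):
        self.size = size

    # Registered at class-definition time. The decorated staticmethod is the
    # factory the registry calls in `build`, matching the pattern used
    # throughout this commit.
    @widgets.register(FixedWidgetConfig)
    @staticmethod
    def from_config(config: "FixedWidgetConfig") -> "FixedWidget":
        return FixedWidget(size=config.size)


widget = widgets.build(FixedWidgetConfig(size=5))
assert isinstance(widget, FixedWidget) and widget.size == 5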
@@ -10,7 +10,7 @@ from batdetect2.core.registries import Registry

 SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

-condition_registry: Registry[SoundEventCondition, []] = Registry("condition")
+conditions: Registry[SoundEventCondition, []] = Registry("condition")


 class HasTagConfig(BaseConfig):
@@ -27,12 +27,10 @@ class HasTag:
     ) -> bool:
         return self.tag in sound_event_annotation.tags

-    @classmethod
-    def from_config(cls, config: HasTagConfig):
-        return cls(tag=config.tag)
-
-
-condition_registry.register(HasTagConfig, HasTag)
+    @conditions.register(HasTagConfig)
+    @staticmethod
+    def from_config(config: HasTagConfig):
+        return HasTag(tag=config.tag)


 class HasAllTagsConfig(BaseConfig):
@@ -52,12 +50,10 @@ class HasAllTags:
     ) -> bool:
         return self.tags.issubset(sound_event_annotation.tags)

-    @classmethod
-    def from_config(cls, config: HasAllTagsConfig):
-        return cls(tags=config.tags)
-
-
-condition_registry.register(HasAllTagsConfig, HasAllTags)
+    @conditions.register(HasAllTagsConfig)
+    @staticmethod
+    def from_config(config: HasAllTagsConfig):
+        return HasAllTags(tags=config.tags)


 class HasAnyTagConfig(BaseConfig):
@@ -77,13 +73,12 @@ class HasAnyTag:
     ) -> bool:
         return bool(self.tags.intersection(sound_event_annotation.tags))

-    @classmethod
-    def from_config(cls, config: HasAnyTagConfig):
-        return cls(tags=config.tags)
+    @conditions.register(HasAnyTagConfig)
+    @staticmethod
+    def from_config(config: HasAnyTagConfig):
+        return HasAnyTag(tags=config.tags)


-condition_registry.register(HasAnyTagConfig, HasAnyTag)
-
 Operator = Literal["gt", "gte", "lt", "lte", "eq"]


@@ -134,12 +129,10 @@ class Duration:

         return self._comparator(duration)

-    @classmethod
-    def from_config(cls, config: DurationConfig):
-        return cls(operator=config.operator, seconds=config.seconds)
-
-
-condition_registry.register(DurationConfig, Duration)
+    @conditions.register(DurationConfig)
+    @staticmethod
+    def from_config(config: DurationConfig):
+        return Duration(operator=config.operator, seconds=config.seconds)


 class FrequencyConfig(BaseConfig):
@@ -181,18 +174,16 @@ class Frequency:

         return self._comparator(high_freq)

-    @classmethod
-    def from_config(cls, config: FrequencyConfig):
-        return cls(
+    @conditions.register(FrequencyConfig)
+    @staticmethod
+    def from_config(config: FrequencyConfig):
+        return Frequency(
             operator=config.operator,
             boundary=config.boundary,
             hertz=config.hertz,
         )


-condition_registry.register(FrequencyConfig, Frequency)
-
-
 class AllOfConfig(BaseConfig):
     name: Literal["all_of"] = "all_of"
     conditions: Sequence["SoundEventConditionConfig"]
@@ -207,15 +198,13 @@ class AllOf:
     ) -> bool:
         return all(c(sound_event_annotation) for c in self.conditions)

-    @classmethod
-    def from_config(cls, config: AllOfConfig):
+    @conditions.register(AllOfConfig)
+    @staticmethod
+    def from_config(config: AllOfConfig):
         conditions = [
             build_sound_event_condition(cond) for cond in config.conditions
         ]
-        return cls(conditions)
-
-
-condition_registry.register(AllOfConfig, AllOf)
+        return AllOf(conditions)


 class AnyOfConfig(BaseConfig):
@@ -232,15 +221,13 @@ class AnyOf:
     ) -> bool:
         return any(c(sound_event_annotation) for c in self.conditions)

-    @classmethod
-    def from_config(cls, config: AnyOfConfig):
+    @conditions.register(AnyOfConfig)
+    @staticmethod
+    def from_config(config: AnyOfConfig):
         conditions = [
             build_sound_event_condition(cond) for cond in config.conditions
         ]
-        return cls(conditions)
-
-
-condition_registry.register(AnyOfConfig, AnyOf)
+        return AnyOf(conditions)


 class NotConfig(BaseConfig):
@@ -257,14 +244,13 @@ class Not:
     ) -> bool:
         return not self.condition(sound_event_annotation)

-    @classmethod
-    def from_config(cls, config: NotConfig):
+    @conditions.register(NotConfig)
+    @staticmethod
+    def from_config(config: NotConfig):
         condition = build_sound_event_condition(config.condition)
-        return cls(condition)
+        return Not(condition)


-condition_registry.register(NotConfig, Not)
-
 SoundEventConditionConfig = Annotated[
     Union[
         HasTagConfig,
@@ -283,7 +269,7 @@ SoundEventConditionConfig = Annotated[
 def build_sound_event_condition(
     config: SoundEventConditionConfig,
 ) -> SoundEventCondition:
-    return condition_registry.build(config)
+    return conditions.build(config)


 def filter_clip_annotation(
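With the registry renamed to `conditions` and every `from_config` registered by decorator, building a condition from its config is unchanged from the caller's side. A small sketch using the `DurationConfig` fields shown above; the import path is an assumption, since this diff does not name the file:

from batdetect2.targets.conditions import (  # import path assumed
    DurationConfig,
    build_sound_event_condition,
)

# Keep only sound events shorter than 10 ms.
config = DurationConfig(operator="lt", seconds=0.01)
is_short = build_sound_event_condition(config)

# `is_short` is a plain callable: SoundEventAnnotation -> bool, e.g.
# keep = [a for a in clip_annotation.sound_events if is_short(a)]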
@@ -17,7 +17,7 @@ SoundEventTransform = Callable[
     data.SoundEventAnnotation,
 ]

-transform_registry: Registry[SoundEventTransform, []] = Registry("transform")
+transforms: Registry[SoundEventTransform, []] = Registry("transform")


 class SetFrequencyBoundConfig(BaseConfig):
@@ -63,12 +63,10 @@ class SetFrequencyBound:
             update=dict(sound_event=sound_event)
         )

-    @classmethod
-    def from_config(cls, config: SetFrequencyBoundConfig):
-        return cls(hertz=config.hertz, boundary=config.boundary)
-
-
-transform_registry.register(SetFrequencyBoundConfig, SetFrequencyBound)
+    @transforms.register(SetFrequencyBoundConfig)
+    @staticmethod
+    def from_config(config: SetFrequencyBoundConfig):
+        return SetFrequencyBound(hertz=config.hertz, boundary=config.boundary)


 class ApplyIfConfig(BaseConfig):
@@ -95,14 +93,12 @@ class ApplyIf:

         return self.transform(sound_event_annotation)

-    @classmethod
-    def from_config(cls, config: ApplyIfConfig):
+    @transforms.register(ApplyIfConfig)
+    @staticmethod
+    def from_config(config: ApplyIfConfig):
         transform = build_sound_event_transform(config.transform)
         condition = build_sound_event_condition(config.condition)
-        return cls(condition=condition, transform=transform)
-
-
-transform_registry.register(ApplyIfConfig, ApplyIf)
+        return ApplyIf(condition=condition, transform=transform)


 class ReplaceTagConfig(BaseConfig):
@@ -134,12 +130,12 @@ class ReplaceTag:

         return sound_event_annotation.model_copy(update=dict(tags=tags))

-    @classmethod
-    def from_config(cls, config: ReplaceTagConfig):
-        return cls(original=config.original, replacement=config.replacement)
-
-
-transform_registry.register(ReplaceTagConfig, ReplaceTag)
+    @transforms.register(ReplaceTagConfig)
+    @staticmethod
+    def from_config(config: ReplaceTagConfig):
+        return ReplaceTag(
+            original=config.original, replacement=config.replacement
+        )


 class MapTagValueConfig(BaseConfig):
@@ -189,18 +185,16 @@ class MapTagValue:

         return sound_event_annotation.model_copy(update=dict(tags=tags))

-    @classmethod
-    def from_config(cls, config: MapTagValueConfig):
-        return cls(
+    @transforms.register(MapTagValueConfig)
+    @staticmethod
+    def from_config(config: MapTagValueConfig):
+        return MapTagValue(
             tag_key=config.tag_key,
             value_mapping=config.value_mapping,
             target_key=config.target_key,
         )


-transform_registry.register(MapTagValueConfig, MapTagValue)
-
-
 class ApplyAllConfig(BaseConfig):
     name: Literal["apply_all"] = "apply_all"
     steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
@@ -219,14 +213,13 @@ class ApplyAll:

         return sound_event_annotation

-    @classmethod
-    def from_config(cls, config: ApplyAllConfig):
+    @transforms.register(ApplyAllConfig)
+    @staticmethod
+    def from_config(config: ApplyAllConfig):
         steps = [build_sound_event_transform(step) for step in config.steps]
-        return cls(steps)
+        return ApplyAll(steps)


-transform_registry.register(ApplyAllConfig, ApplyAll)
-
 SoundEventTransformConfig = Annotated[
     Union[
         SetFrequencyBoundConfig,
@@ -242,7 +235,7 @@ SoundEventTransformConfig = Annotated[
 def build_sound_event_transform(
     config: SoundEventTransformConfig,
 ) -> SoundEventTransform:
-    return transform_registry.build(config)
+    return transforms.build(config)


 def transform_clip_annotation(
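Because conditions and transforms share the same config-driven builders, they compose naturally: `ApplyIf` wraps any transform behind any condition. A hedged sketch; the import paths and the `boundary="low"` value are illustrative assumptions, not confirmed by this diff:

from batdetect2.targets.conditions import DurationConfig  # paths assumed
from batdetect2.targets.transforms import (
    ApplyIfConfig,
    SetFrequencyBoundConfig,
    build_sound_event_transform,
)

# Clamp the low-frequency bound, but only for very short sound events.
config = ApplyIfConfig(
    condition=DurationConfig(operator="lt", seconds=0.005),
    transform=SetFrequencyBoundConfig(hertz=10_000, boundary="low"),
)
transform = build_sound_event_transform(config)
# transformed = transform(sound_event_annotation)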
@@ -1,11 +1,11 @@
 from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
 from batdetect2.evaluate.evaluate import evaluate
-from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
+from batdetect2.evaluate.evaluator import MultipleEvaluator, build_evaluator

 __all__ = [
     "EvaluationConfig",
     "load_evaluation_config",
     "evaluate",
-    "Evaluator",
+    "MultipleEvaluator",
     "build_evaluator",
 ]
@@ -27,12 +27,10 @@ class TimeAffinity(AffinityFunction):
             geometry1, geometry2, time_buffer=self.time_buffer
         )

-    @classmethod
-    def from_config(cls, config: TimeAffinityConfig):
-        return cls(time_buffer=config.time_buffer)
-
-
-affinity_functions.register(TimeAffinityConfig, TimeAffinity)
+    @affinity_functions.register(TimeAffinityConfig)
+    @staticmethod
+    def from_config(config: TimeAffinityConfig):
+        return TimeAffinity(time_buffer=config.time_buffer)


 def compute_timestamp_affinity(
@@ -73,12 +71,10 @@ class IntervalIOU(AffinityFunction):
             time_buffer=self.time_buffer,
         )

-    @classmethod
-    def from_config(cls, config: IntervalIOUConfig):
-        return cls(time_buffer=config.time_buffer)
-
-
-affinity_functions.register(IntervalIOUConfig, IntervalIOU)
+    @affinity_functions.register(IntervalIOUConfig)
+    @staticmethod
+    def from_config(config: IntervalIOUConfig):
+        return IntervalIOU(time_buffer=config.time_buffer)


 def compute_interval_iou(
@@ -127,13 +123,12 @@ class GeometricIOU(AffinityFunction):
             time_buffer=self.time_buffer,
         )

-    @classmethod
-    def from_config(cls, config: GeometricIOUConfig):
-        return cls(time_buffer=config.time_buffer)
+    @affinity_functions.register(GeometricIOUConfig)
+    @staticmethod
+    def from_config(config: GeometricIOUConfig):
+        return GeometricIOU(time_buffer=config.time_buffer)


-affinity_functions.register(GeometricIOUConfig, GeometricIOU)
-
 AffinityConfig = Annotated[
     Union[
         TimeAffinityConfig,
@@ -1,16 +1,13 @@
-from typing import List, Optional
+from typing import Optional

 from pydantic import Field
 from soundevent import data

 from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
-from batdetect2.evaluate.metrics import (
-    ClassificationAPConfig,
-    DetectionAPConfig,
-    MetricConfig,
+from batdetect2.evaluate.evaluator import (
+    EvaluatorConfig,
+    MultipleEvaluatorConfig,
 )
-from batdetect2.evaluate.plots import PlotConfig
 from batdetect2.logging import CSVLoggerConfig, LoggerConfig

 __all__ = [
@@ -20,15 +17,7 @@ __all__ = [


 class EvaluationConfig(BaseConfig):
-    ignore_start_end: float = 0.01
-    match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig)
-    metrics: List[MetricConfig] = Field(
-        default_factory=lambda: [
-            DetectionAPConfig(),
-            ClassificationAPConfig(),
-        ]
-    )
-    plots: List[PlotConfig] = Field(default_factory=list)
+    evaluator: EvaluatorConfig = Field(default_factory=MultipleEvaluatorConfig)
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
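The net schema change: `EvaluationConfig` no longer carries `ignore_start_end`, `match_strategy`, `metrics`, and `plots` directly; those now live on the individual evaluator configs, and `EvaluationConfig` just points at one (by default a `MultipleEvaluatorConfig`). Roughly, in terms of the config objects (a construction sketch, not taken from the diff):

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.evaluator import MultipleEvaluatorConfig
from batdetect2.evaluate.evaluator.single import GlobalEvaluatorConfig

# Old shape (removed):
#   EvaluationConfig(match_strategy=..., metrics=[...], plots=[...])
# New shape: matching and metrics are configured per evaluator.
config = EvaluationConfig(
    evaluator=MultipleEvaluatorConfig(
        evaluations=[GlobalEvaluatorConfig(ignore_start_end=0.02)]
    )
)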
@@ -55,7 +55,10 @@ def evaluate(
         num_workers=num_workers,
     )

-    evaluator = build_evaluator(config=config.evaluation, targets=targets)
+    evaluator = build_evaluator(
+        config=config.evaluation.evaluator,
+        targets=targets,
+    )

     logger = build_logger(
         config.evaluation.logger,
@@ -1,173 +0,0 @@
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
-
-from matplotlib.figure import Figure
-from soundevent import data
-from soundevent.geometry import compute_bounds
-
-from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.evaluate.match import build_matcher, match
-from batdetect2.evaluate.metrics import build_metric
-from batdetect2.evaluate.plots import build_plotter
-from batdetect2.targets import build_targets
-from batdetect2.typing.evaluate import (
-    ClipEvaluation,
-    EvaluatorProtocol,
-    MatcherProtocol,
-    MetricsProtocol,
-    PlotterProtocol,
-)
-from batdetect2.typing.postprocess import RawPrediction
-from batdetect2.typing.targets import TargetProtocol
-
-__all__ = [
-    "Evaluator",
-    "build_evaluator",
-]
-
-
-class Evaluator:
-    def __init__(
-        self,
-        config: EvaluationConfig,
-        targets: TargetProtocol,
-        matcher: MatcherProtocol,
-        metrics: List[MetricsProtocol],
-        plots: List[PlotterProtocol],
-    ):
-        self.config = config
-        self.targets = targets
-        self.matcher = matcher
-        self.metrics = metrics
-        self.plots = plots
-
-    def match(
-        self,
-        clip_annotation: data.ClipAnnotation,
-        predictions: Sequence[RawPrediction],
-    ) -> ClipEvaluation:
-        clip = clip_annotation.clip
-        ground_truth = [
-            sound_event
-            for sound_event in clip_annotation.sound_events
-            if self.filter_sound_event_annotations(sound_event, clip)
-        ]
-        predictions = [
-            prediction
-            for prediction in predictions
-            if self.filter_predictions(prediction, clip)
-        ]
-        return ClipEvaluation(
-            clip=clip_annotation.clip,
-            matches=match(
-                ground_truth,
-                predictions,
-                clip=clip,
-                targets=self.targets,
-                matcher=self.matcher,
-            ),
-        )
-
-    def filter_sound_event_annotations(
-        self,
-        sound_event_annotation: data.SoundEventAnnotation,
-        clip: data.Clip,
-    ) -> bool:
-        if not self.targets.filter(sound_event_annotation):
-            return False
-
-        geometry = sound_event_annotation.sound_event.geometry
-        if geometry is None:
-            return False
-
-        return is_in_bounds(
-            geometry,
-            clip,
-            self.config.ignore_start_end,
-        )
-
-    def filter_predictions(
-        self,
-        prediction: RawPrediction,
-        clip: data.Clip,
-    ) -> bool:
-        return is_in_bounds(
-            prediction.geometry,
-            clip,
-            self.config.ignore_start_end,
-        )
-
-    def evaluate(
-        self,
-        clip_annotations: Sequence[data.ClipAnnotation],
-        predictions: Sequence[Sequence[RawPrediction]],
-    ) -> List[ClipEvaluation]:
-        if len(clip_annotations) != len(predictions):
-            raise ValueError(
-                "Number of annotated clips and sets of predictions do not match"
-            )
-
-        return [
-            self.match(clip_annotation, preds)
-            for clip_annotation, preds in zip(clip_annotations, predictions)
-        ]
-
-    def compute_metrics(
-        self,
-        clip_evaluations: Sequence[ClipEvaluation],
-    ) -> Dict[str, float]:
-        results = {}
-
-        for metric in self.metrics:
-            results.update(metric(clip_evaluations))
-
-        return results
-
-    def generate_plots(
-        self, clip_evaluations: Sequence[ClipEvaluation]
-    ) -> Iterable[Tuple[str, Figure]]:
-        for plotter in self.plots:
-            for name, fig in plotter(clip_evaluations):
-                yield name, fig
-
-
-def build_evaluator(
-    config: Optional[EvaluationConfig] = None,
-    targets: Optional[TargetProtocol] = None,
-    matcher: Optional[MatcherProtocol] = None,
-    plots: Optional[List[PlotterProtocol]] = None,
-    metrics: Optional[List[MetricsProtocol]] = None,
-) -> EvaluatorProtocol:
-    config = config or EvaluationConfig()
-    targets = targets or build_targets()
-    matcher = matcher or build_matcher(config.match_strategy)
-
-    if metrics is None:
-        metrics = [
-            build_metric(config, targets.class_names)
-            for config in config.metrics
-        ]
-
-    if plots is None:
-        plots = [
-            build_plotter(config, targets.class_names)
-            for config in config.plots
-        ]
-
-    return Evaluator(
-        config=config,
-        targets=targets,
-        matcher=matcher,
-        metrics=metrics,
-        plots=plots,
-    )
-
-
-def is_in_bounds(
-    geometry: data.Geometry,
-    clip: data.Clip,
-    buffer: float,
-) -> bool:
-    start_time = compute_bounds(geometry)[0]
-    return (start_time >= clip.start_time + buffer) and (
-        start_time <= clip.end_time - buffer
-    )
src/batdetect2/evaluate/evaluator/__init__.py (new file, 114 lines)
@@ -0,0 +1,114 @@
+from typing import (
+    Annotated,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
+
+from matplotlib.figure import Figure
+from pydantic import Field
+from soundevent import data
+
+from batdetect2.core.configs import BaseConfig
+from batdetect2.evaluate.evaluator.base import evaluators
+from batdetect2.evaluate.evaluator.clip import ClipMetricsConfig
+from batdetect2.evaluate.evaluator.per_class import ClassificationMetricsConfig
+from batdetect2.evaluate.evaluator.single import GlobalEvaluatorConfig
+from batdetect2.targets import build_targets
+from batdetect2.typing import (
+    EvaluatorProtocol,
+    RawPrediction,
+    TargetProtocol,
+)
+
+__all__ = [
+    "EvaluatorConfig",
+    "build_evaluator",
+]
+
+
+EvaluatorConfig = Annotated[
+    Union[
+        ClassificationMetricsConfig,
+        GlobalEvaluatorConfig,
+        ClipMetricsConfig,
+        "MultipleEvaluatorConfig",
+    ],
+    Field(discriminator="name"),
+]
+
+
+class MultipleEvaluatorConfig(BaseConfig):
+    name: Literal["multiple_evaluations"] = "multiple_evaluations"
+    evaluations: List[EvaluatorConfig] = Field(
+        default_factory=lambda: [
+            ClassificationMetricsConfig(),
+            GlobalEvaluatorConfig(),
+        ]
+    )
+
+
+class MultipleEvaluator:
+    def __init__(
+        self,
+        targets: TargetProtocol,
+        evaluators: Sequence[EvaluatorProtocol],
+    ):
+        self.targets = targets
+        self.evaluators = evaluators
+
+    def evaluate(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        predictions: Sequence[Sequence[RawPrediction]],
+    ) -> List[Any]:
+        return [
+            evaluator.evaluate(
+                clip_annotations,
+                predictions,
+            )
+            for evaluator in self.evaluators
+        ]
+
+    def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
+        results = {}
+
+        for evaluator, outputs in zip(self.evaluators, eval_outputs):
+            results.update(evaluator.compute_metrics(outputs))
+
+        return results
+
+    def generate_plots(
+        self,
+        eval_outputs: List[Any],
+    ) -> Iterable[Tuple[str, Figure]]:
+        for evaluator, outputs in zip(self.evaluators, eval_outputs):
+            for name, fig in evaluator.generate_plots(outputs):
+                yield name, fig
+
+    @evaluators.register(MultipleEvaluatorConfig)
+    @staticmethod
+    def from_config(config: MultipleEvaluatorConfig, targets: TargetProtocol):
+        return MultipleEvaluator(
+            evaluators=[
+                build_evaluator(conf, targets=targets)
+                for conf in config.evaluations
+            ],
+            targets=targets,
+        )
+
+
+def build_evaluator(
+    config: Optional[EvaluatorConfig] = None,
+    targets: Optional[TargetProtocol] = None,
+) -> EvaluatorProtocol:
+    targets = targets or build_targets()
+
+    config = config or MultipleEvaluatorConfig()
+    return evaluators.build(config, targets)
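All evaluators share the same three-step protocol (`evaluate`, then `compute_metrics`, then `generate_plots`), and `MultipleEvaluator` simply fans each step out over its children and merges the results. A hedged driver sketch; the `targets`, `clip_annotations`, and `predictions` objects are assumed to exist, and the metric key names shown are illustrative, since they depend on each child's prefix and configured metric names:

from batdetect2.evaluate.evaluator import (
    MultipleEvaluatorConfig,
    build_evaluator,
)

evaluator = build_evaluator(MultipleEvaluatorConfig(), targets=targets)

# Step 1: per-evaluator intermediate outputs (one entry per child).
outputs = evaluator.evaluate(clip_annotations, predictions)

# Step 2: merged, prefix-namespaced scalar metrics, roughly of the form
# {"detection/<metric>": ..., "classification/mean_<metric>": ...}.
metrics = evaluator.compute_metrics(outputs)

# Step 3: lazily generated figures from every child evaluator.
for name, figure in evaluator.generate_plots(outputs):
    figure.savefig(f"{name.replace('/', '_')}.png")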
src/batdetect2/evaluate/evaluator/base.py (new file, 107 lines)
@@ -0,0 +1,107 @@
+from pydantic import Field
+from soundevent import data
+from soundevent.geometry import compute_bounds
+
+from batdetect2.core import BaseConfig
+from batdetect2.core.registries import Registry
+from batdetect2.evaluate.match import (
+    MatchConfig,
+    StartTimeMatchConfig,
+    build_matcher,
+)
+from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
+from batdetect2.typing.postprocess import RawPrediction
+from batdetect2.typing.targets import TargetProtocol
+
+__all__ = [
+    "BaseEvaluatorConfig",
+    "BaseEvaluator",
+]
+
+evaluators: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry("metric")
+
+
+class BaseEvaluatorConfig(BaseConfig):
+    prefix: str
+    ignore_start_end: float = 0.01
+    matching_strategy: MatchConfig = Field(
+        default_factory=StartTimeMatchConfig
+    )
+
+
+class BaseEvaluator(EvaluatorProtocol):
+    targets: TargetProtocol
+
+    matcher: MatcherProtocol
+
+    ignore_start_end: float
+
+    prefix: str
+
+    def __init__(
+        self,
+        matcher: MatcherProtocol,
+        targets: TargetProtocol,
+        prefix: str,
+        ignore_start_end: float = 0.01,
+    ):
+        self.matcher = matcher
+        self.targets = targets
+        self.prefix = prefix
+        self.ignore_start_end = ignore_start_end
+
+    def filter_sound_event_annotations(
+        self,
+        sound_event_annotation: data.SoundEventAnnotation,
+        clip: data.Clip,
+    ) -> bool:
+        if not self.targets.filter(sound_event_annotation):
+            return False
+
+        geometry = sound_event_annotation.sound_event.geometry
+        if geometry is None:
+            return False
+
+        return is_in_bounds(
+            geometry,
+            clip,
+            self.ignore_start_end,
+        )
+
+    def filter_predictions(
+        self,
+        prediction: RawPrediction,
+        clip: data.Clip,
+    ) -> bool:
+        return is_in_bounds(
+            prediction.geometry,
+            clip,
+            self.ignore_start_end,
+        )
+
+    @classmethod
+    def build(
+        cls,
+        config: BaseEvaluatorConfig,
+        targets: TargetProtocol,
+        **kwargs,
+    ):
+        matcher = build_matcher(config.matching_strategy)
+        return cls(
+            matcher=matcher,
+            targets=targets,
+            prefix=config.prefix,
+            ignore_start_end=config.ignore_start_end,
+            **kwargs,
+        )
+
+
+def is_in_bounds(
+    geometry: data.Geometry,
+    clip: data.Clip,
+    buffer: float,
+) -> bool:
+    start_time = compute_bounds(geometry)[0]
+    return (start_time >= clip.start_time + buffer) and (
+        start_time <= clip.end_time - buffer
+    )
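The `ignore_start_end` buffer excludes anything whose start time falls within `buffer` seconds of either clip edge, which keeps boundary-truncated calls out of the scoring. For instance, with the default `buffer=0.01` on a clip spanning 0.0 to 5.0 s, only start times in the closed interval [0.01, 4.99] survive; a sketch of the check reduced to plain floats:

def is_in_bounds(start_time: float, clip_start: float, clip_end: float,
                 buffer: float = 0.01) -> bool:
    # Mirrors the geometry-based check above, with the start time precomputed.
    return clip_start + buffer <= start_time <= clip_end - buffer

assert is_in_bounds(0.5, 0.0, 5.0)
assert not is_in_bounds(0.005, 0.0, 5.0)   # too close to the clip start
assert not is_in_bounds(4.995, 0.0, 5.0)   # too close to the clip end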
src/batdetect2/evaluate/evaluator/clip.py (new file, 163 lines)
@@ -0,0 +1,163 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Literal, Sequence, Set
+
+from pydantic import Field, field_validator
+from sklearn import metrics
+from soundevent import data
+
+from batdetect2.evaluate.evaluator.base import (
+    BaseEvaluator,
+    BaseEvaluatorConfig,
+    evaluators,
+)
+from batdetect2.evaluate.metrics.common import average_precision
+from batdetect2.typing.postprocess import RawPrediction
+from batdetect2.typing.targets import TargetProtocol
+
+
+@dataclass
+class ClipInfo:
+    gt_det: bool
+    gt_classes: Set[str]
+    pred_score: float
+    pred_class_scores: Dict[str, float]
+
+
+ClipMetric = Callable[[Sequence[ClipInfo]], float]
+
+
+def clip_detection_average_precision(
+    clip_evaluations: Sequence[ClipInfo],
+) -> float:
+    y_true = []
+    y_score = []
+
+    for clip_eval in clip_evaluations:
+        y_true.append(clip_eval.gt_det)
+        y_score.append(clip_eval.pred_score)
+
+    return average_precision(y_true=y_true, y_score=y_score)
+
+
+def clip_detection_roc_auc(
+    clip_evaluations: Sequence[ClipInfo],
+) -> float:
+    y_true = []
+    y_score = []
+
+    for clip_eval in clip_evaluations:
+        y_true.append(clip_eval.gt_det)
+        y_score.append(clip_eval.pred_score)
+
+    return float(metrics.roc_auc_score(y_true=y_true, y_score=y_score))
+
+
+clip_metrics = {
+    "average_precision": clip_detection_average_precision,
+    "roc_auc": clip_detection_roc_auc,
+}
+
+
+class ClipMetricsConfig(BaseEvaluatorConfig):
+    name: Literal["clip"] = "clip"
+    prefix: str = "clip"
+    metrics: List[str] = Field(
+        default_factory=lambda: [
+            "average_precision",
+            "roc_auc",
+        ]
+    )
+
+    @field_validator("metrics", mode="after")
+    @classmethod
+    def validate_metrics(cls, v: List[str]) -> List[str]:
+        for metric_name in v:
+            if metric_name not in clip_metrics:
+                raise ValueError(f"Unknown metric {metric_name}")
+        return v
+
+
+class ClipEvaluator(BaseEvaluator):
+    def __init__(self, *args, metrics: Dict[str, ClipMetric], **kwargs):
+        super().__init__(*args, **kwargs)
+        self.metrics = metrics
+
+    def evaluate(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        predictions: Sequence[Sequence[RawPrediction]],
+    ) -> List[ClipInfo]:
+        return [
+            self.match_clip(clip_annotation, preds)
+            for clip_annotation, preds in zip(clip_annotations, predictions)
+        ]
+
+    def compute_metrics(
+        self,
+        eval_outputs: List[ClipInfo],
+    ) -> Dict[str, float]:
+        scores = {
+            name: metric(eval_outputs) for name, metric in self.metrics.items()
+        }
+        return {
+            f"{self.prefix}/{name}": score for name, score in scores.items()
+        }
+
+    def match_clip(
+        self,
+        clip_annotation: data.ClipAnnotation,
+        predictions: Sequence[RawPrediction],
+    ) -> ClipInfo:
+        clip = clip_annotation.clip
+
+        gt_det = False
+        gt_classes = set()
+        for sound_event in clip_annotation.sound_events:
+            if not self.filter_sound_event_annotations(sound_event, clip):
+                continue
+
+            gt_det = True
+            class_name = self.targets.encode_class(sound_event)
+
+            if class_name is None:
+                continue
+
+            gt_classes.add(class_name)
+
+        pred_score = 0
+        pred_class_scores: defaultdict[str, float] = defaultdict(lambda: 0)
+        for pred in predictions:
+            if not self.filter_predictions(pred, clip):
+                continue
+
+            pred_score = max(pred_score, pred.detection_score)
+
+            for class_name, class_score in zip(
+                self.targets.class_names,
+                pred.class_scores,
+            ):
+                pred_class_scores[class_name] = max(
+                    pred_class_scores[class_name],
+                    class_score,
+                )
+
+        return ClipInfo(
+            gt_det=gt_det,
+            gt_classes=gt_classes,
+            pred_score=pred_score,
+            pred_class_scores=pred_class_scores,
+        )
+
+    @evaluators.register(ClipMetricsConfig)
+    @staticmethod
+    def from_config(
+        config: ClipMetricsConfig,
+        targets: TargetProtocol,
+    ):
+        metrics = {name: clip_metrics[name] for name in config.metrics}
+        return ClipEvaluator.build(
+            config=config,
+            metrics=metrics,
+            targets=targets,
+        )
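`ClipEvaluator.match_clip` reduces each clip to a single detection label and score by max-pooling over its events: the ground truth is positive if any in-bounds annotation exists, and the prediction score is the highest in-bounds detection score. A toy illustration of that pooling, using plain data rather than batdetect2 types:

gt_starts = [0.8, 1.4]           # start times of in-bounds annotations
pred_scores = [0.12, 0.93, 0.4]  # in-bounds detection scores

gt_det = len(gt_starts) > 0                 # True: the clip contains a call
pred_score = max(pred_scores, default=0.0)  # 0.93: the clip-level score

# Feeding the (gt_det, pred_score) pairs from many clips into
# average_precision / roc_auc_score yields the clip detection metrics.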
src/batdetect2/evaluate/evaluator/multiple.py (new empty file)

src/batdetect2/evaluate/evaluator/per_class.py (new file, 219 lines)
@@ -0,0 +1,219 @@
+from collections import defaultdict
+from typing import (
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+)
+
+import numpy as np
+from pydantic import Field
+from soundevent import data
+
+from batdetect2.evaluate.evaluator.base import (
+    BaseEvaluator,
+    BaseEvaluatorConfig,
+    evaluators,
+)
+from batdetect2.evaluate.match import match
+from batdetect2.evaluate.metrics.per_class_matches import (
+    ClassificationAveragePrecisionConfig,
+    PerClassMatchMetric,
+    PerClassMatchMetricConfig,
+    build_per_class_matches_metric,
+)
+from batdetect2.typing import (
+    ClipMatches,
+    RawPrediction,
+    TargetProtocol,
+)
+
+ScoreFn = Callable[[RawPrediction, int], float]
+
+
+def score_by_class_score(pred: RawPrediction, class_index: int) -> float:
+    return float(pred.class_scores[class_index])
+
+
+def score_by_adjusted_class_score(
+    pred: RawPrediction,
+    class_index: int,
+) -> float:
+    return float(pred.class_scores[class_index]) * pred.detection_score
+
+
+ScoreFunctionOption = Literal["class_score", "adjusted_class_score"]
+score_functions: Mapping[ScoreFunctionOption, ScoreFn] = {
+    "class_score": score_by_class_score,
+    "adjusted_class_score": score_by_adjusted_class_score,
+}
+
+
+def get_score_fn(name: ScoreFunctionOption) -> ScoreFn:
+    return score_functions[name]
+
+
+class ClassificationMetricsConfig(BaseEvaluatorConfig):
+    name: Literal["classification"] = "classification"
+    prefix: str = "classification"
+    include_generics: bool = True
+    score_by: ScoreFunctionOption = "class_score"
+    metrics: List[PerClassMatchMetricConfig] = Field(
+        default_factory=lambda: [ClassificationAveragePrecisionConfig()]
+    )
+    include: Optional[List[str]] = None
+    exclude: Optional[List[str]] = None
+
+
+class PerClassEvaluator(BaseEvaluator):
+    def __init__(
+        self,
+        *args,
+        metrics: Dict[str, PerClassMatchMetric],
+        score_fn: ScoreFn,
+        include_generics: bool = True,
+        include: Optional[List[str]] = None,
+        exclude: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.score_fn = score_fn
+        self.metrics = metrics
+
+        self.include_generics = include_generics
+
+        self.include = include
+        self.exclude = exclude
+
+        self.selected = self.targets.class_names
+        if include is not None:
+            self.selected = [
+                class_name
+                for class_name in self.selected
+                if class_name in include
+            ]
+
+        if exclude is not None:
+            self.selected = [
+                class_name
+                for class_name in self.selected
+                if class_name not in exclude
+            ]
+
+    def evaluate(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        predictions: Sequence[Sequence[RawPrediction]],
+    ) -> Dict[str, List[ClipMatches]]:
+        ret = defaultdict(list)
+
+        for clip_annotation, preds in zip(clip_annotations, predictions):
+            matches = self.match_clip(clip_annotation, preds)
+            for class_name, clip_eval in matches.items():
+                ret[class_name].append(clip_eval)
+
+        return ret
+
+    def compute_metrics(
+        self,
+        eval_outputs: Dict[str, List[ClipMatches]],
+    ) -> Dict[str, float]:
+        results = {}
+
+        for metric_name, metric in self.metrics.items():
+            class_scores = {
+                class_name: metric(eval_outputs[class_name], class_name)
+                for class_name in self.targets.class_names
+            }
+            mean = float(
+                np.mean([v for v in class_scores.values() if not np.isnan(v)])
+            )
+
+            results[f"{self.prefix}/mean_{metric_name}"] = mean
+
+            for class_name, value in class_scores.items():
+                if class_name not in self.selected:
+                    continue
+
+                results[f"{self.prefix}/{metric_name}/{class_name}"] = value
+
+        return results
+
+    def match_clip(
+        self,
+        clip_annotation: data.ClipAnnotation,
+        predictions: Sequence[RawPrediction],
+    ) -> Dict[str, ClipMatches]:
+        clip = clip_annotation.clip
+
+        preds = [
+            pred for pred in predictions if self.filter_predictions(pred, clip)
+        ]
+
+        all_gts = [
+            sound_event
+            for sound_event in clip_annotation.sound_events
+            if self.filter_sound_event_annotations(sound_event, clip)
+        ]
+
+        ret = {}
+
+        for class_name in self.targets.class_names:
+            class_idx = self.targets.class_names.index(class_name)
+
+            # Only match to targets of the given class
+            gts = [
+                sound_event
+                for sound_event in all_gts
+                if self.is_class(sound_event, class_name)
+            ]
+            scores = [self.score_fn(pred, class_idx) for pred in preds]
+
+            ret[class_name] = match(
+                gts,
+                preds,
+                clip=clip,
+                scores=scores,
+                targets=self.targets,
+                matcher=self.matcher,
+            )
+
+        return ret
+
+    def is_class(
+        self,
+        sound_event: data.SoundEventAnnotation,
+        class_name: str,
+    ) -> bool:
+        sound_event_class = self.targets.encode_class(sound_event)
+
+        if sound_event_class is None and self.include_generics:
+            # Sound events that are generic could be of the given class
+            return True
+
+        return sound_event_class == class_name
+
+    @evaluators.register(ClassificationMetricsConfig)
+    @staticmethod
+    def from_config(
+        config: ClassificationMetricsConfig,
+        targets: TargetProtocol,
+    ):
+        metrics = {
+            metric.name: build_per_class_matches_metric(metric)
+            for metric in config.metrics
+        }
+        return PerClassEvaluator.build(
+            config=config,
+            targets=targets,
+            metrics=metrics,
+            score_fn=get_score_fn(config.score_by),
+            include_generics=config.include_generics,
+            include=config.include,
+            exclude=config.exclude,
+        )
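The per-class summary is the mean over classes, skipping classes whose metric is undefined (NaN, for example when a class has no positives). Note the NaN check must go through `np.isnan`; filtering with `v != np.nan` keeps every value, since `nan != nan` is true. A minimal sketch with illustrative class names:

import numpy as np

class_scores = {"Myotis": 0.81, "Pipistrellus": 0.64, "Nyctalus": np.nan}

valid = [v for v in class_scores.values() if not np.isnan(v)]
mean_ap = float(np.mean(valid))  # 0.725, ignoring the undefined class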
src/batdetect2/evaluate/evaluator/single.py (new file, 126 lines)
@@ -0,0 +1,126 @@
+from typing import Callable, Dict, List, Literal, Mapping, Sequence
+
+from pydantic import Field
+from soundevent import data
+
+from batdetect2.evaluate.evaluator.base import (
+    BaseEvaluator,
+    BaseEvaluatorConfig,
+    evaluators,
+)
+from batdetect2.evaluate.match import match
+from batdetect2.evaluate.metrics.matches import (
+    DetectionAveragePrecisionConfig,
+    MatchesMetric,
+    MatchMetricConfig,
+    build_match_metric,
+)
+from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol
+
+ScoreFn = Callable[[RawPrediction], float]
+
+
+def score_by_detection_score(pred: RawPrediction) -> float:
+    return pred.detection_score
+
+
+def score_by_top_class_score(pred: RawPrediction) -> float:
+    return pred.class_scores.max()
+
+
+ScoreFunctionOption = Literal["detection_score", "top_class_score"]
+score_functions: Mapping[ScoreFunctionOption, ScoreFn] = {
+    "detection_score": score_by_detection_score,
+    "top_class_score": score_by_top_class_score,
+}
+
+
+def get_score_fn(name: ScoreFunctionOption) -> ScoreFn:
+    return score_functions[name]
+
+
+class GlobalEvaluatorConfig(BaseEvaluatorConfig):
+    name: Literal["detection"] = "detection"
+    prefix: str = "detection"
+    score_by: ScoreFunctionOption = "detection_score"
+    metrics: List[MatchMetricConfig] = Field(
+        default_factory=lambda: [DetectionAveragePrecisionConfig()]
+    )
+
+
+class GlobalEvaluator(BaseEvaluator):
+    def __init__(
+        self,
+        *args,
+        score_fn: ScoreFn,
+        metrics: Dict[str, MatchesMetric],
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.metrics = metrics
+        self.score_fn = score_fn
+
+    def compute_metrics(
+        self,
+        eval_outputs: List[ClipMatches],
+    ) -> Dict[str, float]:
+        scores = {
+            name: metric(eval_outputs) for name, metric in self.metrics.items()
+        }
+        return {
+            f"{self.prefix}/{name}": score for name, score in scores.items()
+        }
+
+    def evaluate(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        predictions: Sequence[Sequence[RawPrediction]],
+    ) -> List[ClipMatches]:
+        return [
+            self.match_clip(clip_annotation, preds)
+            for clip_annotation, preds in zip(clip_annotations, predictions)
+        ]
+
+    def match_clip(
+        self,
+        clip_annotation: data.ClipAnnotation,
+        predictions: Sequence[RawPrediction],
+    ) -> ClipMatches:
+        clip = clip_annotation.clip
+
+        gts = [
+            sound_event
+            for sound_event in clip_annotation.sound_events
+            if self.filter_sound_event_annotations(sound_event, clip)
+        ]
+        preds = [
+            pred for pred in predictions if self.filter_predictions(pred, clip)
+        ]
+        scores = [self.score_fn(pred) for pred in preds]
+
+        return match(
+            gts,
+            preds,
+            scores=scores,
+            clip=clip,
+            targets=self.targets,
+            matcher=self.matcher,
+        )
+
+    @evaluators.register(GlobalEvaluatorConfig)
+    @staticmethod
+    def from_config(
+        config: GlobalEvaluatorConfig,
+        targets: TargetProtocol,
+    ):
+        metrics = {
+            metric.name: build_match_metric(metric)
+            for metric in config.metrics
+        }
+        score_fn = get_score_fn(config.score_by)
+        return GlobalEvaluator.build(
+            config=config,
+            score_fn=score_fn,
+            metrics=metrics,
+            targets=targets,
+        )
src/batdetect2/evaluate/evaluator/top_class.py (new file, 133 lines)
@@ -0,0 +1,133 @@
+from typing import Dict, List, Literal, Sequence
+
+from pydantic import Field, field_validator
+from soundevent import data
+
+from batdetect2.evaluate.match import match
+from batdetect2.evaluate.metrics.base import (
+    BaseMetric,
+    BaseMetricConfig,
+    metrics_registry,
+)
+from batdetect2.evaluate.metrics.common import average_precision
+from batdetect2.evaluate.metrics.detection import DetectionMetric
+from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol
+
+__all__ = [
+    "TopClassEvaluator",
+    "TopClassEvaluatorConfig",
+]
+
+
+def top_class_average_precision(
+    clip_evaluations: Sequence[ClipMatches],
+) -> float:
+    y_true = []
+    y_score = []
+    num_positives = 0
+
+    for clip_eval in clip_evaluations:
+        for m in clip_eval.matches:
+            is_generic = m.gt_det and (m.gt_class is None)
+
+            # Ignore ground truth sounds with unknown class
+            if is_generic:
+                continue
+
+            num_positives += int(m.gt_det)
+
+            # Ignore matches that don't correspond to a prediction
+            if m.pred_geometry is None:
+                continue
+
+            y_true.append(m.gt_det & (m.top_class == m.gt_class))
+            y_score.append(m.top_class_score)
+
+    return average_precision(y_true, y_score, num_positives=num_positives)
+
+
+top_class_metrics = {
+    "average_precision": top_class_average_precision,
+}
+
+
+class TopClassEvaluatorConfig(BaseMetricConfig):
+    name: Literal["top_class"] = "top_class"
+    prefix: str = "top_class"
+    metrics: List[str] = Field(default_factory=lambda: ["average_precision"])
+
+    @field_validator("metrics", mode="after")
+    @classmethod
+    def validate_metrics(cls, v: List[str]) -> List[str]:
+        for metric_name in v:
+            if metric_name not in top_class_metrics:
+                raise ValueError(f"Unknown metric {metric_name}")
+        return v
+
+
+class TopClassEvaluator(BaseMetric):
+    def __init__(self, *args, metrics: Dict[str, DetectionMetric], **kwargs):
+        super().__init__(*args, **kwargs)
+        self.metrics = metrics
+
+    def __call__(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        predictions: Sequence[Sequence[RawPrediction]],
+    ) -> Dict[str, float]:
+        clip_evaluations = [
+            self.match_clip(clip_annotation, preds)
+            for clip_annotation, preds in zip(clip_annotations, predictions)
+        ]
+        scores = {
+            name: metric(clip_evaluations)
+            for name, metric in self.metrics.items()
+        }
+        return {
+            f"{self.prefix}/{name}": score for name, score in scores.items()
+        }
+
+    def match_clip(
+        self,
+        clip_annotation: data.ClipAnnotation,
+        predictions: Sequence[RawPrediction],
+    ) -> ClipMatches:
+        clip = clip_annotation.clip
+
+        gts = [
+            sound_event
+            for sound_event in clip_annotation.sound_events
+            if self.filter_sound_event_annotations(sound_event, clip)
+        ]
+        preds = [
+            pred for pred in predictions if self.filter_predictions(pred, clip)
+        ]
+        # Use score of top class for matching
+        scores = [pred.class_scores.max() for pred in preds]
+
+        return match(
+            gts,
+            preds,
+            scores=scores,
+            clip=clip,
+            targets=self.targets,
+            matcher=self.matcher,
+        )
+
+    @classmethod
+    def from_config(
+        cls,
+        config: TopClassEvaluatorConfig,
+        targets: TargetProtocol,
+    ):
+        metrics = {
+            name: top_class_metrics[name] for name in config.metrics
+        }
+        return super().build(
+            config=config,
+            metrics=metrics,
+            targets=targets,
+        )
+
+
+metrics_registry.register(TopClassEvaluatorConfig, TopClassEvaluator)
@@ -8,7 +8,7 @@ from batdetect2.evaluate.tables import FullEvaluationTable
 from batdetect2.logging import get_image_logger, get_table_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
-from batdetect2.typing import ClipEvaluation, EvaluatorProtocol
+from batdetect2.typing import ClipMatches, EvaluatorProtocol


 class EvaluationModule(LightningModule):
@@ -56,7 +56,7 @@ class EvaluationModule(LightningModule):
         self.plot_examples(self.clip_evaluations)
         self.log_table(self.clip_evaluations)

-    def log_table(self, evaluated_clips: Sequence[ClipEvaluation]):
+    def log_table(self, evaluated_clips: Sequence[ClipMatches]):
         table_logger = get_table_logger(self.logger)  # type: ignore

         if table_logger is None:
@@ -65,7 +65,7 @@ class EvaluationModule(LightningModule):
         df = FullEvaluationTable()(evaluated_clips)
         table_logger("full_evaluation", df, 0)

-    def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]):
+    def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
         plotter = get_image_logger(self.logger)  # type: ignore

         if plotter is None:
@@ -74,7 +74,7 @@ class EvaluationModule(LightningModule):
         for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
             plotter(figure_name, fig, self.global_step)

-    def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]):
+    def log_metrics(self, evaluated_clips: Sequence[ClipMatches]):
         metrics = self.evaluator.compute_metrics(evaluated_clips)
         self.log_dict(metrics)
@ -8,8 +8,7 @@ from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.affinity import (
    AffinityConfig,
    GeometricIOUConfig,
@ -17,11 +16,13 @@ from batdetect2.evaluate.affinity import (
)
from batdetect2.targets import build_targets
from batdetect2.typing import (
    AffinityFunction,
    MatcherProtocol,
    MatchEvaluation,
    RawPrediction,
    TargetProtocol,
)
from batdetect2.typing.evaluate import AffinityFunction, MatcherProtocol
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.evaluate import ClipMatches

MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""
@ -33,9 +34,10 @@ def match(
    sound_event_annotations: Sequence[data.SoundEventAnnotation],
    raw_predictions: Sequence[RawPrediction],
    clip: data.Clip,
    scores: Optional[Sequence[float]] = None,
    targets: Optional[TargetProtocol] = None,
    matcher: Optional[MatcherProtocol] = None,
) -> List[MatchEvaluation]:
) -> ClipMatches:
    if matcher is None:
        matcher = build_matcher()

@ -51,8 +53,10 @@ def match(
        raw_prediction.geometry for raw_prediction in raw_predictions
    ]

    if scores is None:
        scores = [
            raw_prediction.detection_score for raw_prediction in raw_predictions
            raw_prediction.detection_score
            for raw_prediction in raw_predictions
        ]

    matches = []
@ -73,9 +77,11 @@ def match(

        gt_det = target_idx is not None
        gt_class = targets.encode_class(target) if target is not None else None
        gt_geometry = (
            target_geometries[target_idx] if target_idx is not None else None
        )

        pred_score = float(prediction.detection_score) if prediction else 0

        pred_geometry = (
            predicted_geometries[source_idx]
            if source_idx is not None
@ -84,7 +90,7 @@ def match(

        class_scores = (
            {
                str(class_name): float(score)
                class_name: score
                for class_name, score in zip(
                    targets.class_names,
                    prediction.class_scores,
@ -100,6 +106,7 @@ def match(
                sound_event_annotation=target,
                gt_det=gt_det,
                gt_class=gt_class,
                gt_geometry=gt_geometry,
                pred_score=pred_score,
                pred_class_scores=class_scores,
                pred_geometry=pred_geometry,
@ -107,7 +114,7 @@ def match(
            )
        )

    return matches
    return ClipMatches(clip=clip, matches=matches)


class StartTimeMatchConfig(BaseConfig):
@ -132,12 +139,10 @@ class StartTimeMatcher(MatcherProtocol):
            distance_threshold=self.distance_threshold,
        )

    @classmethod
    def from_config(cls, config: StartTimeMatchConfig) -> "StartTimeMatcher":
        return cls(distance_threshold=config.distance_threshold)


matching_strategies.register(StartTimeMatchConfig, StartTimeMatcher)
    @matching_strategies.register(StartTimeMatchConfig)
    @staticmethod
    def from_config(config: StartTimeMatchConfig):
        return StartTimeMatcher(distance_threshold=config.distance_threshold)
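
# Aside: the change above (and the analogous ones below) moves registration
# from a module-level `matching_strategies.register(Config, Class)` call to a
# `@matching_strategies.register(Config)` decorator on a `from_config`
# staticmethod. batdetect2's actual Registry is not shown in this diff, so the
# following is only a minimal sketch, under the assumption that the decorator
# records the factory keyed by config type and that `build` dispatches on it;
# `SketchRegistry` is a hypothetical name used for illustration.
from typing import Any, Callable, Dict


class SketchRegistry:
    def __init__(self, name: str):
        self.name = name
        self._factories: Dict[type, Callable[..., Any]] = {}

    def register(self, config_type: type) -> Callable[[Callable], Callable]:
        def decorator(factory: Callable) -> Callable:
            # Store the factory and hand it back unchanged, so it can still
            # be stacked on top of @staticmethod (staticmethod objects are
            # directly callable on Python 3.10+).
            self._factories[config_type] = factory
            return factory

        return decorator

    def build(self, config: Any, *args: Any) -> Any:
        # Dispatch on the concrete config type to the registered factory.
        return self._factories[type(config)](config, *args)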


def match_start_times(
@ -264,19 +269,17 @@ class GreedyMatcher(MatcherProtocol):
            affinity_threshold=self.affinity_threshold,
        )

    @classmethod
    def from_config(cls, config: GreedyMatchConfig):
    @matching_strategies.register(GreedyMatchConfig)
    @staticmethod
    def from_config(config: GreedyMatchConfig):
        affinity_function = build_affinity_function(config.affinity_function)
        return cls(
        return GreedyMatcher(
            geometry=config.geometry,
            affinity_threshold=config.affinity_threshold,
            affinity_function=affinity_function,
        )


matching_strategies.register(GreedyMatchConfig, GreedyMatcher)


def greedy_match(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
@ -313,21 +316,21 @@ def greedy_match(
    unassigned_gt = set(range(len(ground_truth)))

    if not predictions:
        for target_idx in range(len(ground_truth)):
            yield None, target_idx, 0
        for gt_idx in range(len(ground_truth)):
            yield None, gt_idx, 0

        return

    if not ground_truth:
        for source_idx in range(len(predictions)):
            yield source_idx, None, 0
        for pred_idx in range(len(predictions)):
            yield pred_idx, None, 0

        return

    indices = np.argsort(scores)[::-1]

    for source_idx in indices:
        source_geometry = predictions[source_idx]
    for pred_idx in indices:
        source_geometry = predictions[pred_idx]

        affinities = np.array(
            [
@ -340,18 +343,18 @@ def greedy_match(
        affinity = affinities[closest_target]

        if affinities[closest_target] <= affinity_threshold:
            yield source_idx, None, 0
            yield pred_idx, None, 0
            continue

        if closest_target not in unassigned_gt:
            yield source_idx, None, 0
            yield pred_idx, None, 0
            continue

        unassigned_gt.remove(closest_target)
        yield source_idx, closest_target, affinity
        yield pred_idx, closest_target, affinity

    for target_idx in unassigned_gt:
        yield None, target_idx, 0
    for gt_idx in unassigned_gt:
        yield None, gt_idx, 0
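
# Toy illustration of the (pred_idx, gt_idx, affinity) triplets that
# greedy_match yields, using 1-D interval IoU as a stand-in affinity. This is
# a self-contained sketch of the algorithm, not the library function: the
# highest-scoring prediction claims its best ground truth first, claimed
# ground truths are off-limits afterwards, and leftovers come out as misses.
def interval_iou(a, b):
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0


ground_truth = [(0.0, 1.0), (2.0, 3.0)]
predictions = [(0.1, 0.9), (5.0, 6.0)]  # one true hit, one false positive
scores = [0.9, 0.8]

unassigned = set(range(len(ground_truth)))
triplets = []
for pred_idx in sorted(range(len(predictions)), key=lambda i: -scores[i]):
    ious = [interval_iou(predictions[pred_idx], gt) for gt in ground_truth]
    best = max(range(len(ground_truth)), key=lambda i: ious[i])
    if ious[best] <= 0 or best not in unassigned:
        triplets.append((pred_idx, None, 0))  # unmatched prediction
        continue
    unassigned.remove(best)
    triplets.append((pred_idx, best, ious[best]))
for gt_idx in unassigned:
    triplets.append((None, gt_idx, 0))  # missed ground truth

# triplets -> [(0, 0, ~0.8), (1, None, 0), (None, 1, 0)]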


class OptimalMatchConfig(BaseConfig):
@ -386,17 +389,16 @@ class OptimalMatcher(MatcherProtocol):
            affinity_threshold=self.affinity_threshold,
        )

    @classmethod
    def from_config(cls, config: OptimalMatchConfig):
        return cls(
    @matching_strategies.register(OptimalMatchConfig)
    @staticmethod
    def from_config(config: OptimalMatchConfig):
        return OptimalMatcher(
            affinity_threshold=config.affinity_threshold,
            time_buffer=config.time_buffer,
            frequency_buffer=config.frequency_buffer,
        )


matching_strategies.register(OptimalMatchConfig, OptimalMatcher)

MatchConfig = Annotated[
    Union[
        GreedyMatchConfig,
@ -1,712 +0,0 @@
from collections import defaultdict
from collections.abc import Callable, Mapping
from typing import (
    Annotated,
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing

from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation, MetricsProtocol

__all__ = ["DetectionAP", "ClassificationAP"]


metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")


APImplementation = Literal["sklearn", "pascal_voc"]


class DetectionAPConfig(BaseConfig):
    name: Literal["detection_ap"] = "detection_ap"
    ap_implementation: APImplementation = "pascal_voc"


class DetectionAP(MetricsProtocol):
    def __init__(
        self,
        implementation: APImplementation = "pascal_voc",
    ):
        self.implementation = implementation
        self.metric = _ap_impl_mapping[self.implementation]

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true, y_score = zip(
            *[
                (match.gt_det, match.pred_score)
                for clip_eval in clip_evaluations
                for match in clip_eval.matches
            ]
        )
        score = float(self.metric(y_true, y_score))
        return {"detection_AP": score}

    @classmethod
    def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
        return cls(implementation=config.ap_implementation)


metrics_registry.register(DetectionAPConfig, DetectionAP)


class DetectionROCAUCConfig(BaseConfig):
    name: Literal["detection_roc_auc"] = "detection_roc_auc"


class DetectionROCAUC(MetricsProtocol):
    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true, y_score = zip(
            *[
                (match.gt_det, match.pred_score)
                for clip_eval in clip_evaluations
                for match in clip_eval.matches
            ]
        )
        score = float(metrics.roc_auc_score(y_true, y_score))
        return {"detection_ROC_AUC": score}

    @classmethod
    def from_config(
        cls, config: DetectionROCAUCConfig, class_names: List[str]
    ):
        return cls()


metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC)


class ClassificationAPConfig(BaseConfig):
    name: Literal["classification_ap"] = "classification_ap"
    ap_implementation: APImplementation = "pascal_voc"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class ClassificationAP(MetricsProtocol):
    def __init__(
        self,
        class_names: List[str],
        implementation: APImplementation = "pascal_voc",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.implementation = implementation
        self.metric = _ap_impl_mapping[self.implementation]
        self.class_names = class_names

        self.selected = class_names

        if include is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name in include
            ]

        if exclude is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name not in exclude
            ]

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            for match in clip_eval.matches:
                # Ignore generic unclassified targets
                if match.gt_det and match.gt_class is None:
                    continue

                y_true.append(
                    match.gt_class
                    if match.gt_class is not None
                    else "__NONE__"
                )

                y_pred.append(
                    np.array(
                        [
                            match.pred_class_scores.get(name, 0)
                            for name in self.class_names
                        ]
                    )
                )

        y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
        y_pred = np.stack(y_pred)

        class_scores = {}
        for class_index, class_name in enumerate(self.class_names):
            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[:, class_index]
            class_ap = self.metric(y_true_class, y_pred_class)
            class_scores[class_name] = float(class_ap)

        mean_ap = np.mean(
            [value for value in class_scores.values() if value != 0]
        )

        return {
            "classification_mAP": float(mean_ap),
            **{
                f"classification_AP/{class_name}": class_scores[class_name]
                for class_name in self.selected
            },
        }

    @classmethod
    def from_config(
        cls,
        config: ClassificationAPConfig,
        class_names: List[str],
    ):
        return cls(
            class_names,
            implementation=config.ap_implementation,
            include=config.include,
            exclude=config.exclude,
        )


metrics_registry.register(ClassificationAPConfig, ClassificationAP)


class ClassificationROCAUCConfig(BaseConfig):
    name: Literal["classification_roc_auc"] = "classification_roc_auc"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class ClassificationROCAUC(MetricsProtocol):
    def __init__(
        self,
        class_names: List[str],
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.class_names = class_names
        self.selected = class_names

        if include is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name in include
            ]

        if exclude is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name not in exclude
            ]

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            for match in clip_eval.matches:
                # Ignore generic unclassified targets
                if match.gt_det and match.gt_class is None:
                    continue

                y_true.append(
                    match.gt_class
                    if match.gt_class is not None
                    else "__NONE__"
                )

                y_pred.append(
                    np.array(
                        [
                            match.pred_class_scores.get(name, 0)
                            for name in self.class_names
                        ]
                    )
                )

        y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
        y_pred = np.stack(y_pred)

        class_scores = {}
        for class_index, class_name in enumerate(self.class_names):
            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[:, class_index]
            class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
            class_scores[class_name] = float(class_roc_auc)

        mean_roc_auc = np.mean(
            [value for value in class_scores.values() if value != 0]
        )

        return {
            "classification_macro_average_ROC_AUC": float(mean_roc_auc),
            **{
                f"classification_ROC_AUC/{class_name}": class_scores[
                    class_name
                ]
                for class_name in self.selected
            },
        }

    @classmethod
    def from_config(
        cls,
        config: ClassificationROCAUCConfig,
        class_names: List[str],
    ):
        return cls(
            class_names,
            include=config.include,
            exclude=config.exclude,
        )


metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC)


class TopClassAPConfig(BaseConfig):
    name: Literal["top_class_ap"] = "top_class_ap"
    ap_implementation: APImplementation = "pascal_voc"


class TopClassAP(MetricsProtocol):
    def __init__(
        self,
        implementation: APImplementation = "pascal_voc",
    ):
        self.implementation = implementation
        self.metric = _ap_impl_mapping[self.implementation]

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for match in clip_eval.matches:
                # Ignore generic unclassified targets
                if match.gt_det and match.gt_class is None:
                    continue

                top_class = match.pred_class

                y_true.append(top_class == match.gt_class)
                y_score.append(match.pred_class_score)

        score = float(self.metric(y_true, y_score))
        return {"top_class_AP": score}

    @classmethod
    def from_config(cls, config: TopClassAPConfig, class_names: List[str]):
        return cls(implementation=config.ap_implementation)


metrics_registry.register(TopClassAPConfig, TopClassAP)


class ClassificationBalancedAccuracyConfig(BaseConfig):
    name: Literal["classification_balanced_accuracy"] = (
        "classification_balanced_accuracy"
    )


class ClassificationBalancedAccuracy(MetricsProtocol):
    def __init__(self, class_names: List[str]):
        self.class_names = class_names

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            for match in clip_eval.matches:
                top_class = match.pred_class

                # Focus on matches
                if match.gt_class is None or top_class is None:
                    continue

                y_true.append(self.class_names.index(match.gt_class))
                y_pred.append(self.class_names.index(top_class))

        score = float(metrics.balanced_accuracy_score(y_true, y_pred))
        return {"classification_balanced_accuracy": score}

    @classmethod
    def from_config(
        cls,
        config: ClassificationBalancedAccuracyConfig,
        class_names: List[str],
    ):
        return cls(class_names)


metrics_registry.register(
    ClassificationBalancedAccuracyConfig,
    ClassificationBalancedAccuracy,
)


class ClipDetectionAPConfig(BaseConfig):
    name: Literal["clip_detection_ap"] = "clip_detection_ap"
    ap_implementation: APImplementation = "pascal_voc"


class ClipDetectionAP(MetricsProtocol):
    def __init__(
        self,
        implementation: APImplementation,
    ):
        self.implementation = implementation
        self.metric = _ap_impl_mapping[self.implementation]

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            clip_det = []
            clip_scores = []

            for match in clip_eval.matches:
                clip_det.append(match.gt_det)
                clip_scores.append(match.pred_score)

            y_true.append(any(clip_det))
            y_score.append(max(clip_scores or [0]))

        return {"clip_detection_ap": self.metric(y_true, y_score)}

    @classmethod
    def from_config(
        cls,
        config: ClipDetectionAPConfig,
        class_names: List[str],
    ):
        return cls(implementation=config.ap_implementation)


metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)


class ClipDetectionROCAUCConfig(BaseConfig):
    name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"


class ClipDetectionROCAUC(MetricsProtocol):
    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            clip_det = []
            clip_scores = []

            for match in clip_eval.matches:
                clip_det.append(match.gt_det)
                clip_scores.append(match.pred_score)

            y_true.append(any(clip_det))
            y_score.append(max(clip_scores or [0]))

        return {
            "clip_detection_roc_auc": float(
                metrics.roc_auc_score(y_true, y_score)
            )
        }

    @classmethod
    def from_config(
        cls,
        config: ClipDetectionROCAUCConfig,
        class_names: List[str],
    ):
        return cls()


metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)


class ClipMulticlassAPConfig(BaseConfig):
    name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"
    ap_implementation: APImplementation = "pascal_voc"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class ClipMulticlassAP(MetricsProtocol):
    def __init__(
        self,
        class_names: List[str],
        implementation: APImplementation,
        include: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
    ):
        self.implementation = implementation
        self.metric = _ap_impl_mapping[self.implementation]
        self.class_names = class_names

        self.selected = class_names

        if include is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name in include
            ]

        if exclude is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name not in exclude
            ]

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            clip_classes = set()
            clip_scores = defaultdict(list)

            for match in clip_eval.matches:
                if match.gt_class is not None:
                    clip_classes.add(match.gt_class)

                for class_name, score in match.pred_class_scores.items():
                    clip_scores[class_name].append(score)

            y_true.append(clip_classes)
            y_pred.append(
                np.array(
                    [
                        # Get max score for each class
                        max(clip_scores.get(class_name, [0]))
                        for class_name in self.class_names
                    ]
                )
            )

        y_true = preprocessing.MultiLabelBinarizer(
            classes=self.class_names
        ).fit_transform(y_true)
        y_pred = np.stack(y_pred)

        class_scores = {}
        for class_index, class_name in enumerate(self.class_names):
            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[:, class_index]
            class_ap = self.metric(y_true_class, y_pred_class)
            class_scores[class_name] = float(class_ap)

        mean_ap = np.mean(
            [value for value in class_scores.values() if value != 0]
        )
        return {
            "clip_multiclass_mAP": float(mean_ap),
            **{
                f"clip_multiclass_AP/{class_name}": class_scores[class_name]
                for class_name in self.selected
            },
        }

    @classmethod
    def from_config(
        cls, config: ClipMulticlassAPConfig, class_names: List[str]
    ):
        return cls(
            implementation=config.ap_implementation,
            include=config.include,
            exclude=config.exclude,
            class_names=class_names,
        )


metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)


class ClipMulticlassROCAUCConfig(BaseConfig):
    name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class ClipMulticlassROCAUC(MetricsProtocol):
    def __init__(
        self,
        class_names: List[str],
        include: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
    ):
        self.class_names = class_names
        self.selected = class_names

        if include is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name in include
            ]

        if exclude is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name not in exclude
            ]

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            clip_classes = set()
            clip_scores = defaultdict(list)

            for match in clip_eval.matches:
                if match.gt_class is not None:
                    clip_classes.add(match.gt_class)

                for class_name, score in match.pred_class_scores.items():
                    clip_scores[class_name].append(score)

            y_true.append(clip_classes)
            y_pred.append(
                np.array(
                    [
                        # Get maximum score for each class
                        max(clip_scores.get(class_name, [0]))
                        for class_name in self.class_names
                    ]
                )
            )

        y_true = preprocessing.MultiLabelBinarizer(
            classes=self.class_names
        ).fit_transform(y_true)
        y_pred = np.stack(y_pred)

        class_scores = {}
        for class_index, class_name in enumerate(self.class_names):
            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[:, class_index]
            class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
            class_scores[class_name] = float(class_roc_auc)

        mean_roc_auc = np.mean(
            [value for value in class_scores.values() if value != 0]
        )
        return {
            "clip_multiclass_macro_ROC_AUC": float(mean_roc_auc),
            **{
                f"clip_multiclass_ROC_AUC/{class_name}": class_scores[
                    class_name
                ]
                for class_name in self.selected
            },
        }

    @classmethod
    def from_config(
        cls,
        config: ClipMulticlassROCAUCConfig,
        class_names: List[str],
    ):
        return cls(
            include=config.include,
            exclude=config.exclude,
            class_names=class_names,
        )


metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)

MetricConfig = Annotated[
    Union[
        DetectionAPConfig,
        DetectionROCAUCConfig,
        ClassificationAPConfig,
        ClassificationROCAUCConfig,
        TopClassAPConfig,
        ClassificationBalancedAccuracyConfig,
        ClipDetectionAPConfig,
        ClipDetectionROCAUCConfig,
        ClipMulticlassAPConfig,
        ClipMulticlassROCAUCConfig,
    ],
    Field(discriminator="name"),
]


def build_metric(config: MetricConfig, class_names: List[str]):
    return metrics_registry.build(config, class_names)


def pascal_voc_average_precision(y_true, y_score) -> float:
    y_true = np.array(y_true)
    y_score = np.array(y_score)

    sort_ind = np.argsort(y_score)[::-1]
    y_true_sorted = y_true[sort_ind]

    num_positives = y_true.sum()
    false_pos_c = np.cumsum(1 - y_true_sorted)
    true_pos_c = np.cumsum(y_true_sorted)

    recall = true_pos_c / num_positives
    precision = true_pos_c / np.maximum(
        true_pos_c + false_pos_c,
        np.finfo(np.float64).eps,
    )

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

    # pascal 12 way
    mprec = np.hstack((0, precision, 0))
    mrec = np.hstack((0, recall, 1))
    for ii in range(mprec.shape[0] - 2, -1, -1):
        mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
    ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()

    return ave_prec


_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = {
    "sklearn": metrics.average_precision_score,
    "pascal_voc": pascal_voc_average_precision,
}
src/batdetect2/evaluate/metrics/__init__.py (new file, empty)
src/batdetect2/evaluate/metrics/common.py (new file, 46 lines)
@ -0,0 +1,46 @@
from typing import Optional

import numpy as np


def average_precision(
    y_true,
    y_score,
    num_positives: Optional[int] = None,
) -> float:
    y_true = np.array(y_true)
    y_score = np.array(y_score)

    if num_positives is None:
        num_positives = y_true.sum()

    # Remove non-detections
    valid_inds = y_score > 0
    y_true = y_true[valid_inds]
    y_score = y_score[valid_inds]

    # Sort by score
    sort_ind = np.argsort(y_score)[::-1]
    y_true_sorted = y_true[sort_ind]

    false_pos_c = np.cumsum(1 - y_true_sorted)
    true_pos_c = np.cumsum(y_true_sorted)

    recall = true_pos_c / num_positives
    precision = true_pos_c / np.maximum(
        true_pos_c + false_pos_c,
        np.finfo(np.float64).eps,
    )

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

    # pascal 12 way
    mprec = np.hstack((0, precision, 0))
    mrec = np.hstack((0, recall, 1))
    for ii in range(mprec.shape[0] - 2, -1, -1):
        mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
    ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()

    return ave_prec
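
# Worked example of the PASCAL VOC 2012-style interpolation above. With two
# positives and one confident false positive ranked between them:
#   y_true  = [1, 0, 1]
#   y_score = [0.9, 0.8, 0.7]
# precision over the ranked list is [1.0, 0.5, 2/3]; after the backwards
# running-maximum interpolation the precision envelope is 1.0 up to recall
# 0.5 and 2/3 up to recall 1.0, so:
#   AP = 0.5 * 1.0 + 0.5 * (2/3) ≈ 0.833
# i.e. average_precision([1, 0, 1], [0.9, 0.8, 0.7]) ≈ 0.833. Passing
# num_positives explicitly matters when some ground truths never received a
# prediction (score 0): they are dropped by the `y_score > 0` filter but
# still cap the achievable recall.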
src/batdetect2/evaluate/metrics/matches.py (new file, 235 lines)
@ -0,0 +1,235 @@
|
||||
from typing import Annotated, Callable, Literal, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import (
|
||||
ClipMatches,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MatchMetricConfig",
|
||||
"MatchesMetric",
|
||||
"build_match_metric",
|
||||
]
|
||||
|
||||
MatchesMetric = Callable[[Sequence[ClipMatches]], float]
|
||||
|
||||
|
||||
metrics_registry: Registry[MatchesMetric, []] = Registry("match_metric")
|
||||
|
||||
|
||||
class DetectionAveragePrecisionConfig(BaseConfig):
|
||||
name: Literal["detection_average_precision"] = (
|
||||
"detection_average_precision"
|
||||
)
|
||||
ignore_non_predictions: bool = True
|
||||
|
||||
|
||||
class DetectionAveragePrecision:
|
||||
def __init__(self, ignore_non_predictions: bool = True):
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipMatches],
|
||||
) -> float:
|
||||
y_true = []
|
||||
y_score = []
|
||||
num_positives = 0
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for m in clip_eval.matches:
|
||||
num_positives += int(m.gt_det)
|
||||
|
||||
# Ignore matches that don't correspond to a prediction
|
||||
if not m.is_prediction and self.ignore_non_predictions:
|
||||
continue
|
||||
|
||||
y_true.append(m.gt_det)
|
||||
y_score.append(m.pred_score)
|
||||
|
||||
return average_precision(y_true, y_score, num_positives=num_positives)
|
||||
|
||||
@metrics_registry.register(DetectionAveragePrecisionConfig)
|
||||
@staticmethod
|
||||
def from_config(config: DetectionAveragePrecisionConfig):
|
||||
return DetectionAveragePrecision(
|
||||
ignore_non_predictions=config.ignore_non_predictions
|
||||
)
|
||||
|
||||
|
||||
class TopClassAveragePrecisionConfig(BaseConfig):
|
||||
name: Literal["top_class_average_precision"] = (
|
||||
"top_class_average_precision"
|
||||
)
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
|
||||
class TopClassAveragePrecision:
|
||||
def __init__(
|
||||
self,
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
):
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.ignore_generic = ignore_generic
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipMatches],
|
||||
) -> float:
|
||||
y_true = []
|
||||
y_score = []
|
||||
num_positives = 0
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for m in clip_eval.matches:
|
||||
if m.is_generic and self.ignore_generic:
|
||||
# Ignore ground truth sounds with unknown class
|
||||
continue
|
||||
|
||||
num_positives += int(m.gt_det)
|
||||
|
||||
if not m.is_prediction and self.ignore_non_predictions:
|
||||
# Ignore matches that don't correspond to a prediction
|
||||
continue
|
||||
|
||||
y_true.append(m.gt_det & (m.top_class == m.gt_class))
|
||||
y_score.append(m.top_class_score)
|
||||
|
||||
return average_precision(y_true, y_score, num_positives=num_positives)
|
||||
|
||||
@metrics_registry.register(TopClassAveragePrecisionConfig)
|
||||
@staticmethod
|
||||
def from_config(config: TopClassAveragePrecisionConfig):
|
||||
return TopClassAveragePrecision(
|
||||
ignore_non_predictions=config.ignore_non_predictions
|
||||
)


class DetectionROCAUCConfig(BaseConfig):
    name: Literal["detection_roc_auc"] = "detection_roc_auc"
    ignore_non_predictions: bool = True


class DetectionROCAUC:
    def __init__(
        self,
        ignore_non_predictions: bool = True,
    ):
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(self, clip_evaluations: Sequence[ClipMatches]) -> float:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.gt_det)
                y_score.append(m.pred_score)

        return float(metrics.roc_auc_score(y_true, y_score))

    @metrics_registry.register(DetectionROCAUCConfig)
    @staticmethod
    def from_config(config: DetectionROCAUCConfig):
        return DetectionROCAUC(
            ignore_non_predictions=config.ignore_non_predictions
        )


class DetectionRecallConfig(BaseConfig):
    name: Literal["detection_recall"] = "detection_recall"
    threshold: float = 0.5


class DetectionRecall:
    def __init__(self, threshold: float):
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
    ) -> float:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.gt_det:
                    num_positives += 1

                if m.pred_score >= self.threshold and m.gt_det:
                    true_positives += 1

        if num_positives == 0:
            return 1

        return true_positives / num_positives

    @metrics_registry.register(DetectionRecallConfig)
    @staticmethod
    def from_config(config: DetectionRecallConfig):
        return DetectionRecall(threshold=config.threshold)


class DetectionPrecisionConfig(BaseConfig):
    name: Literal["detection_precision"] = "detection_precision"
    threshold: float = 0.5


class DetectionPrecision:
    def __init__(self, threshold: float):
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
    ) -> float:
        num_detections = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_detection = m.pred_score >= self.threshold

                if is_detection:
                    num_detections += 1

                if is_detection and m.gt_det:
                    true_positives += 1

        if num_detections == 0:
            return np.nan

        return true_positives / num_detections

    @metrics_registry.register(DetectionPrecisionConfig)
    @staticmethod
    def from_config(config: DetectionPrecisionConfig):
        return DetectionPrecision(threshold=config.threshold)


MatchMetricConfig = Annotated[
    Union[
        DetectionAveragePrecisionConfig,
        DetectionROCAUCConfig,
        DetectionRecallConfig,
        DetectionPrecisionConfig,
        TopClassAveragePrecisionConfig,
    ],
    Field(discriminator="name"),
]


def build_match_metric(config: MatchMetricConfig):
    return metrics_registry.build(config)
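
# Sketch of how one of these metrics might be parsed from plain data and then
# built. This assumes BaseConfig is a pydantic model and uses the pydantic v2
# TypeAdapter API; both are assumptions, as neither is defined in this diff.
from pydantic import TypeAdapter

config = TypeAdapter(MatchMetricConfig).validate_python(
    {"name": "detection_recall", "threshold": 0.3}
)
# The "name" field acts as the discriminator, so this yields a
# DetectionRecallConfig(threshold=0.3).
metric = build_match_metric(config)  # a DetectionRecall instance

# The result is a plain callable over Sequence[ClipMatches]:
# recall = metric(clip_matches)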
src/batdetect2/evaluate/metrics/per_class_matches.py (new file, 136 lines)
@ -0,0 +1,136 @@
|
||||
from typing import Annotated, Callable, Literal, Sequence, Union
|
||||
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import (
|
||||
ClipMatches,
|
||||
)
|
||||
|
||||
__all__ = []
|
||||
|
||||
PerClassMatchMetric = Callable[[Sequence[ClipMatches], str], float]
|
||||
|
||||
|
||||
metrics_registry: Registry[PerClassMatchMetric, []] = Registry(
|
||||
"match_metric"
|
||||
)
|
||||
|
||||
|
||||
class ClassificationAveragePrecisionConfig(BaseConfig):
|
||||
name: Literal["classification_average_precision"] = (
|
||||
"classification_average_precision"
|
||||
)
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
|
||||
class ClassificationAveragePrecision:
|
||||
def __init__(
|
||||
self,
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
):
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.ignore_generic = ignore_generic
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipMatches],
|
||||
class_name: str,
|
||||
) -> float:
|
||||
y_true = []
|
||||
y_score = []
|
||||
num_positives = 0
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for m in clip_eval.matches:
|
||||
is_class = m.gt_class == class_name
|
||||
|
||||
if is_class:
|
||||
num_positives += 1
|
||||
|
||||
# Ignore matches that don't correspond to a prediction
|
||||
if not m.is_prediction and self.ignore_non_predictions:
|
||||
continue
|
||||
|
||||
# Exclude matches with ground truth sounds where the class is
|
||||
# unknown
|
||||
if m.is_generic and self.ignore_generic:
|
||||
continue
|
||||
|
||||
y_true.append(is_class)
|
||||
y_score.append(m.pred_class_scores.get(class_name, 0))
|
||||
|
||||
return average_precision(y_true, y_score, num_positives=num_positives)
|
||||
|
||||
@metrics_registry.register(ClassificationAveragePrecisionConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ClassificationAveragePrecisionConfig):
|
||||
return ClassificationAveragePrecision(
|
||||
ignore_non_predictions=config.ignore_non_predictions,
|
||||
ignore_generic=config.ignore_generic,
|
||||
)
|
||||
|
||||
|
||||
class ClassificationROCAUCConfig(BaseConfig):
|
||||
name: Literal["classification_roc_auc"] = "classification_roc_auc"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
|
||||
class ClassificationROCAUC:
|
||||
def __init__(
|
||||
self,
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
):
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.ignore_generic = ignore_generic
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipMatches],
|
||||
class_name: str,
|
||||
) -> float:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for m in clip_eval.matches:
|
||||
# Exclude matches with ground truth sounds where the class is
|
||||
# unknown
|
||||
if m.is_generic and self.ignore_generic:
|
||||
continue
|
||||
|
||||
# Ignore matches that don't correspond to a prediction
|
||||
if not m.is_prediction and self.ignore_non_predictions:
|
||||
continue
|
||||
|
||||
y_true.append(m.gt_class == class_name)
|
||||
y_score.append(m.pred_class_scores.get(class_name, 0))
|
||||
|
||||
return float(metrics.roc_auc_score(y_true, y_score))
|
||||
|
||||
@metrics_registry.register(ClassificationROCAUCConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ClassificationROCAUCConfig):
|
||||
return ClassificationROCAUC(
|
||||
ignore_non_predictions=config.ignore_non_predictions,
|
||||
ignore_generic=config.ignore_generic,
|
||||
)
|
||||
|
||||
|
||||
PerClassMatchMetricConfig = Annotated[
|
||||
Union[
|
||||
ClassificationAveragePrecisionConfig,
|
||||
ClassificationROCAUCConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_per_class_matches_metric(config: PerClassMatchMetricConfig):
|
||||
return metrics_registry.build(config)
|
||||
@ -17,7 +17,7 @@ from batdetect2.plotting.matches import plot_matches
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import (
    AudioLoader,
    ClipEvaluation,
    ClipMatches,
    MatchEvaluation,
    PlotterProtocol,
    PreprocessorProtocol,
@ -53,7 +53,7 @@ class ExampleGallery(PlotterProtocol):
        self.preprocessor = preprocessor or build_preprocessor()
        self.audio_loader = audio_loader or build_audio_loader()

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
    def __call__(self, clip_evaluations: Sequence[ClipMatches]):
        per_class_matches = group_matches(clip_evaluations)

        for class_name, matches in per_class_matches.items():
@ -128,7 +128,7 @@ class PlotClipEvaluation(PlotterProtocol):
        self.audio_loader = audio_loader
        self.num_plots = num_plots

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
    def __call__(self, clip_evaluations: Sequence[ClipMatches]):
        examples = random.sample(
            clip_evaluations,
            k=min(self.num_plots, len(clip_evaluations)),
@ -171,7 +171,7 @@ class DetectionPRCurveConfig(BaseConfig):


class DetectionPRCurve(PlotterProtocol):
    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
    def __call__(self, clip_evaluations: Sequence[ClipMatches]):
        y_true, y_score = zip(
            *[
                (match.gt_det, match.pred_score)
@ -231,7 +231,7 @@ class ClassificationPRCurves(PlotterProtocol):
                if class_name not in exclude
            ]

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
    def __call__(self, clip_evaluations: Sequence[ClipMatches]):
        y_true = []
        y_pred = []

@ -303,7 +303,7 @@ class DetectionROCCurveConfig(BaseConfig):


class DetectionROCCurve(PlotterProtocol):
    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
    def __call__(self, clip_evaluations: Sequence[ClipMatches]):
        y_true, y_score = zip(
            *[
                (match.gt_det, match.pred_score)
@ -363,7 +363,7 @@ class ClassificationROCCurves(PlotterProtocol):
                if class_name not in exclude
            ]

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
    def __call__(self, clip_evaluations: Sequence[ClipMatches]):
        y_true = []
        y_pred = []

@ -440,7 +440,7 @@ class ConfusionMatrix(PlotterProtocol):
        self.background_class = background_class
        self.class_names = class_names

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
    def __call__(self, clip_evaluations: Sequence[ClipMatches]):
        y_true = []
        y_pred = []

@ -456,7 +456,7 @@ class ConfusionMatrix(PlotterProtocol):
                else self.background_class
            )

            top_class = match.pred_class
            top_class = match.top_class
            y_pred.append(
                top_class
                if top_class is not None
@ -515,14 +515,14 @@ class ClassMatches:


def group_matches(
    clip_evaluations: Sequence[ClipEvaluation],
    clip_evaluations: Sequence[ClipMatches],
) -> Dict[str, ClassMatches]:
    class_examples = defaultdict(ClassMatches)

    for clip_evaluation in clip_evaluations:
        for match in clip_evaluation.matches:
            gt_class = match.gt_class
            pred_class = match.pred_class
            pred_class = match.top_class

            if pred_class is None:
                class_examples[gt_class].false_negatives.append(match)
@ -550,7 +550,7 @@ def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
        *[
            (index, match.pred_class_scores[pred_class])
            for index, match in enumerate(matches)
            if (pred_class := match.pred_class) is not None
            if (pred_class := match.top_class) is not None
        ]
    )
@ -5,9 +5,9 @@ from pydantic import Field
from soundevent.geometry import compute_bounds

from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation
from batdetect2.typing import ClipMatches

EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame]
EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]


tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
@ -21,20 +21,18 @@ class FullEvaluationTableConfig(BaseConfig):

class FullEvaluationTable:
    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
        self, clip_evaluations: Sequence[ClipMatches]
    ) -> pd.DataFrame:
        return extract_matches_dataframe(clip_evaluations)

    @classmethod
    def from_config(cls, config: FullEvaluationTableConfig):
        return cls()


tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)
    @tables_registry.register(FullEvaluationTableConfig)
    @staticmethod
    def from_config(config: FullEvaluationTableConfig):
        return FullEvaluationTable()


def extract_matches_dataframe(
    clip_evaluations: Sequence[ClipEvaluation],
    clip_evaluations: Sequence[ClipMatches],
) -> pd.DataFrame:
    data = []

@ -78,8 +76,8 @@ def extract_matches_dataframe(
                ("gt", "low_freq"): gt_low_freq,
                ("gt", "high_freq"): gt_high_freq,
                ("pred", "score"): match.pred_score,
                ("pred", "class"): match.pred_class,
                ("pred", "class_score"): match.pred_class_score,
                ("pred", "class"): match.top_class,
                ("pred", "class_score"): match.top_class_score,
                ("pred", "start_time"): pred_start_time,
                ("pred", "end_time"): pred_end_time,
                ("pred", "low_freq"): pred_low_freq,

@ -65,8 +65,6 @@ def plot_anchor_points(
        if not targets.filter(sound_event):
            continue

        sound_event = targets.transform(sound_event)

        position, _ = targets.encode_roi(sound_event)
        positions.append(position)

@ -162,7 +162,7 @@ def plot_false_positive_match(
    plt.text(
        start_time,
        high_freq,
        f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
        f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.top_class} \nTop Class Score: {match.top_class_score:.2f} ",
        va="top",
        ha="right",
        color=color,
@ -312,7 +312,7 @@ def plot_true_positive_match(
    plt.text(
        start_time,
        high_freq,
        f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
        f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ",
        va="top",
        ha="right",
        color=color,
@ -394,7 +394,7 @@ def plot_cross_trigger_match(
    plt.text(
        start_time,
        high_freq,
        f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
        f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.top_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ",
        va="top",
        ha="right",
        color=color,
@ -28,12 +28,10 @@ class CenterAudio(torch.nn.Module):
    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        return center_tensor(wav)

    @classmethod
    def from_config(cls, config: CenterAudioConfig, samplerate: int):
        return cls()


audio_transforms.register(CenterAudioConfig, CenterAudio)
    @audio_transforms.register(CenterAudioConfig)
    @staticmethod
    def from_config(config: CenterAudioConfig, samplerate: int):
        return CenterAudio()


class ScaleAudioConfig(BaseConfig):
@ -44,12 +42,10 @@ class ScaleAudio(torch.nn.Module):
    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        return peak_normalize(wav)

    @classmethod
    def from_config(cls, config: ScaleAudioConfig, samplerate: int):
        return cls()


audio_transforms.register(ScaleAudioConfig, ScaleAudio)
    @audio_transforms.register(ScaleAudioConfig)
    @staticmethod
    def from_config(config: ScaleAudioConfig, samplerate: int):
        return ScaleAudio()


class FixDurationConfig(BaseConfig):
@ -75,13 +71,12 @@ class FixDuration(torch.nn.Module):

        return torch.nn.functional.pad(wav, (0, self.length - length))

    @classmethod
    def from_config(cls, config: FixDurationConfig, samplerate: int):
        return cls(samplerate=samplerate, duration=config.duration)
    @audio_transforms.register(FixDurationConfig)
    @staticmethod
    def from_config(config: FixDurationConfig, samplerate: int):
        return FixDuration(samplerate=samplerate, duration=config.duration)


audio_transforms.register(FixDurationConfig, FixDuration)

AudioTransform = Annotated[
    Union[
        FixDurationConfig,
@ -285,10 +285,11 @@ class PCEN(torch.nn.Module):
            * torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
        ).to(spec.dtype)

    @classmethod
    def from_config(cls, config: PcenConfig, samplerate: int):
    @spectrogram_transforms.register(PcenConfig)
    @staticmethod
    def from_config(config: PcenConfig, samplerate: int):
        smooth = _compute_smoothing_constant(samplerate, config.time_constant)
        return cls(
        return PCEN(
            smoothing_constant=smooth,
            gain=config.gain,
            bias=config.bias,
@ -296,9 +297,6 @@ class PCEN(torch.nn.Module):
        )


spectrogram_transforms.register(PcenConfig, PCEN)


def _compute_smoothing_constant(
    samplerate: int,
    time_constant: float,
@ -335,12 +333,10 @@ class ScaleAmplitude(torch.nn.Module):
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return self.scaler(spec)

    @classmethod
    def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int):
        return cls(scale=config.scale)


spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude)
    @spectrogram_transforms.register(ScaleAmplitudeConfig)
    @staticmethod
    def from_config(config: ScaleAmplitudeConfig, samplerate: int):
        return ScaleAmplitude(scale=config.scale)


class SpectralMeanSubstractionConfig(BaseConfig):
@ -352,19 +348,13 @@ class SpectralMeanSubstraction(torch.nn.Module):
        mean = spec.mean(-1, keepdim=True)
        return (spec - mean).clamp(min=0)

    @classmethod
    @spectrogram_transforms.register(SpectralMeanSubstractionConfig)
    @staticmethod
    def from_config(
        cls,
        config: SpectralMeanSubstractionConfig,
        samplerate: int,
    ):
        return cls()


spectrogram_transforms.register(
    SpectralMeanSubstractionConfig,
    SpectralMeanSubstraction,
)
        return SpectralMeanSubstraction()


class PeakNormalizeConfig(BaseConfig):
@ -375,13 +365,12 @@ class PeakNormalize(torch.nn.Module):
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return peak_normalize(spec)

    @classmethod
    def from_config(cls, config: PeakNormalizeConfig, samplerate: int):
        return cls()
    @spectrogram_transforms.register(PeakNormalizeConfig)
    @staticmethod
    def from_config(config: PeakNormalizeConfig, samplerate: int):
        return PeakNormalize()


spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize)

SpectrogramTransform = Annotated[
    Union[
        PcenConfig,
@ -99,7 +99,7 @@ DEFAULT_DETECTION_CLASS = TargetClassConfig(
DEFAULT_CLASSES = [
    TargetClassConfig(
        name="barbar",
        tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
        tags=[data.Tag(key="class", value="Barbastella barbastellus")],
    ),
    TargetClassConfig(
        name="eptser",
@ -1,11 +1,11 @@
from batdetect2.train.augmentations import (
    AugmentationsConfig,
    EchoAugmentationConfig,
    FrequencyMaskAugmentationConfig,
    AddEchoConfig,
    MaskFrequencyConfig,
    RandomAudioSource,
    TimeMaskAugmentationConfig,
    VolumeAugmentationConfig,
    WarpAugmentationConfig,
    MaskTimeConfig,
    ScaleVolumeConfig,
    WarpConfig,
    add_echo,
    build_augmentations,
    mask_frequency,
@ -43,20 +43,20 @@ __all__ = [
    "AugmentationsConfig",
    "ClassificationLossConfig",
    "DetectionLossConfig",
    "EchoAugmentationConfig",
    "FrequencyMaskAugmentationConfig",
    "AddEchoConfig",
    "MaskFrequencyConfig",
    "LossConfig",
    "LossFunction",
    "PLTrainerConfig",
    "RandomAudioSource",
    "SizeLossConfig",
    "TimeMaskAugmentationConfig",
    "MaskTimeConfig",
    "TrainingConfig",
    "TrainingDataset",
    "TrainingModule",
    "ValidationDataset",
    "VolumeAugmentationConfig",
    "WarpAugmentationConfig",
    "ScaleVolumeConfig",
    "WarpConfig",
    "add_echo",
    "build_augmentations",
    "build_clip_labeler",
@ -12,21 +12,23 @@ from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry

from batdetect2.audio.clips import get_subclip_annotation
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry
from batdetect2.typing import AudioLoader, Augmentation

__all__ = [
    "AugmentationConfig",
    "AugmentationsConfig",
    "DEFAULT_AUGMENTATION_CONFIG",
    "EchoAugmentationConfig",
    "AddEchoConfig",
    "AudioSource",
    "FrequencyMaskAugmentationConfig",
    "MixAugmentationConfig",
    "TimeMaskAugmentationConfig",
    "VolumeAugmentationConfig",
    "WarpAugmentationConfig",
    "MaskFrequencyConfig",
    "MixAudioConfig",
    "MaskTimeConfig",
    "ScaleVolumeConfig",
    "WarpConfig",
    "add_echo",
    "build_augmentations",
    "load_augmentation_config",
@ -37,10 +39,19 @@ __all__ = [
    "warp_spectrogram",
]

AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]

audio_augmentations: Registry[Augmentation, [int, Optional[AudioSource]]] = (
    Registry(name="audio_augmentation")
)

class MixAugmentationConfig(BaseConfig):
spec_augmentations: Registry[Augmentation, []] = Registry(
    name="spec_augmentation"
)


class MixAudioConfig(BaseConfig):
    """Configuration for MixUp augmentation (mixing two examples)."""

    name: Literal["mix_audio"] = "mix_audio"
@ -87,6 +98,19 @@ class MixAudio(torch.nn.Module):
        )
        return mixed_audio, mixed_annotations

    @audio_augmentations.register(MixAudioConfig)
    @staticmethod
    def from_config(
        config: MixAudioConfig,
        samplerate: int,
        source: Optional[AudioSource],
    ):
        return MixAudio(
            example_source=source,
            min_weight=config.min_weight,
            max_weight=config.max_weight,
        )


def mix_audio(
    wav1: torch.Tensor,
@ -136,7 +160,7 @@ def combine_clip_annotations(
    )


class EchoAugmentationConfig(BaseConfig):
class AddEchoConfig(BaseConfig):
    """Configuration for adding synthetic echo/reverb."""

    name: Literal["add_echo"] = "add_echo"
@ -149,14 +173,17 @@ class EchoAugmentationConfig(BaseConfig):
class AddEcho(torch.nn.Module):
    def __init__(
        self,
        samplerate: int = TARGET_SAMPLERATE_HZ,
        min_weight: float = 0.1,
        max_weight: float = 1.0,
        max_delay: int = 2560,
        max_delay: float = 0.005,
    ):
        super().__init__()
        self.samplerate = samplerate
        self.min_weight = min_weight
        self.max_weight = max_weight
        self.max_delay = max_delay
        self.max_delay_s = max_delay
        # max_delay is now given in seconds and converted to samples
        self.max_delay = int(max_delay * samplerate)

    def forward(
        self,
@ -167,6 +194,18 @@ class AddEcho(torch.nn.Module):
        weight = np.random.uniform(self.min_weight, self.max_weight)
        return add_echo(wav, delay=delay, weight=weight), clip_annotation

    @audio_augmentations.register(AddEchoConfig)
    @staticmethod
    def from_config(
        config: AddEchoConfig, samplerate: int, source: AudioSource
    ):
        return AddEcho(
            samplerate=samplerate,
            min_weight=config.min_weight,
            max_weight=config.max_weight,
            max_delay=config.max_delay,
        )


def add_echo(
    wav: torch.Tensor,
@ -183,7 +222,7 @@ def add_echo(
    return mix_audio(wav, audio_delay, weight)
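
# Self-contained sketch of the shift-and-mix idea behind add_echo above:
# delay a copy of the waveform and add it back with a weight. The exact
# mix_audio semantics live elsewhere in this module, so this is illustrative
# only; the 256 kHz samplerate is just a typical ultrasonic recording rate.
import torch


def echo_sketch(wav: torch.Tensor, delay: int, weight: float) -> torch.Tensor:
    delayed = torch.zeros_like(wav)
    delayed[delay:] = wav[: wav.shape[-1] - delay]  # shift right by `delay` samples
    return wav + weight * delayed


wav = torch.randn(256_000)  # ~1 s of audio at 256 kHz
out = echo_sketch(wav, delay=int(0.005 * 256_000), weight=0.5)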
|
||||
|
||||
|
||||
class VolumeAugmentationConfig(BaseConfig):
|
||||
class ScaleVolumeConfig(BaseConfig):
|
||||
"""Configuration for random volume scaling of the spectrogram."""
|
||||
|
||||
name: Literal["scale_volume"] = "scale_volume"
|
||||
@ -206,19 +245,27 @@ class ScaleVolume(torch.nn.Module):
|
||||
factor = np.random.uniform(self.min_scaling, self.max_scaling)
|
||||
return scale_volume(spec, factor=factor), clip_annotation
|
||||
|
||||
@spec_augmentations.register(ScaleVolumeConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ScaleVolumeConfig):
|
||||
return ScaleVolume(
|
||||
min_scaling=config.min_scaling,
|
||||
max_scaling=config.max_scaling,
|
||||
)
|
||||
|
||||
|
||||
def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
|
||||
"""Scale the amplitude of the spectrogram by a factor."""
|
||||
return spec * factor
|
||||
|
||||
|
||||
class WarpAugmentationConfig(BaseConfig):
|
||||
class WarpConfig(BaseConfig):
|
||||
name: Literal["warp"] = "warp"
|
||||
probability: float = 0.2
|
||||
delta: float = 0.04
|
||||
|
||||
|
||||
class WarpSpectrogram(torch.nn.Module):
|
||||
class Warp(torch.nn.Module):
|
||||
def __init__(self, delta: float = 0.04) -> None:
|
||||
super().__init__()
|
||||
self.delta = delta
|
||||
@ -234,6 +281,11 @@ class WarpSpectrogram(torch.nn.Module):
|
||||
warp_clip_annotation(clip_annotation, factor=factor),
|
||||
)
|
||||
|
||||
@spec_augmentations.register(WarpConfig)
|
||||
@staticmethod
|
||||
def from_config(config: WarpConfig):
|
||||
return Warp(delta=config.delta)
|
||||
|
||||
|
||||
def warp_sound_event_annotation(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
@ -294,7 +346,7 @@ def warp_spectrogram(
    ).squeeze(0)


class TimeMaskAugmentationConfig(BaseConfig):
class MaskTimeConfig(BaseConfig):
    name: Literal["mask_time"] = "mask_time"
    probability: float = 0.2
    max_perc: float = 0.05
@ -336,6 +388,14 @@ class MaskTime(torch.nn.Module):
        ]
        return mask_time(spec, masks), clip_annotation

    @spec_augmentations.register(MaskTimeConfig)
    @staticmethod
    def from_config(config: MaskTimeConfig):
        return MaskTime(
            max_perc=config.max_perc,
            max_masks=config.max_masks,
        )


def mask_time(
    spec: torch.Tensor,
@ -351,7 +411,7 @@ def mask_time(
    return spec


class FrequencyMaskAugmentationConfig(BaseConfig):
class MaskFrequencyConfig(BaseConfig):
    name: Literal["mask_freq"] = "mask_freq"
    probability: float = 0.2
    max_perc: float = 0.10
@ -394,6 +454,14 @@ class MaskFrequency(torch.nn.Module):
        ]
        return mask_frequency(spec, masks), clip_annotation

    @spec_augmentations.register(MaskFrequencyConfig)
    @staticmethod
    def from_config(config: MaskFrequencyConfig):
        return MaskFrequency(
            max_perc=config.max_perc,
            max_masks=config.max_masks,
        )


def mask_frequency(
    spec: torch.Tensor,
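Note: `MaskTime` and `MaskFrequency` are SpecAugment-style maskers: each draws up to `max_masks` random spans, each covering at most `max_perc` of its axis, and zeroes them out. A self-contained sketch of the time variant under those assumptions (the frequency variant is the same along the other axis):

import numpy as np
import torch


def mask_time_sketch(
    spec: torch.Tensor,
    max_perc: float = 0.05,
    max_masks: int = 3,
) -> torch.Tensor:
    # spec: (freq, time); zero a few random spans along the time axis.
    spec = spec.clone()
    time_bins = spec.shape[-1]
    num_masks = np.random.randint(1, max_masks + 1)
    for _ in range(num_masks):
        width = np.random.randint(0, max(1, int(time_bins * max_perc)) + 1)
        start = np.random.randint(0, max(1, time_bins - width))
        spec[..., start : start + width] = 0
    return spec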
@ -410,8 +478,8 @@ def mask_frequency(

AudioAugmentationConfig = Annotated[
    Union[
        MixAugmentationConfig,
        EchoAugmentationConfig,
        MixAudioConfig,
        AddEchoConfig,
    ],
    Field(discriminator="name"),
]
@ -419,22 +487,22 @@ AudioAugmentationConfig = Annotated[

SpectrogramAugmentationConfig = Annotated[
    Union[
        VolumeAugmentationConfig,
        WarpAugmentationConfig,
        FrequencyMaskAugmentationConfig,
        TimeMaskAugmentationConfig,
        ScaleVolumeConfig,
        WarpConfig,
        MaskFrequencyConfig,
        MaskTimeConfig,
    ],
    Field(discriminator="name"),
]

AugmentationConfig = Annotated[
    Union[
        MixAugmentationConfig,
        EchoAugmentationConfig,
        VolumeAugmentationConfig,
        WarpAugmentationConfig,
        FrequencyMaskAugmentationConfig,
        TimeMaskAugmentationConfig,
        MixAudioConfig,
        AddEchoConfig,
        ScaleVolumeConfig,
        WarpConfig,
        MaskFrequencyConfig,
        MaskTimeConfig,
    ],
    Field(discriminator="name"),
]
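Note: these `Annotated[Union[...], Field(discriminator="name")]` aliases are pydantic tagged unions: the `name` literal on each config decides which class a raw dict is parsed into, which is why the classes could be renamed here without breaking saved configs. A small illustration, assuming pydantic v2 and stand-in config classes:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class AddEchoConfig(BaseModel):
    name: Literal["add_echo"] = "add_echo"
    max_delay: float = 0.005


class MixAudioConfig(BaseModel):
    name: Literal["mix_audio"] = "mix_audio"  # tag value assumed


AudioAugmentationConfig = Annotated[
    Union[MixAudioConfig, AddEchoConfig],
    Field(discriminator="name"),
]

adapter = TypeAdapter(AudioAugmentationConfig)
config = adapter.validate_python({"name": "add_echo", "max_delay": 0.01})
assert isinstance(config, AddEchoConfig)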
@ -513,7 +581,7 @@ def build_augmentation_from_config(
        )

    if config.name == "warp":
        return WarpSpectrogram(
        return Warp(
            delta=config.delta,
        )

@ -538,14 +606,14 @@ def build_augmentation_from_config(
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
    enabled=True,
    audio=[
        MixAugmentationConfig(),
        EchoAugmentationConfig(),
        MixAudioConfig(),
        AddEchoConfig(),
    ],
    spectrogram=[
        VolumeAugmentationConfig(),
        WarpAugmentationConfig(),
        TimeMaskAugmentationConfig(),
        FrequencyMaskAugmentationConfig(),
        ScaleVolumeConfig(),
        WarpConfig(),
        MaskTimeConfig(),
        MaskFrequencyConfig(),
    ],
)

@ -566,9 +634,9 @@ class AugmentationSequence(torch.nn.Module):
        return tensor, clip_annotation


def build_augmentation_sequence(
    samplerate: int,
    steps: Optional[Sequence[AugmentationConfig]] = None,
def build_audio_augmentations(
    steps: Optional[Sequence[AudioAugmentationConfig]] = None,
    samplerate: int = TARGET_SAMPLERATE_HZ,
    audio_source: Optional[AudioSource] = None,
) -> Optional[Augmentation]:
    if not steps:
@ -577,10 +645,8 @@ def build_augmentation_sequence(
    augmentations = []

    for step_config in steps:
        augmentation = build_augmentation_from_config(
            step_config,
            samplerate=samplerate,
            audio_source=audio_source,
        augmentation = audio_augmentations.build(
            step_config, samplerate, audio_source
        )

        if augmentation is None:
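Note: with the registries in place, `build_audio_augmentations` reduces to a lookup-and-chain loop instead of a hard-coded dispatch. An illustrative call, assuming the function returns a single composed augmentation or None when no steps are configured:

# Illustrative usage; `wav` and `clip_annotation` stand for a loaded
# waveform tensor and its soundevent ClipAnnotation.
steps = [MixAudioConfig(), AddEchoConfig()]
augmentation = build_audio_augmentations(
    steps=steps,
    audio_source=None,  # mixup-style steps may require a real AudioSource
)
if augmentation is not None:
    wav, clip_annotation = augmentation(wav, clip_annotation)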
@ -10,7 +10,6 @@ from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import (
    ClipEvaluation,
    EvaluatorProtocol,
    ModelOutput,
    RawPrediction,
@ -37,22 +36,26 @@ class ValidationMetrics(Callback):
    def generate_plots(
        self,
        pl_module: LightningModule,
        evaluated_clips: List[ClipEvaluation],
    ):
        plotter = get_image_logger(pl_module.logger) # type: ignore

        if plotter is None:
            return

        for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
        for figure_name, fig in self.evaluator.generate_plots(
            self._clip_annotations,
            self._predictions,
        ):
            plotter(figure_name, fig, pl_module.global_step)

    def log_metrics(
        self,
        pl_module: LightningModule,
        evaluated_clips: List[ClipEvaluation],
    ):
        metrics = self.evaluator.compute_metrics(evaluated_clips)
        metrics = self.evaluator.compute_metrics(
            self._clip_annotations,
            self._predictions,
        )
        pl_module.log_dict(metrics)

    def on_validation_epoch_end(
@ -60,13 +63,8 @@ class ValidationMetrics(Callback):
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        clip_evaluations = self.evaluator.evaluate(
            self._clip_annotations,
            self._predictions,
        )

        self.log_metrics(pl_module, clip_evaluations)
        self.generate_plots(pl_module, clip_evaluations)
        self.log_metrics(pl_module)
        self.generate_plots(pl_module)

        return super().on_validation_epoch_end(trainer, pl_module)
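Note: after this change the callback never materializes `ClipEvaluation` objects; it hands the accumulated ground truth and predictions straight to the evaluator, which owns whatever intermediate representation it needs. The accumulation itself is outside this hunk; presumably something along these lines:

from typing import List


class ValidationMetricsSketch:
    # Hypothetical hooks (not shown in this diff): reset per epoch, then
    # collect one ClipAnnotation and one list of RawPrediction per clip.
    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        self._clip_annotations: List = []
        self._predictions: List = []

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ) -> None:
        self._clip_annotations.extend(outputs.clip_annotations)  # assumed
        self._predictions.extend(outputs.predictions)  # assumed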
@ -105,7 +105,10 @@ def train(
    trainer = trainer or build_trainer(
        config,
        targets=targets,
        evaluator=build_evaluator(config.train.validation, targets=targets),
        evaluator=build_evaluator(
            config.train.validation.evaluator,
            targets=targets,
        ),
        checkpoint_dir=checkpoint_dir,
        log_dir=log_dir,
        experiment_name=experiment_name,
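Note: the validation settings now nest the evaluator configuration one level deeper (`config.train.validation.evaluator`), so callback-level options and evaluator options no longer share a flat namespace. The implied shape, with all surrounding field names assumed:

from pydantic import BaseModel


class EvaluatorConfig(BaseModel):
    ...  # matching strategy, metrics, plots, etc.


class ValidationConfig(BaseModel):
    evaluator: EvaluatorConfig = EvaluatorConfig()


class TrainConfig(BaseModel):
    validation: ValidationConfig = ValidationConfig()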
@ -1,6 +1,8 @@
from batdetect2.typing.evaluate import (
    ClipEvaluation,
    AffinityFunction,
    ClipMatches,
    EvaluatorProtocol,
    MatcherProtocol,
    MatchEvaluation,
    MetricsProtocol,
    PlotterProtocol,
@ -36,19 +38,22 @@ from batdetect2.typing.train import (
)

__all__ = [
    "AffinityFunction",
    "AudioLoader",
    "Augmentation",
    "BackboneModel",
    "BatDetect2Prediction",
    "ClipEvaluation",
    "ClipMatches",
    "ClipLabeller",
    "ClipperProtocol",
    "DetectionModel",
    "EvaluatorProtocol",
    "GeometryDecoder",
    "Heatmaps",
    "LossProtocol",
    "Losses",
    "MatchEvaluation",
    "MatcherProtocol",
    "MetricsProtocol",
    "ModelOutput",
    "PlotterProtocol",
@ -63,5 +68,4 @@ __all__ = [
    "SoundEventFilter",
    "TargetProtocol",
    "TrainExample",
    "EvaluatorProtocol",
]
@ -31,6 +31,7 @@ class MatchEvaluation:
    sound_event_annotation: Optional[data.SoundEventAnnotation]
    gt_det: bool
    gt_class: Optional[str]
    gt_geometry: Optional[data.Geometry]

    pred_score: float
    pred_class_scores: Dict[str, float]
@ -39,44 +40,32 @@ class MatchEvaluation:
    affinity: float

    @property
    def pred_class(self) -> Optional[str]:
    def top_class(self) -> Optional[str]:
        if not self.pred_class_scores:
            return None

        return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore

    @property
    def pred_class_score(self) -> float:
        pred_class = self.pred_class
    def is_prediction(self) -> bool:
        return self.pred_geometry is not None

    @property
    def is_generic(self) -> bool:
        return self.gt_det and self.gt_class is None

    @property
    def top_class_score(self) -> float:
        pred_class = self.top_class

        if pred_class is None:
            return 0

        return self.pred_class_scores[pred_class]

    def is_true_positive(self, threshold: float = 0) -> bool:
        return (
            self.gt_det
            and self.pred_score > threshold
            and self.gt_class == self.pred_class
        )

    def is_false_positive(self, threshold: float = 0) -> bool:
        return self.gt_det is None and self.pred_score > threshold

    def is_false_negative(self, threshold: float = 0) -> bool:
        return self.gt_det and self.pred_score <= threshold

    def is_cross_trigger(self, threshold: float = 0) -> bool:
        return (
            self.gt_det
            and self.pred_score > threshold
            and self.gt_class != self.pred_class
        )


@dataclass
class ClipEvaluation:
class ClipMatches:
    clip: data.Clip
    matches: List[MatchEvaluation]
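Note: the rename from `pred_class` to `top_class` makes the semantics explicit: a match stores the full per-class score dict, and the "predicted class" is simply its argmax, falling back to None and 0 when no class scores exist. For example:

scores = {"Myotis daubentonii": 0.7, "Pipistrellus pipistrellus": 0.2}

top_class = max(scores, key=scores.get)  # "Myotis daubentonii"
top_class_score = scores[top_class]      # 0.7

# An empty dict would instead yield top_class None and top_class_score 0,
# mirroring the properties above.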
@ -103,29 +92,36 @@ class AffinityFunction(Protocol, Generic[Geom]):

class MetricsProtocol(Protocol):
    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> Dict[str, float]: ...


class PlotterProtocol(Protocol):
    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> Iterable[Tuple[str, Figure]]: ...


class EvaluatorProtocol(Protocol):
EvaluationOutput = TypeVar("EvaluationOutput")


class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
    targets: TargetProtocol

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[ClipEvaluation]: ...
    ) -> EvaluationOutput: ...

    def compute_metrics(
        self, clip_evaluations: Sequence[ClipEvaluation]
        self, eval_outputs: EvaluationOutput
    ) -> Dict[str, float]: ...

    def generate_plots(
        self, clip_evaluations: Sequence[ClipEvaluation]
        self, eval_outputs: EvaluationOutput
    ) -> Iterable[Tuple[str, Figure]]: ...
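Note: making `EvaluatorProtocol` generic over `EvaluationOutput` lets each evaluator pick its own intermediate representation (for this commit, plausibly `List[ClipMatches]`) while callers only thread the opaque value from `evaluate` into `compute_metrics` and `generate_plots`. A minimal conforming sketch with a placeholder metric:

from typing import Dict, Iterable, List, Sequence, Tuple


class MatchBasedEvaluator:
    # Illustrative evaluator whose EvaluationOutput is a list of per-clip
    # match records; the matching strategy itself is omitted.
    def evaluate(self, clip_annotations: Sequence, predictions: Sequence) -> List:
        return [
            {"annotation": ann, "predictions": preds}
            for ann, preds in zip(clip_annotations, predictions)
        ]

    def compute_metrics(self, eval_outputs: List) -> Dict[str, float]:
        # Placeholder so the sketch runs end to end.
        return {"num_clips": float(len(eval_outputs))}

    def generate_plots(self, eval_outputs: List) -> Iterable[Tuple[str, object]]:
        return []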