Compare commits


No commits in common. "30159d64a9ba3ca51d486e2157810030511cade6" and "4cd983a2c246df84d1fba427efd98726fa5a2f6c" have entirely different histories.

53 changed files with 1971 additions and 4564 deletions

View File

@ -138,49 +138,27 @@ train:
name: csv
validation:
tasks:
- name: sound_event_detection
metrics:
- name: average_precision
- name: sound_event_classification
metrics:
- name: average_precision
metrics:
- name: detection_ap
- name: detection_roc_auc
- name: classification_ap
- name: classification_roc_auc
- name: top_class_ap
- name: classification_balanced_accuracy
- name: clip_ap
- name: clip_roc_auc
evaluation:
tasks:
- name: sound_event_detection
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: score_distribution
- name: example_detection
- name: sound_event_classification
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: top_class_detection
metrics:
- name: average_precision
plots:
- name: pr_curve
- name: confusion_matrix
- name: example_classification
- name: clip_detection
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: roc_curve
- name: score_distribution
- name: clip_classification
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: roc_curve
match_strategy:
name: start_time_match
distance_threshold: 0.01
metrics:
- name: classification_ap
- name: detection_ap
plots:
- name: example_gallery
- name: example_clip
- name: detection_pr_curve
- name: classification_pr_curves
- name: detection_roc_curve
- name: classification_roc_curves

View File

@ -1,7 +1,6 @@
from pathlib import Path
from typing import List, Optional, Sequence
from typing import Optional, Sequence
import torch
from soundevent import data
from batdetect2.audio import build_audio_loader
@ -9,7 +8,6 @@ from batdetect2.config import BatDetect2Config
from batdetect2.evaluate import build_evaluator, evaluate
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.postprocess.decoding import to_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.train import train
@ -21,7 +19,6 @@ from batdetect2.typing import (
PreprocessorProtocol,
TargetProtocol,
)
from batdetect2.typing.postprocess import RawPrediction
class BatDetect2API:
@ -95,18 +92,6 @@ class BatDetect2API:
run_name=run_name,
)
def process_spectrogram(
self,
spec: torch.Tensor,
start_times: Optional[Sequence[float]] = None,
) -> List[List[RawPrediction]]:
outputs = self.model.detector(spec)
clip_detections = self.postprocessor(outputs, start_times=start_times)
return [
to_raw_predictions(clip_dets.numpy(), self.targets)
for clip_dets in clip_detections
]
@classmethod
def from_config(cls, config: BatDetect2Config):
targets = build_targets(config=config.targets)
@ -123,7 +108,10 @@ class BatDetect2API:
config=config.postprocess,
)
evaluator = build_evaluator(config=config.evaluation, targets=targets)
evaluator = build_evaluator(
config=config.evaluation,
targets=targets,
)
# NOTE: Better to have a separate instance of
# preprocessor and postprocessor as these may be moved
@ -175,7 +163,10 @@ class BatDetect2API:
config=config.postprocess,
)
evaluator = build_evaluator(config=config.evaluation, targets=targets)
evaluator = build_evaluator(
config=config.evaluation,
targets=targets,
)
return cls(
config=config,

View File

@ -56,16 +56,18 @@ class RandomClip:
min_sound_event_overlap=self.min_sound_event_overlap,
)
@clipper_registry.register(RandomClipConfig)
@staticmethod
def from_config(config: RandomClipConfig):
return RandomClip(
@classmethod
def from_config(cls, config: RandomClipConfig):
return cls(
duration=config.duration,
max_empty=config.max_empty,
min_sound_event_overlap=config.min_sound_event_overlap,
)
clipper_registry.register(RandomClipConfig, RandomClip)
def get_subclip_annotation(
clip_annotation: data.ClipAnnotation,
random: bool = True,
@ -182,12 +184,13 @@ class PaddedClip:
)
return clip_annotation.model_copy(update=dict(clip=clip))
@clipper_registry.register(PaddedClipConfig)
@staticmethod
def from_config(config: PaddedClipConfig):
return PaddedClip(chunk_size=config.chunk_size)
@classmethod
def from_config(cls, config: PaddedClipConfig):
return cls(chunk_size=config.chunk_size)
clipper_registry.register(PaddedClipConfig, PaddedClip)
ClipConfig = Annotated[
Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
]

View File

@ -53,7 +53,6 @@ class BaseConfig(BaseModel):
"""
return yaml.dump(
self.model_dump(
mode="json",
exclude_none=exclude_none,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,

View File

@ -1,16 +1,16 @@
import sys
from typing import Callable, Dict, Generic, Tuple, Type, TypeVar
from typing import Generic, Protocol, Type, TypeVar
from pydantic import BaseModel
from typing_extensions import assert_type
if sys.version_info >= (3, 10):
from typing import Concatenate, ParamSpec
from typing import ParamSpec
else:
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec
__all__ = [
"Registry",
"SimpleRegistry",
]
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
@ -18,26 +18,19 @@ T_Type = TypeVar("T_Type", covariant=True)
P_Type = ParamSpec("P_Type")
T = TypeVar("T")
class LogicProtocol(Generic[T_Config, T_Type, P_Type], Protocol):
"""A generic protocol for the logic classes."""
@classmethod
def from_config(
cls,
config: T_Config,
*args: P_Type.args,
**kwargs: P_Type.kwargs,
) -> T_Type: ...
class SimpleRegistry(Generic[T]):
def __init__(self, name: str):
self._name = name
self._registry = {}
def register(self, name: str):
def decorator(obj: T) -> T:
self._registry[name] = obj
return obj
return decorator
def get(self, name: str) -> T:
return self._registry[name]
def has(self, name: str) -> bool:
return name in self._registry
T_Proto = TypeVar("T_Proto", bound=LogicProtocol)
class Registry(Generic[T_Type, P_Type]):
@ -45,15 +38,13 @@ class Registry(Generic[T_Type, P_Type]):
def __init__(self, name: str):
self._name = name
self._registry: Dict[
str, Callable[Concatenate[..., P_Type], T_Type]
] = {}
self._config_types: Dict[str, Type[BaseModel]] = {}
self._registry = {}
def register(
self,
config_cls: Type[T_Config],
):
logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
) -> None:
fields = config_cls.model_fields
if "name" not in fields:
@ -61,21 +52,10 @@ class Registry(Generic[T_Type, P_Type]):
name = fields["name"].default
self._config_types[name] = config_cls
if not isinstance(name, str):
raise ValueError("'name' field must be a string literal.")
def decorator(
func: Callable[Concatenate[T_Config, P_Type], T_Type],
):
self._registry[name] = func
return func
return decorator
def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
return tuple(self._config_types.values())
self._registry[name] = logic_cls
def build(
self,
@ -95,4 +75,4 @@ class Registry(Generic[T_Type, P_Type]):
f"No {self._name} with name '{name}' is registered."
)
return self._registry[name](config, *args, **kwargs)
return self._registry[name].from_config(config, *args, **kwargs)
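
For orientation, a minimal usage sketch of the new registry contract follows. The GaussianNoise names are invented for illustration; only Registry, BaseConfig, and the register/build calls come from this diff. A config class carrying a literal name field is paired with a logic class exposing a from_config classmethod, and Registry.build dispatches on the config's name.

from typing import Literal

from batdetect2.core import BaseConfig, Registry


class GaussianNoiseConfig(BaseConfig):  # hypothetical config, for illustration only
    name: Literal["gaussian_noise"] = "gaussian_noise"
    std: float = 0.1


class GaussianNoise:  # hypothetical logic class, for illustration only
    def __init__(self, std: float):
        self.std = std

    @classmethod
    def from_config(cls, config: GaussianNoiseConfig) -> "GaussianNoise":
        return cls(std=config.std)


# Registration is now an explicit module-level call instead of a decorator.
augmentations: Registry[GaussianNoise, []] = Registry("augmentation")
augmentations.register(GaussianNoiseConfig, GaussianNoise)

# build() reads the config's `name` field and calls the registered from_config.
noise = augmentations.build(GaussianNoiseConfig(std=0.2))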

View File

@ -10,7 +10,7 @@ from batdetect2.core.registries import Registry
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
conditions: Registry[SoundEventCondition, []] = Registry("condition")
condition_registry: Registry[SoundEventCondition, []] = Registry("condition")
class HasTagConfig(BaseConfig):
@ -27,10 +27,12 @@ class HasTag:
) -> bool:
return self.tag in sound_event_annotation.tags
@conditions.register(HasTagConfig)
@staticmethod
def from_config(config: HasTagConfig):
return HasTag(tag=config.tag)
@classmethod
def from_config(cls, config: HasTagConfig):
return cls(tag=config.tag)
condition_registry.register(HasTagConfig, HasTag)
class HasAllTagsConfig(BaseConfig):
@ -50,10 +52,12 @@ class HasAllTags:
) -> bool:
return self.tags.issubset(sound_event_annotation.tags)
@conditions.register(HasAllTagsConfig)
@staticmethod
def from_config(config: HasAllTagsConfig):
return HasAllTags(tags=config.tags)
@classmethod
def from_config(cls, config: HasAllTagsConfig):
return cls(tags=config.tags)
condition_registry.register(HasAllTagsConfig, HasAllTags)
class HasAnyTagConfig(BaseConfig):
@ -73,12 +77,13 @@ class HasAnyTag:
) -> bool:
return bool(self.tags.intersection(sound_event_annotation.tags))
@conditions.register(HasAnyTagConfig)
@staticmethod
def from_config(config: HasAnyTagConfig):
return HasAnyTag(tags=config.tags)
@classmethod
def from_config(cls, config: HasAnyTagConfig):
return cls(tags=config.tags)
condition_registry.register(HasAnyTagConfig, HasAnyTag)
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
@ -129,10 +134,12 @@ class Duration:
return self._comparator(duration)
@conditions.register(DurationConfig)
@staticmethod
def from_config(config: DurationConfig):
return Duration(operator=config.operator, seconds=config.seconds)
@classmethod
def from_config(cls, config: DurationConfig):
return cls(operator=config.operator, seconds=config.seconds)
condition_registry.register(DurationConfig, Duration)
class FrequencyConfig(BaseConfig):
@ -174,16 +181,18 @@ class Frequency:
return self._comparator(high_freq)
@conditions.register(FrequencyConfig)
@staticmethod
def from_config(config: FrequencyConfig):
return Frequency(
@classmethod
def from_config(cls, config: FrequencyConfig):
return cls(
operator=config.operator,
boundary=config.boundary,
hertz=config.hertz,
)
condition_registry.register(FrequencyConfig, Frequency)
class AllOfConfig(BaseConfig):
name: Literal["all_of"] = "all_of"
conditions: Sequence["SoundEventConditionConfig"]
@ -198,13 +207,15 @@ class AllOf:
) -> bool:
return all(c(sound_event_annotation) for c in self.conditions)
@conditions.register(AllOfConfig)
@staticmethod
def from_config(config: AllOfConfig):
@classmethod
def from_config(cls, config: AllOfConfig):
conditions = [
build_sound_event_condition(cond) for cond in config.conditions
]
return AllOf(conditions)
return cls(conditions)
condition_registry.register(AllOfConfig, AllOf)
class AnyOfConfig(BaseConfig):
@ -221,13 +232,15 @@ class AnyOf:
) -> bool:
return any(c(sound_event_annotation) for c in self.conditions)
@conditions.register(AnyOfConfig)
@staticmethod
def from_config(config: AnyOfConfig):
@classmethod
def from_config(cls, config: AnyOfConfig):
conditions = [
build_sound_event_condition(cond) for cond in config.conditions
]
return AnyOf(conditions)
return cls(conditions)
condition_registry.register(AnyOfConfig, AnyOf)
class NotConfig(BaseConfig):
@ -244,13 +257,14 @@ class Not:
) -> bool:
return not self.condition(sound_event_annotation)
@conditions.register(NotConfig)
@staticmethod
def from_config(config: NotConfig):
@classmethod
def from_config(cls, config: NotConfig):
condition = build_sound_event_condition(config.condition)
return Not(condition)
return cls(condition)
condition_registry.register(NotConfig, Not)
SoundEventConditionConfig = Annotated[
Union[
HasTagConfig,
@ -269,7 +283,7 @@ SoundEventConditionConfig = Annotated[
def build_sound_event_condition(
config: SoundEventConditionConfig,
) -> SoundEventCondition:
return conditions.build(config)
return condition_registry.build(config)
def filter_clip_annotation(
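
Conditions are now registered through explicit module-level calls but are still built via build_sound_event_condition. A small usage sketch follows; the import path is an assumption, and only DurationConfig, its fields, and build_sound_event_condition appear in this diff.

# NOTE: the module path below is assumed for illustration.
from batdetect2.targets.filtering import (
    DurationConfig,
    build_sound_event_condition,
)

# A condition that is True for annotations shorter than 10 ms.
is_short = build_sound_event_condition(
    DurationConfig(operator="lt", seconds=0.01)
)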

View File

@ -17,7 +17,7 @@ SoundEventTransform = Callable[
data.SoundEventAnnotation,
]
transforms: Registry[SoundEventTransform, []] = Registry("transform")
transform_registry: Registry[SoundEventTransform, []] = Registry("transform")
class SetFrequencyBoundConfig(BaseConfig):
@ -63,10 +63,12 @@ class SetFrequencyBound:
update=dict(sound_event=sound_event)
)
@transforms.register(SetFrequencyBoundConfig)
@staticmethod
def from_config(config: SetFrequencyBoundConfig):
return SetFrequencyBound(hertz=config.hertz, boundary=config.boundary)
@classmethod
def from_config(cls, config: SetFrequencyBoundConfig):
return cls(hertz=config.hertz, boundary=config.boundary)
transform_registry.register(SetFrequencyBoundConfig, SetFrequencyBound)
class ApplyIfConfig(BaseConfig):
@ -93,12 +95,14 @@ class ApplyIf:
return self.transform(sound_event_annotation)
@transforms.register(ApplyIfConfig)
@staticmethod
def from_config(config: ApplyIfConfig):
@classmethod
def from_config(cls, config: ApplyIfConfig):
transform = build_sound_event_transform(config.transform)
condition = build_sound_event_condition(config.condition)
return ApplyIf(condition=condition, transform=transform)
return cls(condition=condition, transform=transform)
transform_registry.register(ApplyIfConfig, ApplyIf)
class ReplaceTagConfig(BaseConfig):
@ -130,12 +134,12 @@ class ReplaceTag:
return sound_event_annotation.model_copy(update=dict(tags=tags))
@transforms.register(ReplaceTagConfig)
@staticmethod
def from_config(config: ReplaceTagConfig):
return ReplaceTag(
original=config.original, replacement=config.replacement
)
@classmethod
def from_config(cls, config: ReplaceTagConfig):
return cls(original=config.original, replacement=config.replacement)
transform_registry.register(ReplaceTagConfig, ReplaceTag)
class MapTagValueConfig(BaseConfig):
@ -185,16 +189,18 @@ class MapTagValue:
return sound_event_annotation.model_copy(update=dict(tags=tags))
@transforms.register(MapTagValueConfig)
@staticmethod
def from_config(config: MapTagValueConfig):
return MapTagValue(
@classmethod
def from_config(cls, config: MapTagValueConfig):
return cls(
tag_key=config.tag_key,
value_mapping=config.value_mapping,
target_key=config.target_key,
)
transform_registry.register(MapTagValueConfig, MapTagValue)
class ApplyAllConfig(BaseConfig):
name: Literal["apply_all"] = "apply_all"
steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
@ -213,13 +219,14 @@ class ApplyAll:
return sound_event_annotation
@transforms.register(ApplyAllConfig)
@staticmethod
def from_config(config: ApplyAllConfig):
@classmethod
def from_config(cls, config: ApplyAllConfig):
steps = [build_sound_event_transform(step) for step in config.steps]
return ApplyAll(steps)
return cls(steps)
transform_registry.register(ApplyAllConfig, ApplyAll)
SoundEventTransformConfig = Annotated[
Union[
SetFrequencyBoundConfig,
@ -235,7 +242,7 @@ SoundEventTransformConfig = Annotated[
def build_sound_event_transform(
config: SoundEventTransformConfig,
) -> SoundEventTransform:
return transforms.build(config)
return transform_registry.build(config)
def transform_clip_annotation(

View File

@ -1,14 +1,11 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.tasks import TaskConfig, build_task
__all__ = [
"EvaluationConfig",
"Evaluator",
"TaskConfig",
"build_evaluator",
"build_task",
"evaluate",
"load_evaluation_config",
"evaluate",
"Evaluator",
"build_evaluator",
]

View File

@ -3,7 +3,6 @@ from typing import Annotated, Literal, Optional, Union
from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.geometry import compute_interval_overlap
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
@ -28,10 +27,12 @@ class TimeAffinity(AffinityFunction):
geometry1, geometry2, time_buffer=self.time_buffer
)
@affinity_functions.register(TimeAffinityConfig)
@staticmethod
def from_config(config: TimeAffinityConfig):
return TimeAffinity(time_buffer=config.time_buffer)
@classmethod
def from_config(cls, config: TimeAffinityConfig):
return cls(time_buffer=config.time_buffer)
affinity_functions.register(TimeAffinityConfig, TimeAffinity)
def compute_timestamp_affinity(
@ -72,10 +73,12 @@ class IntervalIOU(AffinityFunction):
time_buffer=self.time_buffer,
)
@affinity_functions.register(IntervalIOUConfig)
@staticmethod
def from_config(config: IntervalIOUConfig):
return IntervalIOU(time_buffer=config.time_buffer)
@classmethod
def from_config(cls, config: IntervalIOUConfig):
return cls(time_buffer=config.time_buffer)
affinity_functions.register(IntervalIOUConfig, IntervalIOU)
def compute_interval_iou(
@ -94,11 +97,9 @@ def compute_interval_iou(
end_time1 += time_buffer
end_time2 += time_buffer
intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
intersection = max(
0, min(end_time1, end_time2) - max(start_time1, start_time2)
)
union = (
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
)
@ -109,86 +110,6 @@ def compute_interval_iou(
return intersection / union
class BBoxIOUConfig(BaseConfig):
name: Literal["bbox_iou"] = "bbox_iou"
time_buffer: float = 0.01
freq_buffer: float = 1000
class BBoxIOU(AffinityFunction):
def __init__(self, time_buffer: float, freq_buffer: float):
self.time_buffer = time_buffer
self.freq_buffer = freq_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
if not isinstance(geometry1, data.BoundingBox):
raise TypeError(
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
)
if not isinstance(geometry2, data.BoundingBox):
raise TypeError(
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
)
return bbox_iou(
geometry1,
geometry2,
time_buffer=self.time_buffer,
freq_buffer=self.freq_buffer,
)
@affinity_functions.register(BBoxIOUConfig)
@staticmethod
def from_config(config: BBoxIOUConfig):
return BBoxIOU(
time_buffer=config.time_buffer,
freq_buffer=config.freq_buffer,
)
def bbox_iou(
geometry1: data.BoundingBox,
geometry2: data.BoundingBox,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
low_freq1 -= freq_buffer
low_freq2 -= freq_buffer
high_freq1 += freq_buffer
high_freq2 += freq_buffer
time_intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
freq_intersection = max(
0,
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
)
intersection = time_intersection * freq_intersection
if intersection == 0:
return 0
union = (
(end_time1 - start_time1) * (high_freq1 - low_freq1)
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
- intersection
)
return intersection / union
class GeometricIOUConfig(BaseConfig):
name: Literal["geometric_iou"] = "geometric_iou"
time_buffer: float = 0.01
@ -206,17 +127,17 @@ class GeometricIOU(AffinityFunction):
time_buffer=self.time_buffer,
)
@affinity_functions.register(GeometricIOUConfig)
@staticmethod
def from_config(config: GeometricIOUConfig):
return GeometricIOU(time_buffer=config.time_buffer)
@classmethod
def from_config(cls, config: GeometricIOUConfig):
return cls(time_buffer=config.time_buffer)
affinity_functions.register(GeometricIOUConfig, GeometricIOU)
AffinityConfig = Annotated[
Union[
TimeAffinityConfig,
IntervalIOUConfig,
BBoxIOUConfig,
GeometricIOUConfig,
],
Field(discriminator="name"),

View File

@ -4,11 +4,13 @@ from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.tasks import (
TaskConfig,
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
from batdetect2.evaluate.metrics import (
ClassificationAPConfig,
DetectionAPConfig,
MetricConfig,
)
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.plots import PlotConfig
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
__all__ = [
@ -18,12 +20,15 @@ __all__ = [
class EvaluationConfig(BaseConfig):
tasks: List[TaskConfig] = Field(
ignore_start_end: float = 0.01
match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig)
metrics: List[MetricConfig] = Field(
default_factory=lambda: [
DetectionTaskConfig(),
ClassificationTaskConfig(),
DetectionAPConfig(),
ClassificationAPConfig(),
]
)
plots: List[PlotConfig] = Field(default_factory=list)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
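
A minimal sketch of constructing the flattened evaluation config in Python; values are illustrative, and calling EvaluationConfig() with no arguments gives the defaults listed above.

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.match import StartTimeMatchConfig
from batdetect2.evaluate.metrics import ClassificationAPConfig, DetectionAPConfig

config = EvaluationConfig(
    ignore_start_end=0.01,  # buffer (s) excluded at clip edges when filtering events
    match_strategy=StartTimeMatchConfig(distance_threshold=0.01),
    metrics=[
        DetectionAPConfig(),  # ap_implementation defaults to "pascal_voc"
        ClassificationAPConfig(ap_implementation="sklearn"),
    ],
)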

View File

@ -1,12 +1,23 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from matplotlib.figure import Figure
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.match import build_matcher, match
from batdetect2.evaluate.metrics import build_metric
from batdetect2.evaluate.plots import build_plotter
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol
from batdetect2.typing.evaluate import (
ClipEvaluation,
EvaluatorProtocol,
MatcherProtocol,
MetricsProtocol,
PlotterProtocol,
)
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"Evaluator",
@ -17,51 +28,146 @@ __all__ = [
class Evaluator:
def __init__(
self,
config: EvaluationConfig,
targets: TargetProtocol,
tasks: Sequence[EvaluatorProtocol],
matcher: MatcherProtocol,
metrics: List[MetricsProtocol],
plots: List[PlotterProtocol],
):
self.config = config
self.targets = targets
self.tasks = tasks
self.matcher = matcher
self.metrics = metrics
self.plots = plots
def match(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEvaluation:
clip = clip_annotation.clip
ground_truth = [
sound_event
for sound_event in clip_annotation.sound_events
if self.filter_sound_event_annotations(sound_event, clip)
]
predictions = [
prediction
for prediction in predictions
if self.filter_predictions(prediction, clip)
]
return ClipEvaluation(
clip=clip_annotation.clip,
matches=match(
ground_truth,
predictions,
clip=clip,
targets=self.targets,
matcher=self.matcher,
),
)
def filter_sound_event_annotations(
self,
sound_event_annotation: data.SoundEventAnnotation,
clip: data.Clip,
) -> bool:
if not self.targets.filter(sound_event_annotation):
return False
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
return is_in_bounds(
geometry,
clip,
self.config.ignore_start_end,
)
def filter_predictions(
self,
prediction: RawPrediction,
clip: data.Clip,
) -> bool:
return is_in_bounds(
prediction.geometry,
clip,
self.config.ignore_start_end,
)
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[Any]:
) -> List[ClipEvaluation]:
if len(clip_annotations) != len(predictions):
raise ValueError(
"Number of annotated clips and sets of predictions do not match"
)
return [
task.evaluate(clip_annotations, predictions) for task in self.tasks
self.match(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions)
]
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
def compute_metrics(
self,
clip_evaluations: Sequence[ClipEvaluation],
) -> Dict[str, float]:
results = {}
for task, outputs in zip(self.tasks, eval_outputs):
results.update(task.compute_metrics(outputs))
for metric in self.metrics:
results.update(metric(clip_evaluations))
return results
def generate_plots(
self,
eval_outputs: List[Any],
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]:
for task, outputs in zip(self.tasks, eval_outputs):
for name, fig in task.generate_plots(outputs):
for plotter in self.plots:
for name, fig in plotter(clip_evaluations):
yield name, fig
def build_evaluator(
config: Optional[Union[EvaluationConfig, dict]] = None,
config: Optional[EvaluationConfig] = None,
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
plots: Optional[List[PlotterProtocol]] = None,
metrics: Optional[List[MetricsProtocol]] = None,
) -> EvaluatorProtocol:
config = config or EvaluationConfig()
targets = targets or build_targets()
matcher = matcher or build_matcher(config.match_strategy)
if config is None:
config = EvaluationConfig()
if metrics is None:
metrics = [
build_metric(config, targets.class_names)
for config in config.metrics
]
if not isinstance(config, EvaluationConfig):
config = EvaluationConfig.model_validate(config)
if plots is None:
plots = [
build_plotter(config, targets.class_names)
for config in config.plots
]
return Evaluator(
config=config,
targets=targets,
tasks=[build_task(task, targets=targets) for task in config.tasks],
matcher=matcher,
metrics=metrics,
plots=plots,
)
def is_in_bounds(
geometry: data.Geometry,
clip: data.Clip,
buffer: float,
) -> bool:
start_time = compute_bounds(geometry)[0]
return (start_time >= clip.start_time + buffer) and (
start_time <= clip.end_time - buffer
)
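
The reshaped Evaluator exposes a three-step flow: match predictions to annotations per clip, compute metrics, then generate plots. A sketch, assuming clip_annotations and predictions are produced elsewhere (for example by a test dataset and model inference):

from batdetect2.evaluate import build_evaluator

evaluator = build_evaluator()  # default config; targets and matcher built internally

# clip_annotations: Sequence[data.ClipAnnotation], one per clip
# predictions: Sequence[Sequence[RawPrediction]], same length and order
clip_evaluations = evaluator.evaluate(clip_annotations, predictions)

scores = evaluator.compute_metrics(clip_evaluations)  # e.g. {"detection_AP": ...}
for name, figure in evaluator.generate_plots(clip_evaluations):
    figure.savefig(f"{name}.png")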

View File

@ -4,10 +4,11 @@ from lightning import LightningModule
from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.logging import get_image_logger
from batdetect2.evaluate.tables import FullEvaluationTable
from batdetect2.logging import get_image_logger, get_table_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import ClipMatches, EvaluatorProtocol
from batdetect2.typing import ClipEvaluation, EvaluatorProtocol
class EvaluationModule(LightningModule):
@ -53,8 +54,18 @@ class EvaluationModule(LightningModule):
def on_test_epoch_end(self):
self.log_metrics(self.clip_evaluations)
self.plot_examples(self.clip_evaluations)
self.log_table(self.clip_evaluations)
def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
def log_table(self, evaluated_clips: Sequence[ClipEvaluation]):
table_logger = get_table_logger(self.logger) # type: ignore
if table_logger is None:
return
df = FullEvaluationTable()(evaluated_clips)
table_logger("full_evaluation", df, 0)
def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]):
plotter = get_image_logger(self.logger) # type: ignore
if plotter is None:
@ -63,7 +74,7 @@ class EvaluationModule(LightningModule):
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
plotter(figure_name, fig, self.global_step)
def log_metrics(self, evaluated_clips: Sequence[ClipMatches]):
def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]):
metrics = self.evaluator.compute_metrics(evaluated_clips)
self.log_dict(metrics)

View File

@ -8,7 +8,8 @@ from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig, Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.affinity import (
AffinityConfig,
GeometricIOUConfig,
@ -16,13 +17,11 @@ from batdetect2.evaluate.affinity import (
)
from batdetect2.targets import build_targets
from batdetect2.typing import (
AffinityFunction,
MatcherProtocol,
MatchEvaluation,
RawPrediction,
TargetProtocol,
)
from batdetect2.typing.evaluate import ClipMatches
from batdetect2.typing.evaluate import AffinityFunction, MatcherProtocol
from batdetect2.typing.postprocess import RawPrediction
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""
@ -34,10 +33,9 @@ def match(
sound_event_annotations: Sequence[data.SoundEventAnnotation],
raw_predictions: Sequence[RawPrediction],
clip: data.Clip,
scores: Optional[Sequence[float]] = None,
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
) -> ClipMatches:
) -> List[MatchEvaluation]:
if matcher is None:
matcher = build_matcher()
@ -53,11 +51,9 @@ def match(
raw_prediction.geometry for raw_prediction in raw_predictions
]
if scores is None:
scores = [
raw_prediction.detection_score
for raw_prediction in raw_predictions
]
scores = [
raw_prediction.detection_score for raw_prediction in raw_predictions
]
matches = []
@ -77,11 +73,9 @@ def match(
gt_det = target_idx is not None
gt_class = targets.encode_class(target) if target is not None else None
gt_geometry = (
target_geometries[target_idx] if target_idx is not None else None
)
pred_score = float(prediction.detection_score) if prediction else 0
pred_geometry = (
predicted_geometries[source_idx]
if source_idx is not None
@ -90,7 +84,7 @@ def match(
class_scores = (
{
class_name: score
str(class_name): float(score)
for class_name, score in zip(
targets.class_names,
prediction.class_scores,
@ -106,7 +100,6 @@ def match(
sound_event_annotation=target,
gt_det=gt_det,
gt_class=gt_class,
gt_geometry=gt_geometry,
pred_score=pred_score,
pred_class_scores=class_scores,
pred_geometry=pred_geometry,
@ -114,7 +107,7 @@ def match(
)
)
return ClipMatches(clip=clip, matches=matches)
return matches
class StartTimeMatchConfig(BaseConfig):
@ -139,10 +132,12 @@ class StartTimeMatcher(MatcherProtocol):
distance_threshold=self.distance_threshold,
)
@matching_strategies.register(StartTimeMatchConfig)
@staticmethod
def from_config(config: StartTimeMatchConfig):
return StartTimeMatcher(distance_threshold=config.distance_threshold)
@classmethod
def from_config(cls, config: StartTimeMatchConfig) -> "StartTimeMatcher":
return cls(distance_threshold=config.distance_threshold)
matching_strategies.register(StartTimeMatchConfig, StartTimeMatcher)
def match_start_times(
@ -269,17 +264,19 @@ class GreedyMatcher(MatcherProtocol):
affinity_threshold=self.affinity_threshold,
)
@matching_strategies.register(GreedyMatchConfig)
@staticmethod
def from_config(config: GreedyMatchConfig):
@classmethod
def from_config(cls, config: GreedyMatchConfig):
affinity_function = build_affinity_function(config.affinity_function)
return GreedyMatcher(
return cls(
geometry=config.geometry,
affinity_threshold=config.affinity_threshold,
affinity_function=affinity_function,
)
matching_strategies.register(GreedyMatchConfig, GreedyMatcher)
def greedy_match(
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
@ -316,21 +313,21 @@ def greedy_match(
unassigned_gt = set(range(len(ground_truth)))
if not predictions:
for gt_idx in range(len(ground_truth)):
yield None, gt_idx, 0
for target_idx in range(len(ground_truth)):
yield None, target_idx, 0
return
if not ground_truth:
for pred_idx in range(len(predictions)):
yield pred_idx, None, 0
for source_idx in range(len(predictions)):
yield source_idx, None, 0
return
indices = np.argsort(scores)[::-1]
for pred_idx in indices:
source_geometry = predictions[pred_idx]
for source_idx in indices:
source_geometry = predictions[source_idx]
affinities = np.array(
[
@ -343,18 +340,18 @@ def greedy_match(
affinity = affinities[closest_target]
if affinities[closest_target] <= affinity_threshold:
yield pred_idx, None, 0
yield source_idx, None, 0
continue
if closest_target not in unassigned_gt:
yield pred_idx, None, 0
yield source_idx, None, 0
continue
unassigned_gt.remove(closest_target)
yield pred_idx, closest_target, affinity
yield source_idx, closest_target, affinity
for gt_idx in unassigned_gt:
yield None, gt_idx, 0
for target_idx in unassigned_gt:
yield None, target_idx, 0
class OptimalMatchConfig(BaseConfig):
@ -389,16 +386,17 @@ class OptimalMatcher(MatcherProtocol):
affinity_threshold=self.affinity_threshold,
)
@matching_strategies.register(OptimalMatchConfig)
@staticmethod
def from_config(config: OptimalMatchConfig):
return OptimalMatcher(
@classmethod
def from_config(cls, config: OptimalMatchConfig):
return cls(
affinity_threshold=config.affinity_threshold,
time_buffer=config.time_buffer,
frequency_buffer=config.frequency_buffer,
)
matching_strategies.register(OptimalMatchConfig, OptimalMatcher)
MatchConfig = Annotated[
Union[
GreedyMatchConfig,

View File

@ -0,0 +1,712 @@
from collections import defaultdict
from collections.abc import Callable, Mapping
from typing import (
Annotated,
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation, MetricsProtocol
__all__ = ["DetectionAP", "ClassificationAP"]
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
APImplementation = Literal["sklearn", "pascal_voc"]
class DetectionAPConfig(BaseConfig):
name: Literal["detection_ap"] = "detection_ap"
ap_implementation: APImplementation = "pascal_voc"
class DetectionAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation = "pascal_voc",
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
score = float(self.metric(y_true, y_score))
return {"detection_AP": score}
@classmethod
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
return cls(implementation=config.ap_implementation)
metrics_registry.register(DetectionAPConfig, DetectionAP)
class DetectionROCAUCConfig(BaseConfig):
name: Literal["detection_roc_auc"] = "detection_roc_auc"
class DetectionROCAUC(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
score = float(metrics.roc_auc_score(y_true, y_score))
return {"detection_ROC_AUC": score}
@classmethod
def from_config(
cls, config: DetectionROCAUCConfig, class_names: List[str]
):
return cls()
metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC)
class ClassificationAPConfig(BaseConfig):
name: Literal["classification_ap"] = "classification_ap"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
implementation: APImplementation = "pascal_voc",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_ap = self.metric(y_true_class, y_pred_class)
class_scores[class_name] = float(class_ap)
mean_ap = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"classification_mAP": float(mean_ap),
**{
f"classification_AP/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClassificationAPConfig,
class_names: List[str],
):
return cls(
class_names,
implementation=config.ap_implementation,
include=config.include,
exclude=config.exclude,
)
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
class ClassificationROCAUCConfig(BaseConfig):
name: Literal["classification_roc_auc"] = "classification_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
class_scores[class_name] = float(class_roc_auc)
mean_roc_auc = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"classification_macro_average_ROC_AUC": float(mean_roc_auc),
**{
f"classification_ROC_AUC/{class_name}": class_scores[
class_name
]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClassificationROCAUCConfig,
class_names: List[str],
):
return cls(
class_names,
include=config.include,
exclude=config.exclude,
)
metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC)
class TopClassAPConfig(BaseConfig):
name: Literal["top_class_ap"] = "top_class_ap"
ap_implementation: APImplementation = "pascal_voc"
class TopClassAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation = "pascal_voc",
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
top_class = match.pred_class
y_true.append(top_class == match.gt_class)
y_score.append(match.pred_class_score)
score = float(self.metric(y_true, y_score))
return {"top_class_AP": score}
@classmethod
def from_config(cls, config: TopClassAPConfig, class_names: List[str]):
return cls(implementation=config.ap_implementation)
metrics_registry.register(TopClassAPConfig, TopClassAP)
class ClassificationBalancedAccuracyConfig(BaseConfig):
name: Literal["classification_balanced_accuracy"] = (
"classification_balanced_accuracy"
)
class ClassificationBalancedAccuracy(MetricsProtocol):
def __init__(self, class_names: List[str]):
self.class_names = class_names
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
top_class = match.pred_class
# Focus on matches
if match.gt_class is None or top_class is None:
continue
y_true.append(self.class_names.index(match.gt_class))
y_pred.append(self.class_names.index(top_class))
score = float(metrics.balanced_accuracy_score(y_true, y_pred))
return {"classification_balanced_accuracy": score}
@classmethod
def from_config(
cls,
config: ClassificationBalancedAccuracyConfig,
class_names: List[str],
):
return cls(class_names)
metrics_registry.register(
ClassificationBalancedAccuracyConfig,
ClassificationBalancedAccuracy,
)
class ClipDetectionAPConfig(BaseConfig):
name: Literal["clip_detection_ap"] = "clip_detection_ap"
ap_implementation: APImplementation = "pascal_voc"
class ClipDetectionAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {"clip_detection_ap": self.metric(y_true, y_score)}
@classmethod
def from_config(
cls,
config: ClipDetectionAPConfig,
class_names: List[str],
):
return cls(implementation=config.ap_implementation)
metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)
class ClipDetectionROCAUCConfig(BaseConfig):
name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"
class ClipDetectionROCAUC(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {
    "clip_detection_roc_auc": float(metrics.roc_auc_score(y_true, y_score))
}
@classmethod
def from_config(
cls,
config: ClipDetectionROCAUCConfig,
class_names: List[str],
):
return cls()
metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)
class ClipMulticlassAPConfig(BaseConfig):
name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipMulticlassAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
implementation: APImplementation,
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
clip_classes = set()
clip_scores = defaultdict(list)
for match in clip_eval.matches:
if match.gt_class is not None:
clip_classes.add(match.gt_class)
for class_name, score in match.pred_class_scores.items():
clip_scores[class_name].append(score)
y_true.append(clip_classes)
y_pred.append(
np.array(
[
# Get max score for each class
max(clip_scores.get(class_name, [0]))
for class_name in self.class_names
]
)
)
y_true = preprocessing.MultiLabelBinarizer(
classes=self.class_names
).fit_transform(y_true)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_ap = self.metric(y_true_class, y_pred_class)
class_scores[class_name] = float(class_ap)
mean_ap = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"clip_multiclass_mAP": float(mean_ap),
**{
f"clip_multiclass_AP/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls, config: ClipMulticlassAPConfig, class_names: List[str]
):
return cls(
implementation=config.ap_implementation,
include=config.include,
exclude=config.exclude,
class_names=class_names,
)
metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)
class ClipMulticlassROCAUCConfig(BaseConfig):
name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipMulticlassROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
clip_classes = set()
clip_scores = defaultdict(list)
for match in clip_eval.matches:
if match.gt_class is not None:
clip_classes.add(match.gt_class)
for class_name, score in match.pred_class_scores.items():
clip_scores[class_name].append(score)
y_true.append(clip_classes)
y_pred.append(
np.array(
[
# Get maximum score for each class
max(clip_scores.get(class_name, [0]))
for class_name in self.class_names
]
)
)
y_true = preprocessing.MultiLabelBinarizer(
classes=self.class_names
).fit_transform(y_true)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
class_scores[class_name] = float(class_roc_auc)
mean_roc_auc = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"clip_multiclass_macro_ROC_AUC": float(mean_roc_auc),
**{
f"clip_multiclass_ROC_AUC/{class_name}": class_scores[
class_name
]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClipMulticlassROCAUCConfig,
class_names: List[str],
):
return cls(
include=config.include,
exclude=config.exclude,
class_names=class_names,
)
metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)
MetricConfig = Annotated[
Union[
DetectionAPConfig,
DetectionROCAUCConfig,
ClassificationAPConfig,
ClassificationROCAUCConfig,
TopClassAPConfig,
ClassificationBalancedAccuracyConfig,
ClipDetectionAPConfig,
ClipDetectionROCAUCConfig,
ClipMulticlassAPConfig,
ClipMulticlassROCAUCConfig,
],
Field(discriminator="name"),
]
def build_metric(config: MetricConfig, class_names: List[str]):
return metrics_registry.build(config, class_names)
def pascal_voc_average_precision(y_true, y_score) -> float:
y_true = np.array(y_true)
y_score = np.array(y_score)
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
num_positives = y_true.sum()
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c,
np.finfo(np.float64).eps,
)
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
# pascal 12 way
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return ave_prec
_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = {
"sklearn": metrics.average_precision_score,
"pascal_voc": pascal_voc_average_precision,
}
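
A tiny worked example of the two selectable AP implementations, assuming the batdetect2.evaluate.metrics module path used in the config imports earlier in this diff:

from sklearn import metrics

from batdetect2.evaluate.metrics import pascal_voc_average_precision

y_true = [True, False, True, False]  # detection hit/miss per prediction
y_score = [0.9, 0.8, 0.7, 0.3]       # corresponding detection scores

# Interpolated PASCAL VOC AP ("pascal_voc"): the precision envelope is 1.0 up
# to recall 0.5 and ~0.667 up to recall 1.0, so AP = 0.5*1.0 + 0.5*0.667 ≈ 0.83.
print(pascal_voc_average_precision(y_true, y_score))

# The "sklearn" alternative selected with ap_implementation="sklearn".
print(metrics.average_precision_score(y_true, y_score))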

View File

@ -1,267 +0,0 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction, TargetProtocol
__all__ = [
"ClassificationMetric",
"ClassificationMetricConfig",
"build_classification_metric",
]
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
is_prediction: bool
is_ground_truth: bool
is_generic: bool
true_class: Optional[str]
score: float
@dataclass
class ClipEval:
clip: data.Clip
matches: Mapping[str, List[MatchEval]]
ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
Registry("classification_metric")
)
class BaseClassificationConfig(BaseConfig):
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class BaseClassificationMetric:
def __init__(
self,
targets: TargetProtocol,
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.targets = targets
self.include = include
self.exclude = exclude
def include_class(self, class_name: str) -> bool:
if self.include is not None:
return class_name in self.include
if self.exclude is not None:
return class_name not in self.exclude
return True
class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
name: Literal["average_precision"] = "average_precision"
ignore_non_predictions: bool = True
ignore_generic: bool = True
label: str = "average_precision"
class ClassificationAveragePrecision(BaseClassificationMetric):
def __init__(
self,
targets: TargetProtocol,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "average_precision",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
super().__init__(include=include, exclude=exclude, targets=targets)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.label = label
def __call__(
self, clip_evaluations: Sequence[ClipEval]
) -> Dict[str, float]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
class_scores = {
class_name: average_precision(
y_true[class_name],
y_score[class_name],
num_positives=num_positives[class_name],
)
for class_name in self.targets.class_names
}
mean_score = float(
np.mean([v for v in class_scores.values() if v != np.nan])
)
return {
f"mean_{self.label}": mean_score,
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if self.include_class(class_name)
},
}
@classification_metrics.register(ClassificationAveragePrecisionConfig)
@staticmethod
def from_config(
config: ClassificationAveragePrecisionConfig,
targets: TargetProtocol,
):
return ClassificationAveragePrecision(
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
label=config.label,
include=config.include,
exclude=config.exclude,
)
class ClassificationROCAUCConfig(BaseClassificationConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ClassificationROCAUC(BaseClassificationMetric):
def __init__(
self,
targets: TargetProtocol,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "roc_auc",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.targets = targets
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.label = label
self.include = include
self.exclude = exclude
def __call__(
self, clip_evaluations: Sequence[ClipEval]
) -> Dict[str, float]:
y_true, y_score, _ = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
class_scores = {
class_name: float(
metrics.roc_auc_score(
y_true[class_name],
y_score[class_name],
)
)
for class_name in self.targets.class_names
}
mean_score = float(
np.mean([v for v in class_scores.values() if v != np.nan])
)
return {
f"mean_{self.label}": mean_score,
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if self.include_class(class_name)
},
}
@classification_metrics.register(ClassificationROCAUCConfig)
@staticmethod
def from_config(
config: ClassificationROCAUCConfig, targets: TargetProtocol
):
return ClassificationROCAUC(
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
label=config.label,
)
ClassificationMetricConfig = Annotated[
Union[
ClassificationAveragePrecisionConfig,
ClassificationROCAUCConfig,
],
Field(discriminator="name"),
]
def build_classification_metric(
config: ClassificationMetricConfig,
targets: TargetProtocol,
) -> ClassificationMetric:
return classification_metrics.build(config, targets)
def _extract_per_class_metric_data(
clip_evaluations: Sequence[ClipEval],
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
):
y_true = defaultdict(list)
y_score = defaultdict(list)
num_positives = defaultdict(lambda: 0)
for clip_eval in clip_evaluations:
for class_name, matches in clip_eval.matches.items():
for m in matches:
# Exclude matches with ground truth sounds where the class
# is unknown
if m.is_generic and ignore_generic:
continue
is_class = m.true_class == class_name
if is_class:
num_positives[class_name] += 1
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and ignore_non_predictions:
continue
y_true[class_name].append(is_class)
y_score[class_name].append(m.score)
return y_true, y_score, num_positives

View File

@ -1,135 +0,0 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union
import numpy as np
from pydantic import Field
from sklearn import metrics
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision
@dataclass
class ClipEval:
true_classes: Set[str]
class_scores: Dict[str, float]
ClipClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
"clip_classification_metric"
)
class ClipClassificationAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
class ClipClassificationAveragePrecision:
def __init__(self, label: str = "average_precision"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = defaultdict(list)
y_score = defaultdict(list)
for clip_eval in clip_evaluations:
for class_name, score in clip_eval.class_scores.items():
y_true[class_name].append(class_name in clip_eval.true_classes)
y_score[class_name].append(score)
class_scores = {
class_name: float(
average_precision(
y_true=y_true[class_name],
y_score=y_score[class_name],
)
)
for class_name in y_true
}
mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])
return {
f"mean_{self.label}": float(mean),
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if not np.isnan(score)
},
}
@clip_classification_metrics.register(
ClipClassificationAveragePrecisionConfig
)
@staticmethod
def from_config(config: ClipClassificationAveragePrecisionConfig):
return ClipClassificationAveragePrecision(label=config.label)
class ClipClassificationROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
class ClipClassificationROCAUC:
def __init__(self, label: str = "roc_auc"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = defaultdict(list)
y_score = defaultdict(list)
for clip_eval in clip_evaluations:
for class_name, score in clip_eval.class_scores.items():
y_true[class_name].append(class_name in clip_eval.true_classes)
y_score[class_name].append(score)
class_scores = {
class_name: float(
metrics.roc_auc_score(
y_true=y_true[class_name],
y_score=y_score[class_name],
)
)
for class_name in y_true
}
mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])
return {
f"mean_{self.label}": float(mean),
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if not np.isnan(score)
},
}
@clip_classification_metrics.register(ClipClassificationROCAUCConfig)
@staticmethod
def from_config(config: ClipClassificationROCAUCConfig):
return ClipClassificationROCAUC(label=config.label)
ClipClassificationMetricConfig = Annotated[
Union[
ClipClassificationAveragePrecisionConfig,
ClipClassificationROCAUCConfig,
],
Field(discriminator="name"),
]
def build_clip_metric(config: ClipClassificationMetricConfig):
return clip_classification_metrics.build(config)

View File

@ -1,173 +0,0 @@
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Union
import numpy as np
from pydantic import Field
from sklearn import metrics
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision
@dataclass
class ClipEval:
gt_det: bool
score: float
ClipDetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
"clip_detection_metric"
)
class ClipDetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
class ClipDetectionAveragePrecision:
def __init__(self, label: str = "average_precision"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
y_true.append(clip_eval.gt_det)
y_score.append(clip_eval.score)
score = average_precision(y_true=y_true, y_score=y_score)
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionAveragePrecisionConfig)
@staticmethod
def from_config(config: ClipDetectionAveragePrecisionConfig):
return ClipDetectionAveragePrecision(label=config.label)
class ClipDetectionROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
class ClipDetectionROCAUC:
def __init__(self, label: str = "roc_auc"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
y_true.append(clip_eval.gt_det)
y_score.append(clip_eval.score)
score = float(metrics.roc_auc_score(y_true=y_true, y_score=y_score))
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionROCAUCConfig)
@staticmethod
def from_config(config: ClipDetectionROCAUCConfig):
return ClipDetectionROCAUC(label=config.label)
class ClipDetectionRecallConfig(BaseConfig):
name: Literal["recall"] = "recall"
threshold: float = 0.5
label: str = "recall"
class ClipDetectionRecall:
def __init__(self, threshold: float, label: str = "recall"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_positives = 0
true_positives = 0
for clip_eval in clip_evaluations:
if clip_eval.gt_det:
num_positives += 1
if clip_eval.score >= self.threshold and clip_eval.gt_det:
true_positives += 1
if num_positives == 0:
return {self.label: np.nan}
score = true_positives / num_positives
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionRecallConfig)
@staticmethod
def from_config(config: ClipDetectionRecallConfig):
return ClipDetectionRecall(
threshold=config.threshold, label=config.label
)
class ClipDetectionPrecisionConfig(BaseConfig):
name: Literal["precision"] = "precision"
threshold: float = 0.5
label: str = "precision"
class ClipDetectionPrecision:
def __init__(self, threshold: float, label: str = "precision"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_detections = 0
true_positives = 0
for clip_eval in clip_evaluations:
if clip_eval.score >= self.threshold:
num_detections += 1
if clip_eval.score >= self.threshold and clip_eval.gt_det:
true_positives += 1
if num_detections == 0:
return {self.label: np.nan}
score = true_positives / num_detections
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionPrecisionConfig)
@staticmethod
def from_config(config: ClipDetectionPrecisionConfig):
return ClipDetectionPrecision(
threshold=config.threshold, label=config.label
)
ClipDetectionMetricConfig = Annotated[
Union[
ClipDetectionAveragePrecisionConfig,
ClipDetectionROCAUCConfig,
ClipDetectionRecallConfig,
ClipDetectionPrecisionConfig,
],
Field(discriminator="name"),
]
def build_clip_metric(config: ClipDetectionMetricConfig):
return clip_detection_metrics.build(config)
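An illustrative sketch, assuming the registry's `build` dispatches on the config type as the registrations above suggest; the clip scores are made up.

clip_evals = [
    ClipEval(gt_det=True, score=0.8),   # positive clip, scored high
    ClipEval(gt_det=False, score=0.3),  # negative clip, scored low
]

# Precision at the default 0.5 threshold: only the first clip counts as
# a detection, and it is a true positive.
metric = build_clip_metric(ClipDetectionPrecisionConfig())
print(metric(clip_evals))  # expected: {"precision": 1.0}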

View File

@ -1,60 +0,0 @@
from typing import Optional, Tuple
import numpy as np
__all__ = [
"compute_precision_recall",
"average_precision",
]
def compute_precision_recall(
y_true,
y_score,
num_positives: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
y_true = np.array(y_true)
y_score = np.array(y_score)
if num_positives is None:
num_positives = y_true.sum()
# Sort by score
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
y_score_sorted = y_score[sort_ind]
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c,
np.finfo(np.float64).eps,
)
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
return precision, recall, y_score_sorted
def average_precision(
y_true,
y_score,
num_positives: Optional[int] = None,
) -> float:
precision, recall, _ = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
# Pascal VOC 2012-style AP: take the monotone (interpolated) precision envelope
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return ave_prec
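A small worked example (not part of the diff) with fabricated scores. Passing num_positives larger than the number of true entries models ground-truth events that never received a prediction and caps the achievable recall.

# Three true events among five scored predictions.
y_true = [True, False, True, True, False]
y_score = [0.9, 0.8, 0.7, 0.4, 0.2]

print(average_precision(y_true, y_score))

# Declare five ground-truth events in total, two of which were never
# predicted; the resulting AP is lower because full recall is unreachable.
print(average_precision(y_true, y_score, num_positives=5))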

View File

@ -1,226 +0,0 @@
from dataclasses import dataclass
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
__all__ = [
"DetectionMetricConfig",
"DetectionMetric",
"build_detection_metric",
]
@dataclass
class MatchEval:
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
is_prediction: bool
is_ground_truth: bool
score: float
@dataclass
class ClipEval:
clip: data.Clip
matches: List[MatchEval]
DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")
class DetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
ignore_non_predictions: bool = True
class DetectionAveragePrecision:
def __init__(self, label: str, ignore_non_predictions: bool = True):
self.ignore_non_predictions = ignore_non_predictions
self.label = label
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evals:
for m in clip_eval.matches:
num_positives += int(m.is_ground_truth)
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
ap = average_precision(y_true, y_score, num_positives=num_positives)
return {self.label: ap}
@detection_metrics.register(DetectionAveragePrecisionConfig)
@staticmethod
def from_config(config: DetectionAveragePrecisionConfig):
return DetectionAveragePrecision(
label=config.label,
ignore_non_predictions=config.ignore_non_predictions,
)
class DetectionROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
ignore_non_predictions: bool = True
class DetectionROCAUC:
def __init__(
self,
label: str = "roc_auc",
ignore_non_predictions: bool = True,
):
self.label = label
self.ignore_non_predictions = ignore_non_predictions
def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
y_true: List[bool] = []
y_score: List[float] = []
for clip_eval in clip_evals:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
score = float(metrics.roc_auc_score(y_true, y_score))
return {self.label: score}
@detection_metrics.register(DetectionROCAUCConfig)
@staticmethod
def from_config(config: DetectionROCAUCConfig):
return DetectionROCAUC(
label=config.label,
ignore_non_predictions=config.ignore_non_predictions,
)
class DetectionRecallConfig(BaseConfig):
name: Literal["recall"] = "recall"
label: str = "recall"
threshold: float = 0.5
class DetectionRecall:
def __init__(self, threshold: float, label: str = "recall"):
self.label = label
self.threshold = threshold
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_positives = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_ground_truth:
num_positives += 1
if m.score >= self.threshold and m.is_ground_truth:
true_positives += 1
if num_positives == 0:
return {self.label: np.nan}
score = true_positives / num_positives
return {self.label: score}
@detection_metrics.register(DetectionRecallConfig)
@staticmethod
def from_config(config: DetectionRecallConfig):
return DetectionRecall(threshold=config.threshold, label=config.label)
class DetectionPrecisionConfig(BaseConfig):
name: Literal["precision"] = "precision"
label: str = "precision"
threshold: float = 0.5
class DetectionPrecision:
def __init__(self, threshold: float, label: str = "precision"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_detections = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
is_detection = m.score >= self.threshold
if is_detection:
num_detections += 1
if is_detection and m.is_ground_truth:
true_positives += 1
if num_detections == 0:
return {self.label: np.nan}
score = true_positives / num_detections
return {self.label: score}
@detection_metrics.register(DetectionPrecisionConfig)
@staticmethod
def from_config(config: DetectionPrecisionConfig):
return DetectionPrecision(
threshold=config.threshold,
label=config.label,
)
DetectionMetricConfig = Annotated[
Union[
DetectionAveragePrecisionConfig,
DetectionROCAUCConfig,
DetectionRecallConfig,
DetectionPrecisionConfig,
],
Field(discriminator="name"),
]
def build_detection_metric(config: DetectionMetricConfig):
return detection_metrics.build(config)
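A hedged usage sketch: the matches below are invented (the metrics never inspect the clip or the matched annotations, so placeholders are used), and the registry is assumed to dispatch builders by config type.

clip_evals = [
    ClipEval(
        clip=None,  # placeholder; a real soundevent data.Clip in practice
        matches=[
            MatchEval(gt=None, pred=None, is_prediction=True,
                      is_ground_truth=True, score=0.9),
            MatchEval(gt=None, pred=None, is_prediction=True,
                      is_ground_truth=False, score=0.4),
        ],
    )
]

metric = build_detection_metric(
    DetectionRecallConfig(threshold=0.3, label="recall@0.3")
)
print(metric(clip_evals))  # expected: {"recall@0.3": 1.0}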

View File

@ -1,314 +0,0 @@
from dataclasses import dataclass
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
__all__ = [
"TopClassMetricConfig",
"TopClassMetric",
"build_top_class_metric",
]
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
is_ground_truth: bool
is_generic: bool
is_prediction: bool
pred_class: Optional[str]
true_class: Optional[str]
score: float
@dataclass
class ClipEval:
clip: data.Clip
matches: List[MatchEval]
TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")
class TopClassAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
ignore_generic: bool = True
ignore_non_predictions: bool = True
class TopClassAveragePrecision:
def __init__(
self,
ignore_generic: bool = True,
ignore_non_predictions: bool = True,
label: str = "average_precision",
):
self.ignore_generic = ignore_generic
self.ignore_non_predictions = ignore_non_predictions
self.label = label
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evals:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
num_positives += int(m.is_ground_truth)
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
score = average_precision(y_true, y_score, num_positives=num_positives)
return {self.label: score}
@top_class_metrics.register(TopClassAveragePrecisionConfig)
@staticmethod
def from_config(config: TopClassAveragePrecisionConfig):
return TopClassAveragePrecision(
ignore_generic=config.ignore_generic,
ignore_non_predictions=config.ignore_non_predictions,
label=config.label,
)
class TopClassROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
ignore_generic: bool = True
ignore_non_predictions: bool = True
label: str = "roc_auc"
class TopClassROCAUC:
def __init__(
self,
ignore_generic: bool = True,
ignore_non_predictions: bool = True,
label: str = "roc_auc",
):
self.ignore_generic = ignore_generic
self.ignore_non_predictions = ignore_non_predictions
self.label = label
def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
y_true: List[bool] = []
y_score: List[float] = []
for clip_eval in clip_evals:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
score = float(metrics.roc_auc_score(y_true, y_score))
return {self.label: score}
@top_class_metrics.register(TopClassROCAUCConfig)
@staticmethod
def from_config(config: TopClassROCAUCConfig):
return TopClassROCAUC(
ignore_generic=config.ignore_generic,
ignore_non_predictions=config.ignore_non_predictions,
label=config.label,
)
class TopClassRecallConfig(BaseConfig):
name: Literal["recall"] = "recall"
threshold: float = 0.5
label: str = "recall"
class TopClassRecall:
def __init__(self, threshold: float, label: str = "recall"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_positives = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_ground_truth:
num_positives += 1
if m.score >= self.threshold and m.pred_class == m.true_class:
true_positives += 1
if num_positives == 0:
return {self.label: np.nan}
score = true_positives / num_positives
return {self.label: score}
@top_class_metrics.register(TopClassRecallConfig)
@staticmethod
def from_config(config: TopClassRecallConfig):
return TopClassRecall(
threshold=config.threshold,
label=config.label,
)
class TopClassPrecisionConfig(BaseConfig):
name: Literal["precision"] = "precision"
threshold: float = 0.5
label: str = "precision"
class TopClassPrecision:
def __init__(self, threshold: float, label: str = "precision"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_detections = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
is_detection = m.score >= self.threshold
if is_detection:
num_detections += 1
if is_detection and m.pred_class == m.true_class:
true_positives += 1
if num_detections == 0:
return {self.label: np.nan}
score = true_positives / num_detections
return {self.label: score}
@top_class_metrics.register(TopClassPrecisionConfig)
@staticmethod
def from_config(config: TopClassPrecisionConfig):
return TopClassPrecision(
threshold=config.threshold,
label=config.label,
)
class BalancedAccuracyConfig(BaseConfig):
name: Literal["balanced_accuracy"] = "balanced_accuracy"
label: str = "balanced_accuracy"
exclude_noise: bool = False
noise_class: str = "noise"
class BalancedAccuracy:
def __init__(
self,
exclude_noise: bool = True,
noise_class: str = "noise",
label: str = "balanced_accuracy",
):
self.exclude_noise = exclude_noise
self.noise_class = noise_class
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true: List[str] = []
y_pred: List[str] = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic:
# Ignore matches that correspond to a sound event
# with unknown class
continue
if not m.is_ground_truth and self.exclude_noise:
# Ignore predictions that were not matched to a
# ground truth
continue
if m.pred_class is None and self.exclude_noise:
# Ignore non-predictions
continue
y_true.append(m.true_class or self.noise_class)
y_pred.append(m.pred_class or self.noise_class)
encoder = preprocessing.LabelEncoder()
encoder.fit(list(set(y_true) | set(y_pred)))
y_true_encoded = encoder.transform(y_true)
y_pred_encoded = encoder.transform(y_pred)
score = float(
metrics.balanced_accuracy_score(y_true_encoded, y_pred_encoded)
)
return {self.label: score}
@top_class_metrics.register(BalancedAccuracyConfig)
@staticmethod
def from_config(config: BalancedAccuracyConfig):
return BalancedAccuracy(
exclude_noise=config.exclude_noise,
noise_class=config.noise_class,
label=config.label,
)
TopClassMetricConfig = Annotated[
Union[
TopClassAveragePrecisionConfig,
TopClassROCAUCConfig,
TopClassRecallConfig,
TopClassPrecisionConfig,
BalancedAccuracyConfig,
],
Field(discriminator="name"),
]
def build_top_class_metric(config: TopClassMetricConfig):
return top_class_metrics.build(config)
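Another hedged sketch: clip_evals is assumed to be the Sequence[ClipEval] produced by the evaluator's matching step, and the configs below are illustrative.

configs = [
    TopClassAveragePrecisionConfig(),
    TopClassRecallConfig(threshold=0.3, label="recall@0.3"),
    BalancedAccuracyConfig(exclude_noise=True),
]

results = {}
for cfg in configs:
    # Each metric returns a dict of named scores; merge them all.
    results.update(build_top_class_metric(cfg)(clip_evals))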

View File

@ -0,0 +1,560 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import BaseConfig, Registry
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.matches import plot_matches
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import (
AudioLoader,
ClipEvaluation,
MatchEvaluation,
PlotterProtocol,
PreprocessorProtocol,
)
__all__ = [
"build_plotter",
"ExampleGallery",
"ExampleGalleryConfig",
]
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")
class ExampleGalleryConfig(BaseConfig):
name: Literal["example_gallery"] = "example_gallery"
examples_per_class: int = 5
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleGallery(PlotterProtocol):
def __init__(
self,
examples_per_class: int,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
):
self.examples_per_class = examples_per_class
self.preprocessor = preprocessor or build_preprocessor()
self.audio_loader = audio_loader or build_audio_loader()
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
per_class_matches = group_matches(clip_evaluations)
for class_name, matches in per_class_matches.items():
true_positives = get_binned_sample(
matches.true_positives,
n_examples=self.examples_per_class,
)
false_positives = get_binned_sample(
matches.false_positives,
n_examples=self.examples_per_class,
)
false_negatives = random.sample(
matches.false_negatives,
k=min(self.examples_per_class, len(matches.false_negatives)),
)
cross_triggers = get_binned_sample(
matches.cross_triggers,
n_examples=self.examples_per_class,
)
fig = plot_match_gallery(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
n_examples=self.examples_per_class,
)
yield f"example_gallery/{class_name}", fig
plt.close(fig)
@classmethod
def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
audio_loader = build_audio_loader(config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return cls(
examples_per_class=config.examples_per_class,
preprocessor=preprocessor,
audio_loader=audio_loader,
)
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
class ClipEvaluationPlotConfig(BaseConfig):
name: Literal["example_clip"] = "example_clip"
num_plots: int = 5
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class PlotClipEvaluation(PlotterProtocol):
def __init__(
self,
num_plots: int = 3,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
):
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.num_plots = num_plots
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
examples = random.sample(
clip_evaluations,
k=min(self.num_plots, len(clip_evaluations)),
)
for index, clip_evaluation in enumerate(examples):
fig, ax = plt.subplots()
plot_matches(
clip_evaluation.matches,
clip=clip_evaluation.clip,
audio_loader=self.audio_loader,
ax=ax,
)
yield f"clip_evaluation/example_{index}", fig
plt.close(fig)
@classmethod
def from_config(
cls,
config: ClipEvaluationPlotConfig,
class_names: List[str],
):
audio_loader = build_audio_loader(config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return cls(
num_plots=config.num_plots,
preprocessor=preprocessor,
audio_loader=audio_loader,
)
plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)
class DetectionPRCurveConfig(BaseConfig):
name: Literal["detection_pr_curve"] = "detection_pr_curve"
class DetectionPRCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
fig, ax = plt.subplots()
ax.plot(recall, precision, label="Detector")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()
yield "detection_pr_curve", fig
@classmethod
def from_config(
cls,
config: DetectionPRCurveConfig,
class_names: List[str],
):
return cls()
plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)
class ClassificationPRCurvesConfig(BaseConfig):
name: Literal["classification_pr_curves"] = "classification_pr_curves"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationPRCurves(PlotterProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
fig, ax = plt.subplots(figsize=(10, 10))
for class_index, class_name in enumerate(self.class_names):
if class_name not in self.selected:
continue
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
precision, recall, _ = metrics.precision_recall_curve(
y_true_class,
y_pred_class,
)
ax.plot(recall, precision, label=class_name)
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
yield "classification_pr_curve", fig
@classmethod
def from_config(
cls,
config: ClassificationPRCurvesConfig,
class_names: List[str],
):
return cls(
class_names=class_names,
include=config.include,
exclude=config.exclude,
)
plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)
class DetectionROCCurveConfig(BaseConfig):
name: Literal["detection_roc_curve"] = "detection_roc_curve"
class DetectionROCCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
fig, ax = plt.subplots()
ax.plot(fpr, tpr, label="Detection")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend()
yield "detection_roc_curve", fig
@classmethod
def from_config(
cls,
config: DetectionROCCurveConfig,
class_names: List[str],
):
return cls()
plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)
class ClassificationROCCurvesConfig(BaseConfig):
name: Literal["classification_roc_curves"] = "classification_roc_curves"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationROCCurves(PlotterProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
fig, ax = plt.subplots(figsize=(10, 10))
for class_index, class_name in enumerate(self.class_names):
if class_name not in self.selected:
continue
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
fpr, tpr, _ = metrics.roc_curve(
y_true_class,
y_pred_class,
)
ax.plot(fpr, tpr, label=class_name)
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
yield "classification_roc_curve", fig
@classmethod
def from_config(
cls,
config: ClassificationROCCurvesConfig,
class_names: List[str],
):
return cls(
class_names=class_names,
include=config.include,
exclude=config.exclude,
)
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
class ConfusionMatrixConfig(BaseConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
background_class: str = "noise"
class ConfusionMatrix(PlotterProtocol):
def __init__(self, background_class: str, class_names: List[str]):
self.background_class = background_class
self.class_names = class_names
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else self.background_class
)
top_class = match.pred_class
y_pred.append(
top_class
if top_class is not None
else self.background_class
)
display = metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=[*self.class_names, self.background_class],
)
yield "confusion_matrix", display.figure_
@classmethod
def from_config(
cls,
config: ConfusionMatrixConfig,
class_names: List[str],
):
return cls(
background_class=config.background_class,
class_names=class_names,
)
plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
PlotConfig = Annotated[
Union[
ExampleGalleryConfig,
ClipEvaluationPlotConfig,
DetectionPRCurveConfig,
ClassificationPRCurvesConfig,
DetectionROCCurveConfig,
ClassificationROCCurvesConfig,
ConfusionMatrixConfig,
],
Field(discriminator="name"),
]
def build_plotter(
config: PlotConfig, class_names: List[str]
) -> PlotterProtocol:
return plots_registry.build(config, class_names)
@dataclass
class ClassMatches:
false_positives: List[MatchEvaluation] = field(default_factory=list)
false_negatives: List[MatchEvaluation] = field(default_factory=list)
true_positives: List[MatchEvaluation] = field(default_factory=list)
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def group_matches(
clip_evaluations: Sequence[ClipEvaluation],
) -> Dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches)
for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches:
gt_class = match.gt_class
pred_class = match.pred_class
if pred_class is None:
class_examples[gt_class].false_negatives.append(match)
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[gt_class].cross_triggers.append(match)
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
return class_examples
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[
(index, match.pred_class_scores[pred_class])
for index, match in enumerate(matches)
if (pred_class := match.pred_class) is not None
]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").sample(1)
return [matches[ind] for ind in sample["indices"]]
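A minimal consumer-side sketch for the plotters registered above: clip_evaluations and class_names are assumed to come from the evaluator, and the output directory is invented.

from pathlib import Path

out_dir = Path("evaluation_plots")
out_dir.mkdir(exist_ok=True)

plotter = build_plotter(DetectionPRCurveConfig(), class_names)
for name, fig in plotter(clip_evaluations):
    # Plot names may contain "/" (e.g. "example_gallery/<class>"),
    # so flatten them before writing to disk.
    fig.savefig(out_dir / f"{name.replace('/', '_')}.png")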

View File

@ -1,54 +0,0 @@
from typing import Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol
class BasePlotConfig(BaseConfig):
label: str = "plot"
theme: str = "default"
title: Optional[str] = None
figsize: tuple[int, int] = (10, 10)
dpi: int = 100
class BasePlot:
def __init__(
self,
targets: TargetProtocol,
label: str = "plot",
figsize: tuple[int, int] = (10, 10),
title: Optional[str] = None,
dpi: int = 100,
theme: str = "default",
):
self.targets = targets
self.label = label
self.figsize = figsize
self.dpi = dpi
self.theme = theme
self.title = title
def create_figure(self) -> Figure:
plt.style.use(self.theme)
fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
if self.title is not None:
fig.suptitle(self.title)
return fig
@classmethod
def build(cls, config: BasePlotConfig, targets: TargetProtocol, **kwargs):
return cls(
targets=targets,
figsize=config.figsize,
dpi=config.dpi,
theme=config.theme,
label=config.label,
title=config.title,
**kwargs,
)
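A hypothetical subclass sketch showing how BasePlot is meant to be extended: create_figure applies the theme, size, DPI and title, while build forwards the shared config fields. The ScoreHistogram name and body are invented.

class ScoreHistogram(BasePlot):
    """Toy plot: histogram of prediction scores in [0, 1]."""

    def __call__(self, scores):
        fig = self.create_figure()
        ax = fig.subplots()
        ax.hist(scores, bins=20, range=(0, 1))
        ax.set_xlabel("Score")
        ax.set_ylabel("Count")
        yield self.label, fig


# Built like the real plots in this package:
# plot = ScoreHistogram.build(BasePlotConfig(label="score_histogram"), targets=targets)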

View File

@ -1,370 +0,0 @@
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.classification import (
ClipEval,
_extract_per_class_metric_data,
)
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curve,
plot_pr_curves,
plot_roc_curve,
plot_roc_curves,
plot_threshold_precision_curve,
plot_threshold_precision_curves,
plot_threshold_recall_curve,
plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol
ClassificationPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
Registry("classification_plot")
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Classification Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives=num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (precision, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ThresholdPrecisionCurveConfig(BasePlotConfig):
name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
label: str = "threshold_precision_curve"
title: Optional[str] = "Classification Threshold-Precision Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ThresholdPrecisionCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_threshold_precision_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (precision, _, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_threshold_precision_curve(
thresholds,
precision,
ax=ax,
)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ThresholdPrecisionCurveConfig)
@staticmethod
def from_config(
config: ThresholdPrecisionCurveConfig, targets: TargetProtocol
):
return ThresholdPrecisionCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ThresholdRecallCurveConfig(BasePlotConfig):
name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
label: str = "threshold_recall_curve"
title: Optional[str] = "Classification Threshold-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ThresholdRecallCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_threshold_recall_curves(data, ax=ax, add_legend=True)
yield self.label, fig
return
for class_name, (_, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_threshold_recall_curve(
thresholds,
recall,
ax=ax,
)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ThresholdRecallCurveConfig)
@staticmethod
def from_config(
config: ThresholdRecallCurveConfig, targets: TargetProtocol
):
return ThresholdRecallCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Classification ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, _ = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: metrics.roc_curve(
y_true[class_name],
y_score[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (fpr, tpr, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
ClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ThresholdPrecisionCurveConfig,
ThresholdRecallCurveConfig,
],
Field(discriminator="name"),
]
def build_classification_plotter(
config: ClassificationPlotConfig,
targets: TargetProtocol,
) -> ClassificationPlotter:
return classification_plots.build(config, targets)

View File

@ -1,189 +0,0 @@
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curve,
plot_pr_curves,
plot_roc_curve,
plot_roc_curves,
)
from batdetect2.typing import TargetProtocol
__all__ = [
"ClipClassificationPlotConfig",
"ClipClassificationPlotter",
"build_clip_classification_plotter",
]
ClipClassificationPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
clip_classification_plots: Registry[
ClipClassificationPlotter, [TargetProtocol]
] = Registry("clip_classification_plot")
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Clip Classification Precision-Recall Curve"
separate_figures: bool = False
class PRCurve(BasePlot):
def __init__(
self,
*args,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
data = {}
for class_name in self.targets.class_names:
y_true = [class_name in c.true_classes for c in clip_evaluations]
y_score = [
c.class_scores.get(class_name, 0) for c in clip_evaluations
]
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
)
data[class_name] = (precision, recall, thresholds)
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (precision, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@clip_classification_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
separate_figures=config.separate_figures,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Clip Classification ROC Curve"
separate_figures: bool = False
class ROCCurve(BasePlot):
def __init__(
self,
*args,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
data = {}
for class_name in self.targets.class_names:
y_true = [class_name in c.true_classes for c in clip_evaluations]
y_score = [
c.class_scores.get(class_name, 0) for c in clip_evaluations
]
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
data[class_name] = (fpr, tpr, thresholds)
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (fpr, tpr, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@clip_classification_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
separate_figures=config.separate_figures,
)
ClipClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
],
Field(discriminator="name"),
]
def build_clip_classification_plotter(
config: ClipClassificationPlotConfig,
targets: TargetProtocol,
) -> ClipClassificationPlotter:
return clip_classification_plots.build(config, targets)

View File

@ -1,163 +0,0 @@
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol
__all__ = [
"ClipDetectionPlotConfig",
"ClipDetectionPlotter",
"build_clip_detection_plotter",
]
ClipDetectionPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
Registry("clip_detection_plot")
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Clip Detection Precision-Recall Curve"
class PRCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
yield self.label, fig
@clip_detection_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Clip Detection ROC Curve"
class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
yield self.label, fig
@clip_detection_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
)
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Clip Detection Score Distribution"
class ScoreDistributionPlot(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
fig = self.create_figure()
ax = fig.subplots()
df = pd.DataFrame({"is_true": y_true, "score": y_score})
sns.histplot(
data=df,
x="score",
binwidth=0.025,
binrange=(0, 1),
hue="is_true",
ax=ax,
stat="probability",
common_norm=False,
)
yield self.label, fig
@clip_detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod
def from_config(
config: ScoreDistributionPlotConfig, targets: TargetProtocol
):
return ScoreDistributionPlot.build(
config=config,
targets=targets,
)
ClipDetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
],
Field(discriminator="name"),
]
def build_clip_detection_plotter(
config: ClipDetectionPlotConfig,
targets: TargetProtocol,
) -> ClipDetectionPlotter:
return clip_detection_plots.build(config, targets)

View File

@ -1,309 +0,0 @@
import random
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.detections import plot_clip_detections
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
name="detection_plot"
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Detection Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evals:
for m in clip_eval.matches:
num_positives += int(m.is_ground_truth)
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
yield self.label, fig
@detection_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Detection ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
yield self.label, fig
@detection_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Detection Score Distribution"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ScoreDistributionPlot(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
df = pd.DataFrame({"is_true": y_true, "score": y_score})
fig = self.create_figure()
ax = fig.subplots()
sns.histplot(
data=df,
x="score",
binwidth=0.025,
binrange=(0, 1),
hue="is_true",
ax=ax,
stat="probability",
common_norm=False,
)
yield self.label, fig
@detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod
def from_config(
config: ScoreDistributionPlotConfig, targets: TargetProtocol
):
return ScoreDistributionPlot.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ExampleDetectionPlotConfig(BasePlotConfig):
name: Literal["example_detection"] = "example_detection"
label: str = "example_detection"
title: Optional[str] = "Example Detection"
figsize: tuple[int, int] = (10, 4)
num_examples: int = 5
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleDetectionPlot(BasePlot):
def __init__(
self,
*args,
num_examples: int = 5,
threshold: float = 0.2,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_examples = num_examples
self.audio_loader = audio_loader
self.threshold = threshold
self.preprocessor = preprocessor
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
sample = clip_evaluations
if self.num_examples < len(sample):
sample = random.sample(sample, self.num_examples)
for num_example, clip_eval in enumerate(sample):
fig = self.create_figure()
ax = fig.subplots()
plot_clip_detections(
clip_eval,
ax=ax,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
)
yield f"{self.label}/example_{num_example}", fig
plt.close(fig)
@detection_plots.register(ExampleDetectionPlotConfig)
@staticmethod
def from_config(
config: ExampleDetectionPlotConfig,
targets: TargetProtocol,
):
return ExampleDetectionPlot.build(
config=config,
targets=targets,
num_examples=config.num_examples,
threshold=config.threshold,
audio_loader=build_audio_loader(config.audio),
preprocessor=build_preprocessor(config.preprocessing),
)
DetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
ExampleDetectionPlotConfig,
],
Field(discriminator="name"),
]
def build_detection_plotter(
config: DetectionPlotConfig,
targets: TargetProtocol,
) -> DetectionPlotter:
return detection_plots.build(config, targets)

View File

@ -1,444 +0,0 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
Annotated,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
name="top_class_plot"
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Top Class Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
num_positives += int(m.is_ground_truth)
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
yield self.label, fig
@top_class_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Top Class ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
yield self.label, fig
@top_class_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ConfusionMatrixConfig(BasePlotConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
title: Optional[str] = "Top Class Confusion Matrix"
figsize: tuple[int, int] = (10, 10)
label: str = "confusion_matrix"
exclude_generic: bool = True
exclude_noise: bool = False
noise_class: str = "noise"
normalize: Literal["true", "pred", "all", "none"] = "true"
threshold: float = 0.2
add_colorbar: bool = True
cmap: str = "Blues"
class ConfusionMatrix(BasePlot):
def __init__(
self,
*args,
exclude_generic: bool = True,
exclude_noise: bool = False,
noise_class: str = "noise",
add_colorbar: bool = True,
normalize: Literal["true", "pred", "all", "none"] = "true",
cmap: str = "Blues",
threshold: float = 0.2,
**kwargs,
):
super().__init__(*args, **kwargs)
self.exclude_generic = exclude_generic
self.exclude_noise = exclude_noise
self.noise_class = noise_class
self.normalize = normalize
self.add_colorbar = add_colorbar
self.threshold = threshold
self.cmap = cmap
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true: List[str] = []
y_pred: List[str] = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
true_class = m.true_class
pred_class = m.pred_class
if not m.is_prediction and self.exclude_noise:
# Ignore matches that don't correspond to a prediction
continue
if not m.is_ground_truth and self.exclude_noise:
# Ignore matches that don't correspond to a ground truth
continue
if m.score < self.threshold:
if self.exclude_noise:
continue
pred_class = self.noise_class
if m.is_generic:
if self.exclude_generic:
# Ignore gt sounds with unknown class
continue
true_class = self.targets.detection_class_name
y_true.append(true_class or self.noise_class)
y_pred.append(pred_class or self.noise_class)
fig = self.create_figure()
ax = fig.subplots()
class_names = [*self.targets.class_names]
if not self.exclude_generic:
class_names.append(self.targets.detection_class_name)
if not self.exclude_noise:
class_names.append(self.noise_class)
metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=class_names,
ax=ax,
xticks_rotation="vertical",
cmap=self.cmap,
colorbar=self.add_colorbar,
normalize=self.normalize if self.normalize != "none" else None,
values_format=".2f",
)
yield self.label, fig
@top_class_plots.register(ConfusionMatrixConfig)
@staticmethod
def from_config(config: ConfusionMatrixConfig, targets: TargetProtocol):
return ConfusionMatrix.build(
config=config,
targets=targets,
exclude_generic=config.exclude_generic,
exclude_noise=config.exclude_noise,
noise_class=config.noise_class,
add_colorbar=config.add_colorbar,
normalize=config.normalize,
cmap=config.cmap,
)
class ExampleClassificationPlotConfig(BasePlotConfig):
name: Literal["example_classification"] = "example_classification"
label: str = "example_classification"
title: Optional[str] = "Example Classification"
num_examples: int = 4
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleClassificationPlot(BasePlot):
def __init__(
self,
*args,
num_examples: int = 4,
threshold: float = 0.2,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_examples = num_examples
self.audio_loader = audio_loader
self.threshold = threshold
self.preprocessor = preprocessor
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
grouped = group_matches(clip_evaluations, threshold=self.threshold)
for class_name, matches in grouped.items():
true_positives: List[MatchEval] = get_binned_sample(
matches.true_positives,
n_examples=self.num_examples,
)
false_positives: List[MatchEval] = get_binned_sample(
matches.false_positives,
n_examples=self.num_examples,
)
false_negatives: List[MatchEval] = random.sample(
matches.false_negatives,
k=min(self.num_examples, len(matches.false_negatives)),
)
cross_triggers: List[MatchEval] = get_binned_sample(
matches.cross_triggers, n_examples=self.num_examples
)
fig = self.create_figure()
fig = plot_match_gallery(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
n_examples=self.num_examples,
fig=fig,
)
if self.title is not None:
fig.suptitle(f"{self.title}: {class_name}")
else:
fig.suptitle(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@top_class_plots.register(ExampleClassificationPlotConfig)
@staticmethod
def from_config(
config: ExampleClassificationPlotConfig,
targets: TargetProtocol,
):
return ExampleClassificationPlot.build(
config=config,
targets=targets,
num_examples=config.num_examples,
threshold=config.threshold,
audio_loader=build_audio_loader(config.audio),
preprocessor=build_preprocessor(config.preprocessing),
)
TopClassPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ConfusionMatrixConfig,
ExampleClassificationPlotConfig,
],
Field(discriminator="name"),
]
def build_top_class_plotter(
config: TopClassPlotConfig,
targets: TargetProtocol,
) -> TopClassPlotter:
return top_class_plots.build(config, targets)
@dataclass
class ClassMatches:
false_positives: List[MatchEval] = field(default_factory=list)
false_negatives: List[MatchEval] = field(default_factory=list)
true_positives: List[MatchEval] = field(default_factory=list)
cross_triggers: List[MatchEval] = field(default_factory=list)
def group_matches(
clip_evals: Sequence[ClipEval],
threshold: float = 0.2,
) -> Dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches)
for clip_eval in clip_evals:
for match in clip_eval.matches:
gt_class = match.true_class
pred_class = match.pred_class
is_pred = match.score >= threshold
if not is_pred and gt_class is not None:
class_examples[gt_class].false_negatives.append(match)
continue
if not is_pred:
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
return class_examples
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[(index, match.score) for index, match in enumerate(matches)]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").sample(1)
return [matches[ind] for ind in sample["indices"]]
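A minimal, self-contained sketch of the quantile-binned sampling idea used by get_binned_sample above (the Match dataclass and the random scores are made up for illustration, not part of batdetect2): scores are split into up to n_examples quantile bins and one representative is drawn per bin, so the sample spans the full score range.
import random
from dataclasses import dataclass
import pandas as pd
@dataclass
class Match:
    score: float
def binned_sample(matches, n_examples=5):
    if len(matches) < n_examples:
        return matches
    scores = [m.score for m in matches]
    # Up to `n_examples` quantile bins over the scores; duplicates="drop"
    # avoids errors when many scores are tied.
    bins = pd.qcut(scores, q=n_examples, labels=False, duplicates="drop")
    df = pd.DataFrame({"index": range(len(matches)), "bin": bins})
    # Keep one representative per bin.
    picked = df.groupby("bin").sample(1, random_state=0)
    return [matches[i] for i in picked["index"]]
matches = [Match(score=random.random()) for _ in range(100)]
print([round(m.score, 2) for m in binned_sample(matches, n_examples=5)])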

View File

@ -5,9 +5,9 @@ from pydantic import Field
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipMatches
from batdetect2.typing import ClipEvaluation
EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]
EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame]
tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
@ -21,18 +21,20 @@ class FullEvaluationTableConfig(BaseConfig):
class FullEvaluationTable:
def __call__(
self, clip_evaluations: Sequence[ClipMatches]
self, clip_evaluations: Sequence[ClipEvaluation]
) -> pd.DataFrame:
return extract_matches_dataframe(clip_evaluations)
@tables_registry.register(FullEvaluationTableConfig)
@staticmethod
def from_config(config: FullEvaluationTableConfig):
return FullEvaluationTable()
@classmethod
def from_config(cls, config: FullEvaluationTableConfig):
return cls()
tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)
def extract_matches_dataframe(
clip_evaluations: Sequence[ClipMatches],
clip_evaluations: Sequence[ClipEvaluation],
) -> pd.DataFrame:
data = []
@ -76,8 +78,8 @@ def extract_matches_dataframe(
("gt", "low_freq"): gt_low_freq,
("gt", "high_freq"): gt_high_freq,
("pred", "score"): match.pred_score,
("pred", "class"): match.top_class,
("pred", "class_score"): match.top_class_score,
("pred", "class"): match.pred_class,
("pred", "class_score"): match.pred_class_score,
("pred", "start_time"): pred_start_time,
("pred", "end_time"): pred_end_time,
("pred", "low_freq"): pred_low_freq,

View File

@ -1,39 +0,0 @@
from typing import Annotated, Optional, Union
from pydantic import Field
from batdetect2.evaluate.tasks.base import tasks_registry
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.clip_classification import (
ClipClassificationTaskConfig,
)
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
__all__ = [
"TaskConfig",
"build_task",
]
TaskConfig = Annotated[
Union[
ClassificationTaskConfig,
DetectionTaskConfig,
ClipDetectionTaskConfig,
ClipClassificationTaskConfig,
TopClassDetectionTaskConfig,
],
Field(discriminator="name"),
]
def build_task(
config: TaskConfig,
targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
targets = targets or build_targets()
return tasks_registry.build(config, targets)
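As a rough illustration of how the discriminated union above resolves a plain dictionary into a concrete task config (the Demo* classes below are hypothetical stand-ins, not batdetect2 classes), pydantic v2 selects the config class from the "name" field:
from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter
class DemoDetectionConfig(BaseModel):
    name: Literal["sound_event_detection"] = "sound_event_detection"
    prefix: str = "detection"
class DemoClipDetectionConfig(BaseModel):
    name: Literal["clip_detection"] = "clip_detection"
    prefix: str = "clip_detection"
DemoTaskConfig = Annotated[
    Union[DemoDetectionConfig, DemoClipDetectionConfig],
    Field(discriminator="name"),
]
# The "name" field decides which config class gets instantiated.
config = TypeAdapter(DemoTaskConfig).validate_python({"name": "clip_detection"})
print(type(config).__name__)  # DemoClipDetectionConfig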

View File

@ -1,175 +0,0 @@
from typing import (
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.match import (
MatchConfig,
StartTimeMatchConfig,
build_matcher,
)
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"BaseTaskConfig",
"BaseTask",
]
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
"tasks"
)
T_Output = TypeVar("T_Output")
class BaseTaskConfig(BaseConfig):
prefix: str
ignore_start_end: float = 0.01
matching_strategy: MatchConfig = Field(
default_factory=StartTimeMatchConfig
)
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
targets: TargetProtocol
matcher: MatcherProtocol
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
ignore_start_end: float
prefix: str
def __init__(
self,
matcher: MatcherProtocol,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
prefix: str,
ignore_start_end: float = 0.01,
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
):
self.matcher = matcher
self.metrics = metrics
self.plots = plots or []
self.targets = targets
self.prefix = prefix
self.ignore_start_end = ignore_start_end
def compute_metrics(
self,
eval_outputs: List[T_Output],
) -> Dict[str, float]:
scores = [metric(eval_outputs) for metric in self.metrics]
return {
f"{self.prefix}/{name}": score
for metric_output in scores
for name, score in metric_output.items()
}
def generate_plots(
self, eval_outputs: List[T_Output]
) -> Iterable[Tuple[str, Figure]]:
for plot in self.plots:
for name, fig in plot(eval_outputs):
yield f"{self.prefix}/{name}", fig
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[T_Output]:
return [
self.evaluate_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions)
]
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> T_Output: ...
def include_sound_event_annotation(
self,
sound_event_annotation: data.SoundEventAnnotation,
clip: data.Clip,
) -> bool:
if not self.targets.filter(sound_event_annotation):
return False
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
return is_in_bounds(
geometry,
clip,
self.ignore_start_end,
)
def include_prediction(
self,
prediction: RawPrediction,
clip: data.Clip,
) -> bool:
return is_in_bounds(
prediction.geometry,
clip,
self.ignore_start_end,
)
@classmethod
def build(
cls,
config: BaseTaskConfig,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
**kwargs,
):
matcher = build_matcher(config.matching_strategy)
return cls(
matcher=matcher,
targets=targets,
metrics=metrics,
plots=plots,
prefix=config.prefix,
ignore_start_end=config.ignore_start_end,
**kwargs,
)
def is_in_bounds(
geometry: data.Geometry,
clip: data.Clip,
buffer: float,
) -> bool:
start_time = compute_bounds(geometry)[0]
return (start_time >= clip.start_time + buffer) and (
start_time <= clip.end_time - buffer
)
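A small standalone sketch of the boundary filtering performed by is_in_bounds above, using plain floats instead of soundevent geometries: events whose start time falls within `buffer` seconds of either clip edge are dropped.
def in_bounds(start_time: float, clip_start: float, clip_end: float, buffer: float) -> bool:
    # Keep only events whose start lies at least `buffer` seconds away
    # from both edges of the clip.
    return clip_start + buffer <= start_time <= clip_end - buffer
print(in_bounds(0.005, 0.0, 5.0, buffer=0.01))  # False: too close to the clip start
print(in_bounds(2.500, 0.0, 5.0, buffer=0.01))  # True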

View File

@ -1,149 +0,0 @@
from typing import (
List,
Literal,
Sequence,
)
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.classification import (
ClassificationAveragePrecisionConfig,
ClassificationMetricConfig,
ClipEval,
MatchEval,
build_classification_metric,
)
from batdetect2.evaluate.plots.classification import (
ClassificationPlotConfig,
build_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class ClassificationTaskConfig(BaseTaskConfig):
name: Literal["sound_event_classification"] = "sound_event_classification"
prefix: str = "classification"
metrics: List[ClassificationMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
)
plots: List[ClassificationPlotConfig] = Field(default_factory=list)
include_generics: bool = True
class ClassificationTask(BaseTask[ClipEval]):
def __init__(
self,
*args,
include_generics: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.include_generics = include_generics
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
]
all_gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.include_sound_event_annotation(sound_event, clip)
]
per_class_matches = {}
for class_name in self.targets.class_names:
class_idx = self.targets.class_names.index(class_name)
# Only match to targets of the given class
gts = [
sound_event
for sound_event in all_gts
if self.is_class(sound_event, class_name)
]
scores = [float(pred.class_scores[class_idx]) for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
true_class = (
self.targets.encode_class(gt) if gt is not None else None
)
score = (
float(pred.class_scores[class_idx])
if pred is not None
else 0
)
matches.append(
MatchEval(
clip=clip,
gt=gt,
pred=pred,
is_prediction=pred is not None,
is_ground_truth=gt is not None,
is_generic=gt is not None and true_class is None,
true_class=true_class,
score=score,
)
)
per_class_matches[class_name] = matches
return ClipEval(clip=clip, matches=per_class_matches)
def is_class(
self,
sound_event: data.SoundEventAnnotation,
class_name: str,
) -> bool:
sound_event_class = self.targets.encode_class(sound_event)
if sound_event_class is None and self.include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name
@tasks_registry.register(ClassificationTaskConfig)
@staticmethod
def from_config(
config: ClassificationTaskConfig,
targets: TargetProtocol,
):
metrics = [
build_classification_metric(metric, targets)
for metric in config.metrics
]
plots = [
build_classification_plotter(plot, targets)
for plot in config.plots
]
return ClassificationTask.build(
config=config,
plots=plots,
targets=targets,
metrics=metrics,
)

View File

@ -1,85 +0,0 @@
from collections import defaultdict
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.clip_classification import (
ClipClassificationAveragePrecisionConfig,
ClipClassificationMetricConfig,
ClipEval,
build_clip_metric,
)
from batdetect2.evaluate.plots.clip_classification import (
ClipClassificationPlotConfig,
build_clip_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class ClipClassificationTaskConfig(BaseTaskConfig):
name: Literal["clip_classification"] = "clip_classification"
prefix: str = "clip_classification"
metrics: List[ClipClassificationMetricConfig] = Field(
default_factory=lambda: [
ClipClassificationAveragePrecisionConfig(),
]
)
plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)
class ClipClassificationTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gt_classes = set()
for sound_event in clip_annotation.sound_events:
if not self.include_sound_event_annotation(sound_event, clip):
continue
class_name = self.targets.encode_class(sound_event)
if class_name is None:
continue
gt_classes.add(class_name)
pred_scores = defaultdict(float)
for pred in predictions:
if not self.include_prediction(pred, clip):
continue
for class_idx, class_name in enumerate(self.targets.class_names):
pred_scores[class_name] = max(
float(pred.class_scores[class_idx]),
pred_scores[class_name],
)
return ClipEval(true_classes=gt_classes, class_scores=pred_scores)
@tasks_registry.register(ClipClassificationTaskConfig)
@staticmethod
def from_config(
config: ClipClassificationTaskConfig,
targets: TargetProtocol,
):
metrics = [build_clip_metric(metric) for metric in config.metrics]
plots = [
build_clip_classification_plotter(plot, targets)
for plot in config.plots
]
return ClipClassificationTask.build(
config=config,
plots=plots,
metrics=metrics,
targets=targets,
)

View File

@ -1,76 +0,0 @@
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.clip_detection import (
ClipDetectionAveragePrecisionConfig,
ClipDetectionMetricConfig,
ClipEval,
build_clip_metric,
)
from batdetect2.evaluate.plots.clip_detection import (
ClipDetectionPlotConfig,
build_clip_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class ClipDetectionTaskConfig(BaseTaskConfig):
name: Literal["clip_detection"] = "clip_detection"
prefix: str = "clip_detection"
metrics: List[ClipDetectionMetricConfig] = Field(
default_factory=lambda: [
ClipDetectionAveragePrecisionConfig(),
]
)
plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)
class ClipDetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gt_det = any(
self.include_sound_event_annotation(sound_event, clip)
for sound_event in clip_annotation.sound_events
)
pred_score = 0
for pred in predictions:
if not self.include_prediction(pred, clip):
continue
pred_score = max(pred_score, pred.detection_score)
return ClipEval(
gt_det=gt_det,
score=pred_score,
)
@tasks_registry.register(ClipDetectionTaskConfig)
@staticmethod
def from_config(
config: ClipDetectionTaskConfig,
targets: TargetProtocol,
):
metrics = [build_clip_metric(metric) for metric in config.metrics]
plots = [
build_clip_detection_plotter(plot, targets)
for plot in config.plots
]
return ClipDetectionTask.build(
config=config,
metrics=metrics,
targets=targets,
plots=plots,
)
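A rough sketch of the clip-level reduction performed by evaluate_clip above, with made-up values: the clip's ground-truth label is whether any in-bounds sound event exists, and the clip score is the maximum detection score over the in-bounds predictions.
detection_scores = [0.12, 0.83, 0.40]  # made-up per-prediction scores for one clip
has_ground_truth = True                # at least one annotated sound event in the clip
clip_score = max(detection_scores, default=0.0)
print(has_ground_truth, clip_score)  # True 0.83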

View File

@ -1,88 +0,0 @@
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.detection import (
ClipEval,
DetectionAveragePrecisionConfig,
DetectionMetricConfig,
MatchEval,
build_detection_metric,
)
from batdetect2.evaluate.plots.detection import (
DetectionPlotConfig,
build_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class DetectionTaskConfig(BaseTaskConfig):
name: Literal["sound_event_detection"] = "sound_event_detection"
prefix: str = "detection"
metrics: List[DetectionMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()]
)
plots: List[DetectionPlotConfig] = Field(default_factory=list)
class DetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.include_sound_event_annotation(sound_event, clip)
]
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
]
scores = [pred.detection_score for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
matches.append(
MatchEval(
gt=gt,
pred=pred,
is_prediction=pred is not None,
is_ground_truth=gt is not None,
score=pred.detection_score if pred is not None else 0,
)
)
return ClipEval(clip=clip, matches=matches)
@tasks_registry.register(DetectionTaskConfig)
@staticmethod
def from_config(
config: DetectionTaskConfig,
targets: TargetProtocol,
):
metrics = [build_detection_metric(metric) for metric in config.metrics]
plots = [
build_detection_plotter(plot, targets) for plot in config.plots
]
return DetectionTask.build(
config=config,
metrics=metrics,
targets=targets,
plots=plots,
)

View File

@ -1,111 +0,0 @@
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.top_class import (
ClipEval,
MatchEval,
TopClassAveragePrecisionConfig,
TopClassMetricConfig,
build_top_class_metric,
)
from batdetect2.evaluate.plots.top_class import (
TopClassPlotConfig,
build_top_class_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class TopClassDetectionTaskConfig(BaseTaskConfig):
name: Literal["top_class_detection"] = "top_class_detection"
prefix: str = "top_class"
metrics: List[TopClassMetricConfig] = Field(
default_factory=lambda: [TopClassAveragePrecisionConfig()]
)
plots: List[TopClassPlotConfig] = Field(default_factory=list)
class TopClassDetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.include_sound_event_annotation(sound_event, clip)
]
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
]
# Take the highest score for each prediction
scores = [pred.class_scores.max() for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
true_class = (
self.targets.encode_class(gt) if gt is not None else None
)
class_idx = (
pred.class_scores.argmax() if pred is not None else None
)
score = (
float(pred.class_scores[class_idx]) if pred is not None else 0
)
pred_class = (
self.targets.class_names[class_idx]
if class_idx is not None
else None
)
matches.append(
MatchEval(
clip=clip,
gt=gt,
pred=pred,
is_ground_truth=gt is not None,
is_prediction=pred is not None,
true_class=true_class,
is_generic=gt is not None and true_class is None,
pred_class=pred_class,
score=score,
)
)
return ClipEval(clip=clip, matches=matches)
@tasks_registry.register(TopClassDetectionTaskConfig)
@staticmethod
def from_config(
config: TopClassDetectionTaskConfig,
targets: TargetProtocol,
):
metrics = [build_top_class_metric(metric) for metric in config.metrics]
plots = [
build_top_class_plotter(plot, targets) for plot in config.plots
]
return TopClassDetectionTask.build(
config=config,
plots=plots,
metrics=metrics,
targets=targets,
)
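A minimal numpy sketch of the top-class selection in evaluate_clip above (the class names and scores are invented): the predicted class is the argmax over the per-class scores, and the match score is that maximum value.
import numpy as np
class_names = ["barbar", "eptser", "myodau"]   # invented class list
class_scores = np.array([0.10, 0.72, 0.25])    # invented per-class scores
class_idx = int(class_scores.argmax())
pred_class = class_names[class_idx]
score = float(class_scores[class_idx])
print(pred_class, score)  # eptser 0.72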

View File

@ -11,6 +11,7 @@ from batdetect2.plotting.matches import (
plot_cross_trigger_match,
plot_false_negative_match,
plot_false_positive_match,
plot_matches,
plot_true_positive_match,
)
@ -21,6 +22,7 @@ __all__ = [
"plot_cross_trigger_match",
"plot_false_negative_match",
"plot_false_positive_match",
"plot_matches",
"plot_spectrogram",
"plot_true_positive_match",
"plot_detection_heatmap",

View File

@ -65,6 +65,8 @@ def plot_anchor_points(
if not targets.filter(sound_event):
continue
sound_event = targets.transform(sound_event)
position, _ = targets.encode_roi(sound_event)
positions.append(position)

View File

@ -19,7 +19,7 @@ def create_ax(
) -> axes.Axes:
"""Create a new axis if none is provided"""
if ax is None:
_, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs) # type: ignore
_, ax = plt.subplots(figsize=figsize, **kwargs) # type: ignore
return ax # type: ignore
@ -66,9 +66,6 @@ def plot_spectrogram(
vmax=vmax,
)
ax.set_xlim(start_time, end_time)
ax.set_ylim(min_freq, max_freq)
if add_colorbar:
plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))

View File

@ -1,113 +0,0 @@
from typing import Optional
from matplotlib import axes, patches
from soundevent.plot import plot_geometry
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.plotting.clips import (
AudioLoader,
PreprocessorProtocol,
plot_clip,
)
from batdetect2.plotting.common import create_ax
__all__ = [
"plot_clip_detections",
]
def plot_clip_detections(
clip_eval: ClipEval,
figsize: tuple[int, int] = (10, 10),
ax: Optional[axes.Axes] = None,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
threshold: float = 0.2,
add_legend: bool = True,
add_title: bool = True,
fill: bool = False,
linewidth: float = 1.0,
gt_color: str = "green",
gt_linestyle: str = "-",
true_pred_color: str = "yellow",
true_pred_linestyle: str = "--",
false_pred_color: str = "blue",
false_pred_linestyle: str = "-",
missed_gt_color: str = "red",
missed_gt_linestyle: str = "-",
) -> axes.Axes:
ax = create_ax(figsize=figsize, ax=ax)
plot_clip(
clip_eval.clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
ax=ax,
)
for m in clip_eval.matches:
is_match = (
m.pred is not None and m.gt is not None and m.score >= threshold
)
if m.pred is not None:
color = true_pred_color if is_match else false_pred_color
plot_geometry(
m.pred.geometry,
ax=ax,
add_points=False,
facecolor="none" if not fill else color,
alpha=m.pred.detection_score,
linewidth=linewidth,
linestyle=true_pred_linestyle
if is_match
else missed_gt_linestyle,
color=color,
)
if m.gt is not None:
color = gt_color if is_match else missed_gt_color
plot_geometry(
m.gt.sound_event.geometry, # type: ignore
ax=ax,
add_points=False,
linewidth=linewidth,
facecolor="none" if not fill else color,
linestyle=gt_linestyle if is_match else false_pred_linestyle,
color=color,
)
if add_title:
ax.set_title(clip_eval.clip.recording.path.name)
if add_legend:
ax.legend(
handles=[
patches.Patch(
label="found GT",
edgecolor=gt_color,
facecolor="none" if not fill else gt_color,
linestyle=gt_linestyle,
),
patches.Patch(
label="missed GT",
edgecolor=missed_gt_color,
facecolor="none" if not fill else missed_gt_color,
linestyle=missed_gt_linestyle,
),
patches.Patch(
label="true Det",
edgecolor=true_pred_color,
facecolor="none" if not fill else true_pred_color,
linestyle=true_pred_linestyle,
),
patches.Patch(
label="false Det",
edgecolor=false_pred_color,
facecolor="none" if not fill else false_pred_color,
linestyle=false_pred_linestyle,
),
]
)
return ax

View File

@ -1,109 +1,81 @@
from typing import Optional, Sequence
from typing import List, Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from batdetect2.plotting.matches import (
MatchProtocol,
plot_cross_trigger_match,
plot_false_negative_match,
plot_false_positive_match,
plot_true_positive_match,
)
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
__all__ = ["plot_match_gallery"]
def plot_match_gallery(
true_positives: Sequence[MatchProtocol],
false_positives: Sequence[MatchProtocol],
false_negatives: Sequence[MatchProtocol],
cross_triggers: Sequence[MatchProtocol],
true_positives: List[MatchEvaluation],
false_positives: List[MatchEvaluation],
false_negatives: List[MatchEvaluation],
cross_triggers: List[MatchEvaluation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
n_examples: int = 5,
duration: float = 0.1,
fig: Optional[Figure] = None,
):
if fig is None:
fig = plt.figure(figsize=(20, 20))
fig = plt.figure(figsize=(20, 20))
axes = fig.subplots(
nrows=4,
ncols=n_examples,
sharex="none",
sharey="row",
)
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
for index, match in enumerate(true_positives[:n_examples]):
ax = plt.subplot(4, n_examples, index + 1)
try:
plot_true_positive_match(
tp_match,
ax=tp_ax,
match,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
for index, match in enumerate(false_positives[:n_examples]):
ax = plt.subplot(4, n_examples, n_examples + index + 1)
try:
plot_false_positive_match(
fp_match,
ax=fp_ax,
match,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
for index, match in enumerate(false_negatives[:n_examples]):
ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
try:
plot_false_negative_match(
fn_match,
ax=fn_ax,
match,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
for index, match in enumerate(cross_triggers[:n_examples]):
ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
try:
plot_cross_trigger_match(
ct_match,
ax=ct_ax,
match,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
continue
fig.tight_layout()
return fig

View File

@ -1,17 +1,16 @@
from typing import Optional, Protocol, Tuple, Union
from typing import List, Optional, Tuple, Union
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from soundevent import data, plot
from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper
from batdetect2.plotting.clips import plot_clip
from batdetect2.typing import (
AudioLoader,
PreprocessorProtocol,
RawPrediction,
)
from batdetect2.plotting.clips import AudioLoader, plot_clip
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
__all__ = [
"plot_matches",
"plot_false_positive_match",
"plot_true_positive_match",
"plot_false_negative_match",
@ -19,14 +18,6 @@ __all__ = [
]
class MatchProtocol(Protocol):
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
score: float
true_class: Optional[str]
DEFAULT_DURATION = 0.05
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
@ -36,8 +27,88 @@ DEFAULT_ANNOTATION_LINE_STYLE = "-"
DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_matches(
matches: List[MatchEvaluation],
clip: data.Clip,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
color_mapper: Optional[TagColorMapper] = None,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
) -> Axes:
ax = plot_clip(
clip,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
if color_mapper is None:
color_mapper = TagColorMapper()
for match in matches:
if match.is_cross_trigger():
plot_cross_trigger_match(
match,
ax=ax,
fill=fill,
add_points=add_points,
add_spectrogram=False,
use_score=True,
color=cross_trigger_color,
add_text=False,
)
elif match.is_true_positive():
plot_true_positive_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
use_score=True,
add_points=add_points,
color=true_positive_color,
add_text=False,
)
elif match.is_false_negative():
plot_false_negative_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
add_points=add_points,
color=false_negative_color,
add_text=False,
)
elif match.is_false_positive():
plot_false_positive_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
use_score=True,
add_points=add_points,
color=false_positive_color,
add_text=False,
)
else:
continue
return ax
def plot_false_positive_match(
match: MatchProtocol,
match: MatchEvaluation,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
@ -48,24 +119,21 @@ def plot_false_positive_match(
add_spectrogram: bool = True,
add_text: bool = True,
add_points: bool = False,
add_title: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.pred is not None
assert match.pred_geometry is not None
assert match.sound_event_annotation is None
start_time, _, _, high_freq = compute_bounds(match.pred.geometry)
start_time, _, _, high_freq = compute_bounds(match.pred_geometry)
clip = data.Clip(
start_time=max(
start_time - duration / 2,
0,
),
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
match.clip.end_time,
),
recording=match.clip.recording,
)
@ -82,33 +150,30 @@ def plot_false_positive_match(
)
ax = plot.plot_geometry(
match.pred.geometry,
match.pred_geometry,
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=match.score if use_score else 1,
alpha=match.pred_score if use_score else 1,
color=color,
)
if add_text:
ax.text(
plt.text(
start_time,
high_freq,
f"score={match.score:.2f}",
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_title:
ax.set_title("False Positive")
return ax
def plot_false_negative_match(
match: MatchProtocol,
match: MatchEvaluation,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
@ -117,28 +182,26 @@ def plot_false_negative_match(
duration: float = DEFAULT_DURATION,
add_spectrogram: bool = True,
add_points: bool = False,
add_title: bool = True,
add_text: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.gt is not None
geometry = match.gt.sound_event.geometry
assert match.pred_geometry is None
assert match.sound_event_annotation is not None
sound_event = match.sound_event_annotation.sound_event
geometry = sound_event.geometry
assert geometry is not None
start_time = compute_bounds(geometry)[0]
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(
start_time - duration / 2,
0,
),
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
start_time + duration / 2, sound_event.recording.duration
),
recording=match.clip.recording,
recording=sound_event.recording,
)
if add_spectrogram:
@ -152,23 +215,33 @@ def plot_false_negative_match(
spec_cmap=spec_cmap,
)
ax = plot.plot_geometry(
geometry,
ax = plot.plot_annotation(
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
)
if add_title:
ax.set_title("False Negative")
if add_text:
plt.text(
start_time,
high_freq,
f"False Negative \nClass: {match.gt_class} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
def plot_true_positive_match(
match: MatchProtocol,
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None,
@ -185,42 +258,39 @@ def plot_true_positive_match(
fontsize: Union[float, str] = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
add_title: bool = True,
) -> Axes:
assert match.gt is not None
assert match.pred is not None
geometry = match.gt.sound_event.geometry
assert match.sound_event_annotation is not None
assert match.pred_geometry is not None
sound_event = match.sound_event_annotation.sound_event
geometry = sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(
start_time - duration / 2,
0,
),
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
start_time + duration / 2, sound_event.recording.duration
),
recording=match.clip.recording,
recording=sound_event.recording,
)
if add_spectrogram:
ax = plot_clip(
clip,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
ax = plot.plot_geometry(
geometry,
ax = plot.plot_annotation(
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
@ -229,34 +299,31 @@ def plot_true_positive_match(
)
plot.plot_geometry(
match.pred.geometry,
match.pred_geometry,
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=match.score if use_score else 1,
alpha=match.pred_score if use_score else 1,
color=color,
linestyle=prediction_linestyle,
)
if add_text:
ax.text(
plt.text(
start_time,
high_freq,
f"score={match.score:.2f}",
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_title:
ax.set_title("True Positive")
return ax
def plot_cross_trigger_match(
match: MatchProtocol,
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None,
@ -267,7 +334,6 @@ def plot_cross_trigger_match(
add_spectrogram: bool = True,
add_points: bool = False,
add_text: bool = True,
add_title: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
@ -275,24 +341,20 @@ def plot_cross_trigger_match(
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
assert match.gt is not None
assert match.pred is not None
geometry = match.gt.sound_event.geometry
assert match.sound_event_annotation is not None
assert match.pred_geometry is not None
sound_event = match.sound_event_annotation.sound_event
geometry = sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(
start_time - duration / 2,
0,
),
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
start_time + duration / 2, sound_event.recording.duration
),
recording=match.clip.recording,
recording=sound_event.recording,
)
if add_spectrogram:
@ -306,9 +368,11 @@ def plot_cross_trigger_match(
spec_cmap=spec_cmap,
)
ax = plot.plot_geometry(
geometry,
ax = plot.plot_annotation(
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
@ -317,28 +381,24 @@ def plot_cross_trigger_match(
)
ax = plot.plot_geometry(
match.pred.geometry,
match.pred_geometry,
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=match.score if use_score else 1,
alpha=match.pred_score if use_score else 1,
color=color,
linestyle=prediction_linestyle,
)
if add_text:
ax.text(
plt.text(
start_time,
high_freq,
f"score={match.score:.2f}\nclass={match.true_class}",
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_title:
ax.set_title("Cross Trigger")
return ax

View File

@ -1,286 +0,0 @@
from typing import Dict, Optional, Tuple
import numpy as np
import seaborn as sns
from cycler import cycler
from matplotlib import axes
from batdetect2.plotting.common import create_ax
def set_default_styler(ax: axes.Axes) -> axes.Axes:
color_cycler = cycler(color=sns.color_palette("muted"))
style_cycler = cycler(linestyle=["-", "--", ":"]) * cycler(
marker=["o", "s", "^"]
)
custom_cycler = color_cycler * len(style_cycler) + style_cycler * len(
color_cycler
)
ax.set_prop_cycle(custom_cycler)
return ax
def set_default_style(ax: axes.Axes) -> axes.Axes:
ax = set_default_styler(ax)
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
return ax
def plot_pr_curve(
precision: np.ndarray,
recall: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
add_legend: bool = False,
label: str = "PR Curve",
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(
recall,
precision,
label=label,
marker="o",
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_legend:
ax.legend()
if add_labels:
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
return ax
def plot_pr_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (precision, recall, thresholds) in data.items():
ax.plot(
recall,
precision,
label=name,
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
return ax
def plot_threshold_precision_curve(
threshold: np.ndarray,
precision: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(threshold, precision, markevery=_get_marker_positions(threshold))
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Precision")
return ax
def plot_threshold_precision_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (precision, _, thresholds) in data.items():
ax.plot(
thresholds,
precision,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Precision")
return ax
def plot_threshold_recall_curve(
threshold: np.ndarray,
recall: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(threshold, recall, markevery=_get_marker_positions(threshold))
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Recall")
return ax
def plot_threshold_recall_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (_, recall, thresholds) in data.items():
ax.plot(
thresholds,
recall,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Recall")
return ax
def plot_roc_curve(
fpr: np.ndarray,
tpr: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(
fpr,
tpr,
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
return ax
def plot_roc_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (fpr, tpr, thresholds) in data.items():
ax.plot(
fpr,
tpr,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
return ax
def _get_marker_positions(
thresholds: np.ndarray,
n_points: int = 11,
) -> np.ndarray:
size = len(thresholds)
cut_points = np.linspace(0, 1, n_points)
indices = np.searchsorted(thresholds[::-1], cut_points)
return np.clip(size - indices, 0, size - 1) # type: ignore
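A small self-contained sketch of how _get_marker_positions above picks marker indices (the thresholds array here is synthetic, and the helper appears to assume thresholds sorted in descending order, as the reversal suggests): markers land on the curve points closest to evenly spaced threshold values.
import numpy as np
def marker_positions(thresholds: np.ndarray, n_points: int = 11) -> np.ndarray:
    size = len(thresholds)
    cut_points = np.linspace(0, 1, n_points)
    # Search in the ascending (reversed) thresholds, then map the result
    # back to indices into the original descending array.
    indices = np.searchsorted(thresholds[::-1], cut_points)
    return np.clip(size - indices, 0, size - 1)
thresholds = np.linspace(1.0, 0.0, 21)  # synthetic, descending
print(marker_positions(thresholds, n_points=5))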

View File

@ -28,10 +28,12 @@ class CenterAudio(torch.nn.Module):
def forward(self, wav: torch.Tensor) -> torch.Tensor:
return center_tensor(wav)
@audio_transforms.register(CenterAudioConfig)
@staticmethod
def from_config(config: CenterAudioConfig, samplerate: int):
return CenterAudio()
@classmethod
def from_config(cls, config: CenterAudioConfig, samplerate: int):
return cls()
audio_transforms.register(CenterAudioConfig, CenterAudio)
class ScaleAudioConfig(BaseConfig):
@ -42,10 +44,12 @@ class ScaleAudio(torch.nn.Module):
def forward(self, wav: torch.Tensor) -> torch.Tensor:
return peak_normalize(wav)
@audio_transforms.register(ScaleAudioConfig)
@staticmethod
def from_config(config: ScaleAudioConfig, samplerate: int):
return ScaleAudio()
@classmethod
def from_config(cls, config: ScaleAudioConfig, samplerate: int):
return cls()
audio_transforms.register(ScaleAudioConfig, ScaleAudio)
class FixDurationConfig(BaseConfig):
@ -71,12 +75,13 @@ class FixDuration(torch.nn.Module):
return torch.nn.functional.pad(wav, (0, self.length - length))
@audio_transforms.register(FixDurationConfig)
@staticmethod
def from_config(config: FixDurationConfig, samplerate: int):
return FixDuration(samplerate=samplerate, duration=config.duration)
@classmethod
def from_config(cls, config: FixDurationConfig, samplerate: int):
return cls(samplerate=samplerate, duration=config.duration)
audio_transforms.register(FixDurationConfig, FixDuration)
AudioTransform = Annotated[
Union[
FixDurationConfig,

View File

@ -285,11 +285,10 @@ class PCEN(torch.nn.Module):
* torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
).to(spec.dtype)
@spectrogram_transforms.register(PcenConfig)
@staticmethod
def from_config(config: PcenConfig, samplerate: int):
@classmethod
def from_config(cls, config: PcenConfig, samplerate: int):
smooth = _compute_smoothing_constant(samplerate, config.time_constant)
return PCEN(
return cls(
smoothing_constant=smooth,
gain=config.gain,
bias=config.bias,
@ -297,6 +296,9 @@ class PCEN(torch.nn.Module):
)
spectrogram_transforms.register(PcenConfig, PCEN)
def _compute_smoothing_constant(
samplerate: int,
time_constant: float,
@ -333,10 +335,12 @@ class ScaleAmplitude(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
return self.scaler(spec)
@spectrogram_transforms.register(ScaleAmplitudeConfig)
@staticmethod
def from_config(config: ScaleAmplitudeConfig, samplerate: int):
return ScaleAmplitude(scale=config.scale)
@classmethod
def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int):
return cls(scale=config.scale)
spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude)
class SpectralMeanSubstractionConfig(BaseConfig):
@ -348,13 +352,19 @@ class SpectralMeanSubstraction(torch.nn.Module):
mean = spec.mean(-1, keepdim=True)
return (spec - mean).clamp(min=0)
@spectrogram_transforms.register(SpectralMeanSubstractionConfig)
@staticmethod
@classmethod
def from_config(
cls,
config: SpectralMeanSubstractionConfig,
samplerate: int,
):
return SpectralMeanSubstraction()
return cls()
spectrogram_transforms.register(
SpectralMeanSubstractionConfig,
SpectralMeanSubstraction,
)
class PeakNormalizeConfig(BaseConfig):
@ -365,12 +375,13 @@ class PeakNormalize(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
return peak_normalize(spec)
@spectrogram_transforms.register(PeakNormalizeConfig)
@staticmethod
def from_config(config: PeakNormalizeConfig, samplerate: int):
return PeakNormalize()
@classmethod
def from_config(cls, config: PeakNormalizeConfig, samplerate: int):
return cls()
spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize)
SpectrogramTransform = Annotated[
Union[
PcenConfig,

View File

@ -99,7 +99,7 @@ DEFAULT_DETECTION_CLASS = TargetClassConfig(
DEFAULT_CLASSES = [
TargetClassConfig(
name="barbar",
tags=[data.Tag(key="class", value="Barbastella barbastellus")],
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
),
TargetClassConfig(
name="eptser",

View File

@ -1,11 +1,11 @@
from batdetect2.train.augmentations import (
AugmentationsConfig,
AddEchoConfig,
MaskFrequencyConfig,
EchoAugmentationConfig,
FrequencyMaskAugmentationConfig,
RandomAudioSource,
MaskTimeConfig,
ScaleVolumeConfig,
WarpConfig,
TimeMaskAugmentationConfig,
VolumeAugmentationConfig,
WarpAugmentationConfig,
add_echo,
build_augmentations,
mask_frequency,
@ -43,20 +43,20 @@ __all__ = [
"AugmentationsConfig",
"ClassificationLossConfig",
"DetectionLossConfig",
"AddEchoConfig",
"MaskFrequencyConfig",
"EchoAugmentationConfig",
"FrequencyMaskAugmentationConfig",
"LossConfig",
"LossFunction",
"PLTrainerConfig",
"RandomAudioSource",
"SizeLossConfig",
"MaskTimeConfig",
"TimeMaskAugmentationConfig",
"TrainingConfig",
"TrainingDataset",
"TrainingModule",
"ValidationDataset",
"ScaleVolumeConfig",
"WarpConfig",
"VolumeAugmentationConfig",
"WarpAugmentationConfig",
"add_echo",
"build_augmentations",
"build_clip_labeler",

View File

@ -12,23 +12,21 @@ from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry
from batdetect2.audio.clips import get_subclip_annotation
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry
from batdetect2.typing import AudioLoader, Augmentation
__all__ = [
"AugmentationConfig",
"AugmentationsConfig",
"DEFAULT_AUGMENTATION_CONFIG",
"AddEchoConfig",
"EchoAugmentationConfig",
"AudioSource",
"MaskFrequencyConfig",
"MixAudioConfig",
"MaskTimeConfig",
"ScaleVolumeConfig",
"WarpConfig",
"FrequencyMaskAugmentationConfig",
"MixAugmentationConfig",
"TimeMaskAugmentationConfig",
"VolumeAugmentationConfig",
"WarpAugmentationConfig",
"add_echo",
"build_augmentations",
"load_augmentation_config",
@ -39,19 +37,10 @@ __all__ = [
"warp_spectrogram",
]
AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
audio_augmentations: Registry[Augmentation, [int, Optional[AudioSource]]] = (
Registry(name="audio_augmentation")
)
spec_augmentations: Registry[Augmentation, []] = Registry(
name="spec_augmentation"
)
class MixAudioConfig(BaseConfig):
class MixAugmentationConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples)."""
name: Literal["mix_audio"] = "mix_audio"
@ -98,27 +87,6 @@ class MixAudio(torch.nn.Module):
)
return mixed_audio, mixed_annotations
@audio_augmentations.register(MixAudioConfig)
@staticmethod
def from_config(
config: MixAudioConfig,
samplerate: int,
source: Optional[AudioSource],
):
if source is None:
warnings.warn(
"Mix audio augmentation ('mix_audio') requires an "
"'example_source' callable to be provided.",
stacklevel=2,
)
return lambda wav, clip_annotation: (wav, clip_annotation)
return MixAudio(
example_source=source,
min_weight=config.min_weight,
max_weight=config.max_weight,
)
def mix_audio(
wav1: torch.Tensor,
@ -168,7 +136,7 @@ def combine_clip_annotations(
)
class AddEchoConfig(BaseConfig):
class EchoAugmentationConfig(BaseConfig):
"""Configuration for adding synthetic echo/reverb."""
name: Literal["add_echo"] = "add_echo"
@ -181,17 +149,14 @@ class AddEchoConfig(BaseConfig):
class AddEcho(torch.nn.Module):
def __init__(
self,
samplerate: int = TARGET_SAMPLERATE_HZ,
min_weight: float = 0.1,
max_weight: float = 1.0,
max_delay: float = 0.005,
max_delay: int = 2560,
):
super().__init__()
self.samplerate = samplerate
self.min_weight = min_weight
self.max_weight = max_weight
self.max_delay_s = max_delay
self.max_delay = int(max_delay * samplerate)
self.max_delay = max_delay
def forward(
self,
@ -202,20 +167,6 @@ class AddEcho(torch.nn.Module):
weight = np.random.uniform(self.min_weight, self.max_weight)
return add_echo(wav, delay=delay, weight=weight), clip_annotation
@audio_augmentations.register(AddEchoConfig)
@staticmethod
def from_config(
config: AddEchoConfig,
samplerate: int,
source: Optional[AudioSource],
):
return AddEcho(
samplerate=samplerate,
min_weight=config.min_weight,
max_weight=config.max_weight,
max_delay=config.max_delay,
)
def add_echo(
wav: torch.Tensor,
@ -232,7 +183,7 @@ def add_echo(
return mix_audio(wav, audio_delay, weight)
class ScaleVolumeConfig(BaseConfig):
class VolumeAugmentationConfig(BaseConfig):
"""Configuration for random volume scaling of the spectrogram."""
name: Literal["scale_volume"] = "scale_volume"
@ -255,27 +206,19 @@ class ScaleVolume(torch.nn.Module):
factor = np.random.uniform(self.min_scaling, self.max_scaling)
return scale_volume(spec, factor=factor), clip_annotation
@spec_augmentations.register(ScaleVolumeConfig)
@staticmethod
def from_config(config: ScaleVolumeConfig):
return ScaleVolume(
min_scaling=config.min_scaling,
max_scaling=config.max_scaling,
)
def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
"""Scale the amplitude of the spectrogram by a factor."""
return spec * factor
class WarpConfig(BaseConfig):
class WarpAugmentationConfig(BaseConfig):
name: Literal["warp"] = "warp"
probability: float = 0.2
delta: float = 0.04
class Warp(torch.nn.Module):
class WarpSpectrogram(torch.nn.Module):
def __init__(self, delta: float = 0.04) -> None:
super().__init__()
self.delta = delta
@ -291,11 +234,6 @@ class Warp(torch.nn.Module):
warp_clip_annotation(clip_annotation, factor=factor),
)
@spec_augmentations.register(WarpConfig)
@staticmethod
def from_config(config: WarpConfig):
return Warp(delta=config.delta)
def warp_sound_event_annotation(
sound_event_annotation: data.SoundEventAnnotation,
@ -356,7 +294,7 @@ def warp_spectrogram(
).squeeze(0)
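WarpSpectrogram stretches or compresses the time axis by a factor drawn near 1.0 (bounded by delta) and rescales the clip annotations by the same factor so the boxes stay aligned. A hedged sketch of the time-axis resampling; the interpolation mode and any subsequent width re-adjustment in the real warp_spectrogram may differ:

import torch
import torch.nn.functional as F


def warp_time_axis(spec: torch.Tensor, factor: float) -> torch.Tensor:
    # spec: (freq_bins, time_bins)
    freq_bins, time_bins = spec.shape
    warped = F.interpolate(
        spec.unsqueeze(0).unsqueeze(0),  # -> (1, 1, freq, time)
        size=(freq_bins, int(round(time_bins * factor))),
        mode="bilinear",
        align_corners=False,
    )
    return warped.squeeze(0).squeeze(0)


spec = torch.rand(128, 512)
stretched = warp_time_axis(spec, factor=1.04)  # delta=0.04 -> up to +/-4% warp
assert stretched.shape[0] == 128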
class MaskTimeConfig(BaseConfig):
class TimeMaskAugmentationConfig(BaseConfig):
name: Literal["mask_time"] = "mask_time"
probability: float = 0.2
max_perc: float = 0.05
@ -398,14 +336,6 @@ class MaskTime(torch.nn.Module):
]
return mask_time(spec, masks), clip_annotation
@spec_augmentations.register(MaskTimeConfig)
@staticmethod
def from_config(config: MaskTimeConfig):
return MaskTime(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
def mask_time(
spec: torch.Tensor,
@ -421,7 +351,7 @@ def mask_time(
return spec
class MaskFrequencyConfig(BaseConfig):
class FrequencyMaskAugmentationConfig(BaseConfig):
name: Literal["mask_freq"] = "mask_freq"
probability: float = 0.2
max_perc: float = 0.10
@ -464,14 +394,6 @@ class MaskFrequency(torch.nn.Module):
]
return mask_frequency(spec, masks), clip_annotation
@spec_augmentations.register(MaskFrequencyConfig)
@staticmethod
def from_config(config: MaskFrequencyConfig):
return MaskFrequency(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
def mask_frequency(
spec: torch.Tensor,
@ -488,8 +410,8 @@ def mask_frequency(
AudioAugmentationConfig = Annotated[
Union[
MixAudioConfig,
AddEchoConfig,
MixAugmentationConfig,
EchoAugmentationConfig,
],
Field(discriminator="name"),
]
@ -497,22 +419,22 @@ AudioAugmentationConfig = Annotated[
SpectrogramAugmentationConfig = Annotated[
Union[
ScaleVolumeConfig,
WarpConfig,
MaskFrequencyConfig,
MaskTimeConfig,
VolumeAugmentationConfig,
WarpAugmentationConfig,
FrequencyMaskAugmentationConfig,
TimeMaskAugmentationConfig,
],
Field(discriminator="name"),
]
AugmentationConfig = Annotated[
Union[
MixAudioConfig,
AddEchoConfig,
ScaleVolumeConfig,
WarpConfig,
MaskFrequencyConfig,
MaskTimeConfig,
MixAugmentationConfig,
EchoAugmentationConfig,
VolumeAugmentationConfig,
WarpAugmentationConfig,
FrequencyMaskAugmentationConfig,
TimeMaskAugmentationConfig,
],
Field(discriminator="name"),
]
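The Annotated unions with Field(discriminator="name") are what let a plain list of augmentation entries, such as those in the training YAML, resolve to the right config class by their name literal. A self-contained illustration with bare pydantic (v2) models; the real classes derive from batdetect2's BaseConfig, which is assumed here to be a pydantic model:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class TimeMaskAugmentationConfig(BaseModel):
    name: Literal["mask_time"] = "mask_time"
    probability: float = 0.2
    max_perc: float = 0.05


class FrequencyMaskAugmentationConfig(BaseModel):
    name: Literal["mask_freq"] = "mask_freq"
    probability: float = 0.2
    max_perc: float = 0.10


AugmentationConfig = Annotated[
    Union[TimeMaskAugmentationConfig, FrequencyMaskAugmentationConfig],
    Field(discriminator="name"),
]

adapter = TypeAdapter(AugmentationConfig)
config = adapter.validate_python({"name": "mask_freq", "max_perc": 0.2})
assert isinstance(config, FrequencyMaskAugmentationConfig)
assert config.max_perc == 0.2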
@ -591,7 +513,7 @@ def build_augmentation_from_config(
)
if config.name == "warp":
return Warp(
return WarpSpectrogram(
delta=config.delta,
)
@ -616,14 +538,14 @@ def build_augmentation_from_config(
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
enabled=True,
audio=[
MixAudioConfig(),
AddEchoConfig(),
MixAugmentationConfig(),
EchoAugmentationConfig(),
],
spectrogram=[
ScaleVolumeConfig(),
WarpConfig(),
MaskTimeConfig(),
MaskFrequencyConfig(),
VolumeAugmentationConfig(),
WarpAugmentationConfig(),
TimeMaskAugmentationConfig(),
FrequencyMaskAugmentationConfig(),
],
)
@ -644,9 +566,9 @@ class AugmentationSequence(torch.nn.Module):
return tensor, clip_annotation
def build_audio_augmentations(
steps: Optional[Sequence[AudioAugmentationConfig]] = None,
samplerate: int = TARGET_SAMPLERATE_HZ,
def build_augmentation_sequence(
samplerate: int,
steps: Optional[Sequence[AugmentationConfig]] = None,
audio_source: Optional[AudioSource] = None,
) -> Optional[Augmentation]:
if not steps:
@ -655,8 +577,10 @@ def build_audio_augmentations(
augmentations = []
for step_config in steps:
augmentation = audio_augmentations.build(
step_config, samplerate, audio_source
augmentation = build_augmentation_from_config(
step_config,
samplerate=samplerate,
audio_source=audio_source,
)
if augmentation is None:
@ -672,30 +596,6 @@ def build_audio_augmentations(
return AugmentationSequence(augmentations)
def build_spectrogram_augmentations(
steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None,
) -> Optional[Augmentation]:
if not steps:
return None
augmentations = []
for step_config in steps:
augmentation = spec_augmentations.build(step_config)
if augmentation is None:
continue
augmentations.append(
MaybeApply(
augmentation=augmentation,
probability=step_config.probability,
)
)
return AugmentationSequence(augmentations)
def build_augmentations(
samplerate: int,
config: Optional[AugmentationsConfig] = None,
@ -709,14 +609,16 @@ def build_augmentations(
lambda: config.to_yaml_string(),
)
audio_augmentation = build_audio_augmentations(
audio_augmentation = build_augmentation_sequence(
samplerate,
steps=config.audio,
samplerate=samplerate,
audio_source=audio_source,
)
spectrogram_augmentation = build_spectrogram_augmentations(
steps=config.spectrogram,
spectrogram_augmentation = build_augmentation_sequence(
samplerate,
steps=config.spectrogram,
audio_source=audio_source,
)
return audio_augmentation, spectrogram_augmentation
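Both pipelines returned here are assembled the same way: each configured step is wrapped so it fires with its configured probability (the MaybeApply wrapper referenced in the builders above but defined elsewhere in the module) and the wrapped steps are chained by AugmentationSequence. A self-contained sketch of that composition pattern, with the wrapper behaviour assumed from its usage:

import numpy as np
import torch


class MaybeApplySketch(torch.nn.Module):
    def __init__(self, augmentation, probability: float):
        super().__init__()
        self.augmentation = augmentation
        self.probability = probability

    def forward(self, tensor, annotation):
        if np.random.uniform() > self.probability:
            return tensor, annotation
        return self.augmentation(tensor, annotation)


class SequenceSketch(torch.nn.Module):
    def __init__(self, steps):
        super().__init__()
        self.steps = list(steps)

    def forward(self, tensor, annotation):
        for step in self.steps:
            tensor, annotation = step(tensor, annotation)
        return tensor, annotation


def double_volume(spec, annotation):
    return spec * 2.0, annotation


pipeline = SequenceSketch([MaybeApplySketch(double_volume, probability=0.5)])
spec, annotation = pipeline(torch.rand(128, 512), None)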

View File

@ -1,4 +1,4 @@
from typing import Any, List
from typing import List
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
@ -10,6 +10,7 @@ from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import (
ClipEvaluation,
EvaluatorProtocol,
ModelOutput,
RawPrediction,
@ -35,23 +36,23 @@ class ValidationMetrics(Callback):
def generate_plots(
self,
eval_outputs: Any,
pl_module: LightningModule,
evaluated_clips: List[ClipEvaluation],
):
plotter = get_image_logger(pl_module.logger) # type: ignore
if plotter is None:
return
for figure_name, fig in self.evaluator.generate_plots(eval_outputs):
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
plotter(figure_name, fig, pl_module.global_step)
def log_metrics(
self,
eval_outputs: Any,
pl_module: LightningModule,
evaluated_clips: List[ClipEvaluation],
):
metrics = self.evaluator.compute_metrics(eval_outputs)
metrics = self.evaluator.compute_metrics(evaluated_clips)
pl_module.log_dict(metrics)
def on_validation_epoch_end(
@ -59,13 +60,13 @@ class ValidationMetrics(Callback):
trainer: Trainer,
pl_module: LightningModule,
) -> None:
eval_outputs = self.evaluator.evaluate(
clip_evaluations = self.evaluator.evaluate(
self._clip_annotations,
self._predictions,
)
self.log_metrics(eval_outputs, pl_module)
self.generate_plots(eval_outputs, pl_module)
self.log_metrics(pl_module, clip_evaluations)
self.generate_plots(pl_module, clip_evaluations)
return super().on_validation_epoch_end(trainer, pl_module)

View File

@ -8,7 +8,7 @@ from loguru import logger
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate import build_evaluator
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.logging import build_logger
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
@ -105,10 +105,7 @@ def train(
trainer = trainer or build_trainer(
config,
targets=targets,
evaluator=build_evaluator(
config.train.validation,
targets=targets,
),
evaluator=build_evaluator(config.train.validation, targets=targets),
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
experiment_name=experiment_name,
@ -146,7 +143,7 @@ def build_trainer_callbacks(
ModelCheckpoint(
dirpath=str(checkpoint_dir),
save_top_k=1,
monitor="classification/mean_average_precision",
monitor="total_loss/val",
),
ValidationMetrics(evaluator),
]
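ModelCheckpoint's default mode is "min", which suits the "total_loss/val" quantity monitored here. If the monitored quantity were a score such as the previous "classification/mean_average_precision", mode="max" would have to be set explicitly, e.g.:

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(
    dirpath="checkpoints",
    save_top_k=1,
    monitor="classification/mean_average_precision",
    mode="max",  # keep the checkpoint with the highest AP, not the lowest
)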

View File

@ -1,8 +1,6 @@
from batdetect2.typing.evaluate import (
AffinityFunction,
ClipMatches,
ClipEvaluation,
EvaluatorProtocol,
MatcherProtocol,
MatchEvaluation,
MetricsProtocol,
PlotterProtocol,
@ -38,22 +36,19 @@ from batdetect2.typing.train import (
)
__all__ = [
"AffinityFunction",
"AudioLoader",
"Augmentation",
"BackboneModel",
"BatDetect2Prediction",
"ClipMatches",
"ClipEvaluation",
"ClipLabeller",
"ClipperProtocol",
"DetectionModel",
"EvaluatorProtocol",
"GeometryDecoder",
"Heatmaps",
"LossProtocol",
"Losses",
"MatchEvaluation",
"MatcherProtocol",
"MetricsProtocol",
"ModelOutput",
"PlotterProtocol",
@ -68,4 +63,5 @@ __all__ = [
"SoundEventFilter",
"TargetProtocol",
"TrainExample",
"EvaluatorProtocol",
]

View File

@ -31,7 +31,6 @@ class MatchEvaluation:
sound_event_annotation: Optional[data.SoundEventAnnotation]
gt_det: bool
gt_class: Optional[str]
gt_geometry: Optional[data.Geometry]
pred_score: float
pred_class_scores: Dict[str, float]
@ -40,32 +39,44 @@ class MatchEvaluation:
affinity: float
@property
def top_class(self) -> Optional[str]:
def pred_class(self) -> Optional[str]:
if not self.pred_class_scores:
return None
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
@property
def is_prediction(self) -> bool:
return self.pred_geometry is not None
@property
def is_generic(self) -> bool:
return self.gt_det and self.gt_class is None
@property
def top_class_score(self) -> float:
pred_class = self.top_class
def pred_class_score(self) -> float:
pred_class = self.pred_class
if pred_class is None:
return 0
return self.pred_class_scores[pred_class]
def is_true_positive(self, threshold: float = 0) -> bool:
return (
self.gt_det
and self.pred_score > threshold
and self.gt_class == self.pred_class
)
def is_false_positive(self, threshold: float = 0) -> bool:
return not self.gt_det and self.pred_score > threshold
def is_false_negative(self, threshold: float = 0) -> bool:
return self.gt_det and self.pred_score <= threshold
def is_cross_trigger(self, threshold: float = 0) -> bool:
return (
self.gt_det
and self.pred_score > threshold
and self.gt_class != self.pred_class
)
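Restated as a single standalone function for a quick sanity check of the four outcomes above (gt_det: whether a ground-truth sound event was matched; pred_class: the top-scoring predicted class). The fall-through for an unmatched prediction below the threshold is an assumption:

from typing import Optional


def match_outcome(
    gt_det: bool,
    gt_class: Optional[str],
    pred_class: Optional[str],
    pred_score: float,
    threshold: float = 0.0,
) -> str:
    if gt_det and pred_score > threshold and gt_class == pred_class:
        return "true_positive"
    if not gt_det and pred_score > threshold:
        return "false_positive"
    if gt_det and pred_score <= threshold:
        return "false_negative"
    if gt_det and pred_score > threshold and gt_class != pred_class:
        return "cross_trigger"
    return "ignored"  # unmatched prediction below the score threshold


assert match_outcome(True, "Myotis daubentonii", "Myotis daubentonii", 0.9) == "true_positive"
assert match_outcome(False, None, "Myotis daubentonii", 0.9) == "false_positive"
assert match_outcome(True, "Myotis daubentonii", "Pipistrellus pipistrellus", 0.9) == "cross_trigger"
assert match_outcome(True, "Myotis daubentonii", "Myotis daubentonii", 0.0) == "false_negative"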
@dataclass
class ClipMatches:
class ClipEvaluation:
clip: data.Clip
matches: List[MatchEvaluation]
@ -92,36 +103,29 @@ class AffinityFunction(Protocol, Generic[Geom]):
class MetricsProtocol(Protocol):
def __call__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]: ...
class PlotterProtocol(Protocol):
def __call__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]: ...
EvaluationOutput = TypeVar("EvaluationOutput")
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
class EvaluatorProtocol(Protocol):
targets: TargetProtocol
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> EvaluationOutput: ...
) -> List[ClipEvaluation]: ...
def compute_metrics(
self, eval_outputs: EvaluationOutput
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]: ...
def generate_plots(
self, eval_outputs: EvaluationOutput
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]: ...
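To make the contract concrete, a minimal object satisfying EvaluatorProtocol: evaluate() pairs predictions with annotations per clip, compute_metrics() reduces the evaluated clips to named scalars (the keys logged by ValidationMetrics), and generate_plots() yields (name, Figure) pairs. The matching and metric logic below are placeholders, not the real batdetect2 evaluator:

from typing import Dict, Iterable, List, Sequence, Tuple

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from soundevent import data

from batdetect2.typing import ClipEvaluation, RawPrediction, TargetProtocol


class CountingEvaluator:
    def __init__(self, targets: TargetProtocol):
        self.targets = targets

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[ClipEvaluation]:
        # Placeholder: real evaluators match predictions to annotations here.
        return [
            ClipEvaluation(clip=annotation.clip, matches=[])
            for annotation in clip_annotations
        ]

    def compute_metrics(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        return {"num_evaluated_clips": float(len(clip_evaluations))}

    def generate_plots(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Iterable[Tuple[str, Figure]]:
        fig, ax = plt.subplots()
        ax.bar(["clips"], [len(clip_evaluations)])
        ax.set_ylabel("count")
        yield "clip_count", fig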