From e752e96b93fac80c56b149d9670be87bf523714a Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 15 Sep 2025 16:01:15 +0100 Subject: [PATCH] Restructure eval metrics and plotting --- src/batdetect2/data/_core.py | 36 ++- src/batdetect2/data/conditions.py | 34 ++- src/batdetect2/data/transforms.py | 23 +- src/batdetect2/evaluate/__init__.py | 3 + src/batdetect2/evaluate/affinity.py | 151 +++++++++++ src/batdetect2/evaluate/config.py | 20 +- src/batdetect2/evaluate/dataframe.py | 93 +++---- src/batdetect2/evaluate/evaluate.py | 23 +- src/batdetect2/evaluate/evaluator.py | 169 ++++++++++++ src/batdetect2/evaluate/match.py | 371 ++++++++------------------ src/batdetect2/evaluate/metrics.py | 160 ++++++++--- src/batdetect2/evaluate/plots.py | 163 +++++++++++ src/batdetect2/plotting/__init__.py | 2 + src/batdetect2/plotting/evaluation.py | 160 ----------- src/batdetect2/plotting/gallery.py | 81 ++++++ src/batdetect2/plotting/matches.py | 18 +- src/batdetect2/train/callbacks.py | 64 +---- src/batdetect2/train/clips.py | 11 +- src/batdetect2/train/train.py | 20 +- src/batdetect2/typing/evaluate.py | 19 +- 20 files changed, 991 insertions(+), 630 deletions(-) create mode 100644 src/batdetect2/evaluate/affinity.py create mode 100644 src/batdetect2/evaluate/evaluator.py create mode 100644 src/batdetect2/evaluate/plots.py delete mode 100644 src/batdetect2/plotting/evaluation.py create mode 100644 src/batdetect2/plotting/gallery.py diff --git a/src/batdetect2/data/_core.py b/src/batdetect2/data/_core.py index 4ff9f7d..6b16dd8 100644 --- a/src/batdetect2/data/_core.py +++ b/src/batdetect2/data/_core.py @@ -1,6 +1,7 @@ from typing import Generic, Protocol, Type, TypeVar from pydantic import BaseModel +from typing_extensions import ParamSpec __all__ = [ "Registry", @@ -8,26 +9,36 @@ __all__ = [ T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True) T_Type = TypeVar("T_Type", covariant=True) +P_Type = ParamSpec("P_Type") -class LogicProtocol(Generic[T_Config, T_Type], Protocol): - """A generic protocol for the logic classes (conditions or transforms).""" +class LogicProtocol(Generic[T_Config, T_Type, P_Type], Protocol): + """A generic protocol for the logic classes.""" @classmethod - def from_config(cls, config: T_Config) -> T_Type: ... + def from_config( + cls, + config: T_Config, + *args: P_Type.args, + **kwargs: P_Type.kwargs, + ) -> T_Type: ... 
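+
+# Usage sketch (illustrative): the second parameter of `Registry` is a
+# ParamSpec capturing the extra arguments that `build` forwards to each
+# logic class's `from_config`. The metrics registry defined in
+# `batdetect2.evaluate.metrics` threads class names through this way:
+#
+#     metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
+#     metrics_registry.register(DetectionAPConfig, DetectionAP)
+#     metric = metrics_registry.build(DetectionAPConfig(), class_names)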
T_Proto = TypeVar("T_Proto", bound=LogicProtocol) -class Registry(Generic[T_Type]): +class Registry(Generic[T_Type, P_Type]): """A generic class to create and manage a registry of items.""" def __init__(self, name: str): self._name = name self._registry = {} - def register(self, config_cls: Type[T_Config]): + def register( + self, + config_cls: Type[T_Config], + logic_cls: LogicProtocol[T_Config, T_Type, P_Type], + ) -> None: """A decorator factory to register a new item.""" fields = config_cls.model_fields @@ -39,13 +50,14 @@ class Registry(Generic[T_Type]): if not isinstance(name, str): raise ValueError("'name' field must be a string literal.") - def decorator(logic_cls: Type[T_Proto]) -> Type[T_Proto]: - self._registry[name] = logic_cls - return logic_cls + self._registry[name] = logic_cls - return decorator - - def build(self, config: BaseModel) -> T_Type: + def build( + self, + config: BaseModel, + *args: P_Type.args, + **kwargs: P_Type.kwargs, + ) -> T_Type: """Builds a logic instance from a config object.""" name = getattr(config, "name") # noqa: B009 @@ -58,4 +70,4 @@ class Registry(Generic[T_Type]): f"No {self._name} with name '{name}' is registered." ) - return self._registry[name].from_config(config) + return self._registry[name].from_config(config, *args, **kwargs) diff --git a/src/batdetect2/data/conditions.py b/src/batdetect2/data/conditions.py index 42a59e9..b3d8fea 100644 --- a/src/batdetect2/data/conditions.py +++ b/src/batdetect2/data/conditions.py @@ -10,7 +10,7 @@ from batdetect2.data._core import Registry SoundEventCondition = Callable[[data.SoundEventAnnotation], bool] -_conditions: Registry[SoundEventCondition] = Registry("condition") +condition_registry: Registry[SoundEventCondition, []] = Registry("condition") class HasTagConfig(BaseConfig): @@ -18,7 +18,6 @@ class HasTagConfig(BaseConfig): tag: data.Tag -@_conditions.register(HasTagConfig) class HasTag: def __init__(self, tag: data.Tag): self.tag = tag @@ -33,12 +32,14 @@ class HasTag: return cls(tag=config.tag) +condition_registry.register(HasTagConfig, HasTag) + + class HasAllTagsConfig(BaseConfig): name: Literal["has_all_tags"] = "has_all_tags" tags: List[data.Tag] -@_conditions.register(HasAllTagsConfig) class HasAllTags: def __init__(self, tags: List[data.Tag]): if not tags: @@ -56,12 +57,14 @@ class HasAllTags: return cls(tags=config.tags) +condition_registry.register(HasAllTagsConfig, HasAllTags) + + class HasAnyTagConfig(BaseConfig): name: Literal["has_any_tag"] = "has_any_tag" tags: List[data.Tag] -@_conditions.register(HasAnyTagConfig) class HasAnyTag: def __init__(self, tags: List[data.Tag]): if not tags: @@ -79,6 +82,8 @@ class HasAnyTag: return cls(tags=config.tags) +condition_registry.register(HasAnyTagConfig, HasAnyTag) + Operator = Literal["gt", "gte", "lt", "lte", "eq"] @@ -109,7 +114,6 @@ def _build_comparator( raise ValueError(f"Invalid operator {operator}") -@_conditions.register(DurationConfig) class Duration: def __init__(self, operator: Operator, seconds: float): self.operator = operator @@ -135,6 +139,9 @@ class Duration: return cls(operator=config.operator, seconds=config.seconds) +condition_registry.register(DurationConfig, Duration) + + class FrequencyConfig(BaseConfig): name: Literal["frequency"] = "frequency" boundary: Literal["low", "high"] @@ -142,7 +149,6 @@ class FrequencyConfig(BaseConfig): hertz: float -@_conditions.register(FrequencyConfig) class Frequency: def __init__( self, @@ -184,12 +190,14 @@ class Frequency: ) +condition_registry.register(FrequencyConfig, Frequency) + + 
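+# The combinators below (AllOf, AnyOf, Not) nest arbitrary conditions. A
+# minimal sketch (`some_tag` stands in for a `data.Tag` instance):
+#
+#     config = NotConfig(condition=HasTagConfig(tag=some_tag))
+#     condition = build_sound_event_condition(config)
+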
class AllOfConfig(BaseConfig): name: Literal["all_of"] = "all_of" conditions: Sequence["SoundEventConditionConfig"] -@_conditions.register(AllOfConfig) class AllOf: def __init__(self, conditions: List[SoundEventCondition]): self.conditions = conditions @@ -207,12 +215,14 @@ class AllOf: return cls(conditions) +condition_registry.register(AllOfConfig, AllOf) + + class AnyOfConfig(BaseConfig): name: Literal["any_of"] = "any_of" conditions: List["SoundEventConditionConfig"] -@_conditions.register(AnyOfConfig) class AnyOf: def __init__(self, conditions: List[SoundEventCondition]): self.conditions = conditions @@ -230,12 +240,14 @@ class AnyOf: return cls(conditions) +condition_registry.register(AnyOfConfig, AnyOf) + + class NotConfig(BaseConfig): name: Literal["not"] = "not" condition: "SoundEventConditionConfig" -@_conditions.register(NotConfig) class Not: def __init__(self, condition: SoundEventCondition): self.condition = condition @@ -251,6 +263,8 @@ class Not: return cls(condition) +condition_registry.register(NotConfig, Not) + SoundEventConditionConfig = Annotated[ Union[ HasTagConfig, @@ -269,7 +283,7 @@ SoundEventConditionConfig = Annotated[ def build_sound_event_condition( config: SoundEventConditionConfig, ) -> SoundEventCondition: - return _conditions.build(config) + return condition_registry.build(config) def filter_clip_annotation( diff --git a/src/batdetect2/data/transforms.py b/src/batdetect2/data/transforms.py index 5dc7e84..d57f567 100644 --- a/src/batdetect2/data/transforms.py +++ b/src/batdetect2/data/transforms.py @@ -17,7 +17,7 @@ SoundEventTransform = Callable[ data.SoundEventAnnotation, ] -_transforms: Registry[SoundEventTransform] = Registry("transform") +transform_registry: Registry[SoundEventTransform, []] = Registry("transform") class SetFrequencyBoundConfig(BaseConfig): @@ -26,7 +26,6 @@ class SetFrequencyBoundConfig(BaseConfig): hertz: float -@_transforms.register(SetFrequencyBoundConfig) class SetFrequencyBound: def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"): self.hertz = hertz @@ -69,13 +68,15 @@ class SetFrequencyBound: return cls(hertz=config.hertz, boundary=config.boundary) +transform_registry.register(SetFrequencyBoundConfig, SetFrequencyBound) + + class ApplyIfConfig(BaseConfig): name: Literal["apply_if"] = "apply_if" transform: "SoundEventTransformConfig" condition: SoundEventConditionConfig -@_transforms.register(ApplyIfConfig) class ApplyIf: def __init__( self, @@ -101,13 +102,15 @@ class ApplyIf: return cls(condition=condition, transform=transform) +transform_registry.register(ApplyIfConfig, ApplyIf) + + class ReplaceTagConfig(BaseConfig): name: Literal["replace_tag"] = "replace_tag" original: data.Tag replacement: data.Tag -@_transforms.register(ReplaceTagConfig) class ReplaceTag: def __init__( self, @@ -136,6 +139,9 @@ class ReplaceTag: return cls(original=config.original, replacement=config.replacement) +transform_registry.register(ReplaceTagConfig, ReplaceTag) + + class MapTagValueConfig(BaseConfig): name: Literal["map_tag_value"] = "map_tag_value" tag_key: str @@ -143,7 +149,6 @@ class MapTagValueConfig(BaseConfig): target_key: Optional[str] = None -@_transforms.register(MapTagValueConfig) class MapTagValue: def __init__( self, @@ -193,12 +198,14 @@ class MapTagValue: ) +transform_registry.register(MapTagValueConfig, MapTagValue) + + class ApplyAllConfig(BaseConfig): name: Literal["apply_all"] = "apply_all" steps: List["SoundEventTransformConfig"] = Field(default_factory=list) -@_transforms.register(ApplyAllConfig) class 
ApplyAll: def __init__(self, steps: List[SoundEventTransform]): self.steps = steps @@ -218,6 +225,8 @@ class ApplyAll: return cls(steps) +transform_registry.register(ApplyAllConfig, ApplyAll) + SoundEventTransformConfig = Annotated[ Union[ SetFrequencyBoundConfig, @@ -233,7 +242,7 @@ SoundEventTransformConfig = Annotated[ def build_sound_event_transform( config: SoundEventTransformConfig, ) -> SoundEventTransform: - return _transforms.build(config) + return transform_registry.build(config) def transform_clip_annotation( diff --git a/src/batdetect2/evaluate/__init__.py b/src/batdetect2/evaluate/__init__.py index 412ed72..211edf9 100644 --- a/src/batdetect2/evaluate/__init__.py +++ b/src/batdetect2/evaluate/__init__.py @@ -1,6 +1,9 @@ from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config +from batdetect2.evaluate.evaluator import Evaluator, build_evaluator __all__ = [ "EvaluationConfig", "load_evaluation_config", + "Evaluator", + "build_evaluator", ] diff --git a/src/batdetect2/evaluate/affinity.py b/src/batdetect2/evaluate/affinity.py new file mode 100644 index 0000000..fe753bc --- /dev/null +++ b/src/batdetect2/evaluate/affinity.py @@ -0,0 +1,151 @@ +from typing import Annotated, Literal, Optional, Union + +from pydantic import Field +from soundevent import data +from soundevent.evaluation import compute_affinity + +from batdetect2.configs import BaseConfig +from batdetect2.data._core import Registry +from batdetect2.typing.evaluate import AffinityFunction + +affinity_functions: Registry[AffinityFunction, []] = Registry( + "matching_strategy" +) + + +class TimeAffinityConfig(BaseConfig): + name: Literal["time_affinity"] = "time_affinity" + time_buffer: float = 0.01 + + +class TimeAffinity(AffinityFunction): + def __init__(self, time_buffer: float): + self.time_buffer = time_buffer + + def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): + return compute_timestamp_affinity( + geometry1, geometry2, time_buffer=self.time_buffer + ) + + @classmethod + def from_config(cls, config: TimeAffinityConfig): + return cls(time_buffer=config.time_buffer) + + +affinity_functions.register(TimeAffinityConfig, TimeAffinity) + + +def compute_timestamp_affinity( + geometry1: data.Geometry, + geometry2: data.Geometry, + time_buffer: float = 0.01, +) -> float: + assert isinstance(geometry1, data.TimeStamp) + assert isinstance(geometry2, data.TimeStamp) + + start_time1 = geometry1.coordinates + start_time2 = geometry2.coordinates + + a = min(start_time1, start_time2) + b = max(start_time1, start_time2) + + if b - a >= 2 * time_buffer: + return 0 + + intersection = a - b + 2 * time_buffer + union = b - a + 2 * time_buffer + return intersection / union + + +class IntervalIOUConfig(BaseConfig): + name: Literal["interval_iou"] = "interval_iou" + time_buffer: float = 0.01 + + +class IntervalIOU(AffinityFunction): + def __init__(self, time_buffer: float): + self.time_buffer = time_buffer + + def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): + return compute_interval_iou( + geometry1, + geometry2, + time_buffer=self.time_buffer, + ) + + @classmethod + def from_config(cls, config: IntervalIOUConfig): + return cls(time_buffer=config.time_buffer) + + +affinity_functions.register(IntervalIOUConfig, IntervalIOU) + + +def compute_interval_iou( + geometry1: data.Geometry, + geometry2: data.Geometry, + time_buffer: float = 0.01, +) -> float: + assert isinstance(geometry1, data.TimeInterval) + assert isinstance(geometry2, data.TimeInterval) + + start_time1, 
end_time1 = geometry1.coordinates
+    start_time2, end_time2 = geometry2.coordinates
+
+    start_time1 -= time_buffer
+    start_time2 -= time_buffer
+    end_time1 += time_buffer
+    end_time2 += time_buffer
+
+    intersection = max(
+        0, min(end_time1, end_time2) - max(start_time1, start_time2)
+    )
+    union = (
+        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
+    )
+
+    if union == 0:
+        return 0
+
+    return intersection / union
+
+
+class GeometricIOUConfig(BaseConfig):
+    name: Literal["geometric_iou"] = "geometric_iou"
+    time_buffer: float = 0.01
+    freq_buffer: float = 1000
+
+
+class GeometricIOU(AffinityFunction):
+    def __init__(self, time_buffer: float, freq_buffer: float):
+        self.time_buffer = time_buffer
+        self.freq_buffer = freq_buffer
+
+    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
+        return compute_affinity(
+            geometry1,
+            geometry2,
+            time_buffer=self.time_buffer,
+            freq_buffer=self.freq_buffer,
+        )
+
+    @classmethod
+    def from_config(cls, config: GeometricIOUConfig):
+        return cls(
+            time_buffer=config.time_buffer,
+            freq_buffer=config.freq_buffer,
+        )
+
+
+affinity_functions.register(GeometricIOUConfig, GeometricIOU)
+
+AffinityConfig = Annotated[
+    Union[
+        TimeAffinityConfig,
+        IntervalIOUConfig,
+        GeometricIOUConfig,
+    ],
+    Field(discriminator="name"),
+]
+
+
+def build_affinity_function(
+    config: Optional[AffinityConfig] = None,
+) -> AffinityFunction:
+    config = config or GeometricIOUConfig()
+    return affinity_functions.build(config)
diff --git a/src/batdetect2/evaluate/config.py b/src/batdetect2/evaluate/config.py
index cc93a89..324c948 100644
--- a/src/batdetect2/evaluate/config.py
+++ b/src/batdetect2/evaluate/config.py
@@ -1,10 +1,16 @@
-from typing import Optional
+from typing import List, Optional
 
 from pydantic import Field
 from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
+from batdetect2.evaluate.metrics import (
+    ClassificationAPConfig,
+    DetectionAPConfig,
+    MetricConfig,
+)
+from batdetect2.evaluate.plots import ExampleGalleryConfig, PlotConfig
 
 __all__ = [
     "EvaluationConfig",
@@ -13,7 +19,19 @@ class EvaluationConfig(BaseConfig):
+    ignore_start_end: float = 0.01
     match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
+    metrics: List[MetricConfig] = Field(
+        default_factory=lambda: [
+            DetectionAPConfig(),
+            ClassificationAPConfig(),
+        ]
+    )
+    plots: List[PlotConfig] = Field(
+        default_factory=lambda: [
+            ExampleGalleryConfig(),
+        ]
+    )
 
 
 def load_evaluation_config(
diff --git a/src/batdetect2/evaluate/dataframe.py b/src/batdetect2/evaluate/dataframe.py
index 4cc0ff9..7398d34 100644
--- a/src/batdetect2/evaluate/dataframe.py
+++ b/src/batdetect2/evaluate/dataframe.py
@@ -3,60 +3,61 @@ from typing import List
 import pandas as pd
 from soundevent.geometry import compute_bounds
 
-from batdetect2.typing.evaluate import MatchEvaluation
+from batdetect2.typing.evaluate import ClipEvaluation
 
 
-def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame:
+def extract_matches_dataframe(
+    clip_evaluations: List[ClipEvaluation],
+) -> pd.DataFrame:
     data = []
 
-    for match in matches:
-        gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
-        pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
+    for clip_evaluation in clip_evaluations:
+        for match in clip_evaluation.matches:
+            gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
+            pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
 
-        sound_event_annotation = match.sound_event_annotation
+            sound_event_annotation = 
match.sound_event_annotation - if sound_event_annotation is not None: - geometry = sound_event_annotation.sound_event.geometry - assert geometry is not None - gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = ( - compute_bounds(geometry) + if sound_event_annotation is not None: + geometry = sound_event_annotation.sound_event.geometry + assert geometry is not None + gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = ( + compute_bounds(geometry) + ) + + if match.pred_geometry is not None: + pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = ( + compute_bounds(match.pred_geometry) + ) + + data.append( + { + ("recording", "uuid"): match.clip.recording.uuid, + ("clip", "uuid"): match.clip.uuid, + ("clip", "start_time"): match.clip.start_time, + ("clip", "end_time"): match.clip.end_time, + ("gt", "uuid"): match.sound_event_annotation.uuid + if match.sound_event_annotation is not None + else None, + ("gt", "class"): match.gt_class, + ("gt", "det"): match.gt_det, + ("gt", "start_time"): gt_start_time, + ("gt", "end_time"): gt_end_time, + ("gt", "low_freq"): gt_low_freq, + ("gt", "high_freq"): gt_high_freq, + ("pred", "score"): match.pred_score, + ("pred", "class"): match.pred_class, + ("pred", "class_score"): match.pred_class_score, + ("pred", "start_time"): pred_start_time, + ("pred", "end_time"): pred_end_time, + ("pred", "low_freq"): pred_low_freq, + ("pred", "high_freq"): pred_high_freq, + ("match", "affinity"): match.affinity, + **{ + ("pred_class_score", key): value + for key, value in match.pred_class_scores.items() + }, + } ) - if match.pred_geometry is not None: - pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = ( - compute_bounds(match.pred_geometry) - ) - - data.append( - { - ("recording", "uuid"): match.clip.recording.uuid, - ("clip", "uuid"): match.clip.uuid, - ("clip", "start_time"): match.clip.start_time, - ("clip", "end_time"): match.clip.end_time, - ("gt", "uuid"): match.sound_event_annotation.uuid - if match.sound_event_annotation is not None - else None, - ("gt", "class"): match.gt_class, - ("gt", "det"): match.gt_det, - ("gt", "start_time"): gt_start_time, - ("gt", "end_time"): gt_end_time, - ("gt", "low_freq"): gt_low_freq, - ("gt", "high_freq"): gt_high_freq, - ("pred", "score"): match.pred_score, - ("pred", "class"): match.pred_class, - ("pred", "class_score"): match.pred_class_score, - ("pred", "start_time"): pred_start_time, - ("pred", "end_time"): pred_end_time, - ("pred", "low_freq"): pred_low_freq, - ("pred", "high_freq"): pred_high_freq, - ("match", "affinity"): match.affinity, - **{ - ("pred_class_score", key): value - for key, value in match.pred_class_scores.items() - }, - } - ) - df = pd.DataFrame(data) df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore return df diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 174f75b..7bcc71b 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -4,11 +4,8 @@ import pandas as pd from soundevent import data from batdetect2.evaluate.dataframe import extract_matches_dataframe -from batdetect2.evaluate.match import build_matcher, match_all_predictions -from batdetect2.evaluate.metrics import ( - ClassificationMeanAveragePrecision, - DetectionAveragePrecision, -) +from batdetect2.evaluate.evaluator import build_evaluator +from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP from batdetect2.models import Model from batdetect2.plotting.clips import build_audio_loader from 
batdetect2.postprocess import get_raw_predictions @@ -55,6 +52,8 @@ def evaluate( clip_annotations = [] predictions = [] + evaluator = build_evaluator(config=config.evaluation) + for batch in loader: outputs = model.detector(batch.spec) @@ -76,20 +75,12 @@ def evaluate( clip_annotations.extend(clip_annotations) predictions.extend(predictions) - matcher = build_matcher(config.evaluation.match) - - matches = match_all_predictions( - clip_annotations, - predictions, - targets=targets, - matcher=matcher, - ) - + matches = evaluator.evaluate(clip_annotations, predictions) df = extract_matches_dataframe(matches) metrics = [ - DetectionAveragePrecision(), - ClassificationMeanAveragePrecision(class_names=targets.class_names), + DetectionAP(), + ClassificationAP(class_names=targets.class_names), ] results = { diff --git a/src/batdetect2/evaluate/evaluator.py b/src/batdetect2/evaluate/evaluator.py new file mode 100644 index 0000000..d60f333 --- /dev/null +++ b/src/batdetect2/evaluate/evaluator.py @@ -0,0 +1,169 @@ +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +from matplotlib.figure import Figure +from soundevent import data +from soundevent.geometry import compute_bounds + +from batdetect2.evaluate.config import EvaluationConfig +from batdetect2.evaluate.match import build_matcher, match +from batdetect2.evaluate.metrics import build_metric +from batdetect2.evaluate.plots import build_plotter +from batdetect2.targets import build_targets +from batdetect2.typing.evaluate import ( + ClipEvaluation, + MatcherProtocol, + MetricsProtocol, + PlotterProtocol, +) +from batdetect2.typing.postprocess import RawPrediction +from batdetect2.typing.targets import TargetProtocol + +__all__ = [ + "Evaluator", + "build_evaluator", +] + + +class Evaluator: + def __init__( + self, + config: EvaluationConfig, + targets: TargetProtocol, + matcher: MatcherProtocol, + metrics: List[MetricsProtocol], + plots: List[PlotterProtocol], + ): + self.config = config + self.targets = targets + self.matcher = matcher + self.metrics = metrics + self.plots = plots + + def match( + self, + clip_annotation: data.ClipAnnotation, + predictions: Sequence[RawPrediction], + ) -> ClipEvaluation: + clip = clip_annotation.clip + ground_truth = [ + sound_event + for sound_event in clip_annotation.sound_events + if self.filter_sound_event_annotations(sound_event, clip) + ] + predictions = [ + prediction + for prediction in predictions + if self.filter_predictions(prediction, clip) + ] + return ClipEvaluation( + clip=clip_annotation.clip, + matches=match( + ground_truth, + predictions, + clip=clip, + targets=self.targets, + matcher=self.matcher, + ), + ) + + def filter_sound_event_annotations( + self, + sound_event_annotation: data.SoundEventAnnotation, + clip: data.Clip, + ) -> bool: + if not self.targets.filter(sound_event_annotation): + return False + + geometry = sound_event_annotation.sound_event.geometry + if geometry is None: + return False + + return is_in_bounds( + geometry, + clip, + self.config.ignore_start_end, + ) + + def filter_predictions( + self, + prediction: RawPrediction, + clip: data.Clip, + ) -> bool: + return is_in_bounds( + prediction.geometry, + clip, + self.config.ignore_start_end, + ) + + def evaluate( + self, + clip_annotations: Sequence[data.ClipAnnotation], + predictions: Sequence[Sequence[RawPrediction]], + ) -> List[ClipEvaluation]: + if len(clip_annotations) != len(predictions): + raise ValueError( + "Number of annotated clips and sets of predictions do not match" + ) + + return [ + 
self.match(clip_annotation, preds) + for clip_annotation, preds in zip(clip_annotations, predictions) + ] + + def compute_metrics( + self, + clip_evaluations: Sequence[ClipEvaluation], + ) -> Dict[str, float]: + results = {} + + for metric in self.metrics: + results.update(metric(clip_evaluations)) + + return results + + def generate_plots( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Iterable[Tuple[str, Figure]]: + for plotter in self.plots: + for name, fig in plotter(clip_evaluations): + yield name, fig + + +def build_evaluator( + config: Optional[EvaluationConfig] = None, + targets: Optional[TargetProtocol] = None, + matcher: Optional[MatcherProtocol] = None, + plots: Optional[List[PlotterProtocol]] = None, + metrics: Optional[List[MetricsProtocol]] = None, +) -> Evaluator: + config = config or EvaluationConfig() + targets = targets or build_targets() + matcher = matcher or build_matcher(config.match) + + if metrics is None: + metrics = [ + build_metric(config, targets.class_names) + for config in config.metrics + ] + + if plots is None: + plots = [build_plotter(config) for config in config.plots] + + return Evaluator( + config=config, + targets=targets, + matcher=matcher, + metrics=metrics, + plots=plots, + ) + + +def is_in_bounds( + geometry: data.Geometry, + clip: data.Clip, + buffer: float, +) -> bool: + start_time = compute_bounds(geometry)[0] + return (start_time >= clip.start_time + buffer) and ( + start_time <= clip.end_time - buffer + ) diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index 914341d..6df3a36 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -1,9 +1,7 @@ from collections.abc import Callable, Iterable, Mapping -from dataclasses import dataclass, field from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union import numpy as np -from loguru import logger from pydantic import Field from soundevent import data from soundevent.evaluation import compute_affinity @@ -12,6 +10,11 @@ from soundevent.geometry import compute_bounds from batdetect2.configs import BaseConfig from batdetect2.data._core import Registry +from batdetect2.evaluate.affinity import ( + AffinityConfig, + GeometricIOUConfig, + build_affinity_function, +) from batdetect2.targets import build_targets from batdetect2.typing import ( MatchEvaluation, @@ -23,7 +26,88 @@ from batdetect2.typing.postprocess import RawPrediction MatchingGeometry = Literal["bbox", "interval", "timestamp"] """The geometry representation to use for matching.""" -matching_strategy = Registry("matching_strategy") +matching_strategies = Registry("matching_strategy") + + +def match( + sound_event_annotations: Sequence[data.SoundEventAnnotation], + raw_predictions: Sequence[RawPrediction], + clip: data.Clip, + targets: Optional[TargetProtocol] = None, + matcher: Optional[MatcherProtocol] = None, +) -> List[MatchEvaluation]: + if matcher is None: + matcher = build_matcher() + + if targets is None: + targets = build_targets() + + target_geometries: List[data.Geometry] = [ # type: ignore + sound_event_annotation.sound_event.geometry + for sound_event_annotation in sound_event_annotations + ] + + predicted_geometries = [ + raw_prediction.geometry for raw_prediction in raw_predictions + ] + + scores = [ + raw_prediction.detection_score for raw_prediction in raw_predictions + ] + + matches = [] + + for source_idx, target_idx, affinity in matcher( + ground_truth=target_geometries, + predictions=predicted_geometries, + scores=scores, + ): + 
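+        # Each yielded triple pairs a prediction index with a ground-truth
+        # index; either index is None when that side went unmatched.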
target = ( + sound_event_annotations[target_idx] + if target_idx is not None + else None + ) + prediction = ( + raw_predictions[source_idx] if source_idx is not None else None + ) + + gt_det = target_idx is not None + gt_class = targets.encode_class(target) if target is not None else None + + pred_score = float(prediction.detection_score) if prediction else 0 + + pred_geometry = ( + predicted_geometries[source_idx] + if source_idx is not None + else None + ) + + class_scores = ( + { + str(class_name): float(score) + for class_name, score in zip( + targets.class_names, + prediction.class_scores, + ) + } + if prediction is not None + else {} + ) + + matches.append( + MatchEvaluation( + clip=clip, + sound_event_annotation=target, + gt_det=gt_det, + gt_class=gt_class, + pred_score=pred_score, + pred_class_scores=class_scores, + pred_geometry=pred_geometry, + affinity=affinity, + ) + ) + + return matches class StartTimeMatchConfig(BaseConfig): @@ -31,7 +115,6 @@ class StartTimeMatchConfig(BaseConfig): distance_threshold: float = 0.01 -@matching_strategy.register(StartTimeMatchConfig) class StartTimeMatcher(MatcherProtocol): def __init__(self, distance_threshold: float): self.distance_threshold = distance_threshold @@ -54,6 +137,9 @@ class StartTimeMatcher(MatcherProtocol): return cls(distance_threshold=config.distance_threshold) +matching_strategies.register(StartTimeMatchConfig, StartTimeMatcher) + + def match_start_times( ground_truth: Sequence[data.Geometry], predictions: Sequence[data.Geometry], @@ -74,8 +160,8 @@ def match_start_times( gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth]) pred_times = np.array([compute_bounds(geom)[0] for geom in predictions]) - scores = np.array(scores) + scores = np.array(scores) sort_args = np.argsort(scores)[::-1] distances = np.abs(gt_times[None, :] - pred_times[:, None]) @@ -143,89 +229,25 @@ _geometry_cast_functions: Mapping[ } -def _timestamp_affinity( - geometry1: data.Geometry, - geometry2: data.Geometry, - time_buffer: float = 0.01, - freq_buffer: float = 1000, -) -> float: - assert isinstance(geometry1, data.TimeStamp) - assert isinstance(geometry2, data.TimeStamp) - - start_time1 = geometry1.coordinates - start_time2 = geometry2.coordinates - - a = min(start_time1, start_time2) - b = max(start_time1, start_time2) - - if b - a >= 2 * time_buffer: - return 0 - - intersection = a - b + 2 * time_buffer - union = b - a + 2 * time_buffer - return intersection / union - - -def _interval_affinity( - geometry1: data.Geometry, - geometry2: data.Geometry, - time_buffer: float = 0.01, - freq_buffer: float = 1000, -) -> float: - assert isinstance(geometry1, data.TimeInterval) - assert isinstance(geometry2, data.TimeInterval) - - start_time1, end_time1 = geometry1.coordinates - start_time2, end_time2 = geometry1.coordinates - - start_time1 -= time_buffer - start_time2 -= time_buffer - end_time1 += time_buffer - end_time2 += time_buffer - - intersection = max( - 0, min(end_time1, end_time2) - max(start_time1, start_time2) - ) - union = ( - (end_time1 - start_time1) + (end_time2 - start_time2) - intersection - ) - - if union == 0: - return 0 - - return intersection / union - - -_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = { - "timestamp": _timestamp_affinity, - "interval": _interval_affinity, - "bbox": compute_affinity, -} - - class GreedyMatchConfig(BaseConfig): name: Literal["greedy_match"] = "greedy_match" geometry: MatchingGeometry = "timestamp" - affinity_threshold: float = 0.0 - time_buffer: float = 0.005 - 
frequency_buffer: float = 1_000 + affinity_threshold: float = 0.5 + affinity_function: AffinityConfig = Field( + default_factory=GeometricIOUConfig + ) -@matching_strategy.register(GreedyMatchConfig) class GreedyMatcher(MatcherProtocol): def __init__( self, geometry: MatchingGeometry, affinity_threshold: float, - time_buffer: float, - frequency_buffer: float, + affinity_function: AffinityFunction, ): self.geometry = geometry + self.affinity_function = affinity_function self.affinity_threshold = affinity_threshold - self.time_buffer = time_buffer - self.frequency_buffer = frequency_buffer - - self.affinity_function = _affinity_functions[self.geometry] self.cast_geometry = _geometry_cast_functions[self.geometry] def __call__( @@ -240,28 +262,27 @@ class GreedyMatcher(MatcherProtocol): scores=scores, affinity_function=self.affinity_function, affinity_threshold=self.affinity_threshold, - time_buffer=self.time_buffer, - freq_buffer=self.frequency_buffer, ) @classmethod def from_config(cls, config: GreedyMatchConfig): + affinity_function = build_affinity_function(config.affinity_function) return cls( geometry=config.geometry, affinity_threshold=config.affinity_threshold, - time_buffer=config.time_buffer, - frequency_buffer=config.frequency_buffer, + affinity_function=affinity_function, ) +matching_strategies.register(GreedyMatchConfig, GreedyMatcher) + + def greedy_match( ground_truth: Sequence[data.Geometry], predictions: Sequence[data.Geometry], scores: Sequence[float], affinity_threshold: float = 0.5, affinity_function: AffinityFunction = compute_affinity, - time_buffer: float = 0.001, - freq_buffer: float = 1000, ) -> Iterable[Tuple[Optional[int], Optional[int], float]]: """Performs a greedy, one-to-one matching of source to target geometries. @@ -279,10 +300,6 @@ def greedy_match( Confidence scores for each source geometry for prioritization. affinity_threshold The minimum affinity score required for a valid match. - time_buffer - Time tolerance in seconds for affinity calculation. - freq_buffer - Frequency tolerance in Hertz for affinity calculation. 
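+    affinity_function
+        Callable scoring the affinity between a prediction geometry and a
+        ground-truth geometry.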
Yields ------ @@ -314,12 +331,7 @@ def greedy_match( affinities = np.array( [ - affinity_function( - source_geometry, - target_geometry, - time_buffer=time_buffer, - freq_buffer=freq_buffer, - ) + affinity_function(source_geometry, target_geometry) for target_geometry in ground_truth ] ) @@ -344,12 +356,11 @@ def greedy_match( class OptimalMatchConfig(BaseConfig): name: Literal["optimal_match"] = "optimal_match" - affinity_threshold: float = 0.0 + affinity_threshold: float = 0.5 time_buffer: float = 0.005 frequency_buffer: float = 1_000 -@matching_strategy.register(OptimalMatchConfig) class OptimalMatcher(MatcherProtocol): def __init__( self, @@ -384,6 +395,8 @@ class OptimalMatcher(MatcherProtocol): ) +matching_strategies.register(OptimalMatchConfig, OptimalMatcher) + MatchConfig = Annotated[ Union[ GreedyMatchConfig, @@ -396,174 +409,4 @@ MatchConfig = Annotated[ def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol: config = config or StartTimeMatchConfig() - return matching_strategy.build(config) - - -def _is_in_bounds( - geometry: data.Geometry, - clip: data.Clip, - buffer: float, -) -> bool: - start_time = compute_bounds(geometry)[0] - return (start_time >= clip.start_time + buffer) and ( - start_time <= clip.end_time - buffer - ) - - -def match_sound_events_and_predictions( - clip_annotation: data.ClipAnnotation, - raw_predictions: List[RawPrediction], - targets: Optional[TargetProtocol] = None, - matcher: Optional[MatcherProtocol] = None, - ignore_start_end: float = 0.01, -) -> List[MatchEvaluation]: - if matcher is None: - matcher = build_matcher() - - if targets is None: - targets = build_targets() - - target_sound_events = [ - sound_event_annotation - for sound_event_annotation in clip_annotation.sound_events - if targets.filter(sound_event_annotation) - and sound_event_annotation.sound_event.geometry is not None - and _is_in_bounds( - sound_event_annotation.sound_event.geometry, - clip=clip_annotation.clip, - buffer=ignore_start_end, - ) - ] - - target_geometries: List[data.Geometry] = [ - sound_event_annotation.sound_event.geometry - for sound_event_annotation in target_sound_events - if sound_event_annotation.sound_event.geometry is not None - ] - - raw_predictions = [ - raw_prediction - for raw_prediction in raw_predictions - if _is_in_bounds( - raw_prediction.geometry, - clip=clip_annotation.clip, - buffer=ignore_start_end, - ) - ] - - predicted_geometries = [ - raw_prediction.geometry for raw_prediction in raw_predictions - ] - - scores = [ - raw_prediction.detection_score for raw_prediction in raw_predictions - ] - - matches = [] - - for source_idx, target_idx, affinity in matcher( - ground_truth=target_geometries, - predictions=predicted_geometries, - scores=scores, - ): - target = ( - target_sound_events[target_idx] if target_idx is not None else None - ) - prediction = ( - raw_predictions[source_idx] if source_idx is not None else None - ) - - gt_det = target_idx is not None - gt_class = targets.encode_class(target) if target is not None else None - - pred_score = float(prediction.detection_score) if prediction else 0 - - pred_geometry = ( - predicted_geometries[source_idx] - if source_idx is not None - else None - ) - - class_scores = ( - { - str(class_name): float(score) - for class_name, score in zip( - targets.class_names, - prediction.class_scores, - ) - } - if prediction is not None - else {} - ) - - matches.append( - MatchEvaluation( - clip=clip_annotation.clip, - sound_event_annotation=target, - gt_det=gt_det, - gt_class=gt_class, - 
pred_score=pred_score, - pred_class_scores=class_scores, - pred_geometry=pred_geometry, - affinity=affinity, - ) - ) - - return matches - - -def match_all_predictions( - clip_annotations: List[data.ClipAnnotation], - predictions: List[List[RawPrediction]], - targets: Optional[TargetProtocol] = None, - matcher: Optional[MatcherProtocol] = None, - ignore_start_end: float = 0.01, -) -> List[MatchEvaluation]: - logger.info("Matching all annotations and predictions...") - return [ - match - for clip_annotation, raw_predictions in zip( - clip_annotations, - predictions, - ) - for match in match_sound_events_and_predictions( - clip_annotation, - raw_predictions, - targets=targets, - matcher=matcher, - ignore_start_end=ignore_start_end, - ) - ] - - -@dataclass -class ClassExamples: - false_positives: List[MatchEvaluation] = field(default_factory=list) - false_negatives: List[MatchEvaluation] = field(default_factory=list) - true_positives: List[MatchEvaluation] = field(default_factory=list) - cross_triggers: List[MatchEvaluation] = field(default_factory=list) - - -def group_matches(matches: List[MatchEvaluation]) -> ClassExamples: - class_examples = ClassExamples() - - for match in matches: - gt_class = match.gt_class - pred_class = match.pred_class - - if pred_class is None: - class_examples.false_negatives.append(match) - continue - - if gt_class is None: - class_examples.false_positives.append(match) - continue - - if gt_class != pred_class: - class_examples.cross_triggers.append(match) - class_examples.cross_triggers.append(match) - continue - - class_examples.true_positives.append(match) - - return class_examples + return matching_strategies.build(config) diff --git a/src/batdetect2/evaluate/metrics.py b/src/batdetect2/evaluate/metrics.py index c29fa5f..0a9be88 100644 --- a/src/batdetect2/evaluate/metrics.py +++ b/src/batdetect2/evaluate/metrics.py @@ -1,61 +1,151 @@ -from typing import Dict, List +from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union import numpy as np -import pandas as pd +from pydantic import Field from sklearn import metrics from sklearn.preprocessing import label_binarize -from batdetect2.typing import MatchEvaluation, MetricsProtocol +from batdetect2.configs import BaseConfig +from batdetect2.data._core import Registry +from batdetect2.typing import MetricsProtocol +from batdetect2.typing.evaluate import ClipEvaluation -__all__ = ["DetectionAveragePrecision"] +__all__ = ["DetectionAP", "ClassificationAP"] -class DetectionAveragePrecision(MetricsProtocol): - def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: +metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric") + + +class DetectionAPConfig(BaseConfig): + name: Literal["detection_ap"] = "detection_ap" + + +class DetectionAP(MetricsProtocol): + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: y_true, y_score = zip( - *[(match.gt_det, match.pred_score) for match in matches] + *[ + (match.gt_det, match.pred_score) + for clip_eval in clip_evaluations + for match in clip_eval.matches + ] ) score = float(metrics.average_precision_score(y_true, y_score)) return {"detection_AP": score} + @classmethod + def from_config(cls, config: DetectionAPConfig, class_names: List[str]): + return cls() -class ClassificationMeanAveragePrecision(MetricsProtocol): - def __init__(self, class_names: List[str]): + +metrics_registry.register(DetectionAPConfig, DetectionAP) + + +class ClassificationAPConfig(BaseConfig): + name: 
Literal["classification_ap"] = "classification_ap" + include: Optional[List[str]] = None + exclude: Optional[List[str]] = None + + +class ClassificationAP(MetricsProtocol): + def __init__( + self, + class_names: List[str], + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ): self.class_names = class_names - def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: - # NOTE: Need to exclude generic but unclassified targets - y_true = label_binarize( - [ - match.gt_class if match.gt_class is not None else "__NONE__" - for match in matches - if not (match.gt_det and match.gt_class is None) - ], - classes=self.class_names, - ) - y_pred = pd.DataFrame( - [ - { - name: match.pred_class_scores.get(name, 0) - for name in self.class_names - } - for match in matches - if not (match.gt_det and match.gt_class is None) - ] - ).fillna(0) + self.selected = class_names - ret = {} + if include is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name in include + ] + + if exclude is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name not in exclude + ] + + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: + y_true = [] + y_pred = [] + + for clip_eval in clip_evaluations: + for match in clip_eval.matches: + # Ignore generic unclassified targets + if match.gt_det and match.gt_class is None: + continue + + y_true.append( + match.gt_class + if match.gt_class is not None + else "__NONE__" + ) + + y_pred.append( + np.array( + [ + match.pred_class_scores.get(name, 0) + for name in self.class_names + ] + ) + ) + + y_true = label_binarize(y_true, classes=self.class_names) + y_pred = np.stack(y_pred) + + class_scores = {} for class_index, class_name in enumerate(self.class_names): y_true_class = y_true[:, class_index] - y_pred_class = y_pred[class_name] + y_pred_class = y_pred[:, class_index] class_ap = metrics.average_precision_score( y_true_class, y_pred_class, ) - ret[f"classification_AP/{class_name}"] = float(class_ap) + class_scores[class_name] = float(class_ap) - ret["classification_mAP"] = np.mean( - [value for value in ret.values() if value != 0] + mean_ap = np.mean( + [value for value in class_scores.values() if value != 0] ) - return ret + return { + "classification_mAP": float(mean_ap), + **{ + f"classification_AP/{class_name}": class_scores[class_name] + for class_name in self.selected + }, + } + + @classmethod + def from_config( + cls, + config: ClassificationAPConfig, + class_names: List[str], + ): + return cls( + class_names, + include=config.include, + exclude=config.exclude, + ) + + +metrics_registry.register(ClassificationAPConfig, ClassificationAP) + + +MetricConfig = Annotated[ + Union[ClassificationAPConfig, DetectionAPConfig], + Field(discriminator="name"), +] + + +def build_metric(config: MetricConfig, class_names: List[str]): + return metrics_registry.build(config, class_names) diff --git a/src/batdetect2/evaluate/plots.py b/src/batdetect2/evaluate/plots.py new file mode 100644 index 0000000..0a398c7 --- /dev/null +++ b/src/batdetect2/evaluate/plots.py @@ -0,0 +1,163 @@ +import random +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union + +import matplotlib.pyplot as plt +import pandas as pd +from pydantic import Field + +from batdetect2.configs import BaseConfig +from batdetect2.data._core import Registry +from 
batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader +from batdetect2.plotting.gallery import plot_match_gallery +from batdetect2.preprocess import PreprocessingConfig, build_preprocessor +from batdetect2.typing.evaluate import ( + ClipEvaluation, + MatchEvaluation, + PlotterProtocol, +) +from batdetect2.typing.preprocess import AudioLoader + +__all__ = [ + "build_plotter", + "ExampleGallery", + "ExampleGalleryConfig", +] + + +plots_registry: Registry[PlotterProtocol, []] = Registry("plot") + + +class ExampleGalleryConfig(BaseConfig): + name: Literal["example_gallery"] = "example_gallery" + examples_per_class: int = 5 + preprocessing: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + + +class ExampleGallery(PlotterProtocol): + def __init__( + self, + examples_per_class: int, + preprocessor: Optional[PreprocessorProtocol] = None, + audio_loader: Optional[AudioLoader] = None, + ): + self.examples_per_class = examples_per_class + self.preprocessor = preprocessor or build_preprocessor() + self.audio_loader = audio_loader or build_audio_loader() + + def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + per_class_matches = group_matches(clip_evaluations) + + for class_name, matches in per_class_matches.items(): + true_positives = get_binned_sample( + matches.true_positives, + n_examples=self.examples_per_class, + ) + + false_positives = get_binned_sample( + matches.false_positives, + n_examples=self.examples_per_class, + ) + + false_negatives = random.sample( + matches.false_negatives, + k=min(self.examples_per_class, len(matches.false_negatives)), + ) + + cross_triggers = get_binned_sample( + matches.cross_triggers, + n_examples=self.examples_per_class, + ) + + fig = plot_match_gallery( + true_positives, + false_positives, + false_negatives, + cross_triggers, + preprocessor=self.preprocessor, + audio_loader=self.audio_loader, + n_examples=self.examples_per_class, + ) + + yield f"example_gallery/{class_name}", fig + + plt.close(fig) + + @classmethod + def from_config(cls, config: ExampleGalleryConfig): + preprocessor = build_preprocessor(config.preprocessing) + audio_loader = build_audio_loader(config.preprocessing.audio) + return cls( + examples_per_class=config.examples_per_class, + preprocessor=preprocessor, + audio_loader=audio_loader, + ) + + +plots_registry.register(ExampleGalleryConfig, ExampleGallery) + + +PlotConfig = Annotated[ + Union[ExampleGalleryConfig,], Field(discriminator="name") +] + + +def build_plotter(config: PlotConfig) -> PlotterProtocol: + return plots_registry.build(config) + + +@dataclass +class ClassMatches: + false_positives: List[MatchEvaluation] = field(default_factory=list) + false_negatives: List[MatchEvaluation] = field(default_factory=list) + true_positives: List[MatchEvaluation] = field(default_factory=list) + cross_triggers: List[MatchEvaluation] = field(default_factory=list) + + +def group_matches( + clip_evaluations: Sequence[ClipEvaluation], +) -> Dict[str, ClassMatches]: + class_examples = defaultdict(ClassMatches) + + for clip_evaluation in clip_evaluations: + for match in clip_evaluation.matches: + gt_class = match.gt_class + pred_class = match.pred_class + + if pred_class is None: + class_examples[gt_class].false_negatives.append(match) + continue + + if gt_class is None: + class_examples[pred_class].false_positives.append(match) + continue + + if gt_class != pred_class: + class_examples[gt_class].cross_triggers.append(match) + class_examples[pred_class].cross_triggers.append(match) + continue + + 
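+            # Ground truth and prediction agree on the class: a true positive.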
class_examples[gt_class].true_positives.append(match) + + return class_examples + + +def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5): + if len(matches) < n_examples: + return matches + + indices, pred_scores = zip( + *[ + (index, match.pred_class_scores[pred_class]) + for index, match in enumerate(matches) + if (pred_class := match.pred_class) is not None + ] + ) + + bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop") + df = pd.DataFrame({"indices": indices, "bins": bins}) + sample = df.groupby("bins").sample(1) + return [matches[ind] for ind in sample["indices"]] diff --git a/src/batdetect2/plotting/__init__.py b/src/batdetect2/plotting/__init__.py index acf14fb..824ef86 100644 --- a/src/batdetect2/plotting/__init__.py +++ b/src/batdetect2/plotting/__init__.py @@ -2,6 +2,7 @@ from batdetect2.plotting.clip_annotations import plot_clip_annotation from batdetect2.plotting.clip_predictions import plot_clip_prediction from batdetect2.plotting.clips import plot_clip from batdetect2.plotting.common import plot_spectrogram +from batdetect2.plotting.gallery import plot_match_gallery from batdetect2.plotting.heatmaps import ( plot_classification_heatmap, plot_detection_heatmap, @@ -26,4 +27,5 @@ __all__ = [ "plot_true_positive_match", "plot_detection_heatmap", "plot_classification_heatmap", + "plot_match_gallery", ] diff --git a/src/batdetect2/plotting/evaluation.py b/src/batdetect2/plotting/evaluation.py deleted file mode 100644 index 6e07f1e..0000000 --- a/src/batdetect2/plotting/evaluation.py +++ /dev/null @@ -1,160 +0,0 @@ -import random -from collections import defaultdict -from dataclasses import dataclass, field -from typing import List - -import matplotlib.pyplot as plt -import pandas as pd - -from batdetect2 import plotting -from batdetect2.typing.evaluate import MatchEvaluation -from batdetect2.typing.preprocess import PreprocessorProtocol - - -@dataclass -class ClassExamples: - false_positives: List[MatchEvaluation] = field(default_factory=list) - false_negatives: List[MatchEvaluation] = field(default_factory=list) - true_positives: List[MatchEvaluation] = field(default_factory=list) - cross_triggers: List[MatchEvaluation] = field(default_factory=list) - - -def plot_example_gallery( - matches: List[MatchEvaluation], - preprocessor: PreprocessorProtocol, - n_examples: int = 5, -): - class_examples = defaultdict(ClassExamples) - - for match in matches: - gt_class = match.gt_class - pred_class = match.pred_class - - if pred_class is None: - class_examples[gt_class].false_negatives.append(match) - continue - - if gt_class is None: - class_examples[pred_class].false_positives.append(match) - continue - - if gt_class != pred_class: - class_examples[gt_class].cross_triggers.append(match) - class_examples[pred_class].cross_triggers.append(match) - continue - - class_examples[gt_class].true_positives.append(match) - - for class_name, examples in class_examples.items(): - true_positives = get_binned_sample( - examples.true_positives, - n_examples=n_examples, - ) - - false_positives = get_binned_sample( - examples.false_positives, - n_examples=n_examples, - ) - - false_negatives = random.sample( - examples.false_negatives, - k=min(n_examples, len(examples.false_negatives)), - ) - - cross_triggers = get_binned_sample( - examples.cross_triggers, - n_examples=n_examples, - ) - - fig = plot_class_examples( - true_positives, - false_positives, - false_negatives, - cross_triggers, - preprocessor=preprocessor, - n_examples=n_examples, - ) - - yield 
class_name, fig - - plt.close(fig) - - -def plot_class_examples( - true_positives: List[MatchEvaluation], - false_positives: List[MatchEvaluation], - false_negatives: List[MatchEvaluation], - cross_triggers: List[MatchEvaluation], - preprocessor: PreprocessorProtocol, - n_examples: int = 5, - duration: float = 0.1, -): - fig = plt.figure(figsize=(20, 20)) - - for index, match in enumerate(true_positives[:n_examples]): - ax = plt.subplot(4, n_examples, index + 1) - try: - plotting.plot_true_positive_match( - match, - ax=ax, - preprocessor=preprocessor, - duration=duration, - ) - except (ValueError, AssertionError, RuntimeError, FileNotFoundError): - continue - - for index, match in enumerate(false_positives[:n_examples]): - ax = plt.subplot(4, n_examples, n_examples + index + 1) - try: - plotting.plot_false_positive_match( - match, - ax=ax, - preprocessor=preprocessor, - duration=duration, - ) - except (ValueError, AssertionError, RuntimeError, FileNotFoundError): - continue - - for index, match in enumerate(false_negatives[:n_examples]): - ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1) - try: - plotting.plot_false_negative_match( - match, - ax=ax, - preprocessor=preprocessor, - duration=duration, - ) - except (ValueError, AssertionError, RuntimeError, FileNotFoundError): - continue - - for index, match in enumerate(cross_triggers[:n_examples]): - ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1) - try: - plotting.plot_cross_trigger_match( - match, - ax=ax, - preprocessor=preprocessor, - duration=duration, - ) - except (ValueError, AssertionError, RuntimeError, FileNotFoundError): - continue - - return fig - - -def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5): - if len(matches) < n_examples: - return matches - - indices, pred_scores = zip( - *[ - (index, match.pred_class_scores[pred_class]) - for index, match in enumerate(matches) - if (pred_class := match.pred_class) is not None - ] - ) - - bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop") - df = pd.DataFrame({"indices": indices, "bins": bins}) - sample = df.groupby("bins").sample(1) - return [matches[ind] for ind in sample["indices"]] diff --git a/src/batdetect2/plotting/gallery.py b/src/batdetect2/plotting/gallery.py new file mode 100644 index 0000000..175f4c7 --- /dev/null +++ b/src/batdetect2/plotting/gallery.py @@ -0,0 +1,81 @@ +from typing import List, Optional + +import matplotlib.pyplot as plt + +from batdetect2.plotting.matches import ( + plot_cross_trigger_match, + plot_false_negative_match, + plot_false_positive_match, + plot_true_positive_match, +) +from batdetect2.typing.evaluate import MatchEvaluation +from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol + +__all__ = ["plot_match_gallery"] + + +def plot_match_gallery( + true_positives: List[MatchEvaluation], + false_positives: List[MatchEvaluation], + false_negatives: List[MatchEvaluation], + cross_triggers: List[MatchEvaluation], + audio_loader: Optional[AudioLoader] = None, + preprocessor: Optional[PreprocessorProtocol] = None, + n_examples: int = 5, + duration: float = 0.1, +): + fig = plt.figure(figsize=(20, 20)) + + for index, match in enumerate(true_positives[:n_examples]): + ax = plt.subplot(4, n_examples, index + 1) + try: + plot_true_positive_match( + match, + ax=ax, + audio_loader=audio_loader, + preprocessor=preprocessor, + duration=duration, + ) + except (ValueError, AssertionError, RuntimeError, FileNotFoundError): + continue + + for index, match in 
enumerate(false_positives[:n_examples]): + ax = plt.subplot(4, n_examples, n_examples + index + 1) + try: + plot_false_positive_match( + match, + ax=ax, + audio_loader=audio_loader, + preprocessor=preprocessor, + duration=duration, + ) + except (ValueError, AssertionError, RuntimeError, FileNotFoundError): + continue + + for index, match in enumerate(false_negatives[:n_examples]): + ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1) + try: + plot_false_negative_match( + match, + ax=ax, + audio_loader=audio_loader, + preprocessor=preprocessor, + duration=duration, + ) + except (ValueError, AssertionError, RuntimeError, FileNotFoundError): + continue + + for index, match in enumerate(cross_triggers[:n_examples]): + ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1) + try: + plot_cross_trigger_match( + match, + ax=ax, + audio_loader=audio_loader, + preprocessor=preprocessor, + duration=duration, + ) + except (ValueError, AssertionError, RuntimeError, FileNotFoundError): + continue + + return fig diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index c584bea..ccbe718 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ -7,8 +7,8 @@ from soundevent.geometry import compute_bounds from soundevent.plot.tags import TagColorMapper from batdetect2.plotting.clip_predictions import plot_prediction -from batdetect2.plotting.clips import plot_clip -from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor +from batdetect2.plotting.clips import AudioLoader, plot_clip +from batdetect2.preprocess import PreprocessorProtocol from batdetect2.typing.evaluate import MatchEvaluation __all__ = [ @@ -32,6 +32,7 @@ DEFAULT_PREDICTION_LINE_STYLE = "--" def plot_matches( matches: List[data.Match], clip: data.Clip, + audio_loader: Optional[AudioLoader] = None, preprocessor: Optional[PreprocessorProtocol] = None, figsize: Optional[Tuple[int, int]] = None, ax: Optional[Axes] = None, @@ -46,12 +47,11 @@ def plot_matches( annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, ) -> Axes: - if preprocessor is None: - preprocessor = build_preprocessor() - ax = plot_clip( clip, ax=ax, + audio_loader=audio_loader, + preprocessor=preprocessor, figsize=figsize, audio_dir=audio_dir, spec_cmap=spec_cmap, @@ -116,6 +116,7 @@ def plot_matches( def plot_false_positive_match( match: MatchEvaluation, + audio_loader: Optional[AudioLoader] = None, preprocessor: Optional[PreprocessorProtocol] = None, figsize: Optional[Tuple[int, int]] = None, ax: Optional[Axes] = None, @@ -143,6 +144,7 @@ def plot_false_positive_match( ax = plot_clip( clip, + audio_loader=audio_loader, preprocessor=preprocessor, figsize=figsize, ax=ax, @@ -174,6 +176,7 @@ def plot_false_positive_match( def plot_false_negative_match( match: MatchEvaluation, + audio_loader: Optional[AudioLoader] = None, preprocessor: Optional[PreprocessorProtocol] = None, figsize: Optional[Tuple[int, int]] = None, ax: Optional[Axes] = None, @@ -203,6 +206,7 @@ def plot_false_negative_match( ax = plot_clip( clip, + audio_loader=audio_loader, preprocessor=preprocessor, figsize=figsize, ax=ax, @@ -237,6 +241,7 @@ def plot_false_negative_match( def plot_true_positive_match( match: MatchEvaluation, preprocessor: Optional[PreprocessorProtocol] = None, + audio_loader: Optional[AudioLoader] = None, figsize: Optional[Tuple[int, int]] = None, ax: Optional[Axes] = None, audio_dir: Optional[data.PathLike] = None, @@ -267,6 +272,7 @@ 
def plot_true_positive_match( ax = plot_clip( clip, + audio_loader=audio_loader, preprocessor=preprocessor, figsize=figsize, ax=ax, @@ -312,6 +318,7 @@ def plot_true_positive_match( def plot_cross_trigger_match( match: MatchEvaluation, preprocessor: Optional[PreprocessorProtocol] = None, + audio_loader: Optional[AudioLoader] = None, figsize: Optional[Tuple[int, int]] = None, ax: Optional[Axes] = None, audio_dir: Optional[data.PathLike] = None, @@ -342,6 +349,7 @@ def plot_cross_trigger_match( ax = plot_clip( clip, + audio_loader=audio_loader, preprocessor=preprocessor, figsize=figsize, ax=ax, diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index d759766..7607ce0 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -1,49 +1,26 @@ -from typing import List, Optional +from typing import List from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import Callback from soundevent import data from torch.utils.data import DataLoader -from batdetect2.evaluate.match import ( - MatchConfig, - build_matcher, - match_all_predictions, -) -from batdetect2.plotting.clips import PreprocessorProtocol -from batdetect2.plotting.evaluation import plot_example_gallery +from batdetect2.evaluate import Evaluator from batdetect2.postprocess import get_raw_predictions from batdetect2.train.dataset import ValidationDataset from batdetect2.train.lightning import TrainingModule from batdetect2.train.logging import get_image_plotter -from batdetect2.typing import ( - MatchEvaluation, - MetricsProtocol, -) +from batdetect2.typing.evaluate import ClipEvaluation from batdetect2.typing.models import ModelOutput from batdetect2.typing.postprocess import RawPrediction from batdetect2.typing.train import TrainExample class ValidationMetrics(Callback): - def __init__( - self, - metrics: List[MetricsProtocol], - preprocessor: PreprocessorProtocol, - plot: bool = True, - match_config: Optional[MatchConfig] = None, - ): + def __init__(self, evaluator: Evaluator): super().__init__() - if len(metrics) == 0: - raise ValueError("At least one metric needs to be provided") - - self.match_config = match_config - self.metrics = metrics - self.preprocessor = preprocessor - self.plot = plot - - self.matcher = build_matcher(config=match_config) + self.evaluator = evaluator self._clip_annotations: List[data.ClipAnnotation] = [] self._predictions: List[List[RawPrediction]] = [] @@ -58,33 +35,22 @@ class ValidationMetrics(Callback): def plot_examples( self, pl_module: LightningModule, - matches: List[MatchEvaluation], + evaluated_clips: List[ClipEvaluation], ): plotter = get_image_plotter(pl_module.logger) # type: ignore if plotter is None: return - for class_name, fig in plot_example_gallery( - matches, - preprocessor=self.preprocessor, - n_examples=4, - ): - plotter( - f"examples/{class_name}", - fig, - pl_module.global_step, - ) + for figure_name, fig in self.evaluator.generate_plots(evaluated_clips): + plotter(figure_name, fig, pl_module.global_step) def log_metrics( self, pl_module: LightningModule, - matches: List[MatchEvaluation], + evaluated_clips: List[ClipEvaluation], ): - metrics = {} - for metric in self.metrics: - metrics.update(metric(matches).items()) - + metrics = self.evaluator.compute_metrics(evaluated_clips) pl_module.log_dict(metrics) def on_validation_epoch_end( @@ -92,17 +58,13 @@ class ValidationMetrics(Callback): trainer: Trainer, pl_module: LightningModule, ) -> None: - matches = match_all_predictions( + clip_evaluations = 
self.evaluator.evaluate( self._clip_annotations, self._predictions, - targets=pl_module.model.targets, - matcher=self.matcher, ) - self.log_metrics(pl_module, matches) - - if self.plot: - self.plot_examples(pl_module, matches) + self.log_metrics(pl_module, clip_evaluations) + self.plot_examples(pl_module, clip_evaluations) return super().on_validation_epoch_end(trainer, pl_module) diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py index 6333ebb..a91fc49 100644 --- a/src/batdetect2/train/clips.py +++ b/src/batdetect2/train/clips.py @@ -14,7 +14,7 @@ DEFAULT_TRAIN_CLIP_DURATION = 0.256 DEFAULT_MAX_EMPTY_CLIP = 0.1 -registry: Registry[ClipperProtocol] = Registry("clipper") +clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper") class RandomClipConfig(BaseConfig): @@ -25,7 +25,6 @@ class RandomClipConfig(BaseConfig): min_sound_event_overlap: float = 0 -@registry.register(RandomClipConfig) class RandomClip: def __init__( self, @@ -61,6 +60,9 @@ class RandomClip: ) +clipper_registry.register(RandomClipConfig, RandomClip) + + def get_subclip_annotation( clip_annotation: data.ClipAnnotation, random: bool = True, @@ -156,7 +158,6 @@ class PaddedClipConfig(BaseConfig): chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION -@registry.register(PaddedClipConfig) class PaddedClip: def __init__(self, chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION): self.chunk_size = chunk_size @@ -183,6 +184,8 @@ class PaddedClip: return cls(chunk_size=config.chunk_size) +clipper_registry.register(PaddedClipConfig, PaddedClip) + ClipConfig = Annotated[ Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name") ] @@ -195,4 +198,4 @@ def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol: "Building clipper with config: \n{}", lambda: config.to_yaml_string(), ) - return registry.build(config) + return clipper_registry.build(config) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 9822191..8837da4 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -10,10 +10,7 @@ from soundevent import data from torch.utils.data import DataLoader from batdetect2.evaluate.config import EvaluationConfig -from batdetect2.evaluate.metrics import ( - ClassificationMeanAveragePrecision, - DetectionAveragePrecision, -) +from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.plotting.clips import AudioLoader, build_audio_loader from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets @@ -146,7 +143,6 @@ def build_training_module( def build_trainer_callbacks( targets: TargetProtocol, - preprocessor: PreprocessorProtocol, config: EvaluationConfig, checkpoint_dir: Optional[Path] = None, experiment_name: Optional[str] = None, @@ -161,6 +157,8 @@ def build_trainer_callbacks( if run_name is not None: checkpoint_dir = checkpoint_dir / run_name + evaluator = build_evaluator(config=config, targets=targets) + return [ ModelCheckpoint( dirpath=str(checkpoint_dir), @@ -168,16 +166,7 @@ def build_trainer_callbacks( filename="best-{epoch:02d}-{val_loss:.0f}", monitor="total_loss/val", ), - ValidationMetrics( - metrics=[ - DetectionAveragePrecision(), - ClassificationMeanAveragePrecision( - class_names=targets.class_names - ), - ], - preprocessor=preprocessor, - match_config=config.match, - ), + ValidationMetrics(evaluator), ] @@ -214,7 +203,6 @@ def build_trainer( callbacks=build_trainer_callbacks( targets, config=conf.evaluation, - 
preprocessor=build_preprocessor(conf.preprocess), checkpoint_dir=checkpoint_dir, experiment_name=experiment_name, run_name=run_name, diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/typing/evaluate.py index b706bed..e3bf7e0 100644 --- a/src/batdetect2/typing/evaluate.py +++ b/src/batdetect2/typing/evaluate.py @@ -11,6 +11,7 @@ from typing import ( TypeVar, ) +from matplotlib.figure import Figure from soundevent import data __all__ = [ @@ -50,6 +51,12 @@ class MatchEvaluation: return self.pred_class_scores[pred_class] +@dataclass +class ClipEvaluation: + clip: data.Clip + matches: List[MatchEvaluation] + + class MatcherProtocol(Protocol): def __call__( self, @@ -67,10 +74,16 @@ class AffinityFunction(Protocol, Generic[Geom]): self, geometry1: Geom, geometry2: Geom, - time_buffer: float = 0.01, - freq_buffer: float = 1000, ) -> float: ... class MetricsProtocol(Protocol): - def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ... + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: ... + + +class PlotterProtocol(Protocol): + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Iterable[Tuple[str, Figure]]: ...
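
The two protocols above define the whole surface a pluggable metric or plotter needs. Below is a minimal sketch of one implementation of each; MatchCount and MatchHistogram are illustrative names, not part of this patch, and rely only on the ClipEvaluation fields shown above (clip, matches):

from typing import Dict, Iterable, Sequence, Tuple

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from batdetect2.typing.evaluate import ClipEvaluation


class MatchCount:
    """Hypothetical metric: total number of matches across all clips."""

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        total = sum(len(clip_eval.matches) for clip_eval in clip_evaluations)
        return {"matches/total": float(total)}


class MatchHistogram:
    """Hypothetical plotter: histogram of per-clip match counts."""

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Iterable[Tuple[str, Figure]]:
        fig, ax = plt.subplots()
        ax.hist([len(clip_eval.matches) for clip_eval in clip_evaluations])
        ax.set_xlabel("matches per clip")
        ax.set_ylabel("number of clips")
        yield "matches/per_clip_histogram", fig

Since both protocols consume the same Sequence[ClipEvaluation], matching can run once per validation epoch and the result fanned out to every metric and plotter the evaluator holds.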
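
For context, this is roughly how the pieces fit together at validation time after this patch. It is a sketch, not the exact implementation; eval_config, targets, clip_annotations, and predictions stand in for objects the training loop already holds:

from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.train.callbacks import ValidationMetrics


def build_validation_callback(eval_config, targets):
    # The callback now takes a single, fully configured Evaluator instead
    # of separate metrics, preprocessor, and match-config arguments.
    evaluator = build_evaluator(config=eval_config, targets=targets)
    return ValidationMetrics(evaluator)


def sketch_epoch_end(evaluator, clip_annotations, predictions):
    # What on_validation_epoch_end reduces to after this patch:
    clip_evaluations = evaluator.evaluate(clip_annotations, predictions)
    metrics = evaluator.compute_metrics(clip_evaluations)  # Dict[str, float]
    plots = evaluator.generate_plots(clip_evaluations)  # (figure_name, Figure)
    return metrics, plots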
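
The clipper registry follows the same pattern: registration is now an explicit call rather than a decorator. As a hedged sketch of registering a hypothetical third clipper — CenterClipConfig and CenterClip are illustrative, the BaseConfig import path is assumed, and the rest of the ClipperProtocol surface is elided:

from typing import Literal

from batdetect2.configs import BaseConfig  # assumed import path
from batdetect2.train.clips import clipper_registry


class CenterClipConfig(BaseConfig):
    # Discriminator field, mirroring RandomClipConfig and PaddedClipConfig.
    name: Literal["center_clip"] = "center_clip"
    chunk_size: float = 0.256


class CenterClip:
    def __init__(self, chunk_size: float):
        self.chunk_size = chunk_size

    # ...remaining ClipperProtocol methods elided...

    @classmethod
    def from_config(cls, config: CenterClipConfig) -> "CenterClip":
        return cls(chunk_size=config.chunk_size)


clipper_registry.register(CenterClipConfig, CenterClip)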