diff --git a/example_data/config.yaml b/example_data/config.yaml index bd4fa00..8cbf4e5 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -140,13 +140,14 @@ train: validation: metrics: - name: detection_ap - - name: detection_roc_auc - name: classification_ap - - name: classification_roc_auc - - name: top_class_ap - - name: classification_balanced_accuracy - - name: clip_ap - - name: clip_roc_auc + plots: + - name: example_gallery + - name: example_clip + - name: detection_pr_curve + - name: classification_pr_curves + - name: detection_roc_curve + - name: classification_roc_curves evaluation: match_strategy: @@ -155,6 +156,14 @@ evaluation: metrics: - name: classification_ap - name: detection_ap + - name: detection_roc_auc + - name: classification_roc_auc + - name: top_class_ap + - name: classification_balanced_accuracy + - name: clip_multiclass_ap + - name: clip_multiclass_roc_auc + - name: clip_detection_ap + - name: clip_detection_roc_auc plots: - name: example_gallery - name: example_clip diff --git a/src/batdetect2/api/base.py b/src/batdetect2/api/base.py index cdfc843..556db7d 100644 --- a/src/batdetect2/api/base.py +++ b/src/batdetect2/api/base.py @@ -1,6 +1,7 @@ from pathlib import Path -from typing import Optional, Sequence +from typing import List, Optional, Sequence +import torch from soundevent import data from batdetect2.audio import build_audio_loader @@ -8,6 +9,7 @@ from batdetect2.config import BatDetect2Config from batdetect2.evaluate import build_evaluator, evaluate from batdetect2.models import Model, build_model from batdetect2.postprocess import build_postprocessor +from batdetect2.postprocess.decoding import to_raw_predictions from batdetect2.preprocess import build_preprocessor from batdetect2.targets.targets import build_targets from batdetect2.train import train @@ -19,6 +21,7 @@ from batdetect2.typing import ( PreprocessorProtocol, TargetProtocol, ) +from batdetect2.typing.postprocess import RawPrediction class BatDetect2API: @@ -92,6 +95,18 @@ class BatDetect2API: run_name=run_name, ) + def process_spectrogram( + self, + spec: torch.Tensor, + start_times: Optional[Sequence[float]] = None, + ) -> List[List[RawPrediction]]: + outputs = self.model.detector(spec) + clip_detections = self.postprocessor(outputs, start_times=start_times) + return [ + to_raw_predictions(clip_dets.numpy(), self.targets) + for clip_dets in clip_detections + ] + @classmethod def from_config(cls, config: BatDetect2Config): targets = build_targets(config=config.targets) @@ -109,7 +124,7 @@ class BatDetect2API: ) evaluator = build_evaluator( - config=config.evaluation, + config=config.evaluation.evaluator, targets=targets, ) @@ -164,7 +179,7 @@ class BatDetect2API: ) evaluator = build_evaluator( - config=config.evaluation, + config=config.evaluation.evaluator, targets=targets, ) diff --git a/src/batdetect2/audio/clips.py b/src/batdetect2/audio/clips.py index 77100f3..86ddf18 100644 --- a/src/batdetect2/audio/clips.py +++ b/src/batdetect2/audio/clips.py @@ -56,18 +56,16 @@ class RandomClip: min_sound_event_overlap=self.min_sound_event_overlap, ) - @classmethod - def from_config(cls, config: RandomClipConfig): - return cls( + @clipper_registry.register(RandomClipConfig) + @staticmethod + def from_config(config: RandomClipConfig): + return RandomClip( duration=config.duration, max_empty=config.max_empty, min_sound_event_overlap=config.min_sound_event_overlap, ) -clipper_registry.register(RandomClipConfig, RandomClip) - - def get_subclip_annotation( clip_annotation: 
data.ClipAnnotation, random: bool = True, @@ -184,13 +182,12 @@ class PaddedClip: ) return clip_annotation.model_copy(update=dict(clip=clip)) - @classmethod - def from_config(cls, config: PaddedClipConfig): - return cls(chunk_size=config.chunk_size) + @clipper_registry.register(PaddedClipConfig) + @staticmethod + def from_config(config: PaddedClipConfig): + return PaddedClip(chunk_size=config.chunk_size) -clipper_registry.register(PaddedClipConfig, PaddedClip) - ClipConfig = Annotated[ Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name") ] diff --git a/src/batdetect2/core/configs.py b/src/batdetect2/core/configs.py index c7ffcd3..7513d73 100644 --- a/src/batdetect2/core/configs.py +++ b/src/batdetect2/core/configs.py @@ -53,6 +53,7 @@ class BaseConfig(BaseModel): """ return yaml.dump( self.model_dump( + mode="json", exclude_none=exclude_none, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, diff --git a/src/batdetect2/core/registries.py b/src/batdetect2/core/registries.py index 9bd2b09..d059e8c 100644 --- a/src/batdetect2/core/registries.py +++ b/src/batdetect2/core/registries.py @@ -1,16 +1,16 @@ import sys -from typing import Generic, Protocol, Type, TypeVar +from typing import Callable, Dict, Generic, Tuple, Type, TypeVar from pydantic import BaseModel -from typing_extensions import assert_type if sys.version_info >= (3, 10): - from typing import ParamSpec + from typing import Concatenate, ParamSpec else: - from typing_extensions import ParamSpec + from typing_extensions import Concatenate, ParamSpec __all__ = [ "Registry", + "SimpleRegistry", ] T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True) @@ -18,19 +18,26 @@ T_Type = TypeVar("T_Type", covariant=True) P_Type = ParamSpec("P_Type") -class LogicProtocol(Generic[T_Config, T_Type, P_Type], Protocol): - """A generic protocol for the logic classes.""" - - @classmethod - def from_config( - cls, - config: T_Config, - *args: P_Type.args, - **kwargs: P_Type.kwargs, - ) -> T_Type: ... 
+T = TypeVar("T") -T_Proto = TypeVar("T_Proto", bound=LogicProtocol) +class SimpleRegistry(Generic[T]): + def __init__(self, name: str): + self._name = name + self._registry = {} + + def register(self, name: str): + def decorator(obj: T) -> T: + self._registry[name] = obj + return obj + + return decorator + + def get(self, name: str) -> T: + return self._registry[name] + + def has(self, name: str) -> bool: + return name in self._registry class Registry(Generic[T_Type, P_Type]): @@ -38,13 +45,15 @@ class Registry(Generic[T_Type, P_Type]): def __init__(self, name: str): self._name = name - self._registry = {} + self._registry: Dict[ + str, Callable[Concatenate[..., P_Type], T_Type] + ] = {} + self._config_types: Dict[str, Type[BaseModel]] = {} def register( self, config_cls: Type[T_Config], - logic_cls: LogicProtocol[T_Config, T_Type, P_Type], - ) -> None: + ): fields = config_cls.model_fields if "name" not in fields: @@ -52,10 +61,21 @@ class Registry(Generic[T_Type, P_Type]): name = fields["name"].default + self._config_types[name] = config_cls + if not isinstance(name, str): raise ValueError("'name' field must be a string literal.") - self._registry[name] = logic_cls + def decorator( + func: Callable[Concatenate[T_Config, P_Type], T_Type], + ): + self._registry[name] = func + return func + + return decorator + + def get_config_types(self) -> Tuple[Type[BaseModel], ...]: + return tuple(self._config_types.values()) def build( self, @@ -75,4 +95,4 @@ class Registry(Generic[T_Type, P_Type]): f"No {self._name} with name '{name}' is registered." ) - return self._registry[name].from_config(config, *args, **kwargs) + return self._registry[name](config, *args, **kwargs) diff --git a/src/batdetect2/data/conditions.py b/src/batdetect2/data/conditions.py index 54c082b..4ee8590 100644 --- a/src/batdetect2/data/conditions.py +++ b/src/batdetect2/data/conditions.py @@ -10,7 +10,7 @@ from batdetect2.core.registries import Registry SoundEventCondition = Callable[[data.SoundEventAnnotation], bool] -condition_registry: Registry[SoundEventCondition, []] = Registry("condition") +conditions: Registry[SoundEventCondition, []] = Registry("condition") class HasTagConfig(BaseConfig): @@ -27,12 +27,10 @@ class HasTag: ) -> bool: return self.tag in sound_event_annotation.tags - @classmethod - def from_config(cls, config: HasTagConfig): - return cls(tag=config.tag) - - -condition_registry.register(HasTagConfig, HasTag) + @conditions.register(HasTagConfig) + @staticmethod + def from_config(config: HasTagConfig): + return HasTag(tag=config.tag) class HasAllTagsConfig(BaseConfig): @@ -52,12 +50,10 @@ class HasAllTags: ) -> bool: return self.tags.issubset(sound_event_annotation.tags) - @classmethod - def from_config(cls, config: HasAllTagsConfig): - return cls(tags=config.tags) - - -condition_registry.register(HasAllTagsConfig, HasAllTags) + @conditions.register(HasAllTagsConfig) + @staticmethod + def from_config(config: HasAllTagsConfig): + return HasAllTags(tags=config.tags) class HasAnyTagConfig(BaseConfig): @@ -77,13 +73,12 @@ class HasAnyTag: ) -> bool: return bool(self.tags.intersection(sound_event_annotation.tags)) - @classmethod - def from_config(cls, config: HasAnyTagConfig): - return cls(tags=config.tags) + @conditions.register(HasAnyTagConfig) + @staticmethod + def from_config(config: HasAnyTagConfig): + return HasAnyTag(tags=config.tags) -condition_registry.register(HasAnyTagConfig, HasAnyTag) - Operator = Literal["gt", "gte", "lt", "lte", "eq"] @@ -134,12 +129,10 @@ class Duration: return 
self._comparator(duration) - @classmethod - def from_config(cls, config: DurationConfig): - return cls(operator=config.operator, seconds=config.seconds) - - -condition_registry.register(DurationConfig, Duration) + @conditions.register(DurationConfig) + @staticmethod + def from_config(config: DurationConfig): + return Duration(operator=config.operator, seconds=config.seconds) class FrequencyConfig(BaseConfig): @@ -181,18 +174,16 @@ class Frequency: return self._comparator(high_freq) - @classmethod - def from_config(cls, config: FrequencyConfig): - return cls( + @conditions.register(FrequencyConfig) + @staticmethod + def from_config(config: FrequencyConfig): + return Frequency( operator=config.operator, boundary=config.boundary, hertz=config.hertz, ) -condition_registry.register(FrequencyConfig, Frequency) - - class AllOfConfig(BaseConfig): name: Literal["all_of"] = "all_of" conditions: Sequence["SoundEventConditionConfig"] @@ -207,15 +198,13 @@ class AllOf: ) -> bool: return all(c(sound_event_annotation) for c in self.conditions) - @classmethod - def from_config(cls, config: AllOfConfig): + @conditions.register(AllOfConfig) + @staticmethod + def from_config(config: AllOfConfig): conditions = [ build_sound_event_condition(cond) for cond in config.conditions ] - return cls(conditions) - - -condition_registry.register(AllOfConfig, AllOf) + return AllOf(conditions) class AnyOfConfig(BaseConfig): @@ -232,15 +221,13 @@ class AnyOf: ) -> bool: return any(c(sound_event_annotation) for c in self.conditions) - @classmethod - def from_config(cls, config: AnyOfConfig): + @conditions.register(AnyOfConfig) + @staticmethod + def from_config(config: AnyOfConfig): conditions = [ build_sound_event_condition(cond) for cond in config.conditions ] - return cls(conditions) - - -condition_registry.register(AnyOfConfig, AnyOf) + return AnyOf(conditions) class NotConfig(BaseConfig): @@ -257,14 +244,13 @@ class Not: ) -> bool: return not self.condition(sound_event_annotation) - @classmethod - def from_config(cls, config: NotConfig): + @conditions.register(NotConfig) + @staticmethod + def from_config(config: NotConfig): condition = build_sound_event_condition(config.condition) - return cls(condition) + return Not(condition) -condition_registry.register(NotConfig, Not) - SoundEventConditionConfig = Annotated[ Union[ HasTagConfig, @@ -283,7 +269,7 @@ SoundEventConditionConfig = Annotated[ def build_sound_event_condition( config: SoundEventConditionConfig, ) -> SoundEventCondition: - return condition_registry.build(config) + return conditions.build(config) def filter_clip_annotation( diff --git a/src/batdetect2/data/transforms.py b/src/batdetect2/data/transforms.py index 62dd1c4..a826b7d 100644 --- a/src/batdetect2/data/transforms.py +++ b/src/batdetect2/data/transforms.py @@ -17,7 +17,7 @@ SoundEventTransform = Callable[ data.SoundEventAnnotation, ] -transform_registry: Registry[SoundEventTransform, []] = Registry("transform") +transforms: Registry[SoundEventTransform, []] = Registry("transform") class SetFrequencyBoundConfig(BaseConfig): @@ -63,12 +63,10 @@ class SetFrequencyBound: update=dict(sound_event=sound_event) ) - @classmethod - def from_config(cls, config: SetFrequencyBoundConfig): - return cls(hertz=config.hertz, boundary=config.boundary) - - -transform_registry.register(SetFrequencyBoundConfig, SetFrequencyBound) + @transforms.register(SetFrequencyBoundConfig) + @staticmethod + def from_config(config: SetFrequencyBoundConfig): + return SetFrequencyBound(hertz=config.hertz, boundary=config.boundary) class 
ApplyIfConfig(BaseConfig): @@ -95,14 +93,12 @@ class ApplyIf: return self.transform(sound_event_annotation) - @classmethod - def from_config(cls, config: ApplyIfConfig): + @transforms.register(ApplyIfConfig) + @staticmethod + def from_config(config: ApplyIfConfig): transform = build_sound_event_transform(config.transform) condition = build_sound_event_condition(config.condition) - return cls(condition=condition, transform=transform) - - -transform_registry.register(ApplyIfConfig, ApplyIf) + return ApplyIf(condition=condition, transform=transform) class ReplaceTagConfig(BaseConfig): @@ -134,12 +130,12 @@ class ReplaceTag: return sound_event_annotation.model_copy(update=dict(tags=tags)) - @classmethod - def from_config(cls, config: ReplaceTagConfig): - return cls(original=config.original, replacement=config.replacement) - - -transform_registry.register(ReplaceTagConfig, ReplaceTag) + @transforms.register(ReplaceTagConfig) + @staticmethod + def from_config(config: ReplaceTagConfig): + return ReplaceTag( + original=config.original, replacement=config.replacement + ) class MapTagValueConfig(BaseConfig): @@ -189,18 +185,16 @@ class MapTagValue: return sound_event_annotation.model_copy(update=dict(tags=tags)) - @classmethod - def from_config(cls, config: MapTagValueConfig): - return cls( + @transforms.register(MapTagValueConfig) + @staticmethod + def from_config(config: MapTagValueConfig): + return MapTagValue( tag_key=config.tag_key, value_mapping=config.value_mapping, target_key=config.target_key, ) -transform_registry.register(MapTagValueConfig, MapTagValue) - - class ApplyAllConfig(BaseConfig): name: Literal["apply_all"] = "apply_all" steps: List["SoundEventTransformConfig"] = Field(default_factory=list) @@ -219,14 +213,13 @@ class ApplyAll: return sound_event_annotation - @classmethod - def from_config(cls, config: ApplyAllConfig): + @transforms.register(ApplyAllConfig) + @staticmethod + def from_config(config: ApplyAllConfig): steps = [build_sound_event_transform(step) for step in config.steps] - return cls(steps) + return ApplyAll(steps) -transform_registry.register(ApplyAllConfig, ApplyAll) - SoundEventTransformConfig = Annotated[ Union[ SetFrequencyBoundConfig, @@ -242,7 +235,7 @@ SoundEventTransformConfig = Annotated[ def build_sound_event_transform( config: SoundEventTransformConfig, ) -> SoundEventTransform: - return transform_registry.build(config) + return transforms.build(config) def transform_clip_annotation( diff --git a/src/batdetect2/evaluate/__init__.py b/src/batdetect2/evaluate/__init__.py index 3e02ed0..03d31db 100644 --- a/src/batdetect2/evaluate/__init__.py +++ b/src/batdetect2/evaluate/__init__.py @@ -1,11 +1,11 @@ from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config from batdetect2.evaluate.evaluate import evaluate -from batdetect2.evaluate.evaluator import Evaluator, build_evaluator +from batdetect2.evaluate.evaluator import MultipleEvaluator, build_evaluator __all__ = [ "EvaluationConfig", "load_evaluation_config", "evaluate", - "Evaluator", + "MultipleEvaluator", "build_evaluator", ] diff --git a/src/batdetect2/evaluate/affinity.py b/src/batdetect2/evaluate/affinity.py index 5a2ab91..1ffa868 100644 --- a/src/batdetect2/evaluate/affinity.py +++ b/src/batdetect2/evaluate/affinity.py @@ -27,12 +27,10 @@ class TimeAffinity(AffinityFunction): geometry1, geometry2, time_buffer=self.time_buffer ) - @classmethod - def from_config(cls, config: TimeAffinityConfig): - return cls(time_buffer=config.time_buffer) - - 
-affinity_functions.register(TimeAffinityConfig, TimeAffinity) + @affinity_functions.register(TimeAffinityConfig) + @staticmethod + def from_config(config: TimeAffinityConfig): + return TimeAffinity(time_buffer=config.time_buffer) def compute_timestamp_affinity( @@ -73,12 +71,10 @@ class IntervalIOU(AffinityFunction): time_buffer=self.time_buffer, ) - @classmethod - def from_config(cls, config: IntervalIOUConfig): - return cls(time_buffer=config.time_buffer) - - -affinity_functions.register(IntervalIOUConfig, IntervalIOU) + @affinity_functions.register(IntervalIOUConfig) + @staticmethod + def from_config(config: IntervalIOUConfig): + return IntervalIOU(time_buffer=config.time_buffer) def compute_interval_iou( @@ -127,13 +123,12 @@ class GeometricIOU(AffinityFunction): time_buffer=self.time_buffer, ) - @classmethod - def from_config(cls, config: GeometricIOUConfig): - return cls(time_buffer=config.time_buffer) + @affinity_functions.register(GeometricIOUConfig) + @staticmethod + def from_config(config: GeometricIOUConfig): + return GeometricIOU(time_buffer=config.time_buffer) -affinity_functions.register(GeometricIOUConfig, GeometricIOU) - AffinityConfig = Annotated[ Union[ TimeAffinityConfig, diff --git a/src/batdetect2/evaluate/config.py b/src/batdetect2/evaluate/config.py index 2ed5bf3..de1ffae 100644 --- a/src/batdetect2/evaluate/config.py +++ b/src/batdetect2/evaluate/config.py @@ -1,16 +1,13 @@ -from typing import List, Optional +from typing import Optional from pydantic import Field from soundevent import data from batdetect2.core.configs import BaseConfig, load_config -from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig -from batdetect2.evaluate.metrics import ( - ClassificationAPConfig, - DetectionAPConfig, - MetricConfig, +from batdetect2.evaluate.evaluator import ( + EvaluatorConfig, + MultipleEvaluatorConfig, ) -from batdetect2.evaluate.plots import PlotConfig from batdetect2.logging import CSVLoggerConfig, LoggerConfig __all__ = [ @@ -20,15 +17,7 @@ __all__ = [ class EvaluationConfig(BaseConfig): - ignore_start_end: float = 0.01 - match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig) - metrics: List[MetricConfig] = Field( - default_factory=lambda: [ - DetectionAPConfig(), - ClassificationAPConfig(), - ] - ) - plots: List[PlotConfig] = Field(default_factory=list) + evaluator: EvaluatorConfig = Field(default_factory=MultipleEvaluatorConfig) logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 2fd723f..a151107 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -55,7 +55,10 @@ def evaluate( num_workers=num_workers, ) - evaluator = build_evaluator(config=config.evaluation, targets=targets) + evaluator = build_evaluator( + config=config.evaluation.evaluator, + targets=targets, + ) logger = build_logger( config.evaluation.logger, diff --git a/src/batdetect2/evaluate/evaluator.py b/src/batdetect2/evaluate/evaluator.py deleted file mode 100644 index dbbbee2..0000000 --- a/src/batdetect2/evaluate/evaluator.py +++ /dev/null @@ -1,173 +0,0 @@ -from typing import Dict, Iterable, List, Optional, Sequence, Tuple - -from matplotlib.figure import Figure -from soundevent import data -from soundevent.geometry import compute_bounds - -from batdetect2.evaluate.config import EvaluationConfig -from batdetect2.evaluate.match import build_matcher, match -from batdetect2.evaluate.metrics import build_metric -from 
batdetect2.evaluate.plots import build_plotter -from batdetect2.targets import build_targets -from batdetect2.typing.evaluate import ( - ClipEvaluation, - EvaluatorProtocol, - MatcherProtocol, - MetricsProtocol, - PlotterProtocol, -) -from batdetect2.typing.postprocess import RawPrediction -from batdetect2.typing.targets import TargetProtocol - -__all__ = [ - "Evaluator", - "build_evaluator", -] - - -class Evaluator: - def __init__( - self, - config: EvaluationConfig, - targets: TargetProtocol, - matcher: MatcherProtocol, - metrics: List[MetricsProtocol], - plots: List[PlotterProtocol], - ): - self.config = config - self.targets = targets - self.matcher = matcher - self.metrics = metrics - self.plots = plots - - def match( - self, - clip_annotation: data.ClipAnnotation, - predictions: Sequence[RawPrediction], - ) -> ClipEvaluation: - clip = clip_annotation.clip - ground_truth = [ - sound_event - for sound_event in clip_annotation.sound_events - if self.filter_sound_event_annotations(sound_event, clip) - ] - predictions = [ - prediction - for prediction in predictions - if self.filter_predictions(prediction, clip) - ] - return ClipEvaluation( - clip=clip_annotation.clip, - matches=match( - ground_truth, - predictions, - clip=clip, - targets=self.targets, - matcher=self.matcher, - ), - ) - - def filter_sound_event_annotations( - self, - sound_event_annotation: data.SoundEventAnnotation, - clip: data.Clip, - ) -> bool: - if not self.targets.filter(sound_event_annotation): - return False - - geometry = sound_event_annotation.sound_event.geometry - if geometry is None: - return False - - return is_in_bounds( - geometry, - clip, - self.config.ignore_start_end, - ) - - def filter_predictions( - self, - prediction: RawPrediction, - clip: data.Clip, - ) -> bool: - return is_in_bounds( - prediction.geometry, - clip, - self.config.ignore_start_end, - ) - - def evaluate( - self, - clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[Sequence[RawPrediction]], - ) -> List[ClipEvaluation]: - if len(clip_annotations) != len(predictions): - raise ValueError( - "Number of annotated clips and sets of predictions do not match" - ) - - return [ - self.match(clip_annotation, preds) - for clip_annotation, preds in zip(clip_annotations, predictions) - ] - - def compute_metrics( - self, - clip_evaluations: Sequence[ClipEvaluation], - ) -> Dict[str, float]: - results = {} - - for metric in self.metrics: - results.update(metric(clip_evaluations)) - - return results - - def generate_plots( - self, clip_evaluations: Sequence[ClipEvaluation] - ) -> Iterable[Tuple[str, Figure]]: - for plotter in self.plots: - for name, fig in plotter(clip_evaluations): - yield name, fig - - -def build_evaluator( - config: Optional[EvaluationConfig] = None, - targets: Optional[TargetProtocol] = None, - matcher: Optional[MatcherProtocol] = None, - plots: Optional[List[PlotterProtocol]] = None, - metrics: Optional[List[MetricsProtocol]] = None, -) -> EvaluatorProtocol: - config = config or EvaluationConfig() - targets = targets or build_targets() - matcher = matcher or build_matcher(config.match_strategy) - - if metrics is None: - metrics = [ - build_metric(config, targets.class_names) - for config in config.metrics - ] - - if plots is None: - plots = [ - build_plotter(config, targets.class_names) - for config in config.plots - ] - - return Evaluator( - config=config, - targets=targets, - matcher=matcher, - metrics=metrics, - plots=plots, - ) - - -def is_in_bounds( - geometry: data.Geometry, - clip: data.Clip, - buffer: 
float, -) -> bool: - start_time = compute_bounds(geometry)[0] - return (start_time >= clip.start_time + buffer) and ( - start_time <= clip.end_time - buffer - ) diff --git a/src/batdetect2/evaluate/evaluator/__init__.py b/src/batdetect2/evaluate/evaluator/__init__.py new file mode 100644 index 0000000..92b31a9 --- /dev/null +++ b/src/batdetect2/evaluate/evaluator/__init__.py @@ -0,0 +1,114 @@ +from typing import ( + Annotated, + Any, + Dict, + Iterable, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, +) + +from matplotlib.figure import Figure +from pydantic import Field +from soundevent import data + +from batdetect2.core.configs import BaseConfig +from batdetect2.evaluate.evaluator.base import evaluators +from batdetect2.evaluate.evaluator.clip import ClipMetricsConfig +from batdetect2.evaluate.evaluator.per_class import ClassificationMetricsConfig +from batdetect2.evaluate.evaluator.single import GlobalEvaluatorConfig +from batdetect2.targets import build_targets +from batdetect2.typing import ( + EvaluatorProtocol, + RawPrediction, + TargetProtocol, +) + +__all__ = [ + "EvaluatorConfig", + "build_evaluator", +] + + +EvaluatorConfig = Annotated[ + Union[ + ClassificationMetricsConfig, + GlobalEvaluatorConfig, + ClipMetricsConfig, + "MultipleEvaluatorConfig", + ], + Field(discriminator="name"), +] + + +class MultipleEvaluatorConfig(BaseConfig): + name: Literal["multiple_evaluations"] = "multiple_evaluations" + evaluations: List[EvaluatorConfig] = Field( + default_factory=lambda: [ + ClassificationMetricsConfig(), + GlobalEvaluatorConfig(), + ] + ) + + +class MultipleEvaluator: + def __init__( + self, + targets: TargetProtocol, + evaluators: Sequence[EvaluatorProtocol], + ): + self.targets = targets + self.evaluators = evaluators + + def evaluate( + self, + clip_annotations: Sequence[data.ClipAnnotation], + predictions: Sequence[Sequence[RawPrediction]], + ) -> List[Any]: + return [ + evaluator.evaluate( + clip_annotations, + predictions, + ) + for evaluator in self.evaluators + ] + + def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]: + results = {} + + for evaluator, outputs in zip(self.evaluators, eval_outputs): + results.update(evaluator.compute_metrics(outputs)) + + return results + + def generate_plots( + self, + eval_outputs: List[Any], + ) -> Iterable[Tuple[str, Figure]]: + for evaluator, outputs in zip(self.evaluators, eval_outputs): + for name, fig in evaluator.generate_plots(outputs): + yield name, fig + + @evaluators.register(MultipleEvaluatorConfig) + @staticmethod + def from_config(config: MultipleEvaluatorConfig, targets: TargetProtocol): + return MultipleEvaluator( + evaluators=[ + build_evaluator(conf, targets=targets) + for conf in config.evaluations + ], + targets=targets, + ) + + +def build_evaluator( + config: Optional[EvaluatorConfig] = None, + targets: Optional[TargetProtocol] = None, +) -> EvaluatorProtocol: + targets = targets or build_targets() + + config = config or MultipleEvaluatorConfig() + return evaluators.build(config, targets) diff --git a/src/batdetect2/evaluate/evaluator/base.py b/src/batdetect2/evaluate/evaluator/base.py new file mode 100644 index 0000000..8248ee5 --- /dev/null +++ b/src/batdetect2/evaluate/evaluator/base.py @@ -0,0 +1,107 @@ +from pydantic import Field +from soundevent import data +from soundevent.geometry import compute_bounds + +from batdetect2.core import BaseConfig +from batdetect2.core.registries import Registry +from batdetect2.evaluate.match import ( + MatchConfig, + StartTimeMatchConfig, + 
build_matcher, +) +from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol +from batdetect2.typing.postprocess import RawPrediction +from batdetect2.typing.targets import TargetProtocol + +__all__ = [ + "BaseEvaluatorConfig", + "BaseEvaluator", +] + +evaluators: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry("metric") + + +class BaseEvaluatorConfig(BaseConfig): + prefix: str + ignore_start_end: float = 0.01 + matching_strategy: MatchConfig = Field( + default_factory=StartTimeMatchConfig + ) + + +class BaseEvaluator(EvaluatorProtocol): + targets: TargetProtocol + + matcher: MatcherProtocol + + ignore_start_end: float + + prefix: str + + def __init__( + self, + matcher: MatcherProtocol, + targets: TargetProtocol, + prefix: str, + ignore_start_end: float = 0.01, + ): + self.matcher = matcher + self.targets = targets + self.prefix = prefix + self.ignore_start_end = ignore_start_end + + def filter_sound_event_annotations( + self, + sound_event_annotation: data.SoundEventAnnotation, + clip: data.Clip, + ) -> bool: + if not self.targets.filter(sound_event_annotation): + return False + + geometry = sound_event_annotation.sound_event.geometry + if geometry is None: + return False + + return is_in_bounds( + geometry, + clip, + self.ignore_start_end, + ) + + def filter_predictions( + self, + prediction: RawPrediction, + clip: data.Clip, + ) -> bool: + return is_in_bounds( + prediction.geometry, + clip, + self.ignore_start_end, + ) + + @classmethod + def build( + cls, + config: BaseEvaluatorConfig, + targets: TargetProtocol, + **kwargs, + ): + matcher = build_matcher(config.matching_strategy) + return cls( + matcher=matcher, + targets=targets, + prefix=config.prefix, + ignore_start_end=config.ignore_start_end, + **kwargs, + ) + + +def is_in_bounds( + geometry: data.Geometry, + clip: data.Clip, + buffer: float, +) -> bool: + start_time = compute_bounds(geometry)[0] + return (start_time >= clip.start_time + buffer) and ( + start_time <= clip.end_time - buffer + ) diff --git a/src/batdetect2/evaluate/evaluator/clip.py b/src/batdetect2/evaluate/evaluator/clip.py new file mode 100644 index 0000000..1556bc5 --- /dev/null +++ b/src/batdetect2/evaluate/evaluator/clip.py @@ -0,0 +1,163 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import Callable, Dict, List, Literal, Sequence, Set + +from pydantic import Field, field_validator +from sklearn import metrics +from soundevent import data + +from batdetect2.evaluate.evaluator.base import ( + BaseEvaluator, + BaseEvaluatorConfig, + evaluators, +) +from batdetect2.evaluate.metrics.common import average_precision +from batdetect2.typing.postprocess import RawPrediction +from batdetect2.typing.targets import TargetProtocol + + +@dataclass +class ClipInfo: + gt_det: bool + gt_classes: Set[str] + pred_score: float + pred_class_scores: Dict[str, float] + + +ClipMetric = Callable[[Sequence[ClipInfo]], float] + + +def clip_detection_average_precision( + clip_evaluations: Sequence[ClipInfo], +) -> float: + y_true = [] + y_score = [] + + for clip_eval in clip_evaluations: + y_true.append(clip_eval.gt_det) + y_score.append(clip_eval.pred_score) + + return average_precision(y_true=y_true, y_score=y_score) + + +def clip_detection_roc_auc( + clip_evaluations: Sequence[ClipInfo], +) -> float: + y_true = [] + y_score = [] + + for clip_eval in clip_evaluations: + y_true.append(clip_eval.gt_det) + y_score.append(clip_eval.pred_score) + + return float(metrics.roc_auc_score(y_true=y_true, y_score=y_score)) + + 
+clip_metrics = {
+    "average_precision": clip_detection_average_precision,
+    "roc_auc": clip_detection_roc_auc,
+}
+
+
+class ClipMetricsConfig(BaseEvaluatorConfig):
+    name: Literal["clip"] = "clip"
+    prefix: str = "clip"
+    metrics: List[str] = Field(
+        default_factory=lambda: [
+            "average_precision",
+            "roc_auc",
+        ]
+    )
+
+    @field_validator("metrics", mode="after")
+    @classmethod
+    def validate_metrics(cls, v: List[str]) -> List[str]:
+        for metric_name in v:
+            if metric_name not in clip_metrics:
+                raise ValueError(f"Unknown metric {metric_name}")
+        return v
+
+
+class ClipEvaluator(BaseEvaluator):
+    def __init__(self, *args, metrics: Dict[str, ClipMetric], **kwargs):
+        super().__init__(*args, **kwargs)
+        self.metrics = metrics
+
+    def evaluate(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        predictions: Sequence[Sequence[RawPrediction]],
+    ) -> List[ClipInfo]:
+        return [
+            self.match_clip(clip_annotation, preds)
+            for clip_annotation, preds in zip(clip_annotations, predictions)
+        ]
+
+    def compute_metrics(
+        self,
+        eval_outputs: List[ClipInfo],
+    ) -> Dict[str, float]:
+        scores = {
+            name: metric(eval_outputs) for name, metric in self.metrics.items()
+        }
+        return {
+            f"{self.prefix}/{name}": score for name, score in scores.items()
+        }
+
+    def match_clip(
+        self,
+        clip_annotation: data.ClipAnnotation,
+        predictions: Sequence[RawPrediction],
+    ) -> ClipInfo:
+        clip = clip_annotation.clip
+
+        gt_det = False
+        gt_classes = set()
+        for sound_event in clip_annotation.sound_events:
+            if not self.filter_sound_event_annotations(sound_event, clip):
+                continue
+
+            gt_det = True
+            class_name = self.targets.encode_class(sound_event)
+
+            if class_name is None:
+                continue
+
+            gt_classes.add(class_name)
+
+        pred_score = 0
+        pred_class_scores: defaultdict[str, float] = defaultdict(lambda: 0)
+        for pred in predictions:
+            if not self.filter_predictions(pred, clip):
+                continue
+
+            pred_score = max(pred_score, pred.detection_score)
+
+            for class_name, class_score in zip(
+                self.targets.class_names,
+                pred.class_scores,
+            ):
+                pred_class_scores[class_name] = max(
+                    pred_class_scores[class_name],
+                    class_score,
+                )
+
+        return ClipInfo(
+            gt_det=gt_det,
+            gt_classes=gt_classes,
+            pred_score=pred_score,
+            pred_class_scores=pred_class_scores,
+        )
+
+    @evaluators.register(ClipMetricsConfig)
+    @staticmethod
+    def from_config(
+        config: ClipMetricsConfig,
+        targets: TargetProtocol,
+    ):
+        metrics = {name: clip_metrics[name] for name in config.metrics}
+        return ClipEvaluator.build(
+            config=config,
+            metrics=metrics,
+            targets=targets,
+        )
diff --git a/src/batdetect2/evaluate/evaluator/multiple.py b/src/batdetect2/evaluate/evaluator/multiple.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/batdetect2/evaluate/evaluator/per_class.py b/src/batdetect2/evaluate/evaluator/per_class.py
new file mode 100644
index 0000000..c5177d8
--- /dev/null
+++ b/src/batdetect2/evaluate/evaluator/per_class.py
@@ -0,0 +1,219 @@
+from collections import defaultdict
+from typing import (
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+)
+
+import numpy as np
+from pydantic import Field
+from soundevent import data
+
+from batdetect2.evaluate.evaluator.base import (
+    BaseEvaluator,
+    BaseEvaluatorConfig,
+    evaluators,
+)
+from batdetect2.evaluate.match import match
+from batdetect2.evaluate.metrics.per_class_matches import (
+    ClassificationAveragePrecisionConfig,
+    PerClassMatchMetric,
+    PerClassMatchMetricConfig,
+    build_per_class_matches_metric,
+)
+from batdetect2.typing import (
+    ClipMatches,
+    RawPrediction,
+    TargetProtocol,
+)
+
+ScoreFn = Callable[[RawPrediction, int], float]
+
+
+def score_by_class_score(pred: RawPrediction, class_index: int) -> float:
+    return float(pred.class_scores[class_index])
+
+
+def score_by_adjusted_class_score(
+    pred: RawPrediction,
+    class_index: int,
+) -> float:
+    return float(pred.class_scores[class_index]) * pred.detection_score
+
+
+ScoreFunctionOption = Literal["class_score", "adjusted_class_score"]
+score_functions: Mapping[ScoreFunctionOption, ScoreFn] = {
+    "class_score": score_by_class_score,
+    "adjusted_class_score": score_by_adjusted_class_score,
+}
+
+
+def get_score_fn(name: ScoreFunctionOption) -> ScoreFn:
+    return score_functions[name]
+
+
+class ClassificationMetricsConfig(BaseEvaluatorConfig):
+    name: Literal["classification"] = "classification"
+    prefix: str = "classification"
+    include_generics: bool = True
+    score_by: ScoreFunctionOption = "class_score"
+    metrics: List[PerClassMatchMetricConfig] = Field(
+        default_factory=lambda: [ClassificationAveragePrecisionConfig()]
+    )
+    include: Optional[List[str]] = None
+    exclude: Optional[List[str]] = None
+
+
+class PerClassEvaluator(BaseEvaluator):
+    def __init__(
+        self,
+        *args,
+        metrics: Dict[str, PerClassMatchMetric],
+        score_fn: ScoreFn,
+        include_generics: bool = True,
+        include: Optional[List[str]] = None,
+        exclude: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.score_fn = score_fn
+        self.metrics = metrics
+
+        self.include_generics = include_generics
+
+        self.include = include
+        self.exclude = exclude
+
+        self.selected = self.targets.class_names
+        if include is not None:
+            self.selected = [
+                class_name
+                for class_name in self.selected
+                if class_name in include
+            ]
+
+        if exclude is not None:
+            self.selected = [
+                class_name
+                for class_name in self.selected
+                if class_name not in exclude
+            ]
+
+    def evaluate(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        predictions: Sequence[Sequence[RawPrediction]],
+    ) -> Dict[str, List[ClipMatches]]:
+        ret = defaultdict(list)
+
+        for clip_annotation, preds in zip(clip_annotations, predictions):
+            matches = self.match_clip(clip_annotation, preds)
+            for class_name, clip_eval in matches.items():
+                ret[class_name].append(clip_eval)
+
+        return ret
+
+    def compute_metrics(
+        self,
+        eval_outputs: Dict[str, List[ClipMatches]],
+    ) -> Dict[str, float]:
+        results = {}
+
+        for metric_name, metric in self.metrics.items():
+            class_scores = {
+                class_name: metric(eval_outputs[class_name], class_name)
+                for class_name in self.targets.class_names
+            }
+            mean = float(
+                np.mean([v for v in class_scores.values() if not np.isnan(v)])
+            )
+
+            results[f"{self.prefix}/mean_{metric_name}"] = mean
+
+            for class_name, value in class_scores.items():
+                if class_name not in self.selected:
+                    continue
+
+                results[f"{self.prefix}/{metric_name}/{class_name}"] = value
+
+        return results
+
+    def match_clip(
+        self,
+        clip_annotation: data.ClipAnnotation,
+        predictions: Sequence[RawPrediction],
+    ) -> Dict[str, ClipMatches]:
+        clip = clip_annotation.clip
+
+        preds = [
+            pred for pred in predictions if self.filter_predictions(pred, clip)
+        ]
+
+        all_gts = [
+            sound_event
+            for sound_event in clip_annotation.sound_events
+            if self.filter_sound_event_annotations(sound_event, clip)
+        ]
+
+        ret = {}
+
+        for class_name in self.targets.class_names:
+            class_idx = self.targets.class_names.index(class_name)
+
+            # Only match to targets of the given class
+            gts = [
+                sound_event
+                for sound_event in all_gts
if self.is_class(sound_event, class_name) + ] + scores = [self.score_fn(pred, class_idx) for pred in preds] + + ret[class_name] = match( + gts, + preds, + clip=clip, + scores=scores, + targets=self.targets, + matcher=self.matcher, + ) + + return ret + + def is_class( + self, + sound_event: data.SoundEventAnnotation, + class_name: str, + ) -> bool: + sound_event_class = self.targets.encode_class(sound_event) + + if sound_event_class is None and self.include_generics: + # Sound events that are generic could be of the given + # class + return True + + return sound_event_class == class_name + + @evaluators.register(ClassificationMetricsConfig) + @staticmethod + def from_config( + config: ClassificationMetricsConfig, + targets: TargetProtocol, + ): + metrics = { + metric.name: build_per_class_matches_metric(metric) + for metric in config.metrics + } + return PerClassEvaluator.build( + config=config, + targets=targets, + metrics=metrics, + score_fn=get_score_fn(config.score_by), + include_generics=config.include_generics, + include=config.include, + exclude=config.exclude, + ) diff --git a/src/batdetect2/evaluate/evaluator/single.py b/src/batdetect2/evaluate/evaluator/single.py new file mode 100644 index 0000000..4e3d22b --- /dev/null +++ b/src/batdetect2/evaluate/evaluator/single.py @@ -0,0 +1,126 @@ +from typing import Callable, Dict, List, Literal, Mapping, Sequence + +from pydantic import Field +from soundevent import data + +from batdetect2.evaluate.evaluator.base import ( + BaseEvaluator, + BaseEvaluatorConfig, + evaluators, +) +from batdetect2.evaluate.match import match +from batdetect2.evaluate.metrics.matches import ( + DetectionAveragePrecisionConfig, + MatchesMetric, + MatchMetricConfig, + build_match_metric, +) +from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol + +ScoreFn = Callable[[RawPrediction], float] + + +def score_by_detection_score(pred: RawPrediction) -> float: + return pred.detection_score + + +def score_by_top_class_score(pred: RawPrediction) -> float: + return pred.class_scores.max() + + +ScoreFunctionOption = Literal["detection_score", "top_class_score"] +score_functions: Mapping[ScoreFunctionOption, ScoreFn] = { + "detection_score": score_by_detection_score, + "top_class_score": score_by_top_class_score, +} + + +def get_score_fn(name: ScoreFunctionOption) -> ScoreFn: + return score_functions[name] + + +class GlobalEvaluatorConfig(BaseEvaluatorConfig): + name: Literal["detection"] = "detection" + prefix: str = "detection" + score_by: ScoreFunctionOption = "detection_score" + metrics: List[MatchMetricConfig] = Field( + default_factory=lambda: [DetectionAveragePrecisionConfig()] + ) + + +class GlobalEvaluator(BaseEvaluator): + def __init__( + self, + *args, + score_fn: ScoreFn, + metrics: Dict[str, MatchesMetric], + **kwargs, + ): + super().__init__(*args, **kwargs) + self.metrics = metrics + self.score_fn = score_fn + + def compute_metrics( + self, + eval_outputs: List[ClipMatches], + ) -> Dict[str, float]: + scores = { + name: metric(eval_outputs) for name, metric in self.metrics.items() + } + return { + f"{self.prefix}/{name}": score for name, score in scores.items() + } + + def evaluate( + self, + clip_annotations: Sequence[data.ClipAnnotation], + predictions: Sequence[Sequence[RawPrediction]], + ) -> List[ClipMatches]: + return [ + self.match_clip(clip_annotation, preds) + for clip_annotation, preds in zip(clip_annotations, predictions) + ] + + def match_clip( + self, + clip_annotation: data.ClipAnnotation, + predictions: 
Sequence[RawPrediction], + ) -> ClipMatches: + clip = clip_annotation.clip + + gts = [ + sound_event + for sound_event in clip_annotation.sound_events + if self.filter_sound_event_annotations(sound_event, clip) + ] + preds = [ + pred for pred in predictions if self.filter_predictions(pred, clip) + ] + scores = [self.score_fn(pred) for pred in preds] + + return match( + gts, + preds, + scores=scores, + clip=clip, + targets=self.targets, + matcher=self.matcher, + ) + + @evaluators.register(GlobalEvaluatorConfig) + @staticmethod + def from_config( + config: GlobalEvaluatorConfig, + targets: TargetProtocol, + ): + metrics = { + metric.name: build_match_metric(metric) + for metric in config.metrics + } + score_fn = get_score_fn(config.score_by) + return GlobalEvaluator.build( + config=config, + score_fn=score_fn, + metrics=metrics, + targets=targets, + ) diff --git a/src/batdetect2/evaluate/evaluator/top_class.py b/src/batdetect2/evaluate/evaluator/top_class.py new file mode 100644 index 0000000..149447e --- /dev/null +++ b/src/batdetect2/evaluate/evaluator/top_class.py @@ -0,0 +1,133 @@ +from typing import Dict, List, Literal, Sequence + +from pydantic import Field, field_validator +from soundevent import data + +from batdetect2.evaluate.match import match +from batdetect2.evaluate.metrics.base import ( + BaseMetric, + BaseMetricConfig, + metrics_registry, +) +from batdetect2.evaluate.metrics.common import average_precision +from batdetect2.evaluate.metrics.detection import DetectionMetric +from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol + +__all__ = [ + "TopClassEvaluator", + "TopClassEvaluatorConfig", +] + + +def top_class_average_precision( + clip_evaluations: Sequence[ClipMatches], +) -> float: + y_true = [] + y_score = [] + num_positives = 0 + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + is_generic = m.gt_det and (m.gt_class is None) + + # Ignore ground truth sounds with unknown class + if is_generic: + continue + + num_positives += int(m.gt_det) + + # Ignore matches that don't correspond to a prediction + if m.pred_geometry is None: + continue + + y_true.append(m.gt_det & (m.top_class == m.gt_class)) + y_score.append(m.top_class_score) + + return average_precision(y_true, y_score, num_positives=num_positives) + + +top_class_metrics = { + "average_precision": top_class_average_precision, +} + + +class TopClassEvaluatorConfig(BaseMetricConfig): + name: Literal["top_class"] = "top_class" + prefix: str = "top_class" + metrics: List[str] = Field(default_factory=lambda: ["average_precision"]) + + @field_validator("metrics", mode="after") + @classmethod + def validate_metrics(cls, v: List[str]) -> List[str]: + for metric_name in v: + if metric_name not in top_class_metrics: + raise ValueError(f"Unknown metric {metric_name}") + return v + + +class TopClassEvaluator(BaseMetric): + def __init__(self, *args, metrics: Dict[str, DetectionMetric], **kwargs): + super().__init__(*args, **kwargs) + self.metrics = metrics + + def __call__( + self, + clip_annotations: Sequence[data.ClipAnnotation], + predictions: Sequence[Sequence[RawPrediction]], + ) -> Dict[str, float]: + clip_evaluations = [ + self.match_clip(clip_annotation, preds) + for clip_annotation, preds in zip(clip_annotations, predictions) + ] + scores = { + name: metric(clip_evaluations) + for name, metric in self.metrics.items() + } + return { + f"{self.prefix}/{name}": score for name, score in scores.items() + } + + def match_clip( + self, + clip_annotation: data.ClipAnnotation, + predictions: 
Sequence[RawPrediction], + ) -> ClipMatches: + clip = clip_annotation.clip + + gts = [ + sound_event + for sound_event in clip_annotation.sound_events + if self.filter_sound_event_annotations(sound_event, clip) + ] + preds = [ + pred for pred in predictions if self.filter_predictions(pred, clip) + ] + # Use score of top class for matching + scores = [pred.class_scores.max() for pred in preds] + + return match( + gts, + preds, + scores=scores, + clip=clip, + targets=self.targets, + matcher=self.matcher, + ) + + @classmethod + def from_config( + cls, + config: TopClassEvaluatorConfig, + targets: TargetProtocol, + ): + metrics = { + name: top_class_metrics.get(name) for name in config.metrics + } + return super().build( + config=config, + metrics=metrics, + targets=targets, + ) + + +metrics_registry.register(TopClassEvaluatorConfig, TopClassEvaluator) diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py index 625e7fe..621c869 100644 --- a/src/batdetect2/evaluate/lightning.py +++ b/src/batdetect2/evaluate/lightning.py @@ -8,7 +8,7 @@ from batdetect2.evaluate.tables import FullEvaluationTable from batdetect2.logging import get_image_logger, get_table_logger from batdetect2.models import Model from batdetect2.postprocess import to_raw_predictions -from batdetect2.typing import ClipEvaluation, EvaluatorProtocol +from batdetect2.typing import ClipMatches, EvaluatorProtocol class EvaluationModule(LightningModule): @@ -56,7 +56,7 @@ class EvaluationModule(LightningModule): self.plot_examples(self.clip_evaluations) self.log_table(self.clip_evaluations) - def log_table(self, evaluated_clips: Sequence[ClipEvaluation]): + def log_table(self, evaluated_clips: Sequence[ClipMatches]): table_logger = get_table_logger(self.logger) # type: ignore if table_logger is None: @@ -65,7 +65,7 @@ class EvaluationModule(LightningModule): df = FullEvaluationTable()(evaluated_clips) table_logger("full_evaluation", df, 0) - def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]): + def plot_examples(self, evaluated_clips: Sequence[ClipMatches]): plotter = get_image_logger(self.logger) # type: ignore if plotter is None: @@ -74,7 +74,7 @@ class EvaluationModule(LightningModule): for figure_name, fig in self.evaluator.generate_plots(evaluated_clips): plotter(figure_name, fig, self.global_step) - def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]): + def log_metrics(self, evaluated_clips: Sequence[ClipMatches]): metrics = self.evaluator.compute_metrics(evaluated_clips) self.log_dict(metrics) diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index af3545b..74feb4f 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -8,8 +8,7 @@ from soundevent.evaluation import compute_affinity from soundevent.evaluation import match_geometries as optimal_match from soundevent.geometry import compute_bounds -from batdetect2.core.configs import BaseConfig -from batdetect2.core.registries import Registry +from batdetect2.core import BaseConfig, Registry from batdetect2.evaluate.affinity import ( AffinityConfig, GeometricIOUConfig, @@ -17,11 +16,13 @@ from batdetect2.evaluate.affinity import ( ) from batdetect2.targets import build_targets from batdetect2.typing import ( + AffinityFunction, + MatcherProtocol, MatchEvaluation, + RawPrediction, TargetProtocol, ) -from batdetect2.typing.evaluate import AffinityFunction, MatcherProtocol -from batdetect2.typing.postprocess import RawPrediction +from batdetect2.typing.evaluate 
import ClipMatches MatchingGeometry = Literal["bbox", "interval", "timestamp"] """The geometry representation to use for matching.""" @@ -33,9 +34,10 @@ def match( sound_event_annotations: Sequence[data.SoundEventAnnotation], raw_predictions: Sequence[RawPrediction], clip: data.Clip, + scores: Optional[Sequence[float]] = None, targets: Optional[TargetProtocol] = None, matcher: Optional[MatcherProtocol] = None, -) -> List[MatchEvaluation]: +) -> ClipMatches: if matcher is None: matcher = build_matcher() @@ -51,9 +53,11 @@ def match( raw_prediction.geometry for raw_prediction in raw_predictions ] - scores = [ - raw_prediction.detection_score for raw_prediction in raw_predictions - ] + if scores is None: + scores = [ + raw_prediction.detection_score + for raw_prediction in raw_predictions + ] matches = [] @@ -73,9 +77,11 @@ def match( gt_det = target_idx is not None gt_class = targets.encode_class(target) if target is not None else None + gt_geometry = ( + target_geometries[target_idx] if target_idx is not None else None + ) pred_score = float(prediction.detection_score) if prediction else 0 - pred_geometry = ( predicted_geometries[source_idx] if source_idx is not None @@ -84,7 +90,7 @@ def match( class_scores = ( { - str(class_name): float(score) + class_name: score for class_name, score in zip( targets.class_names, prediction.class_scores, @@ -100,6 +106,7 @@ def match( sound_event_annotation=target, gt_det=gt_det, gt_class=gt_class, + gt_geometry=gt_geometry, pred_score=pred_score, pred_class_scores=class_scores, pred_geometry=pred_geometry, @@ -107,7 +114,7 @@ def match( ) ) - return matches + return ClipMatches(clip=clip, matches=matches) class StartTimeMatchConfig(BaseConfig): @@ -132,12 +139,10 @@ class StartTimeMatcher(MatcherProtocol): distance_threshold=self.distance_threshold, ) - @classmethod - def from_config(cls, config: StartTimeMatchConfig) -> "StartTimeMatcher": - return cls(distance_threshold=config.distance_threshold) - - -matching_strategies.register(StartTimeMatchConfig, StartTimeMatcher) + @matching_strategies.register(StartTimeMatchConfig) + @staticmethod + def from_config(config: StartTimeMatchConfig): + return StartTimeMatcher(distance_threshold=config.distance_threshold) def match_start_times( @@ -264,19 +269,17 @@ class GreedyMatcher(MatcherProtocol): affinity_threshold=self.affinity_threshold, ) - @classmethod - def from_config(cls, config: GreedyMatchConfig): + @matching_strategies.register(GreedyMatchConfig) + @staticmethod + def from_config(config: GreedyMatchConfig): affinity_function = build_affinity_function(config.affinity_function) - return cls( + return GreedyMatcher( geometry=config.geometry, affinity_threshold=config.affinity_threshold, affinity_function=affinity_function, ) -matching_strategies.register(GreedyMatchConfig, GreedyMatcher) - - def greedy_match( ground_truth: Sequence[data.Geometry], predictions: Sequence[data.Geometry], @@ -313,21 +316,21 @@ def greedy_match( unassigned_gt = set(range(len(ground_truth))) if not predictions: - for target_idx in range(len(ground_truth)): - yield None, target_idx, 0 + for gt_idx in range(len(ground_truth)): + yield None, gt_idx, 0 return if not ground_truth: - for source_idx in range(len(predictions)): - yield source_idx, None, 0 + for pred_idx in range(len(predictions)): + yield pred_idx, None, 0 return indices = np.argsort(scores)[::-1] - for source_idx in indices: - source_geometry = predictions[source_idx] + for pred_idx in indices: + source_geometry = predictions[pred_idx] affinities = np.array( [ @@ 
-340,18 +343,18 @@ def greedy_match( affinity = affinities[closest_target] if affinities[closest_target] <= affinity_threshold: - yield source_idx, None, 0 + yield pred_idx, None, 0 continue if closest_target not in unassigned_gt: - yield source_idx, None, 0 + yield pred_idx, None, 0 continue unassigned_gt.remove(closest_target) - yield source_idx, closest_target, affinity + yield pred_idx, closest_target, affinity - for target_idx in unassigned_gt: - yield None, target_idx, 0 + for gt_idx in unassigned_gt: + yield None, gt_idx, 0 class OptimalMatchConfig(BaseConfig): @@ -386,17 +389,16 @@ class OptimalMatcher(MatcherProtocol): affinity_threshold=self.affinity_threshold, ) - @classmethod - def from_config(cls, config: OptimalMatchConfig): - return cls( + @matching_strategies.register(OptimalMatchConfig) + @staticmethod + def from_config(config: OptimalMatchConfig): + return OptimalMatcher( affinity_threshold=config.affinity_threshold, time_buffer=config.time_buffer, frequency_buffer=config.frequency_buffer, ) -matching_strategies.register(OptimalMatchConfig, OptimalMatcher) - MatchConfig = Annotated[ Union[ GreedyMatchConfig, diff --git a/src/batdetect2/evaluate/metrics.py b/src/batdetect2/evaluate/metrics.py deleted file mode 100644 index efa024a..0000000 --- a/src/batdetect2/evaluate/metrics.py +++ /dev/null @@ -1,712 +0,0 @@ -from collections import defaultdict -from collections.abc import Callable, Mapping -from typing import ( - Annotated, - Any, - Dict, - List, - Literal, - Optional, - Sequence, - Union, -) - -import numpy as np -from pydantic import Field -from sklearn import metrics, preprocessing - -from batdetect2.core import BaseConfig, Registry -from batdetect2.typing import ClipEvaluation, MetricsProtocol - -__all__ = ["DetectionAP", "ClassificationAP"] - - -metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric") - - -APImplementation = Literal["sklearn", "pascal_voc"] - - -class DetectionAPConfig(BaseConfig): - name: Literal["detection_ap"] = "detection_ap" - ap_implementation: APImplementation = "pascal_voc" - - -class DetectionAP(MetricsProtocol): - def __init__( - self, - implementation: APImplementation = "pascal_voc", - ): - self.implementation = implementation - self.metric = _ap_impl_mapping[self.implementation] - - def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] - ) -> Dict[str, float]: - y_true, y_score = zip( - *[ - (match.gt_det, match.pred_score) - for clip_eval in clip_evaluations - for match in clip_eval.matches - ] - ) - score = float(self.metric(y_true, y_score)) - return {"detection_AP": score} - - @classmethod - def from_config(cls, config: DetectionAPConfig, class_names: List[str]): - return cls(implementation=config.ap_implementation) - - -metrics_registry.register(DetectionAPConfig, DetectionAP) - - -class DetectionROCAUCConfig(BaseConfig): - name: Literal["detection_roc_auc"] = "detection_roc_auc" - - -class DetectionROCAUC(MetricsProtocol): - def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] - ) -> Dict[str, float]: - y_true, y_score = zip( - *[ - (match.gt_det, match.pred_score) - for clip_eval in clip_evaluations - for match in clip_eval.matches - ] - ) - score = float(metrics.roc_auc_score(y_true, y_score)) - return {"detection_ROC_AUC": score} - - @classmethod - def from_config( - cls, config: DetectionROCAUCConfig, class_names: List[str] - ): - return cls() - - -metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC) - - -class ClassificationAPConfig(BaseConfig): - name: 
Literal["classification_ap"] = "classification_ap" - ap_implementation: APImplementation = "pascal_voc" - include: Optional[List[str]] = None - exclude: Optional[List[str]] = None - - -class ClassificationAP(MetricsProtocol): - def __init__( - self, - class_names: List[str], - implementation: APImplementation = "pascal_voc", - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - ): - self.implementation = implementation - self.metric = _ap_impl_mapping[self.implementation] - self.class_names = class_names - - self.selected = class_names - - if include is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name in include - ] - - if exclude is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name not in exclude - ] - - def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] - ) -> Dict[str, float]: - y_true = [] - y_pred = [] - - for clip_eval in clip_evaluations: - for match in clip_eval.matches: - # Ignore generic unclassified targets - if match.gt_det and match.gt_class is None: - continue - - y_true.append( - match.gt_class - if match.gt_class is not None - else "__NONE__" - ) - - y_pred.append( - np.array( - [ - match.pred_class_scores.get(name, 0) - for name in self.class_names - ] - ) - ) - - y_true = preprocessing.label_binarize(y_true, classes=self.class_names) - y_pred = np.stack(y_pred) - - class_scores = {} - for class_index, class_name in enumerate(self.class_names): - y_true_class = y_true[:, class_index] - y_pred_class = y_pred[:, class_index] - class_ap = self.metric(y_true_class, y_pred_class) - class_scores[class_name] = float(class_ap) - - mean_ap = np.mean( - [value for value in class_scores.values() if value != 0] - ) - - return { - "classification_mAP": float(mean_ap), - **{ - f"classification_AP/{class_name}": class_scores[class_name] - for class_name in self.selected - }, - } - - @classmethod - def from_config( - cls, - config: ClassificationAPConfig, - class_names: List[str], - ): - return cls( - class_names, - implementation=config.ap_implementation, - include=config.include, - exclude=config.exclude, - ) - - -metrics_registry.register(ClassificationAPConfig, ClassificationAP) - - -class ClassificationROCAUCConfig(BaseConfig): - name: Literal["classification_roc_auc"] = "classification_roc_auc" - include: Optional[List[str]] = None - exclude: Optional[List[str]] = None - - -class ClassificationROCAUC(MetricsProtocol): - def __init__( - self, - class_names: List[str], - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - ): - self.class_names = class_names - self.selected = class_names - - if include is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name in include - ] - - if exclude is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name not in exclude - ] - - def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] - ) -> Dict[str, float]: - y_true = [] - y_pred = [] - - for clip_eval in clip_evaluations: - for match in clip_eval.matches: - # Ignore generic unclassified targets - if match.gt_det and match.gt_class is None: - continue - - y_true.append( - match.gt_class - if match.gt_class is not None - else "__NONE__" - ) - - y_pred.append( - np.array( - [ - match.pred_class_scores.get(name, 0) - for name in self.class_names - ] - ) - ) - - y_true = preprocessing.label_binarize(y_true, classes=self.class_names) - y_pred = np.stack(y_pred) - - 
class_scores = {} - for class_index, class_name in enumerate(self.class_names): - y_true_class = y_true[:, class_index] - y_pred_class = y_pred[:, class_index] - class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class) - class_scores[class_name] = float(class_roc_auc) - - mean_roc_auc = np.mean( - [value for value in class_scores.values() if value != 0] - ) - - return { - "classification_macro_average_ROC_AUC": float(mean_roc_auc), - **{ - f"classification_ROC_AUC/{class_name}": class_scores[ - class_name - ] - for class_name in self.selected - }, - } - - @classmethod - def from_config( - cls, - config: ClassificationROCAUCConfig, - class_names: List[str], - ): - return cls( - class_names, - include=config.include, - exclude=config.exclude, - ) - - -metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC) - - -class TopClassAPConfig(BaseConfig): - name: Literal["top_class_ap"] = "top_class_ap" - ap_implementation: APImplementation = "pascal_voc" - - -class TopClassAP(MetricsProtocol): - def __init__( - self, - implementation: APImplementation = "pascal_voc", - ): - self.implementation = implementation - self.metric = _ap_impl_mapping[self.implementation] - - def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] - ) -> Dict[str, float]: - y_true = [] - y_score = [] - - for clip_eval in clip_evaluations: - for match in clip_eval.matches: - # Ignore generic unclassified targets - if match.gt_det and match.gt_class is None: - continue - - top_class = match.pred_class - - y_true.append(top_class == match.gt_class) - y_score.append(match.pred_class_score) - - score = float(self.metric(y_true, y_score)) - return {"top_class_AP": score} - - @classmethod - def from_config(cls, config: TopClassAPConfig, class_names: List[str]): - return cls(implementation=config.ap_implementation) - - -metrics_registry.register(TopClassAPConfig, TopClassAP) - - -class ClassificationBalancedAccuracyConfig(BaseConfig): - name: Literal["classification_balanced_accuracy"] = ( - "classification_balanced_accuracy" - ) - - -class ClassificationBalancedAccuracy(MetricsProtocol): - def __init__(self, class_names: List[str]): - self.class_names = class_names - - def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] - ) -> Dict[str, float]: - y_true = [] - y_pred = [] - - for clip_eval in clip_evaluations: - for match in clip_eval.matches: - top_class = match.pred_class - - # Focus on matches - if match.gt_class is None or top_class is None: - continue - - y_true.append(self.class_names.index(match.gt_class)) - y_pred.append(self.class_names.index(top_class)) - - score = float(metrics.balanced_accuracy_score(y_true, y_pred)) - return {"classification_balanced_accuracy": score} - - @classmethod - def from_config( - cls, - config: ClassificationBalancedAccuracyConfig, - class_names: List[str], - ): - return cls(class_names) - - -metrics_registry.register( - ClassificationBalancedAccuracyConfig, - ClassificationBalancedAccuracy, -) - - -class ClipDetectionAPConfig(BaseConfig): - name: Literal["clip_detection_ap"] = "clip_detection_ap" - ap_implementation: APImplementation = "pascal_voc" - - -class ClipDetectionAP(MetricsProtocol): - def __init__( - self, - implementation: APImplementation, - ): - self.implementation = implementation - self.metric = _ap_impl_mapping[self.implementation] - - def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] - ) -> Dict[str, float]: - y_true = [] - y_score = [] - - for clip_eval in clip_evaluations: - clip_det = [] - clip_scores = [] - 
- for match in clip_eval.matches: - clip_det.append(match.gt_det) - clip_scores.append(match.pred_score) - - y_true.append(any(clip_det)) - y_score.append(max(clip_scores or [0])) - - return {"clip_detection_ap": self.metric(y_true, y_score)} - - @classmethod - def from_config( - cls, - config: ClipDetectionAPConfig, - class_names: List[str], - ): - return cls(implementation=config.ap_implementation) - - -metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP) - - -class ClipDetectionROCAUCConfig(BaseConfig): - name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc" - - -class ClipDetectionROCAUC(MetricsProtocol): - def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] - ) -> Dict[str, float]: - y_true = [] - y_score = [] - - for clip_eval in clip_evaluations: - clip_det = [] - clip_scores = [] - - for match in clip_eval.matches: - clip_det.append(match.gt_det) - clip_scores.append(match.pred_score) - - y_true.append(any(clip_det)) - y_score.append(max(clip_scores or [0])) - - return { - "clip_detection_ap": float(metrics.roc_auc_score(y_true, y_score)) - } - - @classmethod - def from_config( - cls, - config: ClipDetectionROCAUCConfig, - class_names: List[str], - ): - return cls() - - -metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC) - - -class ClipMulticlassAPConfig(BaseConfig): - name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap" - ap_implementation: APImplementation = "pascal_voc" - include: Optional[List[str]] = None - exclude: Optional[List[str]] = None - - -class ClipMulticlassAP(MetricsProtocol): - def __init__( - self, - class_names: List[str], - implementation: APImplementation, - include: Optional[Sequence[str]] = None, - exclude: Optional[Sequence[str]] = None, - ): - self.implementation = implementation - self.metric = _ap_impl_mapping[self.implementation] - self.class_names = class_names - - self.selected = class_names - - if include is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name in include - ] - - if exclude is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name not in exclude - ] - - def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] - ) -> Dict[str, float]: - y_true = [] - y_pred = [] - - for clip_eval in clip_evaluations: - clip_classes = set() - clip_scores = defaultdict(list) - - for match in clip_eval.matches: - if match.gt_class is not None: - clip_classes.add(match.gt_class) - - for class_name, score in match.pred_class_scores.items(): - clip_scores[class_name].append(score) - - y_true.append(clip_classes) - y_pred.append( - np.array( - [ - # Get max score for each class - max(clip_scores.get(class_name, [0])) - for class_name in self.class_names - ] - ) - ) - - y_true = preprocessing.MultiLabelBinarizer( - classes=self.class_names - ).fit_transform(y_true) - y_pred = np.stack(y_pred) - - class_scores = {} - for class_index, class_name in enumerate(self.class_names): - y_true_class = y_true[:, class_index] - y_pred_class = y_pred[:, class_index] - class_ap = self.metric(y_true_class, y_pred_class) - class_scores[class_name] = float(class_ap) - - mean_ap = np.mean( - [value for value in class_scores.values() if value != 0] - ) - return { - "clip_multiclass_mAP": float(mean_ap), - **{ - f"clip_multiclass_AP/{class_name}": class_scores[class_name] - for class_name in self.selected - }, - } - - @classmethod - def from_config( - cls, config: ClipMulticlassAPConfig, class_names: List[str] - ): - return cls( 
- implementation=config.ap_implementation, - include=config.include, - exclude=config.exclude, - class_names=class_names, - ) - - -metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP) - - -class ClipMulticlassROCAUCConfig(BaseConfig): - name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc" - include: Optional[List[str]] = None - exclude: Optional[List[str]] = None - - -class ClipMulticlassROCAUC(MetricsProtocol): - def __init__( - self, - class_names: List[str], - include: Optional[Sequence[str]] = None, - exclude: Optional[Sequence[str]] = None, - ): - self.class_names = class_names - self.selected = class_names - - if include is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name in include - ] - - if exclude is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name not in exclude - ] - - def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] - ) -> Dict[str, float]: - y_true = [] - y_pred = [] - - for clip_eval in clip_evaluations: - clip_classes = set() - clip_scores = defaultdict(list) - - for match in clip_eval.matches: - if match.gt_class is not None: - clip_classes.add(match.gt_class) - - for class_name, score in match.pred_class_scores.items(): - clip_scores[class_name].append(score) - - y_true.append(clip_classes) - y_pred.append( - np.array( - [ - # Get maximum score for each class - max(clip_scores.get(class_name, [0])) - for class_name in self.class_names - ] - ) - ) - - y_true = preprocessing.MultiLabelBinarizer( - classes=self.class_names - ).fit_transform(y_true) - y_pred = np.stack(y_pred) - - class_scores = {} - for class_index, class_name in enumerate(self.class_names): - y_true_class = y_true[:, class_index] - y_pred_class = y_pred[:, class_index] - class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class) - class_scores[class_name] = float(class_roc_auc) - - mean_roc_auc = np.mean( - [value for value in class_scores.values() if value != 0] - ) - return { - "clip_multiclass_macro_ROC_AUC": float(mean_roc_auc), - **{ - f"clip_multiclass_ROC_AUC/{class_name}": class_scores[ - class_name - ] - for class_name in self.selected - }, - } - - @classmethod - def from_config( - cls, - config: ClipMulticlassROCAUCConfig, - class_names: List[str], - ): - return cls( - include=config.include, - exclude=config.exclude, - class_names=class_names, - ) - - -metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC) - -MetricConfig = Annotated[ - Union[ - DetectionAPConfig, - DetectionROCAUCConfig, - ClassificationAPConfig, - ClassificationROCAUCConfig, - TopClassAPConfig, - ClassificationBalancedAccuracyConfig, - ClipDetectionAPConfig, - ClipDetectionROCAUCConfig, - ClipMulticlassAPConfig, - ClipMulticlassROCAUCConfig, - ], - Field(discriminator="name"), -] - - -def build_metric(config: MetricConfig, class_names: List[str]): - return metrics_registry.build(config, class_names) - - -def pascal_voc_average_precision(y_true, y_score) -> float: - y_true = np.array(y_true) - y_score = np.array(y_score) - - sort_ind = np.argsort(y_score)[::-1] - y_true_sorted = y_true[sort_ind] - - num_positives = y_true.sum() - false_pos_c = np.cumsum(1 - y_true_sorted) - true_pos_c = np.cumsum(y_true_sorted) - - recall = true_pos_c / num_positives - precision = true_pos_c / np.maximum( - true_pos_c + false_pos_c, - np.finfo(np.float64).eps, - ) - - precision[np.isnan(precision)] = 0 - recall[np.isnan(recall)] = 0 - - # pascal 12 way - mprec = np.hstack((0, 
precision, 0)) - mrec = np.hstack((0, recall, 1)) - for ii in range(mprec.shape[0] - 2, -1, -1): - mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1]) - inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1 - ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum() - - return ave_prec - - -_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = { - "sklearn": metrics.average_precision_score, - "pascal_voc": pascal_voc_average_precision, -} diff --git a/src/batdetect2/evaluate/metrics/__init__.py b/src/batdetect2/evaluate/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/batdetect2/evaluate/metrics/common.py b/src/batdetect2/evaluate/metrics/common.py new file mode 100644 index 0000000..4375477 --- /dev/null +++ b/src/batdetect2/evaluate/metrics/common.py @@ -0,0 +1,46 @@ +from typing import Optional + +import numpy as np + + +def average_precision( + y_true, + y_score, + num_positives: Optional[int] = None, +) -> float: + y_true = np.array(y_true) + y_score = np.array(y_score) + + if num_positives is None: + num_positives = y_true.sum() + + # Remove non-detections + valid_inds = y_score > 0 + y_true = y_true[valid_inds] + y_score = y_score[valid_inds] + + # Sort by score + sort_ind = np.argsort(y_score)[::-1] + y_true_sorted = y_true[sort_ind] + + false_pos_c = np.cumsum(1 - y_true_sorted) + true_pos_c = np.cumsum(y_true_sorted) + + recall = true_pos_c / num_positives + precision = true_pos_c / np.maximum( + true_pos_c + false_pos_c, + np.finfo(np.float64).eps, + ) + + precision[np.isnan(precision)] = 0 + recall[np.isnan(recall)] = 0 + + # pascal 12 way + mprec = np.hstack((0, precision, 0)) + mrec = np.hstack((0, recall, 1)) + for ii in range(mprec.shape[0] - 2, -1, -1): + mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1]) + inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1 + ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum() + + return ave_prec diff --git a/src/batdetect2/evaluate/metrics/matches.py b/src/batdetect2/evaluate/metrics/matches.py new file mode 100644 index 0000000..0c3ec12 --- /dev/null +++ b/src/batdetect2/evaluate/metrics/matches.py @@ -0,0 +1,235 @@ +from typing import Annotated, Callable, Literal, Sequence, Union + +import numpy as np +from pydantic import Field +from sklearn import metrics + +from batdetect2.core import BaseConfig, Registry +from batdetect2.evaluate.metrics.common import average_precision +from batdetect2.typing import ( + ClipMatches, +) + +__all__ = [ + "MatchMetricConfig", + "MatchesMetric", + "build_match_metric", +] + +MatchesMetric = Callable[[Sequence[ClipMatches]], float] + + +metrics_registry: Registry[MatchesMetric, []] = Registry("match_metric") + + +class DetectionAveragePrecisionConfig(BaseConfig): + name: Literal["detection_average_precision"] = ( + "detection_average_precision" + ) + ignore_non_predictions: bool = True + + +class DetectionAveragePrecision: + def __init__(self, ignore_non_predictions: bool = True): + self.ignore_non_predictions = ignore_non_predictions + + def __call__( + self, + clip_evaluations: Sequence[ClipMatches], + ) -> float: + y_true = [] + y_score = [] + num_positives = 0 + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + num_positives += int(m.gt_det) + + # Ignore matches that don't correspond to a prediction + if not m.is_prediction and self.ignore_non_predictions: + continue + + y_true.append(m.gt_det) + y_score.append(m.pred_score) + + return average_precision(y_true, y_score, num_positives=num_positives) + + 
@metrics_registry.register(DetectionAveragePrecisionConfig)
+    @staticmethod
+    def from_config(config: DetectionAveragePrecisionConfig):
+        return DetectionAveragePrecision(
+            ignore_non_predictions=config.ignore_non_predictions
+        )
+
+
+class TopClassAveragePrecisionConfig(BaseConfig):
+    name: Literal["top_class_average_precision"] = (
+        "top_class_average_precision"
+    )
+    ignore_non_predictions: bool = True
+    ignore_generic: bool = True
+
+
+class TopClassAveragePrecision:
+    def __init__(
+        self,
+        ignore_non_predictions: bool = True,
+        ignore_generic: bool = True,
+    ):
+        self.ignore_non_predictions = ignore_non_predictions
+        self.ignore_generic = ignore_generic
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipMatches],
+    ) -> float:
+        y_true = []
+        y_score = []
+        num_positives = 0
+
+        for clip_eval in clip_evaluations:
+            for m in clip_eval.matches:
+                if m.is_generic and self.ignore_generic:
+                    # Ignore ground truth sounds with unknown class
+                    continue
+
+                num_positives += int(m.gt_det)
+
+                if not m.is_prediction and self.ignore_non_predictions:
+                    # Ignore matches that don't correspond to a prediction
+                    continue
+
+                y_true.append(m.gt_det & (m.top_class == m.gt_class))
+                y_score.append(m.top_class_score)
+
+        return average_precision(y_true, y_score, num_positives=num_positives)
+
+    @metrics_registry.register(TopClassAveragePrecisionConfig)
+    @staticmethod
+    def from_config(config: TopClassAveragePrecisionConfig):
+        return TopClassAveragePrecision(
+            ignore_non_predictions=config.ignore_non_predictions,
+            ignore_generic=config.ignore_generic,
+        )
+
+
+class DetectionROCAUCConfig(BaseConfig):
+    name: Literal["detection_roc_auc"] = "detection_roc_auc"
+    ignore_non_predictions: bool = True
+
+
+class DetectionROCAUC:
+    def __init__(
+        self,
+        ignore_non_predictions: bool = True,
+    ):
+        self.ignore_non_predictions = ignore_non_predictions
+
+    def __call__(self, clip_evaluations: Sequence[ClipMatches]) -> float:
+        y_true = []
+        y_score = []
+
+        for clip_eval in clip_evaluations:
+            for m in clip_eval.matches:
+                if not m.is_prediction and self.ignore_non_predictions:
+                    # Ignore matches that don't correspond to a prediction
+                    continue
+
+                y_true.append(m.gt_det)
+                y_score.append(m.pred_score)
+
+        return float(metrics.roc_auc_score(y_true, y_score))
+
+    @metrics_registry.register(DetectionROCAUCConfig)
+    @staticmethod
+    def from_config(config: DetectionROCAUCConfig):
+        return DetectionROCAUC(
+            ignore_non_predictions=config.ignore_non_predictions
+        )
+
+
+class DetectionRecallConfig(BaseConfig):
+    name: Literal["detection_recall"] = "detection_recall"
+    threshold: float = 0.5
+
+
+class DetectionRecall:
+    def __init__(self, threshold: float):
+        self.threshold = threshold
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipMatches],
+    ) -> float:
+        num_positives = 0
+        true_positives = 0
+
+        for clip_eval in clip_evaluations:
+            for m in clip_eval.matches:
+                if m.gt_det:
+                    num_positives += 1
+
+                if m.pred_score >= self.threshold and m.gt_det:
+                    true_positives += 1
+
+        if num_positives == 0:
+            return 1
+
+        return true_positives / num_positives
+
+    @metrics_registry.register(DetectionRecallConfig)
+    @staticmethod
+    def from_config(config: DetectionRecallConfig):
+        return DetectionRecall(threshold=config.threshold)
+
+
+class DetectionPrecisionConfig(BaseConfig):
+    name: Literal["detection_precision"] = "detection_precision"
+    threshold: float = 0.5
+
+
+class DetectionPrecision:
+    def __init__(self, threshold: float):
+        self.threshold = threshold
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipMatches],
+    ) -> float:
+ num_detections = 0 + true_positives = 0 + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + is_detection = m.pred_score >= self.threshold + + if is_detection: + num_detections += 1 + + if is_detection and m.gt_det: + true_positives += 1 + + if num_detections == 0: + return np.nan + + return true_positives / num_detections + + @metrics_registry.register(DetectionPrecisionConfig) + @staticmethod + def from_config(config: DetectionPrecisionConfig): + return DetectionPrecision(threshold=config.threshold) + + +MatchMetricConfig = Annotated[ + Union[ + DetectionAveragePrecisionConfig, + DetectionROCAUCConfig, + DetectionRecallConfig, + DetectionPrecisionConfig, + TopClassAveragePrecisionConfig, + ], + Field(discriminator="name"), +] + + +def build_match_metric(config: MatchMetricConfig): + return metrics_registry.build(config) diff --git a/src/batdetect2/evaluate/metrics/per_class_matches.py b/src/batdetect2/evaluate/metrics/per_class_matches.py new file mode 100644 index 0000000..51e0a8a --- /dev/null +++ b/src/batdetect2/evaluate/metrics/per_class_matches.py @@ -0,0 +1,136 @@ +from typing import Annotated, Callable, Literal, Sequence, Union + +from pydantic import Field +from sklearn import metrics + +from batdetect2.core import BaseConfig, Registry +from batdetect2.evaluate.metrics.common import average_precision +from batdetect2.typing import ( + ClipMatches, +) + +__all__ = [] + +PerClassMatchMetric = Callable[[Sequence[ClipMatches], str], float] + + +metrics_registry: Registry[PerClassMatchMetric, []] = Registry( + "match_metric" +) + + +class ClassificationAveragePrecisionConfig(BaseConfig): + name: Literal["classification_average_precision"] = ( + "classification_average_precision" + ) + ignore_non_predictions: bool = True + ignore_generic: bool = True + + +class ClassificationAveragePrecision: + def __init__( + self, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + ): + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + + def __call__( + self, + clip_evaluations: Sequence[ClipMatches], + class_name: str, + ) -> float: + y_true = [] + y_score = [] + num_positives = 0 + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + is_class = m.gt_class == class_name + + if is_class: + num_positives += 1 + + # Ignore matches that don't correspond to a prediction + if not m.is_prediction and self.ignore_non_predictions: + continue + + # Exclude matches with ground truth sounds where the class is + # unknown + if m.is_generic and self.ignore_generic: + continue + + y_true.append(is_class) + y_score.append(m.pred_class_scores.get(class_name, 0)) + + return average_precision(y_true, y_score, num_positives=num_positives) + + @metrics_registry.register(ClassificationAveragePrecisionConfig) + @staticmethod + def from_config(config: ClassificationAveragePrecisionConfig): + return ClassificationAveragePrecision( + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + ) + + +class ClassificationROCAUCConfig(BaseConfig): + name: Literal["classification_roc_auc"] = "classification_roc_auc" + ignore_non_predictions: bool = True + ignore_generic: bool = True + + +class ClassificationROCAUC: + def __init__( + self, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + ): + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + + def __call__( + self, + clip_evaluations: Sequence[ClipMatches], + class_name: str, + ) 
-> float: + y_true = [] + y_score = [] + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + # Exclude matches with ground truth sounds where the class is + # unknown + if m.is_generic and self.ignore_generic: + continue + + # Ignore matches that don't correspond to a prediction + if not m.is_prediction and self.ignore_non_predictions: + continue + + y_true.append(m.gt_class == class_name) + y_score.append(m.pred_class_scores.get(class_name, 0)) + + return float(metrics.roc_auc_score(y_true, y_score)) + + @metrics_registry.register(ClassificationROCAUCConfig) + @staticmethod + def from_config(config: ClassificationROCAUCConfig): + return ClassificationROCAUC( + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + ) + + +PerClassMatchMetricConfig = Annotated[ + Union[ + ClassificationAveragePrecisionConfig, + ClassificationROCAUCConfig, + ], + Field(discriminator="name"), +] + + +def build_per_class_matches_metric(config: PerClassMatchMetricConfig): + return metrics_registry.build(config) diff --git a/src/batdetect2/evaluate/plots.py b/src/batdetect2/evaluate/plots.py index 5ca5092..53a0420 100644 --- a/src/batdetect2/evaluate/plots.py +++ b/src/batdetect2/evaluate/plots.py @@ -17,7 +17,7 @@ from batdetect2.plotting.matches import plot_matches from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.typing import ( AudioLoader, - ClipEvaluation, + ClipMatches, MatchEvaluation, PlotterProtocol, PreprocessorProtocol, @@ -53,7 +53,7 @@ class ExampleGallery(PlotterProtocol): self.preprocessor = preprocessor or build_preprocessor() self.audio_loader = audio_loader or build_audio_loader() - def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + def __call__(self, clip_evaluations: Sequence[ClipMatches]): per_class_matches = group_matches(clip_evaluations) for class_name, matches in per_class_matches.items(): @@ -128,7 +128,7 @@ class PlotClipEvaluation(PlotterProtocol): self.audio_loader = audio_loader self.num_plots = num_plots - def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + def __call__(self, clip_evaluations: Sequence[ClipMatches]): examples = random.sample( clip_evaluations, k=min(self.num_plots, len(clip_evaluations)), @@ -171,7 +171,7 @@ class DetectionPRCurveConfig(BaseConfig): class DetectionPRCurve(PlotterProtocol): - def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + def __call__(self, clip_evaluations: Sequence[ClipMatches]): y_true, y_score = zip( *[ (match.gt_det, match.pred_score) @@ -231,7 +231,7 @@ class ClassificationPRCurves(PlotterProtocol): if class_name not in exclude ] - def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + def __call__(self, clip_evaluations: Sequence[ClipMatches]): y_true = [] y_pred = [] @@ -303,7 +303,7 @@ class DetectionROCCurveConfig(BaseConfig): class DetectionROCCurve(PlotterProtocol): - def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + def __call__(self, clip_evaluations: Sequence[ClipMatches]): y_true, y_score = zip( *[ (match.gt_det, match.pred_score) @@ -363,7 +363,7 @@ class ClassificationROCCurves(PlotterProtocol): if class_name not in exclude ] - def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + def __call__(self, clip_evaluations: Sequence[ClipMatches]): y_true = [] y_pred = [] @@ -440,7 +440,7 @@ class ConfusionMatrix(PlotterProtocol): self.background_class = background_class self.class_names = class_names - def __call__(self, clip_evaluations: 
Sequence[ClipEvaluation]): + def __call__(self, clip_evaluations: Sequence[ClipMatches]): y_true = [] y_pred = [] @@ -456,7 +456,7 @@ class ConfusionMatrix(PlotterProtocol): else self.background_class ) - top_class = match.pred_class + top_class = match.top_class y_pred.append( top_class if top_class is not None @@ -515,14 +515,14 @@ class ClassMatches: def group_matches( - clip_evaluations: Sequence[ClipEvaluation], + clip_evaluations: Sequence[ClipMatches], ) -> Dict[str, ClassMatches]: class_examples = defaultdict(ClassMatches) for clip_evaluation in clip_evaluations: for match in clip_evaluation.matches: gt_class = match.gt_class - pred_class = match.pred_class + pred_class = match.top_class if pred_class is None: class_examples[gt_class].false_negatives.append(match) @@ -550,7 +550,7 @@ def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5): *[ (index, match.pred_class_scores[pred_class]) for index, match in enumerate(matches) - if (pred_class := match.pred_class) is not None + if (pred_class := match.top_class) is not None ] ) diff --git a/src/batdetect2/evaluate/tables.py b/src/batdetect2/evaluate/tables.py index 9e36dbf..d623529 100644 --- a/src/batdetect2/evaluate/tables.py +++ b/src/batdetect2/evaluate/tables.py @@ -5,9 +5,9 @@ from pydantic import Field from soundevent.geometry import compute_bounds from batdetect2.core import BaseConfig, Registry -from batdetect2.typing import ClipEvaluation +from batdetect2.typing import ClipMatches -EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame] +EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame] tables_registry: Registry[EvaluationTableGenerator, []] = Registry( @@ -21,20 +21,18 @@ class FullEvaluationTableConfig(BaseConfig): class FullEvaluationTable: def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] + self, clip_evaluations: Sequence[ClipMatches] ) -> pd.DataFrame: return extract_matches_dataframe(clip_evaluations) - @classmethod - def from_config(cls, config: FullEvaluationTableConfig): - return cls() - - -tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable) + @tables_registry.register(FullEvaluationTableConfig) + @staticmethod + def from_config(config: FullEvaluationTableConfig): + return FullEvaluationTable() def extract_matches_dataframe( - clip_evaluations: Sequence[ClipEvaluation], + clip_evaluations: Sequence[ClipMatches], ) -> pd.DataFrame: data = [] @@ -78,8 +76,8 @@ def extract_matches_dataframe( ("gt", "low_freq"): gt_low_freq, ("gt", "high_freq"): gt_high_freq, ("pred", "score"): match.pred_score, - ("pred", "class"): match.pred_class, - ("pred", "class_score"): match.pred_class_score, + ("pred", "class"): match.top_class, + ("pred", "class_score"): match.top_class_score, ("pred", "start_time"): pred_start_time, ("pred", "end_time"): pred_end_time, ("pred", "low_freq"): pred_low_freq, diff --git a/src/batdetect2/plotting/clip_annotations.py b/src/batdetect2/plotting/clip_annotations.py index b0b33a5..400a23a 100644 --- a/src/batdetect2/plotting/clip_annotations.py +++ b/src/batdetect2/plotting/clip_annotations.py @@ -65,8 +65,6 @@ def plot_anchor_points( if not targets.filter(sound_event): continue - sound_event = targets.transform(sound_event) - position, _ = targets.encode_roi(sound_event) positions.append(position) diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index ae4775d..29a6cff 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ 
-162,7 +162,7 @@ def plot_false_positive_match( plt.text( start_time, high_freq, - f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ", + f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.top_class} \nTop Class Score: {match.top_class_score:.2f} ", va="top", ha="right", color=color, @@ -312,7 +312,7 @@ def plot_true_positive_match( plt.text( start_time, high_freq, - f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", + f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ", va="top", ha="right", color=color, @@ -394,7 +394,7 @@ def plot_cross_trigger_match( plt.text( start_time, high_freq, - f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", + f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.top_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ", va="top", ha="right", color=color, diff --git a/src/batdetect2/preprocess/audio.py b/src/batdetect2/preprocess/audio.py index a6debb0..a9fd1f1 100644 --- a/src/batdetect2/preprocess/audio.py +++ b/src/batdetect2/preprocess/audio.py @@ -28,12 +28,10 @@ class CenterAudio(torch.nn.Module): def forward(self, wav: torch.Tensor) -> torch.Tensor: return center_tensor(wav) - @classmethod - def from_config(cls, config: CenterAudioConfig, samplerate: int): - return cls() - - -audio_transforms.register(CenterAudioConfig, CenterAudio) + @audio_transforms.register(CenterAudioConfig) + @staticmethod + def from_config(config: CenterAudioConfig, samplerate: int): + return CenterAudio() class ScaleAudioConfig(BaseConfig): @@ -44,12 +42,10 @@ class ScaleAudio(torch.nn.Module): def forward(self, wav: torch.Tensor) -> torch.Tensor: return peak_normalize(wav) - @classmethod - def from_config(cls, config: ScaleAudioConfig, samplerate: int): - return cls() - - -audio_transforms.register(ScaleAudioConfig, ScaleAudio) + @audio_transforms.register(ScaleAudioConfig) + @staticmethod + def from_config(config: ScaleAudioConfig, samplerate: int): + return ScaleAudio() class FixDurationConfig(BaseConfig): @@ -75,13 +71,12 @@ class FixDuration(torch.nn.Module): return torch.nn.functional.pad(wav, (0, self.length - length)) - @classmethod - def from_config(cls, config: FixDurationConfig, samplerate: int): - return cls(samplerate=samplerate, duration=config.duration) + @audio_transforms.register(FixDurationConfig) + @staticmethod + def from_config(config: FixDurationConfig, samplerate: int): + return FixDuration(samplerate=samplerate, duration=config.duration) -audio_transforms.register(FixDurationConfig, FixDuration) - AudioTransform = Annotated[ Union[ FixDurationConfig, diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index 9b8fa7a..6859da9 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -285,10 +285,11 @@ class PCEN(torch.nn.Module): * torch.expm1(self.power * torch.log1p(S * smooth / self.bias)) ).to(spec.dtype) - @classmethod - def from_config(cls, config: PcenConfig, samplerate: int): + @spectrogram_transforms.register(PcenConfig) + @staticmethod + def from_config(config: PcenConfig, samplerate: int): smooth = _compute_smoothing_constant(samplerate, 
config.time_constant) - return cls( + return PCEN( smoothing_constant=smooth, gain=config.gain, bias=config.bias, @@ -296,9 +297,6 @@ class PCEN(torch.nn.Module): ) -spectrogram_transforms.register(PcenConfig, PCEN) - - def _compute_smoothing_constant( samplerate: int, time_constant: float, @@ -335,12 +333,10 @@ class ScaleAmplitude(torch.nn.Module): def forward(self, spec: torch.Tensor) -> torch.Tensor: return self.scaler(spec) - @classmethod - def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int): - return cls(scale=config.scale) - - -spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude) + @spectrogram_transforms.register(ScaleAmplitudeConfig) + @staticmethod + def from_config(config: ScaleAmplitudeConfig, samplerate: int): + return ScaleAmplitude(scale=config.scale) class SpectralMeanSubstractionConfig(BaseConfig): @@ -352,19 +348,13 @@ class SpectralMeanSubstraction(torch.nn.Module): mean = spec.mean(-1, keepdim=True) return (spec - mean).clamp(min=0) - @classmethod + @spectrogram_transforms.register(SpectralMeanSubstractionConfig) + @staticmethod def from_config( - cls, config: SpectralMeanSubstractionConfig, samplerate: int, ): - return cls() - - -spectrogram_transforms.register( - SpectralMeanSubstractionConfig, - SpectralMeanSubstraction, -) + return SpectralMeanSubstraction() class PeakNormalizeConfig(BaseConfig): @@ -375,13 +365,12 @@ class PeakNormalize(torch.nn.Module): def forward(self, spec: torch.Tensor) -> torch.Tensor: return peak_normalize(spec) - @classmethod - def from_config(cls, config: PeakNormalizeConfig, samplerate: int): - return cls() + @spectrogram_transforms.register(PeakNormalizeConfig) + @staticmethod + def from_config(config: PeakNormalizeConfig, samplerate: int): + return PeakNormalize() -spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize) - SpectrogramTransform = Annotated[ Union[ PcenConfig, diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index bfb4eeb..47d7c98 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -99,7 +99,7 @@ DEFAULT_DETECTION_CLASS = TargetClassConfig( DEFAULT_CLASSES = [ TargetClassConfig( name="barbar", - tags=[data.Tag(key="class", value="Barbastellus barbastellus")], + tags=[data.Tag(key="class", value="Barbastella barbastellus")], ), TargetClassConfig( name="eptser", diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py index 029a90a..226dfe8 100644 --- a/src/batdetect2/train/__init__.py +++ b/src/batdetect2/train/__init__.py @@ -1,11 +1,11 @@ from batdetect2.train.augmentations import ( AugmentationsConfig, - EchoAugmentationConfig, - FrequencyMaskAugmentationConfig, + AddEchoConfig, + MaskFrequencyConfig, RandomAudioSource, - TimeMaskAugmentationConfig, - VolumeAugmentationConfig, - WarpAugmentationConfig, + MaskTimeConfig, + ScaleVolumeConfig, + WarpConfig, add_echo, build_augmentations, mask_frequency, @@ -43,20 +43,20 @@ __all__ = [ "AugmentationsConfig", "ClassificationLossConfig", "DetectionLossConfig", - "EchoAugmentationConfig", - "FrequencyMaskAugmentationConfig", + "AddEchoConfig", + "MaskFrequencyConfig", "LossConfig", "LossFunction", "PLTrainerConfig", "RandomAudioSource", "SizeLossConfig", - "TimeMaskAugmentationConfig", + "MaskTimeConfig", "TrainingConfig", "TrainingDataset", "TrainingModule", "ValidationDataset", - "VolumeAugmentationConfig", - "WarpAugmentationConfig", + "ScaleVolumeConfig", + "WarpConfig", "add_echo", "build_augmentations", "build_clip_labeler", diff 
--git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index 172c5ed..7139899 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -12,21 +12,23 @@ from soundevent import data from soundevent.geometry import scale_geometry, shift_geometry from batdetect2.audio.clips import get_subclip_annotation +from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ from batdetect2.core.arrays import adjust_width from batdetect2.core.configs import BaseConfig, load_config +from batdetect2.core.registries import Registry from batdetect2.typing import AudioLoader, Augmentation __all__ = [ "AugmentationConfig", "AugmentationsConfig", "DEFAULT_AUGMENTATION_CONFIG", - "EchoAugmentationConfig", + "AddEchoConfig", "AudioSource", - "FrequencyMaskAugmentationConfig", - "MixAugmentationConfig", - "TimeMaskAugmentationConfig", - "VolumeAugmentationConfig", - "WarpAugmentationConfig", + "MaskFrequencyConfig", + "MixAudioConfig", + "MaskTimeConfig", + "ScaleVolumeConfig", + "WarpConfig", "add_echo", "build_augmentations", "load_augmentation_config", @@ -37,10 +39,19 @@ __all__ = [ "warp_spectrogram", ] + AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]] +audio_augmentations: Registry[Augmentation, [int, Optional[AudioSource]]] = ( + Registry(name="audio_augmentation") +) -class MixAugmentationConfig(BaseConfig): +spec_augmentations: Registry[Augmentation, []] = Registry( + name="spec_augmentation" +) + + +class MixAudioConfig(BaseConfig): """Configuration for MixUp augmentation (mixing two examples).""" name: Literal["mix_audio"] = "mix_audio" @@ -87,6 +98,19 @@ class MixAudio(torch.nn.Module): ) return mixed_audio, mixed_annotations + @audio_augmentations.register(MixAudioConfig) + @staticmethod + def from_config( + config: MixAudioConfig, + samplerate: int, + source: Optional[AudioSource], + ): + return MixAudio( + example_source=source, + min_weight=config.min_weight, + max_weight=config.max_weight, + ) + def mix_audio( wav1: torch.Tensor, @@ -136,7 +160,7 @@ def combine_clip_annotations( ) -class EchoAugmentationConfig(BaseConfig): +class AddEchoConfig(BaseConfig): """Configuration for adding synthetic echo/reverb.""" name: Literal["add_echo"] = "add_echo" @@ -149,14 +173,17 @@ class EchoAugmentationConfig(BaseConfig): class AddEcho(torch.nn.Module): def __init__( self, + samplerate: int = TARGET_SAMPLERATE_HZ, min_weight: float = 0.1, max_weight: float = 1.0, - max_delay: int = 2560, + max_delay: float = 0.005, ): super().__init__() + self.samplerate = samplerate self.min_weight = min_weight self.max_weight = max_weight - self.max_delay = max_delay + self.max_delay_s = max_delay + self.max_delay = int(max_delay * samplerate) def forward( self, @@ -167,6 +194,18 @@ class AddEcho(torch.nn.Module): weight = np.random.uniform(self.min_weight, self.max_weight) return add_echo(wav, delay=delay, weight=weight), clip_annotation + @audio_augmentations.register(AddEchoConfig) + @staticmethod + def from_config( + config: AddEchoConfig, samplerate: int, source: AudioSource + ): + return AddEcho( + samplerate=samplerate, + min_weight=config.min_weight, + max_weight=config.max_weight, + max_delay=config.max_delay, + ) + def add_echo( wav: torch.Tensor, @@ -183,7 +222,7 @@ def add_echo( return mix_audio(wav, audio_delay, weight) -class VolumeAugmentationConfig(BaseConfig): +class ScaleVolumeConfig(BaseConfig): """Configuration for random volume scaling of the spectrogram.""" name: Literal["scale_volume"] = "scale_volume" @@ 
-206,19 +245,27 @@ class ScaleVolume(torch.nn.Module): factor = np.random.uniform(self.min_scaling, self.max_scaling) return scale_volume(spec, factor=factor), clip_annotation + @spec_augmentations.register(ScaleVolumeConfig) + @staticmethod + def from_config(config: ScaleVolumeConfig): + return ScaleVolume( + min_scaling=config.min_scaling, + max_scaling=config.max_scaling, + ) + def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor: """Scale the amplitude of the spectrogram by a factor.""" return spec * factor -class WarpAugmentationConfig(BaseConfig): +class WarpConfig(BaseConfig): name: Literal["warp"] = "warp" probability: float = 0.2 delta: float = 0.04 -class WarpSpectrogram(torch.nn.Module): +class Warp(torch.nn.Module): def __init__(self, delta: float = 0.04) -> None: super().__init__() self.delta = delta @@ -234,6 +281,11 @@ class WarpSpectrogram(torch.nn.Module): warp_clip_annotation(clip_annotation, factor=factor), ) + @spec_augmentations.register(WarpConfig) + @staticmethod + def from_config(config: WarpConfig): + return Warp(delta=config.delta) + def warp_sound_event_annotation( sound_event_annotation: data.SoundEventAnnotation, @@ -294,7 +346,7 @@ def warp_spectrogram( ).squeeze(0) -class TimeMaskAugmentationConfig(BaseConfig): +class MaskTimeConfig(BaseConfig): name: Literal["mask_time"] = "mask_time" probability: float = 0.2 max_perc: float = 0.05 @@ -336,6 +388,14 @@ class MaskTime(torch.nn.Module): ] return mask_time(spec, masks), clip_annotation + @spec_augmentations.register(MaskTimeConfig) + @staticmethod + def from_config(config: MaskTimeConfig): + return MaskTime( + max_perc=config.max_perc, + max_masks=config.max_masks, + ) + def mask_time( spec: torch.Tensor, @@ -351,7 +411,7 @@ def mask_time( return spec -class FrequencyMaskAugmentationConfig(BaseConfig): +class MaskFrequencyConfig(BaseConfig): name: Literal["mask_freq"] = "mask_freq" probability: float = 0.2 max_perc: float = 0.10 @@ -394,6 +454,14 @@ class MaskFrequency(torch.nn.Module): ] return mask_frequency(spec, masks), clip_annotation + @spec_augmentations.register(MaskFrequencyConfig) + @staticmethod + def from_config(config: MaskFrequencyConfig): + return MaskFrequency( + max_perc=config.max_perc, + max_masks=config.max_masks, + ) + def mask_frequency( spec: torch.Tensor, @@ -410,8 +478,8 @@ def mask_frequency( AudioAugmentationConfig = Annotated[ Union[ - MixAugmentationConfig, - EchoAugmentationConfig, + MixAudioConfig, + AddEchoConfig, ], Field(discriminator="name"), ] @@ -419,22 +487,22 @@ AudioAugmentationConfig = Annotated[ SpectrogramAugmentationConfig = Annotated[ Union[ - VolumeAugmentationConfig, - WarpAugmentationConfig, - FrequencyMaskAugmentationConfig, - TimeMaskAugmentationConfig, + ScaleVolumeConfig, + WarpConfig, + MaskFrequencyConfig, + MaskTimeConfig, ], Field(discriminator="name"), ] AugmentationConfig = Annotated[ Union[ - MixAugmentationConfig, - EchoAugmentationConfig, - VolumeAugmentationConfig, - WarpAugmentationConfig, - FrequencyMaskAugmentationConfig, - TimeMaskAugmentationConfig, + MixAudioConfig, + AddEchoConfig, + ScaleVolumeConfig, + WarpConfig, + MaskFrequencyConfig, + MaskTimeConfig, ], Field(discriminator="name"), ] @@ -513,7 +581,7 @@ def build_augmentation_from_config( ) if config.name == "warp": - return WarpSpectrogram( + return Warp( delta=config.delta, ) @@ -538,14 +606,14 @@ def build_augmentation_from_config( DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig( enabled=True, audio=[ - MixAugmentationConfig(), - 
EchoAugmentationConfig(), + MixAudioConfig(), + AddEchoConfig(), ], spectrogram=[ - VolumeAugmentationConfig(), - WarpAugmentationConfig(), - TimeMaskAugmentationConfig(), - FrequencyMaskAugmentationConfig(), + ScaleVolumeConfig(), + WarpConfig(), + MaskTimeConfig(), + MaskFrequencyConfig(), ], ) @@ -566,9 +634,9 @@ class AugmentationSequence(torch.nn.Module): return tensor, clip_annotation -def build_augmentation_sequence( - samplerate: int, - steps: Optional[Sequence[AugmentationConfig]] = None, +def build_audio_augmentations( + steps: Optional[Sequence[AudioAugmentationConfig]] = None, + samplerate: int = TARGET_SAMPLERATE_HZ, audio_source: Optional[AudioSource] = None, ) -> Optional[Augmentation]: if not steps: @@ -577,10 +645,8 @@ def build_augmentation_sequence( augmentations = [] for step_config in steps: - augmentation = build_augmentation_from_config( - step_config, - samplerate=samplerate, - audio_source=audio_source, + augmentation = audio_augmentations.build( + step_config, samplerate, audio_source ) if augmentation is None: diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 19b9751..3c1ed24 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -10,7 +10,6 @@ from batdetect2.postprocess import to_raw_predictions from batdetect2.train.dataset import ValidationDataset from batdetect2.train.lightning import TrainingModule from batdetect2.typing import ( - ClipEvaluation, EvaluatorProtocol, ModelOutput, RawPrediction, @@ -37,22 +36,26 @@ class ValidationMetrics(Callback): def generate_plots( self, pl_module: LightningModule, - evaluated_clips: List[ClipEvaluation], ): plotter = get_image_logger(pl_module.logger) # type: ignore if plotter is None: return - for figure_name, fig in self.evaluator.generate_plots(evaluated_clips): + for figure_name, fig in self.evaluator.generate_plots( + self._clip_annotations, + self._predictions, + ): plotter(figure_name, fig, pl_module.global_step) def log_metrics( self, pl_module: LightningModule, - evaluated_clips: List[ClipEvaluation], ): - metrics = self.evaluator.compute_metrics(evaluated_clips) + metrics = self.evaluator.compute_metrics( + self._clip_annotations, + self._predictions, + ) pl_module.log_dict(metrics) def on_validation_epoch_end( @@ -60,13 +63,8 @@ class ValidationMetrics(Callback): trainer: Trainer, pl_module: LightningModule, ) -> None: - clip_evaluations = self.evaluator.evaluate( - self._clip_annotations, - self._predictions, - ) - - self.log_metrics(pl_module, clip_evaluations) - self.generate_plots(pl_module, clip_evaluations) + self.log_metrics(pl_module) + self.generate_plots(pl_module) return super().on_validation_epoch_end(trainer, pl_module) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 09470e1..0dfc36c 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -105,7 +105,10 @@ def train( trainer = trainer or build_trainer( config, targets=targets, - evaluator=build_evaluator(config.train.validation, targets=targets), + evaluator=build_evaluator( + config.train.validation.evaluator, + targets=targets, + ), checkpoint_dir=checkpoint_dir, log_dir=log_dir, experiment_name=experiment_name, diff --git a/src/batdetect2/typing/__init__.py b/src/batdetect2/typing/__init__.py index a06f387..60697e0 100644 --- a/src/batdetect2/typing/__init__.py +++ b/src/batdetect2/typing/__init__.py @@ -1,6 +1,8 @@ from batdetect2.typing.evaluate import ( - ClipEvaluation, + AffinityFunction, + ClipMatches, 
EvaluatorProtocol, + MatcherProtocol, MatchEvaluation, MetricsProtocol, PlotterProtocol, @@ -36,19 +38,22 @@ from batdetect2.typing.train import ( ) __all__ = [ + "AffinityFunction", "AudioLoader", "Augmentation", "BackboneModel", "BatDetect2Prediction", - "ClipEvaluation", + "ClipMatches", "ClipLabeller", "ClipperProtocol", "DetectionModel", + "EvaluatorProtocol", "GeometryDecoder", "Heatmaps", "LossProtocol", "Losses", "MatchEvaluation", + "MatcherProtocol", "MetricsProtocol", "ModelOutput", "PlotterProtocol", @@ -63,5 +68,4 @@ __all__ = [ "SoundEventFilter", "TargetProtocol", "TrainExample", - "EvaluatorProtocol", ] diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/typing/evaluate.py index 3c71405..8c22c52 100644 --- a/src/batdetect2/typing/evaluate.py +++ b/src/batdetect2/typing/evaluate.py @@ -31,6 +31,7 @@ class MatchEvaluation: sound_event_annotation: Optional[data.SoundEventAnnotation] gt_det: bool gt_class: Optional[str] + gt_geometry: Optional[data.Geometry] pred_score: float pred_class_scores: Dict[str, float] @@ -39,44 +40,32 @@ class MatchEvaluation: affinity: float @property - def pred_class(self) -> Optional[str]: + def top_class(self) -> Optional[str]: if not self.pred_class_scores: return None return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore @property - def pred_class_score(self) -> float: - pred_class = self.pred_class + def is_prediction(self) -> bool: + return self.pred_geometry is not None + + @property + def is_generic(self) -> bool: + return self.gt_det and self.gt_class is None + + @property + def top_class_score(self) -> float: + pred_class = self.top_class if pred_class is None: return 0 return self.pred_class_scores[pred_class] - def is_true_positive(self, threshold: float = 0) -> bool: - return ( - self.gt_det - and self.pred_score > threshold - and self.gt_class == self.pred_class - ) - - def is_false_positive(self, threshold: float = 0) -> bool: - return self.gt_det is None and self.pred_score > threshold - - def is_false_negative(self, threshold: float = 0) -> bool: - return self.gt_det and self.pred_score <= threshold - - def is_cross_trigger(self, threshold: float = 0) -> bool: - return ( - self.gt_det - and self.pred_score > threshold - and self.gt_class != self.pred_class - ) - @dataclass -class ClipEvaluation: +class ClipMatches: clip: data.Clip matches: List[MatchEvaluation] @@ -103,29 +92,36 @@ class AffinityFunction(Protocol, Generic[Geom]): class MetricsProtocol(Protocol): def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] + self, + clip_annotations: Sequence[data.ClipAnnotation], + predictions: Sequence[Sequence[RawPrediction]], ) -> Dict[str, float]: ... class PlotterProtocol(Protocol): def __call__( - self, clip_evaluations: Sequence[ClipEvaluation] + self, + clip_annotations: Sequence[data.ClipAnnotation], + predictions: Sequence[Sequence[RawPrediction]], ) -> Iterable[Tuple[str, Figure]]: ... -class EvaluatorProtocol(Protocol): +EvaluationOutput = TypeVar("EvaluationOutput") + + +class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]): targets: TargetProtocol def evaluate( self, clip_annotations: Sequence[data.ClipAnnotation], predictions: Sequence[Sequence[RawPrediction]], - ) -> List[ClipEvaluation]: ... + ) -> EvaluationOutput: ... def compute_metrics( - self, clip_evaluations: Sequence[ClipEvaluation] + self, eval_outputs: EvaluationOutput ) -> Dict[str, float]: ... 
def generate_plots( - self, clip_evaluations: Sequence[ClipEvaluation] + self, eval_outputs: EvaluationOutput ) -> Iterable[Tuple[str, Figure]]: ...
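
The hunks above move every component to decorator-based registration: a `from_config` staticmethod is decorated with `registry.register(ConfigClass)` instead of calling `registry.register(Config, Class)` at module level. Below is a minimal sketch of how an additional match metric could plug into the new `metrics_registry`; `DetectionF1` and `DetectionF1Config` are hypothetical names, the `Registry.register`/`Registry.build` behaviour is assumed to match its use in `src/batdetect2/evaluate/metrics/matches.py`, and the `MatchMetricConfig` union would also need extending before such a config could be referenced from a YAML file.

from typing import Literal, Sequence

from batdetect2.core import BaseConfig
from batdetect2.evaluate.metrics.matches import (
    build_match_metric,
    metrics_registry,
)
from batdetect2.typing import ClipMatches


class DetectionF1Config(BaseConfig):
    # Hypothetical config, mirroring DetectionRecallConfig / DetectionPrecisionConfig.
    name: Literal["detection_f1"] = "detection_f1"
    threshold: float = 0.5


class DetectionF1:
    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold

    def __call__(self, clip_evaluations: Sequence[ClipMatches]) -> float:
        # Count true/false positives and false negatives at the detection
        # threshold, then combine them into a single F1 score.
        tp = fp = fn = 0
        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_detection = m.pred_score >= self.threshold
                if is_detection and m.gt_det:
                    tp += 1
                elif is_detection:
                    fp += 1
                elif m.gt_det:
                    fn += 1
        denominator = 2 * tp + fp + fn
        return 2 * tp / denominator if denominator else 1.0

    # Registering the factory is what makes the config buildable, mirroring
    # DetectionRecall.from_config and DetectionPrecision.from_config above.
    @metrics_registry.register(DetectionF1Config)
    @staticmethod
    def from_config(config: DetectionF1Config):
        return DetectionF1(threshold=config.threshold)


# Building the metric goes through the same helper the built-in metrics use
# (assumed to dispatch on the config type at runtime).
metric = build_match_metric(DetectionF1Config(threshold=0.3))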