Better evaluation organisation

mbsantiago 2025-09-25 17:48:29 +01:00
parent 4cd983a2c2
commit d6ddc4514c
39 changed files with 1704 additions and 1253 deletions

View File

@ -140,13 +140,14 @@ train:
  validation:
    metrics:
      - name: detection_ap
-     - name: detection_roc_auc
      - name: classification_ap
-     - name: classification_roc_auc
-     - name: top_class_ap
-     - name: classification_balanced_accuracy
-     - name: clip_ap
-     - name: clip_roc_auc
+   plots:
+     - name: example_gallery
+     - name: example_clip
+     - name: detection_pr_curve
+     - name: classification_pr_curves
+     - name: detection_roc_curve
+     - name: classification_roc_curves

 evaluation:
   match_strategy:
@ -155,6 +156,14 @@ evaluation:
   metrics:
     - name: classification_ap
     - name: detection_ap
+    - name: detection_roc_auc
+    - name: classification_roc_auc
+    - name: top_class_ap
+    - name: classification_balanced_accuracy
+    - name: clip_multiclass_ap
+    - name: clip_multiclass_roc_auc
+    - name: clip_detection_ap
+    - name: clip_detection_roc_auc
   plots:
     - name: example_gallery
     - name: example_clip

View File

@ -1,6 +1,7 @@
from pathlib import Path from pathlib import Path
from typing import Optional, Sequence from typing import List, Optional, Sequence
import torch
from soundevent import data from soundevent import data
from batdetect2.audio import build_audio_loader from batdetect2.audio import build_audio_loader
@ -8,6 +9,7 @@ from batdetect2.config import BatDetect2Config
from batdetect2.evaluate import build_evaluator, evaluate from batdetect2.evaluate import build_evaluator, evaluate
from batdetect2.models import Model, build_model from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor from batdetect2.postprocess import build_postprocessor
from batdetect2.postprocess.decoding import to_raw_predictions
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets.targets import build_targets from batdetect2.targets.targets import build_targets
from batdetect2.train import train from batdetect2.train import train
@ -19,6 +21,7 @@ from batdetect2.typing import (
PreprocessorProtocol, PreprocessorProtocol,
TargetProtocol, TargetProtocol,
) )
from batdetect2.typing.postprocess import RawPrediction
class BatDetect2API: class BatDetect2API:
@ -92,6 +95,18 @@ class BatDetect2API:
run_name=run_name, run_name=run_name,
) )
def process_spectrogram(
self,
spec: torch.Tensor,
start_times: Optional[Sequence[float]] = None,
) -> List[List[RawPrediction]]:
outputs = self.model.detector(spec)
clip_detections = self.postprocessor(outputs, start_times=start_times)
return [
to_raw_predictions(clip_dets.numpy(), self.targets)
for clip_dets in clip_detections
]
@classmethod @classmethod
def from_config(cls, config: BatDetect2Config): def from_config(cls, config: BatDetect2Config):
targets = build_targets(config=config.targets) targets = build_targets(config=config.targets)
@ -109,7 +124,7 @@ class BatDetect2API:
) )
evaluator = build_evaluator( evaluator = build_evaluator(
config=config.evaluation, config=config.evaluation.evaluator,
targets=targets, targets=targets,
) )
@ -164,7 +179,7 @@ class BatDetect2API:
) )
evaluator = build_evaluator( evaluator = build_evaluator(
config=config.evaluation, config=config.evaluation.evaluator,
targets=targets, targets=targets,
) )
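The new process_spectrogram method runs the detector on an already-computed spectrogram batch and decodes the outputs into one list of RawPrediction per clip. A hedged usage sketch: the tensor shape and start times are illustrative assumptions, and api stands for an already-built BatDetect2API instance (e.g. obtained via BatDetect2API.from_config):

import torch

# Illustrative only: a batch of two spectrograms shaped the way the
# preprocessor/model normally produces them.
spec = torch.rand(2, 1, 128, 512)

# Returns one List[RawPrediction] per clip; start_times offsets the decoded
# detection times of each clip (here clips starting at 0 s and 5 s).
predictions = api.process_spectrogram(spec, start_times=[0.0, 5.0])

for clip_predictions in predictions:
    for pred in clip_predictions:
        print(pred.detection_score, pred.geometry)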

View File

@ -56,18 +56,16 @@ class RandomClip:
min_sound_event_overlap=self.min_sound_event_overlap, min_sound_event_overlap=self.min_sound_event_overlap,
) )
@classmethod @clipper_registry.register(RandomClipConfig)
def from_config(cls, config: RandomClipConfig): @staticmethod
return cls( def from_config(config: RandomClipConfig):
return RandomClip(
duration=config.duration, duration=config.duration,
max_empty=config.max_empty, max_empty=config.max_empty,
min_sound_event_overlap=config.min_sound_event_overlap, min_sound_event_overlap=config.min_sound_event_overlap,
) )
clipper_registry.register(RandomClipConfig, RandomClip)
def get_subclip_annotation( def get_subclip_annotation(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
random: bool = True, random: bool = True,
@ -184,13 +182,12 @@ class PaddedClip:
) )
return clip_annotation.model_copy(update=dict(clip=clip)) return clip_annotation.model_copy(update=dict(clip=clip))
@classmethod @clipper_registry.register(PaddedClipConfig)
def from_config(cls, config: PaddedClipConfig): @staticmethod
return cls(chunk_size=config.chunk_size) def from_config(config: PaddedClipConfig):
return PaddedClip(chunk_size=config.chunk_size)
clipper_registry.register(PaddedClipConfig, PaddedClip)
ClipConfig = Annotated[ ClipConfig = Annotated[
Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name") Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
] ]

View File

@ -53,6 +53,7 @@ class BaseConfig(BaseModel):
""" """
return yaml.dump( return yaml.dump(
self.model_dump( self.model_dump(
mode="json",
exclude_none=exclude_none, exclude_none=exclude_none,
exclude_unset=exclude_unset, exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults, exclude_defaults=exclude_defaults,

View File

@ -1,16 +1,16 @@
import sys import sys
from typing import Generic, Protocol, Type, TypeVar from typing import Callable, Dict, Generic, Tuple, Type, TypeVar
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import assert_type
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
from typing import ParamSpec from typing import Concatenate, ParamSpec
else: else:
from typing_extensions import ParamSpec from typing_extensions import Concatenate, ParamSpec
__all__ = [ __all__ = [
"Registry", "Registry",
"SimpleRegistry",
] ]
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True) T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
@ -18,19 +18,26 @@ T_Type = TypeVar("T_Type", covariant=True)
P_Type = ParamSpec("P_Type") P_Type = ParamSpec("P_Type")
class LogicProtocol(Generic[T_Config, T_Type, P_Type], Protocol): T = TypeVar("T")
"""A generic protocol for the logic classes."""
@classmethod
def from_config(
cls,
config: T_Config,
*args: P_Type.args,
**kwargs: P_Type.kwargs,
) -> T_Type: ...
T_Proto = TypeVar("T_Proto", bound=LogicProtocol) class SimpleRegistry(Generic[T]):
def __init__(self, name: str):
self._name = name
self._registry = {}
def register(self, name: str):
def decorator(obj: T) -> T:
self._registry[name] = obj
return obj
return decorator
def get(self, name: str) -> T:
return self._registry[name]
def has(self, name: str) -> bool:
return name in self._registry
class Registry(Generic[T_Type, P_Type]): class Registry(Generic[T_Type, P_Type]):
@ -38,13 +45,15 @@ class Registry(Generic[T_Type, P_Type]):
def __init__(self, name: str): def __init__(self, name: str):
self._name = name self._name = name
self._registry = {} self._registry: Dict[
str, Callable[Concatenate[..., P_Type], T_Type]
] = {}
self._config_types: Dict[str, Type[BaseModel]] = {}
def register( def register(
self, self,
config_cls: Type[T_Config], config_cls: Type[T_Config],
logic_cls: LogicProtocol[T_Config, T_Type, P_Type], ):
) -> None:
fields = config_cls.model_fields fields = config_cls.model_fields
if "name" not in fields: if "name" not in fields:
@ -52,10 +61,21 @@ class Registry(Generic[T_Type, P_Type]):
name = fields["name"].default name = fields["name"].default
self._config_types[name] = config_cls
if not isinstance(name, str): if not isinstance(name, str):
raise ValueError("'name' field must be a string literal.") raise ValueError("'name' field must be a string literal.")
self._registry[name] = logic_cls def decorator(
func: Callable[Concatenate[T_Config, P_Type], T_Type],
):
self._registry[name] = func
return func
return decorator
def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
return tuple(self._config_types.values())
def build( def build(
self, self,
@ -75,4 +95,4 @@ class Registry(Generic[T_Type, P_Type]):
f"No {self._name} with name '{name}' is registered." f"No {self._name} with name '{name}' is registered."
) )
return self._registry[name].from_config(config, *args, **kwargs) return self._registry[name](config, *args, **kwargs)
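The registry now hands out a decorator instead of accepting a (config class, logic class) pair: builders are plain callables keyed by the config's default "name", and the config types themselves are remembered for get_config_types(). A minimal sketch of the new pattern, using a made-up GreeterConfig/Greeter pair purely for illustration:

from typing import Literal

from pydantic import BaseModel

from batdetect2.core.registries import Registry


class GreeterConfig(BaseModel):
    # Registration is keyed on the default value of the "name" field.
    name: Literal["greeter"] = "greeter"
    greeting: str = "hello"


class Greeter:
    def __init__(self, greeting: str):
        self.greeting = greeting


# A registry whose builders return a Greeter and take no extra arguments.
greeters: Registry[Greeter, []] = Registry("greeter")


@greeters.register(GreeterConfig)
def build_greeter(config: GreeterConfig) -> Greeter:
    return Greeter(greeting=config.greeting)


# build() dispatches on config.name and calls the registered builder.
instance = greeters.build(GreeterConfig(greeting="hi"))
assert instance.greeting == "hi"

The condition, transform, clipper, matcher, and evaluator registrations in the rest of this commit all follow this decorator shape.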

View File

@ -10,7 +10,7 @@ from batdetect2.core.registries import Registry
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool] SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
condition_registry: Registry[SoundEventCondition, []] = Registry("condition") conditions: Registry[SoundEventCondition, []] = Registry("condition")
class HasTagConfig(BaseConfig): class HasTagConfig(BaseConfig):
@ -27,12 +27,10 @@ class HasTag:
) -> bool: ) -> bool:
return self.tag in sound_event_annotation.tags return self.tag in sound_event_annotation.tags
@classmethod @conditions.register(HasTagConfig)
def from_config(cls, config: HasTagConfig): @staticmethod
return cls(tag=config.tag) def from_config(config: HasTagConfig):
return HasTag(tag=config.tag)
condition_registry.register(HasTagConfig, HasTag)
class HasAllTagsConfig(BaseConfig): class HasAllTagsConfig(BaseConfig):
@ -52,12 +50,10 @@ class HasAllTags:
) -> bool: ) -> bool:
return self.tags.issubset(sound_event_annotation.tags) return self.tags.issubset(sound_event_annotation.tags)
@classmethod @conditions.register(HasAllTagsConfig)
def from_config(cls, config: HasAllTagsConfig): @staticmethod
return cls(tags=config.tags) def from_config(config: HasAllTagsConfig):
return HasAllTags(tags=config.tags)
condition_registry.register(HasAllTagsConfig, HasAllTags)
class HasAnyTagConfig(BaseConfig): class HasAnyTagConfig(BaseConfig):
@ -77,13 +73,12 @@ class HasAnyTag:
) -> bool: ) -> bool:
return bool(self.tags.intersection(sound_event_annotation.tags)) return bool(self.tags.intersection(sound_event_annotation.tags))
@classmethod @conditions.register(HasAnyTagConfig)
def from_config(cls, config: HasAnyTagConfig): @staticmethod
return cls(tags=config.tags) def from_config(config: HasAnyTagConfig):
return HasAnyTag(tags=config.tags)
condition_registry.register(HasAnyTagConfig, HasAnyTag)
Operator = Literal["gt", "gte", "lt", "lte", "eq"] Operator = Literal["gt", "gte", "lt", "lte", "eq"]
@ -134,12 +129,10 @@ class Duration:
return self._comparator(duration) return self._comparator(duration)
@classmethod @conditions.register(DurationConfig)
def from_config(cls, config: DurationConfig): @staticmethod
return cls(operator=config.operator, seconds=config.seconds) def from_config(config: DurationConfig):
return Duration(operator=config.operator, seconds=config.seconds)
condition_registry.register(DurationConfig, Duration)
class FrequencyConfig(BaseConfig): class FrequencyConfig(BaseConfig):
@ -181,18 +174,16 @@ class Frequency:
return self._comparator(high_freq) return self._comparator(high_freq)
@classmethod @conditions.register(FrequencyConfig)
def from_config(cls, config: FrequencyConfig): @staticmethod
return cls( def from_config(config: FrequencyConfig):
return Frequency(
operator=config.operator, operator=config.operator,
boundary=config.boundary, boundary=config.boundary,
hertz=config.hertz, hertz=config.hertz,
) )
condition_registry.register(FrequencyConfig, Frequency)
class AllOfConfig(BaseConfig): class AllOfConfig(BaseConfig):
name: Literal["all_of"] = "all_of" name: Literal["all_of"] = "all_of"
conditions: Sequence["SoundEventConditionConfig"] conditions: Sequence["SoundEventConditionConfig"]
@ -207,15 +198,13 @@ class AllOf:
) -> bool: ) -> bool:
return all(c(sound_event_annotation) for c in self.conditions) return all(c(sound_event_annotation) for c in self.conditions)
@classmethod @conditions.register(AllOfConfig)
def from_config(cls, config: AllOfConfig): @staticmethod
def from_config(config: AllOfConfig):
conditions = [ conditions = [
build_sound_event_condition(cond) for cond in config.conditions build_sound_event_condition(cond) for cond in config.conditions
] ]
return cls(conditions) return AllOf(conditions)
condition_registry.register(AllOfConfig, AllOf)
class AnyOfConfig(BaseConfig): class AnyOfConfig(BaseConfig):
@ -232,15 +221,13 @@ class AnyOf:
) -> bool: ) -> bool:
return any(c(sound_event_annotation) for c in self.conditions) return any(c(sound_event_annotation) for c in self.conditions)
@classmethod @conditions.register(AnyOfConfig)
def from_config(cls, config: AnyOfConfig): @staticmethod
def from_config(config: AnyOfConfig):
conditions = [ conditions = [
build_sound_event_condition(cond) for cond in config.conditions build_sound_event_condition(cond) for cond in config.conditions
] ]
return cls(conditions) return AnyOf(conditions)
condition_registry.register(AnyOfConfig, AnyOf)
class NotConfig(BaseConfig): class NotConfig(BaseConfig):
@ -257,14 +244,13 @@ class Not:
) -> bool: ) -> bool:
return not self.condition(sound_event_annotation) return not self.condition(sound_event_annotation)
@classmethod @conditions.register(NotConfig)
def from_config(cls, config: NotConfig): @staticmethod
def from_config(config: NotConfig):
condition = build_sound_event_condition(config.condition) condition = build_sound_event_condition(config.condition)
return cls(condition) return Not(condition)
condition_registry.register(NotConfig, Not)
SoundEventConditionConfig = Annotated[ SoundEventConditionConfig = Annotated[
Union[ Union[
HasTagConfig, HasTagConfig,
@ -283,7 +269,7 @@ SoundEventConditionConfig = Annotated[
def build_sound_event_condition( def build_sound_event_condition(
config: SoundEventConditionConfig, config: SoundEventConditionConfig,
) -> SoundEventCondition: ) -> SoundEventCondition:
return condition_registry.build(config) return conditions.build(config)
def filter_clip_annotation( def filter_clip_annotation(

View File

@ -17,7 +17,7 @@ SoundEventTransform = Callable[
data.SoundEventAnnotation, data.SoundEventAnnotation,
] ]
transform_registry: Registry[SoundEventTransform, []] = Registry("transform") transforms: Registry[SoundEventTransform, []] = Registry("transform")
class SetFrequencyBoundConfig(BaseConfig): class SetFrequencyBoundConfig(BaseConfig):
@ -63,12 +63,10 @@ class SetFrequencyBound:
update=dict(sound_event=sound_event) update=dict(sound_event=sound_event)
) )
@classmethod @transforms.register(SetFrequencyBoundConfig)
def from_config(cls, config: SetFrequencyBoundConfig): @staticmethod
return cls(hertz=config.hertz, boundary=config.boundary) def from_config(config: SetFrequencyBoundConfig):
return SetFrequencyBound(hertz=config.hertz, boundary=config.boundary)
transform_registry.register(SetFrequencyBoundConfig, SetFrequencyBound)
class ApplyIfConfig(BaseConfig): class ApplyIfConfig(BaseConfig):
@ -95,14 +93,12 @@ class ApplyIf:
return self.transform(sound_event_annotation) return self.transform(sound_event_annotation)
@classmethod @transforms.register(ApplyIfConfig)
def from_config(cls, config: ApplyIfConfig): @staticmethod
def from_config(config: ApplyIfConfig):
transform = build_sound_event_transform(config.transform) transform = build_sound_event_transform(config.transform)
condition = build_sound_event_condition(config.condition) condition = build_sound_event_condition(config.condition)
return cls(condition=condition, transform=transform) return ApplyIf(condition=condition, transform=transform)
transform_registry.register(ApplyIfConfig, ApplyIf)
class ReplaceTagConfig(BaseConfig): class ReplaceTagConfig(BaseConfig):
@ -134,12 +130,12 @@ class ReplaceTag:
return sound_event_annotation.model_copy(update=dict(tags=tags)) return sound_event_annotation.model_copy(update=dict(tags=tags))
@classmethod @transforms.register(ReplaceTagConfig)
def from_config(cls, config: ReplaceTagConfig): @staticmethod
return cls(original=config.original, replacement=config.replacement) def from_config(config: ReplaceTagConfig):
return ReplaceTag(
original=config.original, replacement=config.replacement
transform_registry.register(ReplaceTagConfig, ReplaceTag) )
class MapTagValueConfig(BaseConfig): class MapTagValueConfig(BaseConfig):
@ -189,18 +185,16 @@ class MapTagValue:
return sound_event_annotation.model_copy(update=dict(tags=tags)) return sound_event_annotation.model_copy(update=dict(tags=tags))
@classmethod @transforms.register(MapTagValueConfig)
def from_config(cls, config: MapTagValueConfig): @staticmethod
return cls( def from_config(config: MapTagValueConfig):
return MapTagValue(
tag_key=config.tag_key, tag_key=config.tag_key,
value_mapping=config.value_mapping, value_mapping=config.value_mapping,
target_key=config.target_key, target_key=config.target_key,
) )
transform_registry.register(MapTagValueConfig, MapTagValue)
class ApplyAllConfig(BaseConfig): class ApplyAllConfig(BaseConfig):
name: Literal["apply_all"] = "apply_all" name: Literal["apply_all"] = "apply_all"
steps: List["SoundEventTransformConfig"] = Field(default_factory=list) steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
@ -219,14 +213,13 @@ class ApplyAll:
return sound_event_annotation return sound_event_annotation
@classmethod @transforms.register(ApplyAllConfig)
def from_config(cls, config: ApplyAllConfig): @staticmethod
def from_config(config: ApplyAllConfig):
steps = [build_sound_event_transform(step) for step in config.steps] steps = [build_sound_event_transform(step) for step in config.steps]
return cls(steps) return ApplyAll(steps)
transform_registry.register(ApplyAllConfig, ApplyAll)
SoundEventTransformConfig = Annotated[ SoundEventTransformConfig = Annotated[
Union[ Union[
SetFrequencyBoundConfig, SetFrequencyBoundConfig,
@ -242,7 +235,7 @@ SoundEventTransformConfig = Annotated[
def build_sound_event_transform( def build_sound_event_transform(
config: SoundEventTransformConfig, config: SoundEventTransformConfig,
) -> SoundEventTransform: ) -> SoundEventTransform:
return transform_registry.build(config) return transforms.build(config)
def transform_clip_annotation( def transform_clip_annotation(

View File

@ -1,11 +1,11 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import evaluate from batdetect2.evaluate.evaluate import evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator from batdetect2.evaluate.evaluator import MultipleEvaluator, build_evaluator
__all__ = [ __all__ = [
"EvaluationConfig", "EvaluationConfig",
"load_evaluation_config", "load_evaluation_config",
"evaluate", "evaluate",
"Evaluator", "MultipleEvaluator",
"build_evaluator", "build_evaluator",
] ]

View File

@ -27,12 +27,10 @@ class TimeAffinity(AffinityFunction):
geometry1, geometry2, time_buffer=self.time_buffer geometry1, geometry2, time_buffer=self.time_buffer
) )
@classmethod @affinity_functions.register(TimeAffinityConfig)
def from_config(cls, config: TimeAffinityConfig): @staticmethod
return cls(time_buffer=config.time_buffer) def from_config(config: TimeAffinityConfig):
return TimeAffinity(time_buffer=config.time_buffer)
affinity_functions.register(TimeAffinityConfig, TimeAffinity)
def compute_timestamp_affinity( def compute_timestamp_affinity(
@ -73,12 +71,10 @@ class IntervalIOU(AffinityFunction):
time_buffer=self.time_buffer, time_buffer=self.time_buffer,
) )
@classmethod @affinity_functions.register(IntervalIOUConfig)
def from_config(cls, config: IntervalIOUConfig): @staticmethod
return cls(time_buffer=config.time_buffer) def from_config(config: IntervalIOUConfig):
return IntervalIOU(time_buffer=config.time_buffer)
affinity_functions.register(IntervalIOUConfig, IntervalIOU)
def compute_interval_iou( def compute_interval_iou(
@ -127,13 +123,12 @@ class GeometricIOU(AffinityFunction):
time_buffer=self.time_buffer, time_buffer=self.time_buffer,
) )
@classmethod @affinity_functions.register(GeometricIOUConfig)
def from_config(cls, config: GeometricIOUConfig): @staticmethod
return cls(time_buffer=config.time_buffer) def from_config(config: GeometricIOUConfig):
return GeometricIOU(time_buffer=config.time_buffer)
affinity_functions.register(GeometricIOUConfig, GeometricIOU)
AffinityConfig = Annotated[ AffinityConfig = Annotated[
Union[ Union[
TimeAffinityConfig, TimeAffinityConfig,

View File

@ -1,16 +1,13 @@
from typing import List, Optional from typing import Optional
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig from batdetect2.evaluate.evaluator import (
from batdetect2.evaluate.metrics import ( EvaluatorConfig,
ClassificationAPConfig, MultipleEvaluatorConfig,
DetectionAPConfig,
MetricConfig,
) )
from batdetect2.evaluate.plots import PlotConfig
from batdetect2.logging import CSVLoggerConfig, LoggerConfig from batdetect2.logging import CSVLoggerConfig, LoggerConfig
__all__ = [ __all__ = [
@ -20,15 +17,7 @@ __all__ = [
class EvaluationConfig(BaseConfig): class EvaluationConfig(BaseConfig):
ignore_start_end: float = 0.01 evaluator: EvaluatorConfig = Field(default_factory=MultipleEvaluatorConfig)
match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig)
metrics: List[MetricConfig] = Field(
default_factory=lambda: [
DetectionAPConfig(),
ClassificationAPConfig(),
]
)
plots: List[PlotConfig] = Field(default_factory=list)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)

View File

@ -55,7 +55,10 @@ def evaluate(
num_workers=num_workers, num_workers=num_workers,
) )
evaluator = build_evaluator(config=config.evaluation, targets=targets) evaluator = build_evaluator(
config=config.evaluation.evaluator,
targets=targets,
)
logger = build_logger( logger = build_logger(
config.evaluation.logger, config.evaluation.logger,

View File

@ -1,173 +0,0 @@
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from matplotlib.figure import Figure
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.match import build_matcher, match
from batdetect2.evaluate.metrics import build_metric
from batdetect2.evaluate.plots import build_plotter
from batdetect2.targets import build_targets
from batdetect2.typing.evaluate import (
ClipEvaluation,
EvaluatorProtocol,
MatcherProtocol,
MetricsProtocol,
PlotterProtocol,
)
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"Evaluator",
"build_evaluator",
]
class Evaluator:
def __init__(
self,
config: EvaluationConfig,
targets: TargetProtocol,
matcher: MatcherProtocol,
metrics: List[MetricsProtocol],
plots: List[PlotterProtocol],
):
self.config = config
self.targets = targets
self.matcher = matcher
self.metrics = metrics
self.plots = plots
def match(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEvaluation:
clip = clip_annotation.clip
ground_truth = [
sound_event
for sound_event in clip_annotation.sound_events
if self.filter_sound_event_annotations(sound_event, clip)
]
predictions = [
prediction
for prediction in predictions
if self.filter_predictions(prediction, clip)
]
return ClipEvaluation(
clip=clip_annotation.clip,
matches=match(
ground_truth,
predictions,
clip=clip,
targets=self.targets,
matcher=self.matcher,
),
)
def filter_sound_event_annotations(
self,
sound_event_annotation: data.SoundEventAnnotation,
clip: data.Clip,
) -> bool:
if not self.targets.filter(sound_event_annotation):
return False
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
return is_in_bounds(
geometry,
clip,
self.config.ignore_start_end,
)
def filter_predictions(
self,
prediction: RawPrediction,
clip: data.Clip,
) -> bool:
return is_in_bounds(
prediction.geometry,
clip,
self.config.ignore_start_end,
)
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[ClipEvaluation]:
if len(clip_annotations) != len(predictions):
raise ValueError(
"Number of annotated clips and sets of predictions do not match"
)
return [
self.match(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions)
]
def compute_metrics(
self,
clip_evaluations: Sequence[ClipEvaluation],
) -> Dict[str, float]:
results = {}
for metric in self.metrics:
results.update(metric(clip_evaluations))
return results
def generate_plots(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]:
for plotter in self.plots:
for name, fig in plotter(clip_evaluations):
yield name, fig
def build_evaluator(
config: Optional[EvaluationConfig] = None,
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
plots: Optional[List[PlotterProtocol]] = None,
metrics: Optional[List[MetricsProtocol]] = None,
) -> EvaluatorProtocol:
config = config or EvaluationConfig()
targets = targets or build_targets()
matcher = matcher or build_matcher(config.match_strategy)
if metrics is None:
metrics = [
build_metric(config, targets.class_names)
for config in config.metrics
]
if plots is None:
plots = [
build_plotter(config, targets.class_names)
for config in config.plots
]
return Evaluator(
config=config,
targets=targets,
matcher=matcher,
metrics=metrics,
plots=plots,
)
def is_in_bounds(
geometry: data.Geometry,
clip: data.Clip,
buffer: float,
) -> bool:
start_time = compute_bounds(geometry)[0]
return (start_time >= clip.start_time + buffer) and (
start_time <= clip.end_time - buffer
)

View File

@ -0,0 +1,114 @@
from typing import (
Annotated,
Any,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.evaluate.evaluator.base import evaluators
from batdetect2.evaluate.evaluator.clip import ClipMetricsConfig
from batdetect2.evaluate.evaluator.per_class import ClassificationMetricsConfig
from batdetect2.evaluate.evaluator.single import GlobalEvaluatorConfig
from batdetect2.targets import build_targets
from batdetect2.typing import (
EvaluatorProtocol,
RawPrediction,
TargetProtocol,
)
__all__ = [
"EvaluatorConfig",
"build_evaluator",
]
EvaluatorConfig = Annotated[
Union[
ClassificationMetricsConfig,
GlobalEvaluatorConfig,
ClipMetricsConfig,
"MultipleEvaluatorConfig",
],
Field(discriminator="name"),
]
class MultipleEvaluatorConfig(BaseConfig):
name: Literal["multiple_evaluations"] = "multiple_evaluations"
evaluations: List[EvaluatorConfig] = Field(
default_factory=lambda: [
ClassificationMetricsConfig(),
GlobalEvaluatorConfig(),
]
)
class MultipleEvaluator:
def __init__(
self,
targets: TargetProtocol,
evaluators: Sequence[EvaluatorProtocol],
):
self.targets = targets
self.evaluators = evaluators
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[Any]:
return [
evaluator.evaluate(
clip_annotations,
predictions,
)
for evaluator in self.evaluators
]
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
results = {}
for evaluator, outputs in zip(self.evaluators, eval_outputs):
results.update(evaluator.compute_metrics(outputs))
return results
def generate_plots(
self,
eval_outputs: List[Any],
) -> Iterable[Tuple[str, Figure]]:
for evaluator, outputs in zip(self.evaluators, eval_outputs):
for name, fig in evaluator.generate_plots(outputs):
yield name, fig
@evaluators.register(MultipleEvaluatorConfig)
@staticmethod
def from_config(config: MultipleEvaluatorConfig, targets: TargetProtocol):
return MultipleEvaluator(
evaluators=[
build_evaluator(conf, targets=targets)
for conf in config.evaluations
],
targets=targets,
)
def build_evaluator(
config: Optional[EvaluatorConfig] = None,
targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
targets = targets or build_targets()
config = config or MultipleEvaluatorConfig()
return evaluators.build(config, targets)
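build_evaluator is now a thin wrapper over the evaluator registry, dispatching on the config's name and forwarding the targets. A minimal sketch of assembling the default combination explicitly (import paths follow this file's imports; calling build_targets() with no arguments assumes the default target configuration):

from batdetect2.evaluate.evaluator import MultipleEvaluatorConfig, build_evaluator
from batdetect2.evaluate.evaluator.per_class import ClassificationMetricsConfig
from batdetect2.evaluate.evaluator.single import GlobalEvaluatorConfig
from batdetect2.targets import build_targets

config = MultipleEvaluatorConfig(
    evaluations=[
        GlobalEvaluatorConfig(),        # detection metrics, "detection/" prefix
        ClassificationMetricsConfig(),  # per-class metrics, "classification/" prefix
    ]
)

evaluator = build_evaluator(config, targets=build_targets())

# The evaluator protocol is then used in three steps:
#   outputs = evaluator.evaluate(clip_annotations, predictions)
#   metrics = evaluator.compute_metrics(outputs)   # Dict[str, float]
#   figures = evaluator.generate_plots(outputs)    # Iterable of (name, Figure)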

View File

@ -0,0 +1,107 @@
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.match import (
MatchConfig,
StartTimeMatchConfig,
build_matcher,
)
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"BaseEvaluatorConfig",
"BaseEvaluator",
]
evaluators: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry("metric")
class BaseEvaluatorConfig(BaseConfig):
prefix: str
ignore_start_end: float = 0.01
matching_strategy: MatchConfig = Field(
default_factory=StartTimeMatchConfig
)
class BaseEvaluator(EvaluatorProtocol):
targets: TargetProtocol
matcher: MatcherProtocol
ignore_start_end: float
prefix: str
def __init__(
self,
matcher: MatcherProtocol,
targets: TargetProtocol,
prefix: str,
ignore_start_end: float = 0.01,
):
self.matcher = matcher
self.targets = targets
self.prefix = prefix
self.ignore_start_end = ignore_start_end
def filter_sound_event_annotations(
self,
sound_event_annotation: data.SoundEventAnnotation,
clip: data.Clip,
) -> bool:
if not self.targets.filter(sound_event_annotation):
return False
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
return is_in_bounds(
geometry,
clip,
self.ignore_start_end,
)
def filter_predictions(
self,
prediction: RawPrediction,
clip: data.Clip,
) -> bool:
return is_in_bounds(
prediction.geometry,
clip,
self.ignore_start_end,
)
@classmethod
def build(
cls,
config: BaseEvaluatorConfig,
targets: TargetProtocol,
**kwargs,
):
matcher = build_matcher(config.matching_strategy)
return cls(
matcher=matcher,
targets=targets,
prefix=config.prefix,
ignore_start_end=config.ignore_start_end,
**kwargs,
)
def is_in_bounds(
geometry: data.Geometry,
clip: data.Clip,
buffer: float,
) -> bool:
start_time = compute_bounds(geometry)[0]
return (start_time >= clip.start_time + buffer) and (
start_time <= clip.end_time - buffer
)

View File

@ -0,0 +1,163 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, List, Literal, Sequence, Set
from pydantic import Field, field_validator
from sklearn import metrics
from soundevent import data
from batdetect2.evaluate.evaluator.base import (
BaseEvaluator,
BaseEvaluatorConfig,
evaluators,
)
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol
@dataclass
class ClipInfo:
gt_det: bool
gt_classes: Set[str]
pred_score: float
pred_class_scores: Dict[str, float]
ClipMetric = Callable[[Sequence[ClipInfo]], float]
def clip_detection_average_precision(
clip_evaluations: Sequence[ClipInfo],
) -> float:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
y_true.append(clip_eval.gt_det)
y_score.append(clip_eval.pred_score)
return average_precision(y_true=y_true, y_score=y_score)
def clip_detection_roc_auc(
clip_evaluations: Sequence[ClipInfo],
) -> float:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
y_true.append(clip_eval.gt_det)
y_score.append(clip_eval.pred_score)
return float(metrics.roc_auc_score(y_true=y_true, y_score=y_score))
clip_metrics = {
"average_precision": clip_detection_average_precision,
"roc_auc": clip_detection_roc_auc,
}
class ClipMetricsConfig(BaseEvaluatorConfig):
name: Literal["clip"] = "clip"
prefix: str = "clip"
metrics: List[str] = Field(
default_factory=lambda: [
"average_precision",
"roc_auc",
]
)
@field_validator("metrics", mode="after")
@classmethod
def validate_metrics(cls, v: List[str]) -> List[str]:
for metric_name in v:
if metric_name not in clip_metrics:
raise ValueError(f"Unknown metric {metric_name}")
return v
class ClipEvaluator(BaseEvaluator):
def __init__(self, *args, metrics: Dict[str, ClipMetric], **kwargs):
super().__init__(*args, **kwargs)
self.metrics = metrics
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[ClipInfo]:
return [
self.match_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions)
]
def compute_metrics(
self,
eval_outputs: List[ClipInfo],
) -> Dict[str, float]:
scores = {
name: metric(eval_outputs) for name, metric in self.metrics.items()
}
return {
f"{self.prefix}/{name}": score for name, score in scores.items()
}
def match_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipInfo:
clip = clip_annotation.clip
gt_det = False
gt_classes = set()
for sound_event in clip_annotation.sound_events:
# keep only annotations that pass the target filter and the edge buffer
if not self.filter_sound_event_annotations(sound_event, clip):
continue
gt_det = True
class_name = self.targets.encode_class(sound_event)
if class_name is None:
continue
gt_classes.add(class_name)
pred_score = 0
pred_class_scores: defaultdict[str, float] = defaultdict(lambda: 0)
for pred in predictions:
if not self.filter_predictions(pred, clip):
continue
pred_score = max(pred_score, pred.detection_score)
for class_name, class_score in zip(
self.targets.class_names,
pred.class_scores,
):
pred_class_scores[class_name] = max(
pred_class_scores[class_name],
class_score,
)
return ClipInfo(
gt_det=gt_det,
gt_classes=gt_classes,
pred_score=pred_score,
pred_class_scores=pred_class_scores,
)
@evaluators.register(ClipMetricsConfig)
@staticmethod
def from_config(
config: ClipMetricsConfig,
targets: TargetProtocol,
):
metrics = {name: clip_metrics.get(name) for name in config.metrics}
return ClipEvaluator.build(
config=config,
metrics=metrics,
targets=targets,
)
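Each clip is reduced to a single ClipInfo record, and the clip-level metrics operate directly on those records (reported as "clip/average_precision" and "clip/roc_auc" with the defaults). A small sketch of that data flow; the class name "myotis" is hypothetical and used only for illustration:

from batdetect2.evaluate.evaluator.clip import (
    ClipInfo,
    clip_detection_average_precision,
)

# Three clips: two containing at least one ground-truth call, one empty.
clips = [
    ClipInfo(
        gt_det=True,
        gt_classes={"myotis"},
        pred_score=0.9,
        pred_class_scores={"myotis": 0.8},
    ),
    ClipInfo(gt_det=True, gt_classes=set(), pred_score=0.4, pred_class_scores={}),
    ClipInfo(gt_det=False, gt_classes=set(), pred_score=0.2, pred_class_scores={}),
]

# Clip-level detection AP over the (gt_det, pred_score) pairs.
score = clip_detection_average_precision(clips)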

View File

@ -0,0 +1,219 @@
from collections import defaultdict
from typing import (
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
)
import numpy as np
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.evaluator.base import (
BaseEvaluator,
BaseEvaluatorConfig,
evaluators,
)
from batdetect2.evaluate.match import match
from batdetect2.evaluate.metrics.per_class_matches import (
ClassificationAveragePrecisionConfig,
PerClassMatchMetric,
PerClassMatchMetricConfig,
build_per_class_matches_metric,
)
from batdetect2.typing import (
ClipMatches,
RawPrediction,
TargetProtocol,
)
ScoreFn = Callable[[RawPrediction, int], float]
def score_by_class_score(pred: RawPrediction, class_index: int) -> float:
return float(pred.class_scores[class_index])
def score_by_adjusted_class_score(
pred: RawPrediction,
class_index: int,
) -> float:
return float(pred.class_scores[class_index]) * pred.detection_score
ScoreFunctionOption = Literal["class_score", "adjusted_class_score"]
score_functions: Mapping[ScoreFunctionOption, ScoreFn] = {
"class_score": score_by_class_score,
"adjusted_class_score": score_by_adjusted_class_score,
}
def get_score_fn(name: ScoreFunctionOption) -> ScoreFn:
return score_functions[name]
class ClassificationMetricsConfig(BaseEvaluatorConfig):
name: Literal["classification"] = "classification"
prefix: str = "classification"
include_generics: bool = True
score_by: ScoreFunctionOption = "class_score"
metrics: List[PerClassMatchMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
)
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class PerClassEvaluator(BaseEvaluator):
def __init__(
self,
*args,
metrics: Dict[str, PerClassMatchMetric],
score_fn: ScoreFn,
include_generics: bool = True,
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.score_fn = score_fn
self.metrics = metrics
self.include_generics = include_generics
self.include = include
self.exclude = exclude
self.selected = self.targets.class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> Dict[str, List[ClipMatches]]:
ret = defaultdict(list)
for clip_annotation, preds in zip(clip_annotations, predictions):
matches = self.match_clip(clip_annotation, preds)
for class_name, clip_eval in matches.items():
ret[class_name].append(clip_eval)
return ret
def compute_metrics(
self,
eval_outputs: Dict[str, List[ClipMatches]],
) -> Dict[str, float]:
results = {}
for metric_name, metric in self.metrics.items():
class_scores = {
class_name: metric(eval_outputs[class_name], class_name)
for class_name in self.targets.class_names
}
mean = float(
np.mean([v for v in class_scores.values() if not np.isnan(v)])
)
results[f"{self.prefix}/mean_{metric_name}"] = mean
for class_name, value in class_scores.items():
if class_name not in self.selected:
continue
results[f"{self.prefix}/{metric_name}/{class_name}"] = value
return results
def match_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> Dict[str, ClipMatches]:
clip = clip_annotation.clip
preds = [
pred for pred in predictions if self.filter_predictions(pred, clip)
]
all_gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.filter_sound_event_annotations(sound_event, clip)
]
ret = {}
for class_name in self.targets.class_names:
class_idx = self.targets.class_names.index(class_name)
# Only match to targets of the given class
gts = [
sound_event
for sound_event in all_gts
if self.is_class(sound_event, class_name)
]
scores = [self.score_fn(pred, class_idx) for pred in preds]
ret[class_name] = match(
gts,
preds,
clip=clip,
scores=scores,
targets=self.targets,
matcher=self.matcher,
)
return ret
def is_class(
self,
sound_event: data.SoundEventAnnotation,
class_name: str,
) -> bool:
sound_event_class = self.targets.encode_class(sound_event)
if sound_event_class is None and self.include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name
@evaluators.register(ClassificationMetricsConfig)
@staticmethod
def from_config(
config: ClassificationMetricsConfig,
targets: TargetProtocol,
):
metrics = {
metric.name: build_per_class_matches_metric(metric)
for metric in config.metrics
}
return PerClassEvaluator.build(
config=config,
targets=targets,
metrics=metrics,
score_fn=get_score_fn(config.score_by),
include_generics=config.include_generics,
include=config.include,
exclude=config.exclude,
)

View File

@ -0,0 +1,126 @@
from typing import Callable, Dict, List, Literal, Mapping, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.evaluator.base import (
BaseEvaluator,
BaseEvaluatorConfig,
evaluators,
)
from batdetect2.evaluate.match import match
from batdetect2.evaluate.metrics.matches import (
DetectionAveragePrecisionConfig,
MatchesMetric,
MatchMetricConfig,
build_match_metric,
)
from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol
ScoreFn = Callable[[RawPrediction], float]
def score_by_detection_score(pred: RawPrediction) -> float:
return pred.detection_score
def score_by_top_class_score(pred: RawPrediction) -> float:
return pred.class_scores.max()
ScoreFunctionOption = Literal["detection_score", "top_class_score"]
score_functions: Mapping[ScoreFunctionOption, ScoreFn] = {
"detection_score": score_by_detection_score,
"top_class_score": score_by_top_class_score,
}
def get_score_fn(name: ScoreFunctionOption) -> ScoreFn:
return score_functions[name]
class GlobalEvaluatorConfig(BaseEvaluatorConfig):
name: Literal["detection"] = "detection"
prefix: str = "detection"
score_by: ScoreFunctionOption = "detection_score"
metrics: List[MatchMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()]
)
class GlobalEvaluator(BaseEvaluator):
def __init__(
self,
*args,
score_fn: ScoreFn,
metrics: Dict[str, MatchesMetric],
**kwargs,
):
super().__init__(*args, **kwargs)
self.metrics = metrics
self.score_fn = score_fn
def compute_metrics(
self,
eval_outputs: List[ClipMatches],
) -> Dict[str, float]:
scores = {
name: metric(eval_outputs) for name, metric in self.metrics.items()
}
return {
f"{self.prefix}/{name}": score for name, score in scores.items()
}
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[ClipMatches]:
return [
self.match_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions)
]
def match_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipMatches:
clip = clip_annotation.clip
gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.filter_sound_event_annotations(sound_event, clip)
]
preds = [
pred for pred in predictions if self.filter_predictions(pred, clip)
]
scores = [self.score_fn(pred) for pred in preds]
return match(
gts,
preds,
scores=scores,
clip=clip,
targets=self.targets,
matcher=self.matcher,
)
@evaluators.register(GlobalEvaluatorConfig)
@staticmethod
def from_config(
config: GlobalEvaluatorConfig,
targets: TargetProtocol,
):
metrics = {
metric.name: build_match_metric(metric)
for metric in config.metrics
}
score_fn = get_score_fn(config.score_by)
return GlobalEvaluator.build(
config=config,
score_fn=score_fn,
metrics=metrics,
targets=targets,
)

View File

@ -0,0 +1,133 @@
from typing import Dict, List, Literal, Sequence
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.evaluate.match import match
from batdetect2.evaluate.metrics.base import (
BaseMetric,
BaseMetricConfig,
metrics_registry,
)
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.evaluate.metrics.detection import DetectionMetric
from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol
__all__ = [
"TopClassEvaluator",
"TopClassEvaluatorConfig",
]
def top_class_average_precision(
clip_evaluations: Sequence[ClipMatches],
) -> float:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
is_generic = m.gt_det and (m.gt_class is None)
# Ignore ground truth sounds with unknown class
if is_generic:
continue
num_positives += int(m.gt_det)
# Ignore matches that don't correspond to a prediction
if m.pred_geometry is None:
continue
y_true.append(m.gt_det & (m.top_class == m.gt_class))
y_score.append(m.top_class_score)
return average_precision(y_true, y_score, num_positives=num_positives)
top_class_metrics = {
"average_precision": top_class_average_precision,
}
class TopClassEvaluatorConfig(BaseMetricConfig):
name: Literal["top_class"] = "top_class"
prefix: str = "top_class"
metrics: List[str] = Field(default_factory=lambda: ["average_precision"])
@field_validator("metrics", mode="after")
@classmethod
def validate_metrics(cls, v: List[str]) -> List[str]:
for metric_name in v:
if metric_name not in top_class_metrics:
raise ValueError(f"Unknown metric {metric_name}")
return v
class TopClassEvaluator(BaseMetric):
def __init__(self, *args, metrics: Dict[str, DetectionMetric], **kwargs):
super().__init__(*args, **kwargs)
self.metrics = metrics
def __call__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> Dict[str, float]:
clip_evaluations = [
self.match_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions)
]
scores = {
name: metric(clip_evaluations)
for name, metric in self.metrics.items()
}
return {
f"{self.prefix}/{name}": score for name, score in scores.items()
}
def match_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipMatches:
clip = clip_annotation.clip
gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.filter_sound_event_annotations(sound_event, clip)
]
preds = [
pred for pred in predictions if self.filter_predictions(pred, clip)
]
# Use score of top class for matching
scores = [pred.class_scores.max() for pred in preds]
return match(
gts,
preds,
scores=scores,
clip=clip,
targets=self.targets,
matcher=self.matcher,
)
@classmethod
def from_config(
cls,
config: TopClassEvaluatorConfig,
targets: TargetProtocol,
):
metrics = {
name: top_class_metrics.get(name) for name in config.metrics
}
return super().build(
config=config,
metrics=metrics,
targets=targets,
)
metrics_registry.register(TopClassEvaluatorConfig, TopClassEvaluator)

View File

@ -8,7 +8,7 @@ from batdetect2.evaluate.tables import FullEvaluationTable
from batdetect2.logging import get_image_logger, get_table_logger from batdetect2.logging import get_image_logger, get_table_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import ClipEvaluation, EvaluatorProtocol from batdetect2.typing import ClipMatches, EvaluatorProtocol
class EvaluationModule(LightningModule): class EvaluationModule(LightningModule):
@ -56,7 +56,7 @@ class EvaluationModule(LightningModule):
self.plot_examples(self.clip_evaluations) self.plot_examples(self.clip_evaluations)
self.log_table(self.clip_evaluations) self.log_table(self.clip_evaluations)
def log_table(self, evaluated_clips: Sequence[ClipEvaluation]): def log_table(self, evaluated_clips: Sequence[ClipMatches]):
table_logger = get_table_logger(self.logger) # type: ignore table_logger = get_table_logger(self.logger) # type: ignore
if table_logger is None: if table_logger is None:
@ -65,7 +65,7 @@ class EvaluationModule(LightningModule):
df = FullEvaluationTable()(evaluated_clips) df = FullEvaluationTable()(evaluated_clips)
table_logger("full_evaluation", df, 0) table_logger("full_evaluation", df, 0)
def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]): def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
plotter = get_image_logger(self.logger) # type: ignore plotter = get_image_logger(self.logger) # type: ignore
if plotter is None: if plotter is None:
@ -74,7 +74,7 @@ class EvaluationModule(LightningModule):
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips): for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
plotter(figure_name, fig, self.global_step) plotter(figure_name, fig, self.global_step)
def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]): def log_metrics(self, evaluated_clips: Sequence[ClipMatches]):
metrics = self.evaluator.compute_metrics(evaluated_clips) metrics = self.evaluator.compute_metrics(evaluated_clips)
self.log_dict(metrics) self.log_dict(metrics)

View File

@ -8,8 +8,7 @@ from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core.configs import BaseConfig from batdetect2.core import BaseConfig, Registry
from batdetect2.core.registries import Registry
from batdetect2.evaluate.affinity import ( from batdetect2.evaluate.affinity import (
AffinityConfig, AffinityConfig,
GeometricIOUConfig, GeometricIOUConfig,
@ -17,11 +16,13 @@ from batdetect2.evaluate.affinity import (
) )
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
from batdetect2.typing import ( from batdetect2.typing import (
AffinityFunction,
MatcherProtocol,
MatchEvaluation, MatchEvaluation,
RawPrediction,
TargetProtocol, TargetProtocol,
) )
from batdetect2.typing.evaluate import AffinityFunction, MatcherProtocol from batdetect2.typing.evaluate import ClipMatches
from batdetect2.typing.postprocess import RawPrediction
MatchingGeometry = Literal["bbox", "interval", "timestamp"] MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching.""" """The geometry representation to use for matching."""
@ -33,9 +34,10 @@ def match(
sound_event_annotations: Sequence[data.SoundEventAnnotation], sound_event_annotations: Sequence[data.SoundEventAnnotation],
raw_predictions: Sequence[RawPrediction], raw_predictions: Sequence[RawPrediction],
clip: data.Clip, clip: data.Clip,
scores: Optional[Sequence[float]] = None,
targets: Optional[TargetProtocol] = None, targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None, matcher: Optional[MatcherProtocol] = None,
) -> List[MatchEvaluation]: ) -> ClipMatches:
if matcher is None: if matcher is None:
matcher = build_matcher() matcher = build_matcher()
@ -51,9 +53,11 @@ def match(
raw_prediction.geometry for raw_prediction in raw_predictions raw_prediction.geometry for raw_prediction in raw_predictions
] ]
scores = [ if scores is None:
raw_prediction.detection_score for raw_prediction in raw_predictions scores = [
] raw_prediction.detection_score
for raw_prediction in raw_predictions
]
matches = [] matches = []
@ -73,9 +77,11 @@ def match(
gt_det = target_idx is not None gt_det = target_idx is not None
gt_class = targets.encode_class(target) if target is not None else None gt_class = targets.encode_class(target) if target is not None else None
gt_geometry = (
target_geometries[target_idx] if target_idx is not None else None
)
pred_score = float(prediction.detection_score) if prediction else 0 pred_score = float(prediction.detection_score) if prediction else 0
pred_geometry = ( pred_geometry = (
predicted_geometries[source_idx] predicted_geometries[source_idx]
if source_idx is not None if source_idx is not None
@ -84,7 +90,7 @@ def match(
class_scores = ( class_scores = (
{ {
str(class_name): float(score) class_name: score
for class_name, score in zip( for class_name, score in zip(
targets.class_names, targets.class_names,
prediction.class_scores, prediction.class_scores,
@ -100,6 +106,7 @@ def match(
sound_event_annotation=target, sound_event_annotation=target,
gt_det=gt_det, gt_det=gt_det,
gt_class=gt_class, gt_class=gt_class,
gt_geometry=gt_geometry,
pred_score=pred_score, pred_score=pred_score,
pred_class_scores=class_scores, pred_class_scores=class_scores,
pred_geometry=pred_geometry, pred_geometry=pred_geometry,
@ -107,7 +114,7 @@ def match(
) )
) )
return matches return ClipMatches(clip=clip, matches=matches)
class StartTimeMatchConfig(BaseConfig): class StartTimeMatchConfig(BaseConfig):
@ -132,12 +139,10 @@ class StartTimeMatcher(MatcherProtocol):
distance_threshold=self.distance_threshold, distance_threshold=self.distance_threshold,
) )
@classmethod @matching_strategies.register(StartTimeMatchConfig)
def from_config(cls, config: StartTimeMatchConfig) -> "StartTimeMatcher": @staticmethod
return cls(distance_threshold=config.distance_threshold) def from_config(config: StartTimeMatchConfig):
return StartTimeMatcher(distance_threshold=config.distance_threshold)
matching_strategies.register(StartTimeMatchConfig, StartTimeMatcher)
def match_start_times( def match_start_times(
@ -264,19 +269,17 @@ class GreedyMatcher(MatcherProtocol):
affinity_threshold=self.affinity_threshold, affinity_threshold=self.affinity_threshold,
) )
@classmethod @matching_strategies.register(GreedyMatchConfig)
def from_config(cls, config: GreedyMatchConfig): @staticmethod
def from_config(config: GreedyMatchConfig):
affinity_function = build_affinity_function(config.affinity_function) affinity_function = build_affinity_function(config.affinity_function)
return cls( return GreedyMatcher(
geometry=config.geometry, geometry=config.geometry,
affinity_threshold=config.affinity_threshold, affinity_threshold=config.affinity_threshold,
affinity_function=affinity_function, affinity_function=affinity_function,
) )
matching_strategies.register(GreedyMatchConfig, GreedyMatcher)
def greedy_match( def greedy_match(
ground_truth: Sequence[data.Geometry], ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry], predictions: Sequence[data.Geometry],
@ -313,21 +316,21 @@ def greedy_match(
unassigned_gt = set(range(len(ground_truth))) unassigned_gt = set(range(len(ground_truth)))
if not predictions: if not predictions:
for target_idx in range(len(ground_truth)): for gt_idx in range(len(ground_truth)):
yield None, target_idx, 0 yield None, gt_idx, 0
return return
if not ground_truth: if not ground_truth:
for source_idx in range(len(predictions)): for pred_idx in range(len(predictions)):
yield source_idx, None, 0 yield pred_idx, None, 0
return return
indices = np.argsort(scores)[::-1] indices = np.argsort(scores)[::-1]
for source_idx in indices: for pred_idx in indices:
source_geometry = predictions[source_idx] source_geometry = predictions[pred_idx]
affinities = np.array( affinities = np.array(
[ [
@ -340,18 +343,18 @@ def greedy_match(
affinity = affinities[closest_target] affinity = affinities[closest_target]
if affinities[closest_target] <= affinity_threshold: if affinities[closest_target] <= affinity_threshold:
yield source_idx, None, 0 yield pred_idx, None, 0
continue continue
if closest_target not in unassigned_gt: if closest_target not in unassigned_gt:
yield source_idx, None, 0 yield pred_idx, None, 0
continue continue
unassigned_gt.remove(closest_target) unassigned_gt.remove(closest_target)
yield source_idx, closest_target, affinity yield pred_idx, closest_target, affinity
for target_idx in unassigned_gt: for gt_idx in unassigned_gt:
yield None, target_idx, 0 yield None, gt_idx, 0
class OptimalMatchConfig(BaseConfig): class OptimalMatchConfig(BaseConfig):
@ -386,17 +389,16 @@ class OptimalMatcher(MatcherProtocol):
affinity_threshold=self.affinity_threshold, affinity_threshold=self.affinity_threshold,
) )
@classmethod @matching_strategies.register(OptimalMatchConfig)
def from_config(cls, config: OptimalMatchConfig): @staticmethod
return cls( def from_config(config: OptimalMatchConfig):
return OptimalMatcher(
affinity_threshold=config.affinity_threshold, affinity_threshold=config.affinity_threshold,
time_buffer=config.time_buffer, time_buffer=config.time_buffer,
frequency_buffer=config.frequency_buffer, frequency_buffer=config.frequency_buffer,
) )
matching_strategies.register(OptimalMatchConfig, OptimalMatcher)
MatchConfig = Annotated[ MatchConfig = Annotated[
Union[ Union[
GreedyMatchConfig, GreedyMatchConfig,
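match() now accepts an optional scores sequence, used by the per-class evaluator to rank predictions by a class score instead of the detection score, and wraps its result in a ClipMatches record. A hedged sketch of both call styles, assuming gts (sound event annotations), preds (RawPrediction list), clip, and targets already exist in scope:

from batdetect2.evaluate.match import match

# Default behaviour: predictions are ranked by their detection_score.
clip_matches = match(gts, preds, clip=clip, targets=targets)

# Per-class use (as in PerClassEvaluator.match_clip): rank the same
# predictions by the score of one class. "myotis" is a hypothetical name.
class_idx = targets.class_names.index("myotis")
scores = [float(p.class_scores[class_idx]) for p in preds]
clip_matches = match(gts, preds, clip=clip, scores=scores, targets=targets)

for m in clip_matches.matches:
    print(m.gt_det, m.gt_class, m.pred_score, m.pred_class_scores)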

View File

@ -1,712 +0,0 @@
from collections import defaultdict
from collections.abc import Callable, Mapping
from typing import (
Annotated,
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation, MetricsProtocol
__all__ = ["DetectionAP", "ClassificationAP"]
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
APImplementation = Literal["sklearn", "pascal_voc"]
class DetectionAPConfig(BaseConfig):
name: Literal["detection_ap"] = "detection_ap"
ap_implementation: APImplementation = "pascal_voc"
class DetectionAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation = "pascal_voc",
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
score = float(self.metric(y_true, y_score))
return {"detection_AP": score}
@classmethod
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
return cls(implementation=config.ap_implementation)
metrics_registry.register(DetectionAPConfig, DetectionAP)
class DetectionROCAUCConfig(BaseConfig):
name: Literal["detection_roc_auc"] = "detection_roc_auc"
class DetectionROCAUC(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
score = float(metrics.roc_auc_score(y_true, y_score))
return {"detection_ROC_AUC": score}
@classmethod
def from_config(
cls, config: DetectionROCAUCConfig, class_names: List[str]
):
return cls()
metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC)
class ClassificationAPConfig(BaseConfig):
name: Literal["classification_ap"] = "classification_ap"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
implementation: APImplementation = "pascal_voc",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_ap = self.metric(y_true_class, y_pred_class)
class_scores[class_name] = float(class_ap)
mean_ap = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"classification_mAP": float(mean_ap),
**{
f"classification_AP/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClassificationAPConfig,
class_names: List[str],
):
return cls(
class_names,
implementation=config.ap_implementation,
include=config.include,
exclude=config.exclude,
)
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
class ClassificationROCAUCConfig(BaseConfig):
name: Literal["classification_roc_auc"] = "classification_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
class_scores[class_name] = float(class_roc_auc)
mean_roc_auc = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"classification_macro_average_ROC_AUC": float(mean_roc_auc),
**{
f"classification_ROC_AUC/{class_name}": class_scores[
class_name
]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClassificationROCAUCConfig,
class_names: List[str],
):
return cls(
class_names,
include=config.include,
exclude=config.exclude,
)
metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC)
class TopClassAPConfig(BaseConfig):
name: Literal["top_class_ap"] = "top_class_ap"
ap_implementation: APImplementation = "pascal_voc"
class TopClassAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation = "pascal_voc",
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
top_class = match.pred_class
y_true.append(top_class == match.gt_class)
y_score.append(match.pred_class_score)
score = float(self.metric(y_true, y_score))
return {"top_class_AP": score}
@classmethod
def from_config(cls, config: TopClassAPConfig, class_names: List[str]):
return cls(implementation=config.ap_implementation)
metrics_registry.register(TopClassAPConfig, TopClassAP)
class ClassificationBalancedAccuracyConfig(BaseConfig):
name: Literal["classification_balanced_accuracy"] = (
"classification_balanced_accuracy"
)
class ClassificationBalancedAccuracy(MetricsProtocol):
def __init__(self, class_names: List[str]):
self.class_names = class_names
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
top_class = match.pred_class
# Focus on matches
if match.gt_class is None or top_class is None:
continue
y_true.append(self.class_names.index(match.gt_class))
y_pred.append(self.class_names.index(top_class))
score = float(metrics.balanced_accuracy_score(y_true, y_pred))
return {"classification_balanced_accuracy": score}
@classmethod
def from_config(
cls,
config: ClassificationBalancedAccuracyConfig,
class_names: List[str],
):
return cls(class_names)
metrics_registry.register(
ClassificationBalancedAccuracyConfig,
ClassificationBalancedAccuracy,
)
class ClipDetectionAPConfig(BaseConfig):
name: Literal["clip_detection_ap"] = "clip_detection_ap"
ap_implementation: APImplementation = "pascal_voc"
class ClipDetectionAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {"clip_detection_ap": self.metric(y_true, y_score)}
@classmethod
def from_config(
cls,
config: ClipDetectionAPConfig,
class_names: List[str],
):
return cls(implementation=config.ap_implementation)
metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)
class ClipDetectionROCAUCConfig(BaseConfig):
name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"
class ClipDetectionROCAUC(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {
"clip_detection_ap": float(metrics.roc_auc_score(y_true, y_score))
}
@classmethod
def from_config(
cls,
config: ClipDetectionROCAUCConfig,
class_names: List[str],
):
return cls()
metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)
class ClipMulticlassAPConfig(BaseConfig):
name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipMulticlassAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
implementation: APImplementation,
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
clip_classes = set()
clip_scores = defaultdict(list)
for match in clip_eval.matches:
if match.gt_class is not None:
clip_classes.add(match.gt_class)
for class_name, score in match.pred_class_scores.items():
clip_scores[class_name].append(score)
y_true.append(clip_classes)
y_pred.append(
np.array(
[
# Get max score for each class
max(clip_scores.get(class_name, [0]))
for class_name in self.class_names
]
)
)
y_true = preprocessing.MultiLabelBinarizer(
classes=self.class_names
).fit_transform(y_true)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_ap = self.metric(y_true_class, y_pred_class)
class_scores[class_name] = float(class_ap)
mean_ap = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"clip_multiclass_mAP": float(mean_ap),
**{
f"clip_multiclass_AP/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls, config: ClipMulticlassAPConfig, class_names: List[str]
):
return cls(
implementation=config.ap_implementation,
include=config.include,
exclude=config.exclude,
class_names=class_names,
)
metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)
class ClipMulticlassROCAUCConfig(BaseConfig):
name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipMulticlassROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
clip_classes = set()
clip_scores = defaultdict(list)
for match in clip_eval.matches:
if match.gt_class is not None:
clip_classes.add(match.gt_class)
for class_name, score in match.pred_class_scores.items():
clip_scores[class_name].append(score)
y_true.append(clip_classes)
y_pred.append(
np.array(
[
# Get maximum score for each class
max(clip_scores.get(class_name, [0]))
for class_name in self.class_names
]
)
)
y_true = preprocessing.MultiLabelBinarizer(
classes=self.class_names
).fit_transform(y_true)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
class_scores[class_name] = float(class_roc_auc)
mean_roc_auc = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"clip_multiclass_macro_ROC_AUC": float(mean_roc_auc),
**{
f"clip_multiclass_ROC_AUC/{class_name}": class_scores[
class_name
]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClipMulticlassROCAUCConfig,
class_names: List[str],
):
return cls(
include=config.include,
exclude=config.exclude,
class_names=class_names,
)
metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)
MetricConfig = Annotated[
Union[
DetectionAPConfig,
DetectionROCAUCConfig,
ClassificationAPConfig,
ClassificationROCAUCConfig,
TopClassAPConfig,
ClassificationBalancedAccuracyConfig,
ClipDetectionAPConfig,
ClipDetectionROCAUCConfig,
ClipMulticlassAPConfig,
ClipMulticlassROCAUCConfig,
],
Field(discriminator="name"),
]
def build_metric(config: MetricConfig, class_names: List[str]):
return metrics_registry.build(config, class_names)
def pascal_voc_average_precision(y_true, y_score) -> float:
y_true = np.array(y_true)
y_score = np.array(y_score)
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
num_positives = y_true.sum()
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c,
np.finfo(np.float64).eps,
)
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
# pascal 12 way
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return ave_prec
_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = {
"sklearn": metrics.average_precision_score,
"pascal_voc": pascal_voc_average_precision,
}

View File

@ -0,0 +1,46 @@
from typing import Optional
import numpy as np
def average_precision(
y_true,
y_score,
num_positives: Optional[int] = None,
) -> float:
y_true = np.array(y_true)
y_score = np.array(y_score)
if num_positives is None:
num_positives = y_true.sum()
# Remove non-detections
valid_inds = y_score > 0
y_true = y_true[valid_inds]
y_score = y_score[valid_inds]
# Sort by score
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c,
np.finfo(np.float64).eps,
)
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
# pascal 12 way
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return ave_prec
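
A quick usage sketch (editor's illustration, invented numbers): passing num_positives explicitly is what lets ground-truth sounds that never received a prediction still count against recall, even though their zero-score rows are dropped above.

from batdetect2.evaluate.metrics.common import average_precision

# Two true detections and one false positive; a third ground-truth call was
# missed entirely, so num_positives is 3 rather than sum(y_true) = 2.
ap = average_precision(
    y_true=[1, 0, 1],
    y_score=[0.9, 0.8, 0.6],
    num_positives=3,
)
print(round(ap, 3))  # roughly 0.556 under the PASCAL-style interpolation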

View File

@ -0,0 +1,235 @@
from typing import Annotated, Callable, Literal, Sequence, Union
import numpy as np
from pydantic import Field
from sklearn import metrics
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import (
ClipMatches,
)
__all__ = [
"MatchMetricConfig",
"MatchesMetric",
"build_match_metric",
]
MatchesMetric = Callable[[Sequence[ClipMatches]], float]
metrics_registry: Registry[MatchesMetric, []] = Registry("match_metric")
class DetectionAveragePrecisionConfig(BaseConfig):
name: Literal["detection_average_precision"] = (
"detection_average_precision"
)
ignore_non_predictions: bool = True
class DetectionAveragePrecision:
def __init__(self, ignore_non_predictions: bool = True):
self.ignore_non_predictions = ignore_non_predictions
def __call__(
self,
clip_evaluations: Sequence[ClipMatches],
) -> float:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
num_positives += int(m.gt_det)
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true.append(m.gt_det)
y_score.append(m.pred_score)
return average_precision(y_true, y_score, num_positives=num_positives)
@metrics_registry.register(DetectionAveragePrecisionConfig)
@staticmethod
def from_config(config: DetectionAveragePrecisionConfig):
return DetectionAveragePrecision(
ignore_non_predictions=config.ignore_non_predictions
)
class TopClassAveragePrecisionConfig(BaseConfig):
name: Literal["top_class_average_precision"] = (
"top_class_average_precision"
)
ignore_non_predictions: bool = True
ignore_generic: bool = True
class TopClassAveragePrecision:
def __init__(
self,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
):
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipMatches],
) -> float:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore ground truth sounds with unknown class
continue
num_positives += int(m.gt_det)
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.gt_det & (m.top_class == m.gt_class))
y_score.append(m.top_class_score)
return average_precision(y_true, y_score, num_positives=num_positives)
@metrics_registry.register(TopClassAveragePrecisionConfig)
@staticmethod
def from_config(config: TopClassAveragePrecisionConfig):
return TopClassAveragePrecision(
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
)
class DetectionROCAUCConfig(BaseConfig):
name: Literal["detection_roc_auc"] = "detection_roc_auc"
ignore_non_predictions: bool = True
class DetectionROCAUC:
def __init__(
self,
ignore_non_predictions: bool = True,
):
self.ignore_non_predictions = ignore_non_predictions
def __call__(self, clip_evaluations: Sequence[ClipMatches]) -> float:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.gt_det)
y_score.append(m.pred_score)
return float(metrics.roc_auc_score(y_true, y_score))
@metrics_registry.register(DetectionROCAUCConfig)
@staticmethod
def from_config(config: DetectionROCAUCConfig):
return DetectionROCAUC(
ignore_non_predictions=config.ignore_non_predictions
)
class DetectionRecallConfig(BaseConfig):
name: Literal["detection_recall"] = "detection_recall"
threshold: float = 0.5
class DetectionRecall:
def __init__(self, threshold: float):
self.threshold = threshold
def __call__(
self,
clip_evaluations: Sequence[ClipMatches],
) -> float:
num_positives = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.gt_det:
num_positives += 1
if m.pred_score >= self.threshold and m.gt_det:
true_positives += 1
if num_positives == 0:
return 1
return true_positives / num_positives
@metrics_registry.register(DetectionRecallConfig)
@staticmethod
def from_config(config: DetectionRecallConfig):
return DetectionRecall(threshold=config.threshold)
class DetectionPrecisionConfig(BaseConfig):
name: Literal["detection_precision"] = "detection_precision"
threshold: float = 0.5
class DetectionPrecision:
def __init__(self, threshold: float):
self.threshold = threshold
def __call__(
self,
clip_evaluations: Sequence[ClipMatches],
) -> float:
num_detections = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
is_detection = m.pred_score >= self.threshold
if is_detection:
num_detections += 1
if is_detection and m.gt_det:
true_positives += 1
if num_detections == 0:
return np.nan
return true_positives / num_detections
@metrics_registry.register(DetectionPrecisionConfig)
@staticmethod
def from_config(config: DetectionPrecisionConfig):
return DetectionPrecision(threshold=config.threshold)
MatchMetricConfig = Annotated[
Union[
DetectionAveragePrecisionConfig,
DetectionROCAUCConfig,
DetectionRecallConfig,
DetectionPrecisionConfig,
TopClassAveragePrecisionConfig,
],
Field(discriminator="name"),
]
def build_match_metric(config: MatchMetricConfig):
return metrics_registry.build(config)
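
To make the registry wiring above concrete, a minimal sketch (editor's illustration; the import path is an assumption, since file names are not shown in this view) of building a metric from its discriminated config and calling it:

from batdetect2.evaluate.metrics.match_metrics import (  # module path assumed
    DetectionRecallConfig,
    build_match_metric,
)

metric = build_match_metric(DetectionRecallConfig(threshold=0.3))
# With no clips (and therefore no ground truth), recall falls back to 1
# by the convention coded above.
print(metric([]))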

View File

@ -0,0 +1,136 @@
from typing import Annotated, Callable, Literal, Sequence, Union
from pydantic import Field
from sklearn import metrics
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import (
ClipMatches,
)
__all__ = [
    "PerClassMatchMetric",
    "PerClassMatchMetricConfig",
    "build_per_class_matches_metric",
]
PerClassMatchMetric = Callable[[Sequence[ClipMatches], str], float]
metrics_registry: Registry[PerClassMatchMetric, []] = Registry(
"match_metric"
)
class ClassificationAveragePrecisionConfig(BaseConfig):
name: Literal["classification_average_precision"] = (
"classification_average_precision"
)
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ClassificationAveragePrecision:
def __init__(
self,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
):
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipMatches],
class_name: str,
) -> float:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
is_class = m.gt_class == class_name
if is_class:
num_positives += 1
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
# Exclude matches with ground truth sounds where the class is
# unknown
if m.is_generic and self.ignore_generic:
continue
y_true.append(is_class)
y_score.append(m.pred_class_scores.get(class_name, 0))
return average_precision(y_true, y_score, num_positives=num_positives)
@metrics_registry.register(ClassificationAveragePrecisionConfig)
@staticmethod
def from_config(config: ClassificationAveragePrecisionConfig):
return ClassificationAveragePrecision(
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ClassificationROCAUCConfig(BaseConfig):
name: Literal["classification_roc_auc"] = "classification_roc_auc"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ClassificationROCAUC:
def __init__(
self,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
):
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipMatches],
class_name: str,
) -> float:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
# Exclude matches with ground truth sounds where the class is
# unknown
if m.is_generic and self.ignore_generic:
continue
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true.append(m.gt_class == class_name)
y_score.append(m.pred_class_scores.get(class_name, 0))
return float(metrics.roc_auc_score(y_true, y_score))
@metrics_registry.register(ClassificationROCAUCConfig)
@staticmethod
def from_config(config: ClassificationROCAUCConfig):
return ClassificationROCAUC(
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
PerClassMatchMetricConfig = Annotated[
Union[
ClassificationAveragePrecisionConfig,
ClassificationROCAUCConfig,
],
Field(discriminator="name"),
]
def build_per_class_matches_metric(config: PerClassMatchMetricConfig):
return metrics_registry.build(config)
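
A per-class counterpart (editor's sketch; the import path, class names, and the empty clip_matches stand-in are assumptions, not code from this commit), showing how the same callable is applied once per class and then macro-averaged:

from batdetect2.evaluate.metrics.per_class_metrics import (  # module path assumed
    ClassificationAveragePrecisionConfig,
    build_per_class_matches_metric,
)

class_names = ["barbar", "eptser"]  # illustrative class names
clip_matches = []  # stand-in; real ClipMatches come from the evaluator's matcher

metric = build_per_class_matches_metric(ClassificationAveragePrecisionConfig())
per_class_ap = {name: metric(clip_matches, name) for name in class_names}
macro_ap = sum(per_class_ap.values()) / len(per_class_ap)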

View File

@ -17,7 +17,7 @@ from batdetect2.plotting.matches import plot_matches
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import ( from batdetect2.typing import (
AudioLoader, AudioLoader,
ClipEvaluation, ClipMatches,
MatchEvaluation, MatchEvaluation,
PlotterProtocol, PlotterProtocol,
PreprocessorProtocol, PreprocessorProtocol,
@ -53,7 +53,7 @@ class ExampleGallery(PlotterProtocol):
self.preprocessor = preprocessor or build_preprocessor() self.preprocessor = preprocessor or build_preprocessor()
self.audio_loader = audio_loader or build_audio_loader() self.audio_loader = audio_loader or build_audio_loader()
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): def __call__(self, clip_evaluations: Sequence[ClipMatches]):
per_class_matches = group_matches(clip_evaluations) per_class_matches = group_matches(clip_evaluations)
for class_name, matches in per_class_matches.items(): for class_name, matches in per_class_matches.items():
@ -128,7 +128,7 @@ class PlotClipEvaluation(PlotterProtocol):
self.audio_loader = audio_loader self.audio_loader = audio_loader
self.num_plots = num_plots self.num_plots = num_plots
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): def __call__(self, clip_evaluations: Sequence[ClipMatches]):
examples = random.sample( examples = random.sample(
clip_evaluations, clip_evaluations,
k=min(self.num_plots, len(clip_evaluations)), k=min(self.num_plots, len(clip_evaluations)),
@ -171,7 +171,7 @@ class DetectionPRCurveConfig(BaseConfig):
class DetectionPRCurve(PlotterProtocol): class DetectionPRCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): def __call__(self, clip_evaluations: Sequence[ClipMatches]):
y_true, y_score = zip( y_true, y_score = zip(
*[ *[
(match.gt_det, match.pred_score) (match.gt_det, match.pred_score)
@ -231,7 +231,7 @@ class ClassificationPRCurves(PlotterProtocol):
if class_name not in exclude if class_name not in exclude
] ]
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): def __call__(self, clip_evaluations: Sequence[ClipMatches]):
y_true = [] y_true = []
y_pred = [] y_pred = []
@ -303,7 +303,7 @@ class DetectionROCCurveConfig(BaseConfig):
class DetectionROCCurve(PlotterProtocol): class DetectionROCCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): def __call__(self, clip_evaluations: Sequence[ClipMatches]):
y_true, y_score = zip( y_true, y_score = zip(
*[ *[
(match.gt_det, match.pred_score) (match.gt_det, match.pred_score)
@ -363,7 +363,7 @@ class ClassificationROCCurves(PlotterProtocol):
if class_name not in exclude if class_name not in exclude
] ]
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): def __call__(self, clip_evaluations: Sequence[ClipMatches]):
y_true = [] y_true = []
y_pred = [] y_pred = []
@ -440,7 +440,7 @@ class ConfusionMatrix(PlotterProtocol):
self.background_class = background_class self.background_class = background_class
self.class_names = class_names self.class_names = class_names
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): def __call__(self, clip_evaluations: Sequence[ClipMatches]):
y_true = [] y_true = []
y_pred = [] y_pred = []
@ -456,7 +456,7 @@ class ConfusionMatrix(PlotterProtocol):
else self.background_class else self.background_class
) )
top_class = match.pred_class top_class = match.top_class
y_pred.append( y_pred.append(
top_class top_class
if top_class is not None if top_class is not None
@ -515,14 +515,14 @@ class ClassMatches:
def group_matches( def group_matches(
clip_evaluations: Sequence[ClipEvaluation], clip_evaluations: Sequence[ClipMatches],
) -> Dict[str, ClassMatches]: ) -> Dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches) class_examples = defaultdict(ClassMatches)
for clip_evaluation in clip_evaluations: for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches: for match in clip_evaluation.matches:
gt_class = match.gt_class gt_class = match.gt_class
pred_class = match.pred_class pred_class = match.top_class
if pred_class is None: if pred_class is None:
class_examples[gt_class].false_negatives.append(match) class_examples[gt_class].false_negatives.append(match)
@ -550,7 +550,7 @@ def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
*[ *[
(index, match.pred_class_scores[pred_class]) (index, match.pred_class_scores[pred_class])
for index, match in enumerate(matches) for index, match in enumerate(matches)
if (pred_class := match.pred_class) is not None if (pred_class := match.top_class) is not None
] ]
) )

View File

@ -5,9 +5,9 @@ from pydantic import Field
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig, Registry from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation from batdetect2.typing import ClipMatches
EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame] EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]
tables_registry: Registry[EvaluationTableGenerator, []] = Registry( tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
@ -21,20 +21,18 @@ class FullEvaluationTableConfig(BaseConfig):
class FullEvaluationTable: class FullEvaluationTable:
def __call__( def __call__(
self, clip_evaluations: Sequence[ClipEvaluation] self, clip_evaluations: Sequence[ClipMatches]
) -> pd.DataFrame: ) -> pd.DataFrame:
return extract_matches_dataframe(clip_evaluations) return extract_matches_dataframe(clip_evaluations)
@classmethod @tables_registry.register(FullEvaluationTableConfig)
def from_config(cls, config: FullEvaluationTableConfig): @staticmethod
return cls() def from_config(config: FullEvaluationTableConfig):
return FullEvaluationTable()
tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)
def extract_matches_dataframe( def extract_matches_dataframe(
clip_evaluations: Sequence[ClipEvaluation], clip_evaluations: Sequence[ClipMatches],
) -> pd.DataFrame: ) -> pd.DataFrame:
data = [] data = []
@ -78,8 +76,8 @@ def extract_matches_dataframe(
("gt", "low_freq"): gt_low_freq, ("gt", "low_freq"): gt_low_freq,
("gt", "high_freq"): gt_high_freq, ("gt", "high_freq"): gt_high_freq,
("pred", "score"): match.pred_score, ("pred", "score"): match.pred_score,
("pred", "class"): match.pred_class, ("pred", "class"): match.top_class,
("pred", "class_score"): match.pred_class_score, ("pred", "class_score"): match.top_class_score,
("pred", "start_time"): pred_start_time, ("pred", "start_time"): pred_start_time,
("pred", "end_time"): pred_end_time, ("pred", "end_time"): pred_end_time,
("pred", "low_freq"): pred_low_freq, ("pred", "low_freq"): pred_low_freq,

View File

@ -65,8 +65,6 @@ def plot_anchor_points(
if not targets.filter(sound_event): if not targets.filter(sound_event):
continue continue
sound_event = targets.transform(sound_event)
position, _ = targets.encode_roi(sound_event) position, _ = targets.encode_roi(sound_event)
positions.append(position) positions.append(position)

View File

@ -162,7 +162,7 @@ def plot_false_positive_match(
plt.text( plt.text(
start_time, start_time,
high_freq, high_freq,
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ", f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.top_class} \nTop Class Score: {match.top_class_score:.2f} ",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,
@ -312,7 +312,7 @@ def plot_true_positive_match(
plt.text( plt.text(
start_time, start_time,
high_freq, high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,
@ -394,7 +394,7 @@ def plot_cross_trigger_match(
plt.text( plt.text(
start_time, start_time,
high_freq, high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.top_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ",
va="top", va="top",
ha="right", ha="right",
color=color, color=color,

View File

@ -28,12 +28,10 @@ class CenterAudio(torch.nn.Module):
def forward(self, wav: torch.Tensor) -> torch.Tensor: def forward(self, wav: torch.Tensor) -> torch.Tensor:
return center_tensor(wav) return center_tensor(wav)
@classmethod @audio_transforms.register(CenterAudioConfig)
def from_config(cls, config: CenterAudioConfig, samplerate: int): @staticmethod
return cls() def from_config(config: CenterAudioConfig, samplerate: int):
return CenterAudio()
audio_transforms.register(CenterAudioConfig, CenterAudio)
class ScaleAudioConfig(BaseConfig): class ScaleAudioConfig(BaseConfig):
@ -44,12 +42,10 @@ class ScaleAudio(torch.nn.Module):
def forward(self, wav: torch.Tensor) -> torch.Tensor: def forward(self, wav: torch.Tensor) -> torch.Tensor:
return peak_normalize(wav) return peak_normalize(wav)
@classmethod @audio_transforms.register(ScaleAudioConfig)
def from_config(cls, config: ScaleAudioConfig, samplerate: int): @staticmethod
return cls() def from_config(config: ScaleAudioConfig, samplerate: int):
return ScaleAudio()
audio_transforms.register(ScaleAudioConfig, ScaleAudio)
class FixDurationConfig(BaseConfig): class FixDurationConfig(BaseConfig):
@ -75,13 +71,12 @@ class FixDuration(torch.nn.Module):
return torch.nn.functional.pad(wav, (0, self.length - length)) return torch.nn.functional.pad(wav, (0, self.length - length))
@classmethod @audio_transforms.register(FixDurationConfig)
def from_config(cls, config: FixDurationConfig, samplerate: int): @staticmethod
return cls(samplerate=samplerate, duration=config.duration) def from_config(config: FixDurationConfig, samplerate: int):
return FixDuration(samplerate=samplerate, duration=config.duration)
audio_transforms.register(FixDurationConfig, FixDuration)
AudioTransform = Annotated[ AudioTransform = Annotated[
Union[ Union[
FixDurationConfig, FixDurationConfig,

View File

@ -285,10 +285,11 @@ class PCEN(torch.nn.Module):
* torch.expm1(self.power * torch.log1p(S * smooth / self.bias)) * torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
).to(spec.dtype) ).to(spec.dtype)
@classmethod @spectrogram_transforms.register(PcenConfig)
def from_config(cls, config: PcenConfig, samplerate: int): @staticmethod
def from_config(config: PcenConfig, samplerate: int):
smooth = _compute_smoothing_constant(samplerate, config.time_constant) smooth = _compute_smoothing_constant(samplerate, config.time_constant)
return cls( return PCEN(
smoothing_constant=smooth, smoothing_constant=smooth,
gain=config.gain, gain=config.gain,
bias=config.bias, bias=config.bias,
@ -296,9 +297,6 @@ class PCEN(torch.nn.Module):
) )
spectrogram_transforms.register(PcenConfig, PCEN)
def _compute_smoothing_constant( def _compute_smoothing_constant(
samplerate: int, samplerate: int,
time_constant: float, time_constant: float,
@ -335,12 +333,10 @@ class ScaleAmplitude(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
return self.scaler(spec) return self.scaler(spec)
@classmethod @spectrogram_transforms.register(ScaleAmplitudeConfig)
def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int): @staticmethod
return cls(scale=config.scale) def from_config(config: ScaleAmplitudeConfig, samplerate: int):
return ScaleAmplitude(scale=config.scale)
spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude)
class SpectralMeanSubstractionConfig(BaseConfig): class SpectralMeanSubstractionConfig(BaseConfig):
@ -352,19 +348,13 @@ class SpectralMeanSubstraction(torch.nn.Module):
mean = spec.mean(-1, keepdim=True) mean = spec.mean(-1, keepdim=True)
return (spec - mean).clamp(min=0) return (spec - mean).clamp(min=0)
@classmethod @spectrogram_transforms.register(SpectralMeanSubstractionConfig)
@staticmethod
def from_config( def from_config(
cls,
config: SpectralMeanSubstractionConfig, config: SpectralMeanSubstractionConfig,
samplerate: int, samplerate: int,
): ):
return cls() return SpectralMeanSubstraction()
spectrogram_transforms.register(
SpectralMeanSubstractionConfig,
SpectralMeanSubstraction,
)
class PeakNormalizeConfig(BaseConfig): class PeakNormalizeConfig(BaseConfig):
@ -375,13 +365,12 @@ class PeakNormalize(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
return peak_normalize(spec) return peak_normalize(spec)
@classmethod @spectrogram_transforms.register(PeakNormalizeConfig)
def from_config(cls, config: PeakNormalizeConfig, samplerate: int): @staticmethod
return cls() def from_config(config: PeakNormalizeConfig, samplerate: int):
return PeakNormalize()
spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize)
SpectrogramTransform = Annotated[ SpectrogramTransform = Annotated[
Union[ Union[
PcenConfig, PcenConfig,

View File

@ -99,7 +99,7 @@ DEFAULT_DETECTION_CLASS = TargetClassConfig(
DEFAULT_CLASSES = [ DEFAULT_CLASSES = [
TargetClassConfig( TargetClassConfig(
name="barbar", name="barbar",
tags=[data.Tag(key="class", value="Barbastellus barbastellus")], tags=[data.Tag(key="class", value="Barbastella barbastellus")],
), ),
TargetClassConfig( TargetClassConfig(
name="eptser", name="eptser",

View File

@ -1,11 +1,11 @@
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
AugmentationsConfig, AugmentationsConfig,
EchoAugmentationConfig, AddEchoConfig,
FrequencyMaskAugmentationConfig, MaskFrequencyConfig,
RandomAudioSource, RandomAudioSource,
TimeMaskAugmentationConfig, MaskTimeConfig,
VolumeAugmentationConfig, ScaleVolumeConfig,
WarpAugmentationConfig, WarpConfig,
add_echo, add_echo,
build_augmentations, build_augmentations,
mask_frequency, mask_frequency,
@ -43,20 +43,20 @@ __all__ = [
"AugmentationsConfig", "AugmentationsConfig",
"ClassificationLossConfig", "ClassificationLossConfig",
"DetectionLossConfig", "DetectionLossConfig",
"EchoAugmentationConfig", "AddEchoConfig",
"FrequencyMaskAugmentationConfig", "MaskFrequencyConfig",
"LossConfig", "LossConfig",
"LossFunction", "LossFunction",
"PLTrainerConfig", "PLTrainerConfig",
"RandomAudioSource", "RandomAudioSource",
"SizeLossConfig", "SizeLossConfig",
"TimeMaskAugmentationConfig", "MaskTimeConfig",
"TrainingConfig", "TrainingConfig",
"TrainingDataset", "TrainingDataset",
"TrainingModule", "TrainingModule",
"ValidationDataset", "ValidationDataset",
"VolumeAugmentationConfig", "ScaleVolumeConfig",
"WarpAugmentationConfig", "WarpConfig",
"add_echo", "add_echo",
"build_augmentations", "build_augmentations",
"build_clip_labeler", "build_clip_labeler",

View File

@ -12,21 +12,23 @@ from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry from soundevent.geometry import scale_geometry, shift_geometry
from batdetect2.audio.clips import get_subclip_annotation from batdetect2.audio.clips import get_subclip_annotation
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.arrays import adjust_width from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry
from batdetect2.typing import AudioLoader, Augmentation from batdetect2.typing import AudioLoader, Augmentation
__all__ = [ __all__ = [
"AugmentationConfig", "AugmentationConfig",
"AugmentationsConfig", "AugmentationsConfig",
"DEFAULT_AUGMENTATION_CONFIG", "DEFAULT_AUGMENTATION_CONFIG",
"EchoAugmentationConfig", "AddEchoConfig",
"AudioSource", "AudioSource",
"FrequencyMaskAugmentationConfig", "MaskFrequencyConfig",
"MixAugmentationConfig", "MixAudioConfig",
"TimeMaskAugmentationConfig", "MaskTimeConfig",
"VolumeAugmentationConfig", "ScaleVolumeConfig",
"WarpAugmentationConfig", "WarpConfig",
"add_echo", "add_echo",
"build_augmentations", "build_augmentations",
"load_augmentation_config", "load_augmentation_config",
@ -37,10 +39,19 @@ __all__ = [
"warp_spectrogram", "warp_spectrogram",
] ]
AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]] AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
audio_augmentations: Registry[Augmentation, [int, Optional[AudioSource]]] = (
Registry(name="audio_augmentation")
)
class MixAugmentationConfig(BaseConfig): spec_augmentations: Registry[Augmentation, []] = Registry(
name="spec_augmentation"
)
class MixAudioConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples).""" """Configuration for MixUp augmentation (mixing two examples)."""
name: Literal["mix_audio"] = "mix_audio" name: Literal["mix_audio"] = "mix_audio"
@ -87,6 +98,19 @@ class MixAudio(torch.nn.Module):
) )
return mixed_audio, mixed_annotations return mixed_audio, mixed_annotations
@audio_augmentations.register(MixAudioConfig)
@staticmethod
def from_config(
config: MixAudioConfig,
samplerate: int,
source: Optional[AudioSource],
):
return MixAudio(
example_source=source,
min_weight=config.min_weight,
max_weight=config.max_weight,
)
def mix_audio( def mix_audio(
wav1: torch.Tensor, wav1: torch.Tensor,
@ -136,7 +160,7 @@ def combine_clip_annotations(
) )
class EchoAugmentationConfig(BaseConfig): class AddEchoConfig(BaseConfig):
"""Configuration for adding synthetic echo/reverb.""" """Configuration for adding synthetic echo/reverb."""
name: Literal["add_echo"] = "add_echo" name: Literal["add_echo"] = "add_echo"
@ -149,14 +173,17 @@ class EchoAugmentationConfig(BaseConfig):
class AddEcho(torch.nn.Module): class AddEcho(torch.nn.Module):
def __init__( def __init__(
self, self,
samplerate: int = TARGET_SAMPLERATE_HZ,
min_weight: float = 0.1, min_weight: float = 0.1,
max_weight: float = 1.0, max_weight: float = 1.0,
max_delay: int = 2560, max_delay: float = 0.005,
): ):
super().__init__() super().__init__()
self.samplerate = samplerate
self.min_weight = min_weight self.min_weight = min_weight
self.max_weight = max_weight self.max_weight = max_weight
self.max_delay = max_delay self.max_delay_s = max_delay
self.max_delay = int(max_delay * samplerate)
def forward( def forward(
self, self,
@ -167,6 +194,18 @@ class AddEcho(torch.nn.Module):
weight = np.random.uniform(self.min_weight, self.max_weight) weight = np.random.uniform(self.min_weight, self.max_weight)
return add_echo(wav, delay=delay, weight=weight), clip_annotation return add_echo(wav, delay=delay, weight=weight), clip_annotation
@audio_augmentations.register(AddEchoConfig)
@staticmethod
def from_config(
config: AddEchoConfig, samplerate: int, source: AudioSource
):
return AddEcho(
samplerate=samplerate,
min_weight=config.min_weight,
max_weight=config.max_weight,
max_delay=config.max_delay,
)
def add_echo( def add_echo(
wav: torch.Tensor, wav: torch.Tensor,
@ -183,7 +222,7 @@ def add_echo(
return mix_audio(wav, audio_delay, weight) return mix_audio(wav, audio_delay, weight)
class VolumeAugmentationConfig(BaseConfig): class ScaleVolumeConfig(BaseConfig):
"""Configuration for random volume scaling of the spectrogram.""" """Configuration for random volume scaling of the spectrogram."""
name: Literal["scale_volume"] = "scale_volume" name: Literal["scale_volume"] = "scale_volume"
@ -206,19 +245,27 @@ class ScaleVolume(torch.nn.Module):
factor = np.random.uniform(self.min_scaling, self.max_scaling) factor = np.random.uniform(self.min_scaling, self.max_scaling)
return scale_volume(spec, factor=factor), clip_annotation return scale_volume(spec, factor=factor), clip_annotation
@spec_augmentations.register(ScaleVolumeConfig)
@staticmethod
def from_config(config: ScaleVolumeConfig):
return ScaleVolume(
min_scaling=config.min_scaling,
max_scaling=config.max_scaling,
)
def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor: def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
"""Scale the amplitude of the spectrogram by a factor.""" """Scale the amplitude of the spectrogram by a factor."""
return spec * factor return spec * factor
class WarpAugmentationConfig(BaseConfig): class WarpConfig(BaseConfig):
name: Literal["warp"] = "warp" name: Literal["warp"] = "warp"
probability: float = 0.2 probability: float = 0.2
delta: float = 0.04 delta: float = 0.04
class WarpSpectrogram(torch.nn.Module): class Warp(torch.nn.Module):
def __init__(self, delta: float = 0.04) -> None: def __init__(self, delta: float = 0.04) -> None:
super().__init__() super().__init__()
self.delta = delta self.delta = delta
@ -234,6 +281,11 @@ class WarpSpectrogram(torch.nn.Module):
warp_clip_annotation(clip_annotation, factor=factor), warp_clip_annotation(clip_annotation, factor=factor),
) )
@spec_augmentations.register(WarpConfig)
@staticmethod
def from_config(config: WarpConfig):
return Warp(delta=config.delta)
def warp_sound_event_annotation( def warp_sound_event_annotation(
sound_event_annotation: data.SoundEventAnnotation, sound_event_annotation: data.SoundEventAnnotation,
@ -294,7 +346,7 @@ def warp_spectrogram(
).squeeze(0) ).squeeze(0)
class TimeMaskAugmentationConfig(BaseConfig): class MaskTimeConfig(BaseConfig):
name: Literal["mask_time"] = "mask_time" name: Literal["mask_time"] = "mask_time"
probability: float = 0.2 probability: float = 0.2
max_perc: float = 0.05 max_perc: float = 0.05
@ -336,6 +388,14 @@ class MaskTime(torch.nn.Module):
] ]
return mask_time(spec, masks), clip_annotation return mask_time(spec, masks), clip_annotation
@spec_augmentations.register(MaskTimeConfig)
@staticmethod
def from_config(config: MaskTimeConfig):
return MaskTime(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
def mask_time( def mask_time(
spec: torch.Tensor, spec: torch.Tensor,
@ -351,7 +411,7 @@ def mask_time(
return spec return spec
class FrequencyMaskAugmentationConfig(BaseConfig): class MaskFrequencyConfig(BaseConfig):
name: Literal["mask_freq"] = "mask_freq" name: Literal["mask_freq"] = "mask_freq"
probability: float = 0.2 probability: float = 0.2
max_perc: float = 0.10 max_perc: float = 0.10
@ -394,6 +454,14 @@ class MaskFrequency(torch.nn.Module):
] ]
return mask_frequency(spec, masks), clip_annotation return mask_frequency(spec, masks), clip_annotation
@spec_augmentations.register(MaskFrequencyConfig)
@staticmethod
def from_config(config: MaskFrequencyConfig):
return MaskFrequency(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
def mask_frequency( def mask_frequency(
spec: torch.Tensor, spec: torch.Tensor,
@ -410,8 +478,8 @@ def mask_frequency(
AudioAugmentationConfig = Annotated[ AudioAugmentationConfig = Annotated[
Union[ Union[
MixAugmentationConfig, MixAudioConfig,
EchoAugmentationConfig, AddEchoConfig,
], ],
Field(discriminator="name"), Field(discriminator="name"),
] ]
@ -419,22 +487,22 @@ AudioAugmentationConfig = Annotated[
SpectrogramAugmentationConfig = Annotated[ SpectrogramAugmentationConfig = Annotated[
Union[ Union[
VolumeAugmentationConfig, ScaleVolumeConfig,
WarpAugmentationConfig, WarpConfig,
FrequencyMaskAugmentationConfig, MaskFrequencyConfig,
TimeMaskAugmentationConfig, MaskTimeConfig,
], ],
Field(discriminator="name"), Field(discriminator="name"),
] ]
AugmentationConfig = Annotated[ AugmentationConfig = Annotated[
Union[ Union[
MixAugmentationConfig, MixAudioConfig,
EchoAugmentationConfig, AddEchoConfig,
VolumeAugmentationConfig, ScaleVolumeConfig,
WarpAugmentationConfig, WarpConfig,
FrequencyMaskAugmentationConfig, MaskFrequencyConfig,
TimeMaskAugmentationConfig, MaskTimeConfig,
], ],
Field(discriminator="name"), Field(discriminator="name"),
] ]
@ -513,7 +581,7 @@ def build_augmentation_from_config(
) )
if config.name == "warp": if config.name == "warp":
return WarpSpectrogram( return Warp(
delta=config.delta, delta=config.delta,
) )
@ -538,14 +606,14 @@ def build_augmentation_from_config(
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig( DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
enabled=True, enabled=True,
audio=[ audio=[
MixAugmentationConfig(), MixAudioConfig(),
EchoAugmentationConfig(), AddEchoConfig(),
], ],
spectrogram=[ spectrogram=[
VolumeAugmentationConfig(), ScaleVolumeConfig(),
WarpAugmentationConfig(), WarpConfig(),
TimeMaskAugmentationConfig(), MaskTimeConfig(),
FrequencyMaskAugmentationConfig(), MaskFrequencyConfig(),
], ],
) )
@ -566,9 +634,9 @@ class AugmentationSequence(torch.nn.Module):
return tensor, clip_annotation return tensor, clip_annotation
def build_augmentation_sequence( def build_audio_augmentations(
samplerate: int, steps: Optional[Sequence[AudioAugmentationConfig]] = None,
steps: Optional[Sequence[AugmentationConfig]] = None, samplerate: int = TARGET_SAMPLERATE_HZ,
audio_source: Optional[AudioSource] = None, audio_source: Optional[AudioSource] = None,
) -> Optional[Augmentation]: ) -> Optional[Augmentation]:
if not steps: if not steps:
@ -577,10 +645,8 @@ def build_augmentation_sequence(
augmentations = [] augmentations = []
for step_config in steps: for step_config in steps:
augmentation = build_augmentation_from_config( augmentation = audio_augmentations.build(
step_config, step_config, samplerate, audio_source
samplerate=samplerate,
audio_source=audio_source,
) )
if augmentation is None: if augmentation is None:

View File

@ -10,7 +10,6 @@ from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import ( from batdetect2.typing import (
ClipEvaluation,
EvaluatorProtocol, EvaluatorProtocol,
ModelOutput, ModelOutput,
RawPrediction, RawPrediction,
@ -37,22 +36,26 @@ class ValidationMetrics(Callback):
def generate_plots( def generate_plots(
self, self,
pl_module: LightningModule, pl_module: LightningModule,
evaluated_clips: List[ClipEvaluation],
): ):
plotter = get_image_logger(pl_module.logger) # type: ignore plotter = get_image_logger(pl_module.logger) # type: ignore
if plotter is None: if plotter is None:
return return
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips): for figure_name, fig in self.evaluator.generate_plots(
self._clip_annotations,
self._predictions,
):
plotter(figure_name, fig, pl_module.global_step) plotter(figure_name, fig, pl_module.global_step)
def log_metrics( def log_metrics(
self, self,
pl_module: LightningModule, pl_module: LightningModule,
evaluated_clips: List[ClipEvaluation],
): ):
metrics = self.evaluator.compute_metrics(evaluated_clips) metrics = self.evaluator.compute_metrics(
self._clip_annotations,
self._predictions,
)
pl_module.log_dict(metrics) pl_module.log_dict(metrics)
def on_validation_epoch_end( def on_validation_epoch_end(
@ -60,13 +63,8 @@ class ValidationMetrics(Callback):
trainer: Trainer, trainer: Trainer,
pl_module: LightningModule, pl_module: LightningModule,
) -> None: ) -> None:
clip_evaluations = self.evaluator.evaluate( self.log_metrics(pl_module)
self._clip_annotations, self.generate_plots(pl_module)
self._predictions,
)
self.log_metrics(pl_module, clip_evaluations)
self.generate_plots(pl_module, clip_evaluations)
return super().on_validation_epoch_end(trainer, pl_module) return super().on_validation_epoch_end(trainer, pl_module)

View File

@ -105,7 +105,10 @@ def train(
trainer = trainer or build_trainer( trainer = trainer or build_trainer(
config, config,
targets=targets, targets=targets,
evaluator=build_evaluator(config.train.validation, targets=targets), evaluator=build_evaluator(
config.train.validation.evaluator,
targets=targets,
),
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
log_dir=log_dir, log_dir=log_dir,
experiment_name=experiment_name, experiment_name=experiment_name,

View File

@ -1,6 +1,8 @@
from batdetect2.typing.evaluate import ( from batdetect2.typing.evaluate import (
ClipEvaluation, AffinityFunction,
ClipMatches,
EvaluatorProtocol, EvaluatorProtocol,
MatcherProtocol,
MatchEvaluation, MatchEvaluation,
MetricsProtocol, MetricsProtocol,
PlotterProtocol, PlotterProtocol,
@ -36,19 +38,22 @@ from batdetect2.typing.train import (
) )
__all__ = [ __all__ = [
"AffinityFunction",
"AudioLoader", "AudioLoader",
"Augmentation", "Augmentation",
"BackboneModel", "BackboneModel",
"BatDetect2Prediction", "BatDetect2Prediction",
"ClipEvaluation", "ClipMatches",
"ClipLabeller", "ClipLabeller",
"ClipperProtocol", "ClipperProtocol",
"DetectionModel", "DetectionModel",
"EvaluatorProtocol",
"GeometryDecoder", "GeometryDecoder",
"Heatmaps", "Heatmaps",
"LossProtocol", "LossProtocol",
"Losses", "Losses",
"MatchEvaluation", "MatchEvaluation",
"MatcherProtocol",
"MetricsProtocol", "MetricsProtocol",
"ModelOutput", "ModelOutput",
"PlotterProtocol", "PlotterProtocol",
@ -63,5 +68,4 @@ __all__ = [
"SoundEventFilter", "SoundEventFilter",
"TargetProtocol", "TargetProtocol",
"TrainExample", "TrainExample",
"EvaluatorProtocol",
] ]

View File

@ -31,6 +31,7 @@ class MatchEvaluation:
sound_event_annotation: Optional[data.SoundEventAnnotation] sound_event_annotation: Optional[data.SoundEventAnnotation]
gt_det: bool gt_det: bool
gt_class: Optional[str] gt_class: Optional[str]
gt_geometry: Optional[data.Geometry]
pred_score: float pred_score: float
pred_class_scores: Dict[str, float] pred_class_scores: Dict[str, float]
@ -39,44 +40,32 @@ class MatchEvaluation:
affinity: float affinity: float
@property @property
def pred_class(self) -> Optional[str]: def top_class(self) -> Optional[str]:
if not self.pred_class_scores: if not self.pred_class_scores:
return None return None
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
@property @property
def pred_class_score(self) -> float: def is_prediction(self) -> bool:
pred_class = self.pred_class return self.pred_geometry is not None
@property
def is_generic(self) -> bool:
return self.gt_det and self.gt_class is None
@property
def top_class_score(self) -> float:
pred_class = self.top_class
if pred_class is None: if pred_class is None:
return 0 return 0
return self.pred_class_scores[pred_class] return self.pred_class_scores[pred_class]
def is_true_positive(self, threshold: float = 0) -> bool:
return (
self.gt_det
and self.pred_score > threshold
and self.gt_class == self.pred_class
)
def is_false_positive(self, threshold: float = 0) -> bool:
return self.gt_det is None and self.pred_score > threshold
def is_false_negative(self, threshold: float = 0) -> bool:
return self.gt_det and self.pred_score <= threshold
def is_cross_trigger(self, threshold: float = 0) -> bool:
return (
self.gt_det
and self.pred_score > threshold
and self.gt_class != self.pred_class
)
@dataclass @dataclass
class ClipEvaluation: class ClipMatches:
clip: data.Clip clip: data.Clip
matches: List[MatchEvaluation] matches: List[MatchEvaluation]
@ -103,29 +92,36 @@ class AffinityFunction(Protocol, Generic[Geom]):
class MetricsProtocol(Protocol): class MetricsProtocol(Protocol):
def __call__( def __call__(
self, clip_evaluations: Sequence[ClipEvaluation] self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> Dict[str, float]: ... ) -> Dict[str, float]: ...
class PlotterProtocol(Protocol): class PlotterProtocol(Protocol):
def __call__( def __call__(
self, clip_evaluations: Sequence[ClipEvaluation] self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> Iterable[Tuple[str, Figure]]: ... ) -> Iterable[Tuple[str, Figure]]: ...
class EvaluatorProtocol(Protocol): EvaluationOutput = TypeVar("EvaluationOutput")
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
targets: TargetProtocol targets: TargetProtocol
def evaluate( def evaluate(
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]], predictions: Sequence[Sequence[RawPrediction]],
) -> List[ClipEvaluation]: ... ) -> EvaluationOutput: ...
def compute_metrics( def compute_metrics(
self, clip_evaluations: Sequence[ClipEvaluation] self, eval_outputs: EvaluationOutput
) -> Dict[str, float]: ... ) -> Dict[str, float]: ...
def generate_plots( def generate_plots(
self, clip_evaluations: Sequence[ClipEvaluation] self, eval_outputs: EvaluationOutput
) -> Iterable[Tuple[str, Figure]]: ... ) -> Iterable[Tuple[str, Figure]]: ...