Better matching module, remove generic from classification evaluations

This commit is contained in:
mbsantiago 2025-09-14 18:16:59 +01:00
parent 8628133fd7
commit ec1c0ff020
10 changed files with 355 additions and 338 deletions

View File

@ -1,11 +1,6 @@
from batdetect2.evaluate.config import (
EvaluationConfig,
load_evaluation_config,
)
from batdetect2.evaluate.match import match_predictions_and_annotations
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
__all__ = [
"EvaluationConfig",
"load_evaluation_config",
"match_predictions_and_annotations",
]

View File

@ -4,7 +4,7 @@ from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.match import MatchConfig
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
__all__ = [
"EvaluationConfig",
@ -13,7 +13,7 @@ __all__ = [
class EvaluationConfig(BaseConfig):
match: MatchConfig = Field(default_factory=MatchConfig)
match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
def load_evaluation_config(

View File

@ -4,9 +4,8 @@ import pandas as pd
from soundevent import data
from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.evaluate.match import match_all_predictions
from batdetect2.evaluate.match import build_matcher, match_all_predictions
from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
@ -77,11 +76,13 @@ def evaluate(
clip_annotations.extend(clip_annotations)
predictions.extend(predictions)
matcher = build_matcher(config.evaluation.match)
matches = match_all_predictions(
clip_annotations,
predictions,
targets=targets,
config=config.evaluation.match,
matcher=matcher,
)
df = extract_matches_dataframe(matches)
@ -89,7 +90,6 @@ def evaluate(
metrics = [
DetectionAveragePrecision(),
ClassificationMeanAveragePrecision(class_names=targets.class_names),
ClassificationAccuracy(class_names=targets.class_names),
]
results = {

View File

@ -1,63 +1,120 @@
from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Protocol, Tuple
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
from loguru import logger
from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.targets import build_targets
from batdetect2.typing import (
MatchEvaluation,
TargetProtocol,
)
from batdetect2.typing.evaluate import AffinityFunction, MatcherProtocol
from batdetect2.typing.postprocess import RawPrediction
MatchingStrategy = Literal["greedy", "optimal"]
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""
matching_strategy = Registry("matching_strategy")
class StartTimeMatchConfig(BaseConfig):
    """Configuration for the start-time matching strategy.

    Attributes
    ----------
    name : Literal["start_time"], default="start_time"
        Fixed tag that identifies this configuration; it is the field used
        as discriminator by the ``MatchConfig`` union.
    distance_threshold : float, default=0.01
        Largest absolute difference between the start time of a prediction
        and of a ground-truth geometry for which the pair can still be
        matched (presumably in seconds -- confirm against the geometry
        units).
    """

    name: Literal["start_time"] = "start_time"
    distance_threshold: float = 0.01
@matching_strategy.register(StartTimeMatchConfig)
class StartTimeMatcher(MatcherProtocol):
def __init__(self, distance_threshold: float):
self.distance_threshold = distance_threshold
class AffinityFunction(Protocol):
def __call__(
self,
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float: ...
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
):
return match_start_times(
ground_truth,
predictions,
scores,
distance_threshold=self.distance_threshold,
)
@classmethod
def from_config(cls, config: StartTimeMatchConfig) -> "StartTimeMatcher":
return cls(distance_threshold=config.distance_threshold)
class MatchConfig(BaseConfig):
"""Configuration for matching geometries.
def match_start_times(
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
distance_threshold: float = 0.01,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
if not ground_truth:
for index in range(len(predictions)):
yield index, None, 0
Attributes
----------
strategy : MatchingStrategy, default="greedy"
The matching algorithm to use. 'greedy' prioritizes high-confidence
predictions, while 'optimal' finds the globally best set of matches.
geometry : MatchingGeometry, default="timestamp"
The geometric representation to use when computing affinity.
affinity_threshold : float, default=0.0
The minimum affinity score (e.g., IoU) required for a valid match.
time_buffer : float, default=0.005
Time tolerance in seconds used in affinity calculations.
frequency_buffer : float, default=1000
Frequency tolerance in Hertz used in affinity calculations.
"""
return
strategy: MatchingStrategy = "greedy"
geometry: MatchingGeometry = "timestamp"
affinity_threshold: float = 0.0
time_buffer: float = 0.005
frequency_buffer: float = 1_000
ignore_start_end: float = 0.01
if not predictions:
for index in range(len(ground_truth)):
yield None, index, 0
return
gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth])
pred_times = np.array([compute_bounds(geom)[0] for geom in predictions])
scores = np.array(scores)
sort_args = np.argsort(scores)[::-1]
distances = np.abs(gt_times[None, :] - pred_times[:, None])
closests = np.argmin(distances, axis=-1)
unmatched_gt = set(range(len(gt_times)))
for pred_index in sort_args:
# Get the closest ground truth
gt_closest_index = closests[pred_index]
if gt_closest_index not in unmatched_gt:
# Does not match if closest has been assigned
yield pred_index, None, 0
continue
# Get the actual distance
distance = distances[pred_index, gt_closest_index]
if distance > distance_threshold:
# Does not match if too far from closest
yield pred_index, None, 0
continue
# Return affinity value: linear interpolation between 0 to 1, where a
# distance at the threshold maps to 0 affinity and a zero distance maps
# to 1.
affinity = np.interp(
distance,
[0, distance_threshold],
[1, 0],
left=1,
right=0,
)
unmatched_gt.remove(gt_closest_index)
yield pred_index, gt_closest_index, affinity
for missing_index in unmatched_gt:
yield None, missing_index, 0
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
@ -142,50 +199,65 @@ def _interval_affinity(
_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
"timestamp": _timestamp_affinity,
"interval": _interval_affinity,
"bbox": compute_affinity,
}
def match_geometries(
source: List[data.Geometry],
target: List[data.Geometry],
config: MatchConfig,
scores: Optional[List[float]] = None,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
geometry_cast = _geometry_cast_functions[config.geometry]
affinity_function = _affinity_functions.get(
config.geometry,
compute_affinity,
)
class GreedyMatchConfig(BaseConfig):
    """Configuration for the greedy matching strategy.

    Attributes
    ----------
    name : Literal["greedy_match"], default="greedy_match"
        Fixed tag that identifies this configuration; it is the field used
        as discriminator by the ``MatchConfig`` union.
    geometry : MatchingGeometry, default="timestamp"
        The geometric representation to use when computing affinity.
    affinity_threshold : float, default=0.0
        The minimum affinity score required for a valid match.
    time_buffer : float, default=0.005
        Time tolerance in seconds used in affinity calculations.
    frequency_buffer : float, default=1000
        Frequency tolerance in Hertz used in affinity calculations.
    """

    name: Literal["greedy_match"] = "greedy_match"
    geometry: MatchingGeometry = "timestamp"
    affinity_threshold: float = 0.0
    time_buffer: float = 0.005
    frequency_buffer: float = 1_000
if config.strategy == "optimal":
return optimal_match(
source=[geometry_cast(geom) for geom in source],
target=[geometry_cast(geom) for geom in target],
time_buffer=config.time_buffer,
freq_buffer=config.frequency_buffer,
affinity_threshold=config.affinity_threshold,
)
if config.strategy == "greedy":
@matching_strategy.register(GreedyMatchConfig)
class GreedyMatcher(MatcherProtocol):
def __init__(
self,
geometry: MatchingGeometry,
affinity_threshold: float,
time_buffer: float,
frequency_buffer: float,
):
self.geometry = geometry
self.affinity_threshold = affinity_threshold
self.time_buffer = time_buffer
self.frequency_buffer = frequency_buffer
self.affinity_function = _affinity_functions[self.geometry]
self.cast_geometry = _geometry_cast_functions[self.geometry]
def __call__(
self,
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
):
return greedy_match(
source=[geometry_cast(geom) for geom in source],
target=[geometry_cast(geom) for geom in target],
time_buffer=config.time_buffer,
freq_buffer=config.frequency_buffer,
affinity_threshold=config.affinity_threshold,
affinity_function=affinity_function,
ground_truth=[self.cast_geometry(geom) for geom in ground_truth],
predictions=[self.cast_geometry(geom) for geom in predictions],
scores=scores,
affinity_function=self.affinity_function,
affinity_threshold=self.affinity_threshold,
time_buffer=self.time_buffer,
freq_buffer=self.frequency_buffer,
)
raise NotImplementedError(
f"Matching strategy not implemented {config.strategy}"
)
@classmethod
def from_config(cls, config: GreedyMatchConfig):
return cls(
geometry=config.geometry,
affinity_threshold=config.affinity_threshold,
time_buffer=config.time_buffer,
frequency_buffer=config.frequency_buffer,
)
def greedy_match(
source: List[data.Geometry],
target: List[data.Geometry],
scores: Optional[List[float]] = None,
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
affinity_threshold: float = 0.5,
affinity_function: AffinityFunction = compute_affinity,
time_buffer: float = 0.001,
@ -221,27 +293,24 @@ def greedy_match(
- Unmatched Source (False Positive): `(source_idx, None, 0)`
- Unmatched Target (False Negative): `(None, target_idx, 0)`
"""
assigned = set()
unassigned_gt = set(range(len(ground_truth)))
if not source:
for target_idx in range(len(target)):
if not predictions:
for target_idx in range(len(ground_truth)):
yield None, target_idx, 0
return
if not target:
for source_idx in range(len(source)):
if not ground_truth:
for source_idx in range(len(predictions)):
yield source_idx, None, 0
return
if scores is None:
indices = np.arange(len(source))
else:
indices = np.argsort(scores)[::-1]
indices = np.argsort(scores)[::-1]
for source_idx in indices:
source_geometry = source[source_idx]
source_geometry = predictions[source_idx]
affinities = np.array(
[
@ -251,7 +320,7 @@ def greedy_match(
time_buffer=time_buffer,
freq_buffer=freq_buffer,
)
for target_geometry in target
for target_geometry in ground_truth
]
)
@ -262,18 +331,74 @@ def greedy_match(
yield source_idx, None, 0
continue
if closest_target in assigned:
if closest_target not in unassigned_gt:
yield source_idx, None, 0
continue
assigned.add(closest_target)
unassigned_gt.remove(closest_target)
yield source_idx, closest_target, affinity
missed_ground_truth = set(range(len(target))) - assigned
for target_idx in missed_ground_truth:
for target_idx in unassigned_gt:
yield None, target_idx, 0
class OptimalMatchConfig(BaseConfig):
    """Configuration for the optimal matching strategy.

    Attributes
    ----------
    name : Literal["optimal_match"], default="optimal_match"
        Fixed tag that identifies this configuration; it is the field used
        as discriminator by the ``MatchConfig`` union.
    affinity_threshold : float, default=0.0
        The minimum affinity score required for a valid match.
    time_buffer : float, default=0.005
        Time tolerance in seconds used in affinity calculations.
    frequency_buffer : float, default=1000
        Frequency tolerance in Hertz used in affinity calculations.
    """

    name: Literal["optimal_match"] = "optimal_match"
    affinity_threshold: float = 0.0
    time_buffer: float = 0.005
    frequency_buffer: float = 1_000
@matching_strategy.register(OptimalMatchConfig)
class OptimalMatcher(MatcherProtocol):
    """Matcher that delegates the assignment to ``optimal_match``.

    ``optimal_match`` is the ``match_geometries`` function imported from
    ``soundevent.evaluation``. The matcher only stores the tolerances and
    forwards them on every call.

    Parameters
    ----------
    affinity_threshold : float
        Forwarded as ``affinity_threshold`` to ``optimal_match``.
    time_buffer : float
        Forwarded as ``time_buffer`` to ``optimal_match``.
    frequency_buffer : float
        Forwarded as ``freq_buffer`` to ``optimal_match``.
    """

    def __init__(
        self,
        affinity_threshold: float,
        time_buffer: float,
        frequency_buffer: float,
    ):
        self.affinity_threshold = affinity_threshold
        self.time_buffer = time_buffer
        self.frequency_buffer = frequency_buffer

    @classmethod
    def from_config(cls, config: OptimalMatchConfig):
        """Create a matcher from an ``OptimalMatchConfig``."""
        return cls(
            affinity_threshold=config.affinity_threshold,
            time_buffer=config.time_buffer,
            frequency_buffer=config.frequency_buffer,
        )

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        """Match predictions (source) against ground truth (target).

        ``scores`` is part of the matcher call signature but is not
        forwarded to ``optimal_match``.
        """
        matched = optimal_match(
            source=predictions,
            target=ground_truth,
            time_buffer=self.time_buffer,
            freq_buffer=self.frequency_buffer,
            affinity_threshold=self.affinity_threshold,
        )
        return matched
# Union of every supported matcher configuration. The ``name`` field of each
# member (a distinct ``Literal`` value per config class) is declared as the
# pydantic discriminator, so it decides which config class a given input is
# validated against.
MatchConfig = Annotated[
Union[
GreedyMatchConfig,
StartTimeMatchConfig,
OptimalMatchConfig,
],
Field(discriminator="name"),
]
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
    """Build the matcher described by a matching configuration.

    Parameters
    ----------
    config : MatchConfig, optional
        Configuration of the matcher to build. When omitted (or falsy) a
        default ``StartTimeMatchConfig`` is used.

    Returns
    -------
    MatcherProtocol
        Whatever the ``matching_strategy`` registry builds for the config.
    """
    if not config:
        config = StartTimeMatchConfig()

    return matching_strategy.build(config)
def _is_in_bounds(
geometry: data.Geometry,
clip: data.Clip,
@ -285,13 +410,18 @@ def _is_in_bounds(
)
def match_sound_events_and_raw_predictions(
def match_sound_events_and_predictions(
clip_annotation: data.ClipAnnotation,
raw_predictions: List[RawPrediction],
targets: TargetProtocol,
config: Optional[MatchConfig] = None,
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
ignore_start_end: float = 0.01,
) -> List[MatchEvaluation]:
config = config or MatchConfig()
if matcher is None:
matcher = build_matcher()
if targets is None:
targets = build_targets()
target_sound_events = [
sound_event_annotation
@ -301,7 +431,7 @@ def match_sound_events_and_raw_predictions(
and _is_in_bounds(
sound_event_annotation.sound_event.geometry,
clip=clip_annotation.clip,
buffer=config.ignore_start_end,
buffer=ignore_start_end,
)
]
@ -317,7 +447,7 @@ def match_sound_events_and_raw_predictions(
if _is_in_bounds(
raw_prediction.geometry,
clip=clip_annotation.clip,
buffer=config.ignore_start_end,
buffer=ignore_start_end,
)
]
@ -331,10 +461,9 @@ def match_sound_events_and_raw_predictions(
matches = []
for source_idx, target_idx, affinity in match_geometries(
source=predicted_geometries,
target=target_geometries,
config=config,
for source_idx, target_idx, affinity in matcher(
ground_truth=target_geometries,
predictions=predicted_geometries,
scores=scores,
):
target = (
@ -344,7 +473,7 @@ def match_sound_events_and_raw_predictions(
raw_predictions[source_idx] if source_idx is not None else None
)
gt_det = target is not None
gt_det = target_idx is not None
gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.detection_score) if prediction else 0
@ -383,76 +512,12 @@ def match_sound_events_and_raw_predictions(
return matches
def match_predictions_and_annotations(
    clip_annotation: data.ClipAnnotation,
    clip_prediction: data.ClipPrediction,
    config: Optional[MatchConfig] = None,
) -> List[data.Match]:
    """Match the predicted and annotated sound events of a single clip.

    Sound events without a geometry, on either side, are left out of the
    matching entirely.

    Parameters
    ----------
    clip_annotation : data.ClipAnnotation
        Annotated clip whose sound events act as matching targets.
    clip_prediction : data.ClipPrediction
        Predicted clip whose sound events act as matching sources.
    config : MatchConfig, optional
        Matching configuration. A default ``MatchConfig()`` is used when
        omitted.

    Returns
    -------
    List[data.Match]
        One ``data.Match`` per tuple produced by ``match_geometries``. The
        ``source`` or ``target`` is ``None`` when the corresponding index
        in that tuple is ``None``.
    """
    if not config:
        config = MatchConfig()

    # Collect each kept sound event together with its geometry so that
    # positions in the geometry lists map directly back to the events.
    annotated_sound_events = []
    annotated_geometries: List[data.Geometry] = []
    for annotation in clip_annotation.sound_events:
        geometry = annotation.sound_event.geometry
        if geometry is None:
            continue
        annotated_sound_events.append(annotation)
        annotated_geometries.append(geometry)

    predicted_sound_events = []
    predicted_geometries: List[data.Geometry] = []
    scores = []
    for prediction in clip_prediction.sound_events:
        geometry = prediction.sound_event.geometry
        if geometry is None:
            continue
        predicted_sound_events.append(prediction)
        predicted_geometries.append(geometry)
        scores.append(prediction.score)

    return [
        data.Match(
            source=(
                None
                if source_idx is None
                else predicted_sound_events[source_idx]
            ),
            target=(
                None
                if target_idx is None
                else annotated_sound_events[target_idx]
            ),
            affinity=affinity,
        )
        for source_idx, target_idx, affinity in match_geometries(
            source=predicted_geometries,
            target=annotated_geometries,
            config=config,
            scores=scores,
        )
    ]
def match_all_predictions(
clip_annotations: List[data.ClipAnnotation],
predictions: List[List[RawPrediction]],
targets: TargetProtocol,
config: Optional[MatchConfig] = None,
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
ignore_start_end: float = 0.01,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...")
return [
@ -461,11 +526,12 @@ def match_all_predictions(
clip_annotations,
predictions,
)
for match in match_sound_events_and_raw_predictions(
for match in match_sound_events_and_predictions(
clip_annotation,
raw_predictions,
targets=targets,
config=config,
matcher=matcher,
ignore_start_end=ignore_start_end,
)
]

View File

@ -24,10 +24,12 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
self.class_names = class_names
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
# NOTE: Need to exclude generic but unclassified targets
y_true = label_binarize(
[
match.gt_class if match.gt_class is not None else "__NONE__"
for match in matches
if not (match.gt_det and match.gt_class is None)
],
classes=self.class_names,
)
@ -38,11 +40,11 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
for name in self.class_names
}
for match in matches
if not (match.gt_det and match.gt_class is None)
]
).fillna(0)
ret = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[class_name]
@ -57,39 +59,3 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
)
return ret
class ClassificationAccuracy(MetricsProtocol):
    """Balanced classification accuracy over a list of match evaluations.

    Parameters
    ----------
    class_names : List[str]
        Names of the classes whose scores are read from each match.
    """

    def __init__(self, class_names: List[str]):
        self.class_names = class_names

    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
        """Compute the metric.

        Returns
        -------
        Dict[str, float]
            A single entry, ``"classification_acc"``, holding the balanced
            accuracy score.
        """
        y_true = []
        score_rows = []
        for match in matches:
            # Matches without a ground-truth class get a placeholder label.
            y_true.append(
                "__NONE__" if match.gt_class is None else match.gt_class
            )
            score_rows.append(
                {
                    name: match.pred_class_scores.get(name, 0)
                    for name in self.class_names
                }
            )

        score_table = pd.DataFrame(score_rows).fillna(0)

        def pick_label(row):
            # Keep the top-scoring class only when its score is at least
            # one minus the sum of all class scores in the row; otherwise
            # the row is labelled with the placeholder.
            if row.max() >= (1 - row.sum()):
                return row.idxmax()
            return "__NONE__"

        y_pred = score_table.apply(pick_label, axis=1)

        accuracy = metrics.balanced_accuracy_score(
            y_true,
            y_pred,
        )

        return {
            "classification_acc": float(accuracy),
        }

View File

@ -8,13 +8,10 @@ from pydantic import Field, field_validator
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.conditions import (
SoundEventCondition,
build_sound_event_condition,
)
from batdetect2.data.conditions import build_sound_event_condition
from batdetect2.targets.classes import (
DEFAULT_CLASSES,
DEFAULT_GENERIC_CLASS,
DEFAULT_DETECTION_CLASS,
SoundEventDecoder,
SoundEventEncoder,
TargetClassConfig,
@ -58,7 +55,9 @@ __all__ = [
class TargetConfig(BaseConfig):
detection_target: TargetClassConfig = Field(default=DEFAULT_GENERIC_CLASS)
detection_target: TargetClassConfig = Field(
default=DEFAULT_DETECTION_CLASS
)
classification_targets: List[TargetClassConfig] = Field(
default_factory=lambda: DEFAULT_CLASSES
@ -151,49 +150,36 @@ class Targets(TargetProtocol):
dimension_names: List[str]
detection_class_name: str
def __init__(
self,
detection_class_name: str,
encode_fn: SoundEventEncoder,
decode_fn: SoundEventDecoder,
roi_mapper: ROITargetMapper,
class_names: list[str],
detection_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventCondition] = None,
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
):
"""Initialize the Targets object.
def __init__(self, config: TargetConfig):
"""Initialize the Targets object."""
self.config = config
Note: This constructor is typically called internally by the
`build_targets` factory function.
self._filter_fn = build_sound_event_condition(
config.detection_target.match_if
)
self._encode_fn = build_sound_event_encoder(
config.classification_targets
)
self._decode_fn = build_sound_event_decoder(
config.classification_targets
)
Parameters
----------
encode_fn : SoundEventEncoder
Configured function to encode annotations to class names.
decode_fn : SoundEventDecoder
Configured function to decode class names to tags.
roi_mapper : ROITargetMapper
Configured object for mapping geometry to/from position/size.
class_names : list[str]
Ordered list of specific target class names.
generic_class_tags : List[data.Tag]
List of tags representing the generic class.
filter_fn : SoundEventFilter, optional
Configured function to filter annotations. Defaults to None.
transform_fn : SoundEventTransformation, optional
Configured function to transform annotation tags. Defaults to None.
"""
self.detection_class_name = detection_class_name
self.class_names = class_names
self.detection_class_tags = detection_class_tags
self.dimension_names = roi_mapper.dimension_names
self._roi_mapper = build_roi_mapper(config.roi)
self._roi_mapper = roi_mapper
self._filter_fn = filter_fn
self._encode_fn = encode_fn
self._decode_fn = decode_fn
self._roi_mapper_overrides = roi_mapper_overrides or {}
self.dimension_names = self._roi_mapper.dimension_names
self.class_names = get_class_names_from_config(
config.classification_targets
)
self.detection_class_name = config.detection_target.name
self.detection_class_tags = config.detection_target.assign_tags
self._roi_mapper_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classification_targets
if class_config.roi is not None
}
for class_name in self._roi_mapper_overrides:
if class_name not in self.class_names:
@ -218,8 +204,6 @@ class Targets(TargetProtocol):
True if the annotation should be kept (passes the filter),
False otherwise. If no filter was configured, always returns True.
"""
if not self._filter_fn:
return True
return self._filter_fn(sound_event)
def encode_class(
@ -331,7 +315,7 @@ class Targets(TargetProtocol):
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
classification_targets=DEFAULT_CLASSES,
detection_target=DEFAULT_GENERIC_CLASS,
detection_target=DEFAULT_DETECTION_CLASS,
roi=AnchorBBoxMapperConfig(),
)
@ -339,13 +323,6 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
This factory function takes the unified `TargetConfig` and constructs all
necessary functional components (filter, transform, encoder,
decoder, ROI mapper) by calling their respective builder functions. It also
extracts metadata (class names, generic tags, dimension names) to create
and return a fully initialized `Targets` instance, ready to process
annotations.
Parameters
----------
config : TargetConfig
@ -370,31 +347,7 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
lambda: config.to_yaml_string(),
)
filter_fn = build_sound_event_condition(config.detection_target.match_if)
encode_fn = build_sound_event_encoder(config.classification_targets)
decode_fn = build_sound_event_decoder(config.classification_targets)
roi_mapper = build_roi_mapper(config.roi)
class_names = get_class_names_from_config(config.classification_targets)
generic_class_tags = config.detection_target.assign_tags
roi_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classification_targets
if class_config.roi is not None
}
return Targets(
filter_fn=filter_fn,
encode_fn=encode_fn,
decode_fn=decode_fn,
class_names=class_names,
roi_mapper=roi_mapper,
detection_class_name=config.detection_target.name,
detection_class_tags=generic_class_tags,
roi_mapper_overrides=roi_overrides,
)
return Targets(config=config)
def load_targets(

View File

@ -15,6 +15,7 @@ from batdetect2.data.conditions import (
build_sound_event_condition,
)
from batdetect2.targets.rois import ROIMapperConfig
from batdetect2.targets.terms import call_type, generic_class
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
__all__ = [
@ -69,24 +70,27 @@ class TargetClassConfig(BaseConfig):
return self._match_if
DEFAULT_GENERIC_CLASS = TargetClassConfig(
DEFAULT_DETECTION_CLASS = TargetClassConfig(
name="bat",
match_if=AllOfConfig(
conditions=[
HasTagConfig(tag=data.Tag(key="event", value="Echolocation")),
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
NotConfig(
condition=HasAnyTagConfig(
tags=[
data.Tag(key="event", value="Feeding"),
data.Tag(key="event", value="Unknown"),
data.Tag(key="event", value="Not Bat"),
data.Tag(term=call_type, value="Feeding"),
data.Tag(term=call_type, value="Social"),
data.Tag(term=call_type, value="Unknown"),
data.Tag(term=generic_class, value="Unknown"),
data.Tag(term=generic_class, value="Not Bat"),
data.Tag(term=call_type, value="Not Bat"),
]
)
),
]
),
assign_tags=[
data.Tag(key="call_type", value="Echolocation"),
data.Tag(term=call_type, value="Echolocation"),
data.Tag(key="order", value="Chiroptera"),
],
)
@ -94,73 +98,73 @@ DEFAULT_GENERIC_CLASS = TargetClassConfig(
DEFAULT_CLASSES = [
TargetClassConfig(
name="myomys",
tags=[data.Tag(key="class", value="Myotis mystacinus")],
),
TargetClassConfig(
name="myoalc",
tags=[data.Tag(key="class", value="Myotis alcathoe")],
name="barbar",
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
),
TargetClassConfig(
name="eptser",
tags=[data.Tag(key="class", value="Eptesicus serotinus")],
),
TargetClassConfig(
name="pipnat",
tags=[data.Tag(key="class", value="Pipistrellus nathusii")],
),
TargetClassConfig(
name="barbar",
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
),
TargetClassConfig(
name="myonat",
tags=[data.Tag(key="class", value="Myotis nattereri")],
),
TargetClassConfig(
name="myodau",
tags=[data.Tag(key="class", value="Myotis daubentonii")],
),
TargetClassConfig(
name="myobra",
tags=[data.Tag(key="class", value="Myotis brandtii")],
),
TargetClassConfig(
name="pippip",
tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
name="myoalc",
tags=[data.Tag(key="class", value="Myotis alcathoe")],
),
TargetClassConfig(
name="myobec",
tags=[data.Tag(key="class", value="Myotis bechsteinii")],
),
TargetClassConfig(
name="pippyg",
tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")],
name="myobra",
tags=[data.Tag(key="class", value="Myotis brandtii")],
),
TargetClassConfig(
name="rhihip",
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
name="myodau",
tags=[data.Tag(key="class", value="Myotis daubentonii")],
),
TargetClassConfig(
name="myomys",
tags=[data.Tag(key="class", value="Myotis mystacinus")],
),
TargetClassConfig(
name="myonat",
tags=[data.Tag(key="class", value="Myotis nattereri")],
),
TargetClassConfig(
name="nyclei",
tags=[data.Tag(key="class", value="Nyctalus leisleri")],
),
TargetClassConfig(
name="rhifer",
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
name="nycnoc",
tags=[data.Tag(key="class", value="Nyctalus noctula")],
),
TargetClassConfig(
name="pipnat",
tags=[data.Tag(key="class", value="Pipistrellus nathusii")],
),
TargetClassConfig(
name="pippip",
tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
),
TargetClassConfig(
name="pippyg",
tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")],
),
TargetClassConfig(
name="pleaur",
tags=[data.Tag(key="class", value="Plecotus auritus")],
),
TargetClassConfig(
name="nycnoc",
tags=[data.Tag(key="class", value="Nyctalus noctula")],
),
TargetClassConfig(
name="pleaus",
tags=[data.Tag(key="class", value="Plecotus austriacus")],
),
TargetClassConfig(
name="rhifer",
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
),
TargetClassConfig(
name="rhihip",
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
),
]

View File

@ -7,6 +7,7 @@ from torch.utils.data import DataLoader
from batdetect2.evaluate.match import (
MatchConfig,
build_matcher,
match_all_predictions,
)
from batdetect2.plotting.clips import PreprocessorProtocol
@ -42,6 +43,8 @@ class ValidationMetrics(Callback):
self.preprocessor = preprocessor
self.plot = plot
self.matcher = build_matcher(config=match_config)
self._clip_annotations: List[data.ClipAnnotation] = []
self._predictions: List[List[RawPrediction]] = []
@ -93,7 +96,7 @@ class ValidationMetrics(Callback):
self._clip_annotations,
self._predictions,
targets=pl_module.model.targets,
config=self.match_config,
matcher=self.matcher,
)
self.log_metrics(pl_module, matches)

View File

@ -11,7 +11,6 @@ from torch.utils.data import DataLoader
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
@ -175,7 +174,6 @@ def build_trainer_callbacks(
ClassificationMeanAveragePrecision(
class_names=targets.class_names
),
ClassificationAccuracy(class_names=targets.class_names),
],
preprocessor=preprocessor,
match_config=config.match,

View File

@ -1,5 +1,15 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol
from typing import (
Dict,
Generic,
Iterable,
List,
Optional,
Protocol,
Sequence,
Tuple,
TypeVar,
)
from soundevent import data
@ -40,5 +50,27 @@ class MatchEvaluation:
return self.pred_class_scores[pred_class]
class MatcherProtocol(Protocol):
    """Callable interface for pairing predictions with ground truth."""

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
        """Match predicted geometries against ground-truth geometries.

        Parameters
        ----------
        ground_truth : Sequence[data.Geometry]
            Geometries of the annotated sound events.
        predictions : Sequence[data.Geometry]
            Geometries of the predicted sound events.
        scores : Sequence[float]
            Scores of the predictions.

        Returns
        -------
        Iterable[Tuple[Optional[int], Optional[int], float]]
            Tuples of two optional indices and a float. The matchers in
            ``batdetect2.evaluate.match`` yield them as (prediction index,
            ground-truth index, affinity), with ``None`` standing in for
            the missing side of an unmatched item.
        """
        ...
# Type variable for the geometry type an affinity function accepts; bounded by
# ``data.Geometry`` and declared contravariant.
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
class AffinityFunction(Protocol, Generic[Geom]):
    """Callable interface for scoring the affinity of two geometries."""

    def __call__(
        self,
        geometry1: Geom,
        geometry2: Geom,
        time_buffer: float = 0.01,
        freq_buffer: float = 1000,
    ) -> float:
        """Return the affinity between ``geometry1`` and ``geometry2``.

        Parameters
        ----------
        geometry1 : Geom
            First geometry.
        geometry2 : Geom
            Second geometry.
        time_buffer : float, default=0.01
            Time tolerance for the computation (presumably in seconds --
            confirm against the implementations).
        freq_buffer : float, default=1000
            Frequency tolerance for the computation (presumably in Hertz --
            confirm against the implementations).

        Returns
        -------
        float
            The affinity value.
        """
        ...
class MetricsProtocol(Protocol):
    """Callable interface for metrics computed from match evaluations."""

    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
        """Compute named metric values from a list of match evaluations.

        Parameters
        ----------
        matches : List[MatchEvaluation]
            The match evaluations to summarise.

        Returns
        -------
        Dict[str, float]
            Mapping from metric name to metric value.
        """
        ...