Improve the matching module; exclude generic, unclassified detections from classification evaluations

mbsantiago 2025-09-14 18:16:59 +01:00
parent 8628133fd7
commit ec1c0ff020
10 changed files with 355 additions and 338 deletions

View File

@@ -1,11 +1,6 @@
from batdetect2.evaluate.config import ( from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
EvaluationConfig,
load_evaluation_config,
)
from batdetect2.evaluate.match import match_predictions_and_annotations
__all__ = [ __all__ = [
"EvaluationConfig", "EvaluationConfig",
"load_evaluation_config", "load_evaluation_config",
"match_predictions_and_annotations",
] ]

View File

@@ -4,7 +4,7 @@ from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.match import MatchConfig from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
__all__ = [ __all__ = [
"EvaluationConfig", "EvaluationConfig",
@@ -13,7 +13,7 @@ __all__ = [
class EvaluationConfig(BaseConfig): class EvaluationConfig(BaseConfig):
match: MatchConfig = Field(default_factory=MatchConfig) match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
def load_evaluation_config( def load_evaluation_config(
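A minimal sketch of how the new discriminated match field behaves from Python, assuming the remaining EvaluationConfig fields all have defaults (the GreedyMatchConfig values below are illustrative):

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.match import GreedyMatchConfig, StartTimeMatchConfig

# Start-time matching is now the default strategy.
config = EvaluationConfig()
assert isinstance(config.match, StartTimeMatchConfig)
assert config.match.distance_threshold == 0.01

# The `name` field acts as the discriminator, so another strategy can be
# selected either in YAML (name: greedy_match) or directly in Python.
config = EvaluationConfig(
    match=GreedyMatchConfig(geometry="interval", affinity_threshold=0.5),
)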

View File

@@ -4,9 +4,8 @@ import pandas as pd
from soundevent import data from soundevent import data
from batdetect2.evaluate.dataframe import extract_matches_dataframe from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.evaluate.match import match_all_predictions from batdetect2.evaluate.match import build_matcher, match_all_predictions
from batdetect2.evaluate.metrics import ( from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision, ClassificationMeanAveragePrecision,
DetectionAveragePrecision, DetectionAveragePrecision,
) )
@@ -77,11 +76,13 @@ def evaluate(
clip_annotations.extend(clip_annotations) clip_annotations.extend(clip_annotations)
predictions.extend(predictions) predictions.extend(predictions)
matcher = build_matcher(config.evaluation.match)
matches = match_all_predictions( matches = match_all_predictions(
clip_annotations, clip_annotations,
predictions, predictions,
targets=targets, targets=targets,
config=config.evaluation.match, matcher=matcher,
) )
df = extract_matches_dataframe(matches) df = extract_matches_dataframe(matches)
@@ -89,7 +90,6 @@ def evaluate(
metrics = [ metrics = [
DetectionAveragePrecision(), DetectionAveragePrecision(),
ClassificationMeanAveragePrecision(class_names=targets.class_names), ClassificationMeanAveragePrecision(class_names=targets.class_names),
ClassificationAccuracy(class_names=targets.class_names),
] ]
results = { results = {
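A condensed, hypothetical wrapper showing the updated flow in evaluate(): the matcher is built once from the config and passed to match_all_predictions instead of a raw MatchConfig (the function name and result aggregation here are illustrative):

from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.evaluate.match import build_matcher, match_all_predictions
from batdetect2.evaluate.metrics import (
    ClassificationMeanAveragePrecision,
    DetectionAveragePrecision,
)


def run_evaluation(config, clip_annotations, predictions, targets):
    # Build the matcher once from the evaluation config and reuse it.
    matcher = build_matcher(config.evaluation.match)

    matches = match_all_predictions(
        clip_annotations,
        predictions,
        targets=targets,
        matcher=matcher,
    )
    df = extract_matches_dataframe(matches)

    metrics = [
        DetectionAveragePrecision(),
        ClassificationMeanAveragePrecision(class_names=targets.class_names),
    ]
    # Each metric maps the matches to a dict of named scores (MetricsProtocol);
    # merging them is one possible way to assemble the results.
    results = {k: v for metric in metrics for k, v in metric(matches).items()}
    return df, results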

View File

@@ -1,63 +1,120 @@
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Literal, Optional, Protocol, Tuple from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import compute_affinity from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.targets import build_targets
from batdetect2.typing import ( from batdetect2.typing import (
MatchEvaluation, MatchEvaluation,
TargetProtocol, TargetProtocol,
) )
from batdetect2.typing.evaluate import AffinityFunction, MatcherProtocol
from batdetect2.typing.postprocess import RawPrediction from batdetect2.typing.postprocess import RawPrediction
MatchingStrategy = Literal["greedy", "optimal"]
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
MatchingGeometry = Literal["bbox", "interval", "timestamp"] MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching.""" """The geometry representation to use for matching."""
matching_strategy = Registry("matching_strategy")
class StartTimeMatchConfig(BaseConfig):
name: Literal["start_time"] = "start_time"
distance_threshold: float = 0.01
@matching_strategy.register(StartTimeMatchConfig)
class StartTimeMatcher(MatcherProtocol):
def __init__(self, distance_threshold: float):
self.distance_threshold = distance_threshold
class AffinityFunction(Protocol):
def __call__( def __call__(
self, self,
geometry1: data.Geometry, ground_truth: Sequence[data.Geometry],
geometry2: data.Geometry, predictions: Sequence[data.Geometry],
time_buffer: float = 0.01, scores: Sequence[float],
freq_buffer: float = 1000, ):
) -> float: ... return match_start_times(
ground_truth,
predictions,
scores,
distance_threshold=self.distance_threshold,
)
@classmethod
def from_config(cls, config: StartTimeMatchConfig) -> "StartTimeMatcher":
return cls(distance_threshold=config.distance_threshold)
class MatchConfig(BaseConfig): def match_start_times(
"""Configuration for matching geometries. ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
distance_threshold: float = 0.01,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
if not ground_truth:
for index in range(len(predictions)):
yield index, None, 0
Attributes return
----------
strategy : MatchingStrategy, default="greedy"
The matching algorithm to use. 'greedy' prioritizes high-confidence
predictions, while 'optimal' finds the globally best set of matches.
geometry : MatchingGeometry, default="timestamp"
The geometric representation to use when computing affinity.
affinity_threshold : float, default=0.0
The minimum affinity score (e.g., IoU) required for a valid match.
time_buffer : float, default=0.005
Time tolerance in seconds used in affinity calculations.
frequency_buffer : float, default=1000
Frequency tolerance in Hertz used in affinity calculations.
"""
strategy: MatchingStrategy = "greedy" if not predictions:
geometry: MatchingGeometry = "timestamp" for index in range(len(ground_truth)):
affinity_threshold: float = 0.0 yield None, index, 0
time_buffer: float = 0.005
frequency_buffer: float = 1_000 return
ignore_start_end: float = 0.01
gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth])
pred_times = np.array([compute_bounds(geom)[0] for geom in predictions])
scores = np.array(scores)
sort_args = np.argsort(scores)[::-1]
distances = np.abs(gt_times[None, :] - pred_times[:, None])
closests = np.argmin(distances, axis=-1)
unmatched_gt = set(range(len(gt_times)))
for pred_index in sort_args:
# Get the closest ground truth
gt_closest_index = closests[pred_index]
if gt_closest_index not in unmatched_gt:
# Does not match if closest has been assigned
yield pred_index, None, 0
continue
# Get the actual distance
distance = distances[pred_index, gt_closest_index]
if distance > distance_threshold:
# Does not match if too far from closest
yield pred_index, None, 0
continue
# Return affinity value: linear interpolation between 0 to 1, where a
# distance at the threshold maps to 0 affinity and a zero distance maps
# to 1.
affinity = np.interp(
distance,
[0, distance_threshold],
[1, 0],
left=1,
right=0,
)
unmatched_gt.remove(gt_closest_index)
yield pred_index, gt_closest_index, affinity
for missing_index in unmatched_gt:
yield None, missing_index, 0
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox: def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
@@ -142,50 +199,65 @@ def _interval_affinity(
_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = { _affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
"timestamp": _timestamp_affinity, "timestamp": _timestamp_affinity,
"interval": _interval_affinity, "interval": _interval_affinity,
"bbox": compute_affinity,
} }
def match_geometries( class GreedyMatchConfig(BaseConfig):
source: List[data.Geometry], name: Literal["greedy_match"] = "greedy_match"
target: List[data.Geometry], geometry: MatchingGeometry = "timestamp"
config: MatchConfig, affinity_threshold: float = 0.0
scores: Optional[List[float]] = None, time_buffer: float = 0.005
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: frequency_buffer: float = 1_000
geometry_cast = _geometry_cast_functions[config.geometry]
affinity_function = _affinity_functions.get(
config.geometry,
compute_affinity,
)
if config.strategy == "optimal":
return optimal_match(
source=[geometry_cast(geom) for geom in source],
target=[geometry_cast(geom) for geom in target],
time_buffer=config.time_buffer,
freq_buffer=config.frequency_buffer,
affinity_threshold=config.affinity_threshold,
)
if config.strategy == "greedy": @matching_strategy.register(GreedyMatchConfig)
class GreedyMatcher(MatcherProtocol):
def __init__(
self,
geometry: MatchingGeometry,
affinity_threshold: float,
time_buffer: float,
frequency_buffer: float,
):
self.geometry = geometry
self.affinity_threshold = affinity_threshold
self.time_buffer = time_buffer
self.frequency_buffer = frequency_buffer
self.affinity_function = _affinity_functions[self.geometry]
self.cast_geometry = _geometry_cast_functions[self.geometry]
def __call__(
self,
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
):
return greedy_match( return greedy_match(
source=[geometry_cast(geom) for geom in source], ground_truth=[self.cast_geometry(geom) for geom in ground_truth],
target=[geometry_cast(geom) for geom in target], predictions=[self.cast_geometry(geom) for geom in predictions],
time_buffer=config.time_buffer,
freq_buffer=config.frequency_buffer,
affinity_threshold=config.affinity_threshold,
affinity_function=affinity_function,
scores=scores, scores=scores,
affinity_function=self.affinity_function,
affinity_threshold=self.affinity_threshold,
time_buffer=self.time_buffer,
freq_buffer=self.frequency_buffer,
) )
raise NotImplementedError( @classmethod
f"Matching strategy not implemented {config.strategy}" def from_config(cls, config: GreedyMatchConfig):
return cls(
geometry=config.geometry,
affinity_threshold=config.affinity_threshold,
time_buffer=config.time_buffer,
frequency_buffer=config.frequency_buffer,
) )
def greedy_match( def greedy_match(
source: List[data.Geometry], ground_truth: Sequence[data.Geometry],
target: List[data.Geometry], predictions: Sequence[data.Geometry],
scores: Optional[List[float]] = None, scores: Sequence[float],
affinity_threshold: float = 0.5, affinity_threshold: float = 0.5,
affinity_function: AffinityFunction = compute_affinity, affinity_function: AffinityFunction = compute_affinity,
time_buffer: float = 0.001, time_buffer: float = 0.001,
@@ -221,27 +293,24 @@ def greedy_match(
- Unmatched Source (False Positive): `(source_idx, None, 0)` - Unmatched Source (False Positive): `(source_idx, None, 0)`
- Unmatched Target (False Negative): `(None, target_idx, 0)` - Unmatched Target (False Negative): `(None, target_idx, 0)`
""" """
assigned = set() unassigned_gt = set(range(len(ground_truth)))
if not source: if not predictions:
for target_idx in range(len(target)): for target_idx in range(len(ground_truth)):
yield None, target_idx, 0 yield None, target_idx, 0
return return
if not target: if not ground_truth:
for source_idx in range(len(source)): for source_idx in range(len(predictions)):
yield source_idx, None, 0 yield source_idx, None, 0
return return
if scores is None:
indices = np.arange(len(source))
else:
indices = np.argsort(scores)[::-1] indices = np.argsort(scores)[::-1]
for source_idx in indices: for source_idx in indices:
source_geometry = source[source_idx] source_geometry = predictions[source_idx]
affinities = np.array( affinities = np.array(
[ [
@@ -251,7 +320,7 @@ def greedy_match(
time_buffer=time_buffer, time_buffer=time_buffer,
freq_buffer=freq_buffer, freq_buffer=freq_buffer,
) )
for target_geometry in target for target_geometry in ground_truth
] ]
) )
@@ -262,18 +331,74 @@ def greedy_match(
yield source_idx, None, 0 yield source_idx, None, 0
continue continue
if closest_target in assigned: if closest_target not in unassigned_gt:
yield source_idx, None, 0 yield source_idx, None, 0
continue continue
assigned.add(closest_target) unassigned_gt.remove(closest_target)
yield source_idx, closest_target, affinity yield source_idx, closest_target, affinity
missed_ground_truth = set(range(len(target))) - assigned for target_idx in unassigned_gt:
for target_idx in missed_ground_truth:
yield None, target_idx, 0 yield None, target_idx, 0
class OptimalMatchConfig(BaseConfig):
name: Literal["optimal_match"] = "optimal_match"
affinity_threshold: float = 0.0
time_buffer: float = 0.005
frequency_buffer: float = 1_000
@matching_strategy.register(OptimalMatchConfig)
class OptimalMatcher(MatcherProtocol):
def __init__(
self,
affinity_threshold: float,
time_buffer: float,
frequency_buffer: float,
):
self.affinity_threshold = affinity_threshold
self.time_buffer = time_buffer
self.frequency_buffer = frequency_buffer
def __call__(
self,
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
):
return optimal_match(
source=predictions,
target=ground_truth,
time_buffer=self.time_buffer,
freq_buffer=self.frequency_buffer,
affinity_threshold=self.affinity_threshold,
)
@classmethod
def from_config(cls, config: OptimalMatchConfig):
return cls(
affinity_threshold=config.affinity_threshold,
time_buffer=config.time_buffer,
frequency_buffer=config.frequency_buffer,
)
MatchConfig = Annotated[
Union[
GreedyMatchConfig,
StartTimeMatchConfig,
OptimalMatchConfig,
],
Field(discriminator="name"),
]
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
config = config or StartTimeMatchConfig()
return matching_strategy.build(config)
def _is_in_bounds( def _is_in_bounds(
geometry: data.Geometry, geometry: data.Geometry,
clip: data.Clip, clip: data.Clip,
@@ -285,13 +410,18 @@ def _is_in_bounds(
) )
def match_sound_events_and_raw_predictions( def match_sound_events_and_predictions(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
raw_predictions: List[RawPrediction], raw_predictions: List[RawPrediction],
targets: TargetProtocol, targets: Optional[TargetProtocol] = None,
config: Optional[MatchConfig] = None, matcher: Optional[MatcherProtocol] = None,
ignore_start_end: float = 0.01,
) -> List[MatchEvaluation]: ) -> List[MatchEvaluation]:
config = config or MatchConfig() if matcher is None:
matcher = build_matcher()
if targets is None:
targets = build_targets()
target_sound_events = [ target_sound_events = [
sound_event_annotation sound_event_annotation
@@ -301,7 +431,7 @@ def match_sound_events_and_raw_predictions(
and _is_in_bounds( and _is_in_bounds(
sound_event_annotation.sound_event.geometry, sound_event_annotation.sound_event.geometry,
clip=clip_annotation.clip, clip=clip_annotation.clip,
buffer=config.ignore_start_end, buffer=ignore_start_end,
) )
] ]
@@ -317,7 +447,7 @@ def match_sound_events_and_raw_predictions(
if _is_in_bounds( if _is_in_bounds(
raw_prediction.geometry, raw_prediction.geometry,
clip=clip_annotation.clip, clip=clip_annotation.clip,
buffer=config.ignore_start_end, buffer=ignore_start_end,
) )
] ]
@@ -331,10 +461,9 @@ def match_sound_events_and_raw_predictions(
matches = [] matches = []
for source_idx, target_idx, affinity in match_geometries( for source_idx, target_idx, affinity in matcher(
source=predicted_geometries, ground_truth=target_geometries,
target=target_geometries, predictions=predicted_geometries,
config=config,
scores=scores, scores=scores,
): ):
target = ( target = (
@@ -344,7 +473,7 @@ def match_sound_events_and_raw_predictions(
raw_predictions[source_idx] if source_idx is not None else None raw_predictions[source_idx] if source_idx is not None else None
) )
gt_det = target is not None gt_det = target_idx is not None
gt_class = targets.encode_class(target) if target is not None else None gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.detection_score) if prediction else 0 pred_score = float(prediction.detection_score) if prediction else 0
@@ -383,76 +512,12 @@ def match_sound_events_and_raw_predictions(
return matches return matches
def match_predictions_and_annotations(
clip_annotation: data.ClipAnnotation,
clip_prediction: data.ClipPrediction,
config: Optional[MatchConfig] = None,
) -> List[data.Match]:
config = config or MatchConfig()
annotated_sound_events = [
sound_event_annotation
for sound_event_annotation in clip_annotation.sound_events
if sound_event_annotation.sound_event.geometry is not None
]
predicted_sound_events = [
sound_event_prediction
for sound_event_prediction in clip_prediction.sound_events
if sound_event_prediction.sound_event.geometry is not None
]
annotated_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry
for sound_event in annotated_sound_events
if sound_event.sound_event.geometry is not None
]
predicted_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry
for sound_event in predicted_sound_events
if sound_event.sound_event.geometry is not None
]
scores = [
sound_event.score
for sound_event in predicted_sound_events
if sound_event.sound_event.geometry is not None
]
matches = []
for source_idx, target_idx, affinity in match_geometries(
source=predicted_geometries,
target=annotated_geometries,
config=config,
scores=scores,
):
target = (
annotated_sound_events[target_idx]
if target_idx is not None
else None
)
source = (
predicted_sound_events[source_idx]
if source_idx is not None
else None
)
matches.append(
data.Match(
source=source,
target=target,
affinity=affinity,
)
)
return matches
def match_all_predictions( def match_all_predictions(
clip_annotations: List[data.ClipAnnotation], clip_annotations: List[data.ClipAnnotation],
predictions: List[List[RawPrediction]], predictions: List[List[RawPrediction]],
targets: TargetProtocol, targets: Optional[TargetProtocol] = None,
config: Optional[MatchConfig] = None, matcher: Optional[MatcherProtocol] = None,
ignore_start_end: float = 0.01,
) -> List[MatchEvaluation]: ) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...") logger.info("Matching all annotations and predictions...")
return [ return [
@@ -461,11 +526,12 @@ def match_all_predictions(
clip_annotations, clip_annotations,
predictions, predictions,
) )
for match in match_sound_events_and_raw_predictions( for match in match_sound_events_and_predictions(
clip_annotation, clip_annotation,
raw_predictions, raw_predictions,
targets=targets, targets=targets,
config=config, matcher=matcher,
ignore_start_end=ignore_start_end,
) )
] ]
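A small behavioural sketch of the new matcher entry point, assuming the registry dispatches on the config type as registered above and that soundevent bounding boxes use the [start_time, low_freq, end_time, high_freq] coordinate order:

from soundevent import data

from batdetect2.evaluate.match import (
    GreedyMatchConfig,
    StartTimeMatchConfig,
    build_matcher,
)


def bbox(start: float, end: float) -> data.BoundingBox:
    # Toy call spanning 20-80 kHz; only the start time matters for the
    # start-time matcher.
    return data.BoundingBox(coordinates=[start, 20_000, end, 80_000])


ground_truth = [bbox(0.100, 0.110), bbox(0.500, 0.510)]
predictions = [bbox(0.102, 0.112), bbox(0.900, 0.910)]
scores = [0.9, 0.4]

# Default strategy: pair detections by onset time within a 10 ms tolerance.
matcher = build_matcher(StartTimeMatchConfig(distance_threshold=0.01))
for pred_idx, gt_idx, affinity in matcher(
    ground_truth=ground_truth,
    predictions=predictions,
    scores=scores,
):
    print(pred_idx, gt_idx, affinity)
# Expected: prediction 0 pairs with ground truth 0 (affinity ~0.8),
# prediction 1 is a false positive, ground truth 1 a false negative.

# Any other registered strategy is selected the same way.
greedy = build_matcher(GreedyMatchConfig(affinity_threshold=0.3))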

View File

@@ -24,10 +24,12 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
self.class_names = class_names self.class_names = class_names
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
# NOTE: Need to exclude generic but unclassified targets
y_true = label_binarize( y_true = label_binarize(
[ [
match.gt_class if match.gt_class is not None else "__NONE__" match.gt_class if match.gt_class is not None else "__NONE__"
for match in matches for match in matches
if not (match.gt_det and match.gt_class is None)
], ],
classes=self.class_names, classes=self.class_names,
) )
@@ -38,11 +40,11 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
for name in self.class_names for name in self.class_names
} }
for match in matches for match in matches
if not (match.gt_det and match.gt_class is None)
] ]
).fillna(0) ).fillna(0)
ret = {} ret = {}
for class_index, class_name in enumerate(self.class_names): for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index] y_true_class = y_true[:, class_index]
y_pred_class = y_pred[class_name] y_pred_class = y_pred[class_name]
@@ -57,39 +59,3 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
) )
return ret return ret
class ClassificationAccuracy(MetricsProtocol):
def __init__(self, class_names: List[str]):
self.class_names = class_names
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = [
match.gt_class if match.gt_class is not None else "__NONE__"
for match in matches
]
y_pred = pd.DataFrame(
[
{
name: match.pred_class_scores.get(name, 0)
for name in self.class_names
}
for match in matches
]
).fillna(0)
y_pred = y_pred.apply(
lambda row: row.idxmax()
if row.max() >= (1 - row.sum())
else "__NONE__",
axis=1,
)
accuracy = metrics.balanced_accuracy_score(
y_true,
y_pred,
)
return {
"classification_acc": float(accuracy),
}
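The exclusion rule added above can be read as a standalone predicate; a hedged sketch (the helper name is made up) of how generic, unclassified ground-truth detections are kept out of the classification mAP:

from typing import List

from batdetect2.typing import MatchEvaluation


def keep_for_classification(matches: List[MatchEvaluation]) -> List[MatchEvaluation]:
    # A ground-truth detection (gt_det=True) that carries no class label
    # (gt_class=None) is a generic target: it is neither a positive nor a
    # negative example for any specific class, so it is dropped here.
    return [
        match
        for match in matches
        if not (match.gt_det and match.gt_class is None)
    ]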

View File

@@ -8,13 +8,10 @@ from pydantic import Field, field_validator
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.conditions import ( from batdetect2.data.conditions import build_sound_event_condition
SoundEventCondition,
build_sound_event_condition,
)
from batdetect2.targets.classes import ( from batdetect2.targets.classes import (
DEFAULT_CLASSES, DEFAULT_CLASSES,
DEFAULT_GENERIC_CLASS, DEFAULT_DETECTION_CLASS,
SoundEventDecoder, SoundEventDecoder,
SoundEventEncoder, SoundEventEncoder,
TargetClassConfig, TargetClassConfig,
@@ -58,7 +55,9 @@ __all__ = [
class TargetConfig(BaseConfig): class TargetConfig(BaseConfig):
detection_target: TargetClassConfig = Field(default=DEFAULT_GENERIC_CLASS) detection_target: TargetClassConfig = Field(
default=DEFAULT_DETECTION_CLASS
)
classification_targets: List[TargetClassConfig] = Field( classification_targets: List[TargetClassConfig] = Field(
default_factory=lambda: DEFAULT_CLASSES default_factory=lambda: DEFAULT_CLASSES
@@ -151,49 +150,36 @@ class Targets(TargetProtocol):
dimension_names: List[str] dimension_names: List[str]
detection_class_name: str detection_class_name: str
def __init__( def __init__(self, config: TargetConfig):
self, """Initialize the Targets object."""
detection_class_name: str, self.config = config
encode_fn: SoundEventEncoder,
decode_fn: SoundEventDecoder,
roi_mapper: ROITargetMapper,
class_names: list[str],
detection_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventCondition] = None,
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
):
"""Initialize the Targets object.
Note: This constructor is typically called internally by the self._filter_fn = build_sound_event_condition(
`build_targets` factory function. config.detection_target.match_if
)
self._encode_fn = build_sound_event_encoder(
config.classification_targets
)
self._decode_fn = build_sound_event_decoder(
config.classification_targets
)
Parameters self._roi_mapper = build_roi_mapper(config.roi)
----------
encode_fn : SoundEventEncoder
Configured function to encode annotations to class names.
decode_fn : SoundEventDecoder
Configured function to decode class names to tags.
roi_mapper : ROITargetMapper
Configured object for mapping geometry to/from position/size.
class_names : list[str]
Ordered list of specific target class names.
generic_class_tags : List[data.Tag]
List of tags representing the generic class.
filter_fn : SoundEventFilter, optional
Configured function to filter annotations. Defaults to None.
transform_fn : SoundEventTransformation, optional
Configured function to transform annotation tags. Defaults to None.
"""
self.detection_class_name = detection_class_name
self.class_names = class_names
self.detection_class_tags = detection_class_tags
self.dimension_names = roi_mapper.dimension_names
self._roi_mapper = roi_mapper self.dimension_names = self._roi_mapper.dimension_names
self._filter_fn = filter_fn
self._encode_fn = encode_fn self.class_names = get_class_names_from_config(
self._decode_fn = decode_fn config.classification_targets
self._roi_mapper_overrides = roi_mapper_overrides or {} )
self.detection_class_name = config.detection_target.name
self.detection_class_tags = config.detection_target.assign_tags
self._roi_mapper_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classification_targets
if class_config.roi is not None
}
for class_name in self._roi_mapper_overrides: for class_name in self._roi_mapper_overrides:
if class_name not in self.class_names: if class_name not in self.class_names:
@@ -218,8 +204,6 @@ class Targets(TargetProtocol):
True if the annotation should be kept (passes the filter), True if the annotation should be kept (passes the filter),
False otherwise. If no filter was configured, always returns True. False otherwise. If no filter was configured, always returns True.
""" """
if not self._filter_fn:
return True
return self._filter_fn(sound_event) return self._filter_fn(sound_event)
def encode_class( def encode_class(
@@ -331,7 +315,7 @@ class Targets(TargetProtocol):
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
classification_targets=DEFAULT_CLASSES, classification_targets=DEFAULT_CLASSES,
detection_target=DEFAULT_GENERIC_CLASS, detection_target=DEFAULT_DETECTION_CLASS,
roi=AnchorBBoxMapperConfig(), roi=AnchorBBoxMapperConfig(),
) )
@@ -339,13 +323,6 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
def build_targets(config: Optional[TargetConfig] = None) -> Targets: def build_targets(config: Optional[TargetConfig] = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig. """Build a Targets object from a loaded TargetConfig.
This factory function takes the unified `TargetConfig` and constructs all
necessary functional components (filter, transform, encoder,
decoder, ROI mapper) by calling their respective builder functions. It also
extracts metadata (class names, generic tags, dimension names) to create
and return a fully initialized `Targets` instance, ready to process
annotations.
Parameters Parameters
---------- ----------
config : TargetConfig config : TargetConfig
@@ -370,31 +347,7 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
lambda: config.to_yaml_string(), lambda: config.to_yaml_string(),
) )
filter_fn = build_sound_event_condition(config.detection_target.match_if) return Targets(config=config)
encode_fn = build_sound_event_encoder(config.classification_targets)
decode_fn = build_sound_event_decoder(config.classification_targets)
roi_mapper = build_roi_mapper(config.roi)
class_names = get_class_names_from_config(config.classification_targets)
generic_class_tags = config.detection_target.assign_tags
roi_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classification_targets
if class_config.roi is not None
}
return Targets(
filter_fn=filter_fn,
encode_fn=encode_fn,
decode_fn=decode_fn,
class_names=class_names,
roi_mapper=roi_mapper,
detection_class_name=config.detection_target.name,
detection_class_tags=generic_class_tags,
roi_mapper_overrides=roi_overrides,
)
def load_targets( def load_targets(
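A minimal usage sketch of the simplified construction path, assuming build_targets still falls back to DEFAULT_TARGET_CONFIG when called without arguments (the printed values are indicative only):

from batdetect2.targets import build_targets

# Targets now builds all of its sub-components (filter, encoder, decoder,
# ROI mappers) internally from the single TargetConfig it receives.
targets = build_targets()

print(targets.detection_class_name)  # "bat"
print(targets.class_names)           # e.g. ["barbar", "eptser", ...]
print(targets.dimension_names)       # dimensions exposed by the ROI mapper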

View File

@@ -15,6 +15,7 @@ from batdetect2.data.conditions import (
build_sound_event_condition, build_sound_event_condition,
) )
from batdetect2.targets.rois import ROIMapperConfig from batdetect2.targets.rois import ROIMapperConfig
from batdetect2.targets.terms import call_type, generic_class
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
__all__ = [ __all__ = [
@@ -69,24 +70,27 @@ class TargetClassConfig(BaseConfig):
return self._match_if return self._match_if
DEFAULT_GENERIC_CLASS = TargetClassConfig( DEFAULT_DETECTION_CLASS = TargetClassConfig(
name="bat", name="bat",
match_if=AllOfConfig( match_if=AllOfConfig(
conditions=[ conditions=[
HasTagConfig(tag=data.Tag(key="event", value="Echolocation")), HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
NotConfig( NotConfig(
condition=HasAnyTagConfig( condition=HasAnyTagConfig(
tags=[ tags=[
data.Tag(key="event", value="Feeding"), data.Tag(term=call_type, value="Feeding"),
data.Tag(key="event", value="Unknown"), data.Tag(term=call_type, value="Social"),
data.Tag(key="event", value="Not Bat"), data.Tag(term=call_type, value="Unknown"),
data.Tag(term=generic_class, value="Unknown"),
data.Tag(term=generic_class, value="Not Bat"),
data.Tag(term=call_type, value="Not Bat"),
] ]
) )
), ),
] ]
), ),
assign_tags=[ assign_tags=[
data.Tag(key="call_type", value="Echolocation"), data.Tag(term=call_type, value="Echolocation"),
data.Tag(key="order", value="Chiroptera"), data.Tag(key="order", value="Chiroptera"),
], ],
) )
@@ -94,73 +98,73 @@ DEFAULT_GENERIC_CLASS = TargetClassConfig(
DEFAULT_CLASSES = [ DEFAULT_CLASSES = [
TargetClassConfig( TargetClassConfig(
name="myomys", name="barbar",
tags=[data.Tag(key="class", value="Myotis mystacinus")], tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
),
TargetClassConfig(
name="myoalc",
tags=[data.Tag(key="class", value="Myotis alcathoe")],
), ),
TargetClassConfig( TargetClassConfig(
name="eptser", name="eptser",
tags=[data.Tag(key="class", value="Eptesicus serotinus")], tags=[data.Tag(key="class", value="Eptesicus serotinus")],
), ),
TargetClassConfig( TargetClassConfig(
name="pipnat", name="myoalc",
tags=[data.Tag(key="class", value="Pipistrellus nathusii")], tags=[data.Tag(key="class", value="Myotis alcathoe")],
),
TargetClassConfig(
name="barbar",
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
),
TargetClassConfig(
name="myonat",
tags=[data.Tag(key="class", value="Myotis nattereri")],
),
TargetClassConfig(
name="myodau",
tags=[data.Tag(key="class", value="Myotis daubentonii")],
),
TargetClassConfig(
name="myobra",
tags=[data.Tag(key="class", value="Myotis brandtii")],
),
TargetClassConfig(
name="pippip",
tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
), ),
TargetClassConfig( TargetClassConfig(
name="myobec", name="myobec",
tags=[data.Tag(key="class", value="Myotis bechsteinii")], tags=[data.Tag(key="class", value="Myotis bechsteinii")],
), ),
TargetClassConfig( TargetClassConfig(
name="pippyg", name="myobra",
tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")], tags=[data.Tag(key="class", value="Myotis brandtii")],
), ),
TargetClassConfig( TargetClassConfig(
name="rhihip", name="myodau",
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")], tags=[data.Tag(key="class", value="Myotis daubentonii")],
),
TargetClassConfig(
name="myomys",
tags=[data.Tag(key="class", value="Myotis mystacinus")],
),
TargetClassConfig(
name="myonat",
tags=[data.Tag(key="class", value="Myotis nattereri")],
), ),
TargetClassConfig( TargetClassConfig(
name="nyclei", name="nyclei",
tags=[data.Tag(key="class", value="Nyctalus leisleri")], tags=[data.Tag(key="class", value="Nyctalus leisleri")],
), ),
TargetClassConfig( TargetClassConfig(
name="rhifer", name="nycnoc",
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")], tags=[data.Tag(key="class", value="Nyctalus noctula")],
),
TargetClassConfig(
name="pipnat",
tags=[data.Tag(key="class", value="Pipistrellus nathusii")],
),
TargetClassConfig(
name="pippip",
tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
),
TargetClassConfig(
name="pippyg",
tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")],
), ),
TargetClassConfig( TargetClassConfig(
name="pleaur", name="pleaur",
tags=[data.Tag(key="class", value="Plecotus auritus")], tags=[data.Tag(key="class", value="Plecotus auritus")],
), ),
TargetClassConfig(
name="nycnoc",
tags=[data.Tag(key="class", value="Nyctalus noctula")],
),
TargetClassConfig( TargetClassConfig(
name="pleaus", name="pleaus",
tags=[data.Tag(key="class", value="Plecotus austriacus")], tags=[data.Tag(key="class", value="Plecotus austriacus")],
), ),
TargetClassConfig(
name="rhifer",
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
),
TargetClassConfig(
name="rhihip",
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
),
] ]
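For illustration, a hedged sketch of declaring extra target classes in the same term-based style, assuming HasTagConfig is exposed by batdetect2.data.conditions as in the imports above (the class names and tag values are made up):

from soundevent import data

from batdetect2.data.conditions import HasTagConfig
from batdetect2.targets.classes import TargetClassConfig
from batdetect2.targets.terms import call_type

# Hypothetical detection-style target that keeps only social calls and
# mirrors the structure of DEFAULT_DETECTION_CLASS.
social_calls = TargetClassConfig(
    name="social",
    match_if=HasTagConfig(tag=data.Tag(term=call_type, value="Social")),
    assign_tags=[data.Tag(term=call_type, value="Social")],
)

# Species classes keep the simpler tag-list form used in DEFAULT_CLASSES.
myoema = TargetClassConfig(
    name="myoema",
    tags=[data.Tag(key="class", value="Myotis emarginatus")],
)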

View File

@@ -7,6 +7,7 @@ from torch.utils.data import DataLoader
from batdetect2.evaluate.match import ( from batdetect2.evaluate.match import (
MatchConfig, MatchConfig,
build_matcher,
match_all_predictions, match_all_predictions,
) )
from batdetect2.plotting.clips import PreprocessorProtocol from batdetect2.plotting.clips import PreprocessorProtocol
@@ -42,6 +43,8 @@ class ValidationMetrics(Callback):
self.preprocessor = preprocessor self.preprocessor = preprocessor
self.plot = plot self.plot = plot
self.matcher = build_matcher(config=match_config)
self._clip_annotations: List[data.ClipAnnotation] = [] self._clip_annotations: List[data.ClipAnnotation] = []
self._predictions: List[List[RawPrediction]] = [] self._predictions: List[List[RawPrediction]] = []
@@ -93,7 +96,7 @@ class ValidationMetrics(Callback):
self._clip_annotations, self._clip_annotations,
self._predictions, self._predictions,
targets=pl_module.model.targets, targets=pl_module.model.targets,
config=self.match_config, matcher=self.matcher,
) )
self.log_metrics(pl_module, matches) self.log_metrics(pl_module, matches)

View File

@@ -11,7 +11,6 @@ from torch.utils.data import DataLoader
from batdetect2.evaluate.config import EvaluationConfig from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.metrics import ( from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision, ClassificationMeanAveragePrecision,
DetectionAveragePrecision, DetectionAveragePrecision,
) )
@@ -175,7 +174,6 @@ def build_trainer_callbacks(
ClassificationMeanAveragePrecision( ClassificationMeanAveragePrecision(
class_names=targets.class_names class_names=targets.class_names
), ),
ClassificationAccuracy(class_names=targets.class_names),
], ],
preprocessor=preprocessor, preprocessor=preprocessor,
match_config=config.match, match_config=config.match,

View File

@@ -1,5 +1,15 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol from typing import (
Dict,
Generic,
Iterable,
List,
Optional,
Protocol,
Sequence,
Tuple,
TypeVar,
)
from soundevent import data from soundevent import data
@@ -40,5 +50,27 @@ class MatchEvaluation:
return self.pred_class_scores[pred_class] return self.pred_class_scores[pred_class]
class MatcherProtocol(Protocol):
def __call__(
self,
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ...
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
class AffinityFunction(Protocol, Generic[Geom]):
def __call__(
self,
geometry1: Geom,
geometry2: Geom,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float: ...
class MetricsProtocol(Protocol): class MetricsProtocol(Protocol):
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ... def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
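Because matchers are plain callables satisfying MatcherProtocol, a custom strategy can be passed directly as matcher= to match_all_predictions without touching the registry. A minimal sketch (this centre-time matcher is illustrative, not part of batdetect2):

from typing import Iterable, Optional, Sequence, Tuple

from soundevent import data
from soundevent.geometry import compute_bounds


class CenterTimeMatcher:
    """Toy matcher: greedily pair each prediction with the nearest unmatched
    ground-truth call by centre time, within max_distance seconds."""

    def __init__(self, max_distance: float = 0.02):
        self.max_distance = max_distance

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
        def center(geom: data.Geometry) -> float:
            start, _, end, _ = compute_bounds(geom)
            return (start + end) / 2

        unmatched = set(range(len(ground_truth)))
        # Visit high-confidence predictions first, like the built-in matchers.
        order = sorted(range(len(predictions)), key=lambda i: -scores[i])
        for pred_idx in order:
            pred_center = center(predictions[pred_idx])
            best = min(
                unmatched,
                key=lambda gt_idx: abs(center(ground_truth[gt_idx]) - pred_center),
                default=None,
            )
            if best is None or abs(center(ground_truth[best]) - pred_center) > self.max_distance:
                # No unmatched ground truth close enough: false positive.
                yield pred_idx, None, 0.0
                continue
            unmatched.remove(best)
            yield pred_idx, best, 1.0
        for gt_idx in unmatched:
            # Remaining ground truth calls are false negatives.
            yield None, gt_idx, 0.0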