mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Better matching module, remove generic from classification evaluations
This commit is contained in:
parent
8628133fd7
commit
ec1c0ff020
@ -1,11 +1,6 @@
|
|||||||
from batdetect2.evaluate.config import (
|
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
||||||
EvaluationConfig,
|
|
||||||
load_evaluation_config,
|
|
||||||
)
|
|
||||||
from batdetect2.evaluate.match import match_predictions_and_annotations
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EvaluationConfig",
|
"EvaluationConfig",
|
||||||
"load_evaluation_config",
|
"load_evaluation_config",
|
||||||
"match_predictions_and_annotations",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate.match import MatchConfig
|
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EvaluationConfig",
|
"EvaluationConfig",
|
||||||
@ -13,7 +13,7 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class EvaluationConfig(BaseConfig):
|
class EvaluationConfig(BaseConfig):
|
||||||
match: MatchConfig = Field(default_factory=MatchConfig)
|
match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
|
||||||
|
|
||||||
|
|
||||||
def load_evaluation_config(
|
def load_evaluation_config(
|
||||||
|
|||||||
@ -4,9 +4,8 @@ import pandas as pd
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.evaluate.dataframe import extract_matches_dataframe
|
from batdetect2.evaluate.dataframe import extract_matches_dataframe
|
||||||
from batdetect2.evaluate.match import match_all_predictions
|
from batdetect2.evaluate.match import build_matcher, match_all_predictions
|
||||||
from batdetect2.evaluate.metrics import (
|
from batdetect2.evaluate.metrics import (
|
||||||
ClassificationAccuracy,
|
|
||||||
ClassificationMeanAveragePrecision,
|
ClassificationMeanAveragePrecision,
|
||||||
DetectionAveragePrecision,
|
DetectionAveragePrecision,
|
||||||
)
|
)
|
||||||
@ -77,11 +76,13 @@ def evaluate(
|
|||||||
clip_annotations.extend(clip_annotations)
|
clip_annotations.extend(clip_annotations)
|
||||||
predictions.extend(predictions)
|
predictions.extend(predictions)
|
||||||
|
|
||||||
|
matcher = build_matcher(config.evaluation.match)
|
||||||
|
|
||||||
matches = match_all_predictions(
|
matches = match_all_predictions(
|
||||||
clip_annotations,
|
clip_annotations,
|
||||||
predictions,
|
predictions,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
config=config.evaluation.match,
|
matcher=matcher,
|
||||||
)
|
)
|
||||||
|
|
||||||
df = extract_matches_dataframe(matches)
|
df = extract_matches_dataframe(matches)
|
||||||
@ -89,7 +90,6 @@ def evaluate(
|
|||||||
metrics = [
|
metrics = [
|
||||||
DetectionAveragePrecision(),
|
DetectionAveragePrecision(),
|
||||||
ClassificationMeanAveragePrecision(class_names=targets.class_names),
|
ClassificationMeanAveragePrecision(class_names=targets.class_names),
|
||||||
ClassificationAccuracy(class_names=targets.class_names),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
results = {
|
results = {
|
||||||
|
|||||||
@ -1,63 +1,120 @@
|
|||||||
from collections.abc import Callable, Iterable, Mapping
|
from collections.abc import Callable, Iterable, Mapping
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Literal, Optional, Protocol, Tuple
|
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import compute_affinity
|
from soundevent.evaluation import compute_affinity
|
||||||
from soundevent.evaluation import match_geometries as optimal_match
|
from soundevent.evaluation import match_geometries as optimal_match
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
|
from batdetect2.data._core import Registry
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
MatchEvaluation,
|
MatchEvaluation,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
from batdetect2.typing.evaluate import AffinityFunction, MatcherProtocol
|
||||||
from batdetect2.typing.postprocess import RawPrediction
|
from batdetect2.typing.postprocess import RawPrediction
|
||||||
|
|
||||||
MatchingStrategy = Literal["greedy", "optimal"]
|
|
||||||
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
|
|
||||||
|
|
||||||
|
|
||||||
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
|
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
|
||||||
"""The geometry representation to use for matching."""
|
"""The geometry representation to use for matching."""
|
||||||
|
|
||||||
|
matching_strategy = Registry("matching_strategy")
|
||||||
|
|
||||||
|
|
||||||
|
class StartTimeMatchConfig(BaseConfig):
|
||||||
|
name: Literal["start_time"] = "start_time"
|
||||||
|
distance_threshold: float = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
@matching_strategy.register(StartTimeMatchConfig)
|
||||||
|
class StartTimeMatcher(MatcherProtocol):
|
||||||
|
def __init__(self, distance_threshold: float):
|
||||||
|
self.distance_threshold = distance_threshold
|
||||||
|
|
||||||
class AffinityFunction(Protocol):
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
geometry1: data.Geometry,
|
ground_truth: Sequence[data.Geometry],
|
||||||
geometry2: data.Geometry,
|
predictions: Sequence[data.Geometry],
|
||||||
time_buffer: float = 0.01,
|
scores: Sequence[float],
|
||||||
freq_buffer: float = 1000,
|
):
|
||||||
) -> float: ...
|
return match_start_times(
|
||||||
|
ground_truth,
|
||||||
|
predictions,
|
||||||
|
scores,
|
||||||
|
distance_threshold=self.distance_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: StartTimeMatchConfig) -> "StartTimeMatcher":
|
||||||
|
return cls(distance_threshold=config.distance_threshold)
|
||||||
|
|
||||||
|
|
||||||
class MatchConfig(BaseConfig):
|
def match_start_times(
|
||||||
"""Configuration for matching geometries.
|
ground_truth: Sequence[data.Geometry],
|
||||||
|
predictions: Sequence[data.Geometry],
|
||||||
|
scores: Sequence[float],
|
||||||
|
distance_threshold: float = 0.01,
|
||||||
|
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
||||||
|
if not ground_truth:
|
||||||
|
for index in range(len(predictions)):
|
||||||
|
yield index, None, 0
|
||||||
|
|
||||||
Attributes
|
return
|
||||||
----------
|
|
||||||
strategy : MatchingStrategy, default="greedy"
|
|
||||||
The matching algorithm to use. 'greedy' prioritizes high-confidence
|
|
||||||
predictions, while 'optimal' finds the globally best set of matches.
|
|
||||||
geometry : MatchingGeometry, default="timestamp"
|
|
||||||
The geometric representation to use when computing affinity.
|
|
||||||
affinity_threshold : float, default=0.0
|
|
||||||
The minimum affinity score (e.g., IoU) required for a valid match.
|
|
||||||
time_buffer : float, default=0.005
|
|
||||||
Time tolerance in seconds used in affinity calculations.
|
|
||||||
frequency_buffer : float, default=1000
|
|
||||||
Frequency tolerance in Hertz used in affinity calculations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
strategy: MatchingStrategy = "greedy"
|
if not predictions:
|
||||||
geometry: MatchingGeometry = "timestamp"
|
for index in range(len(ground_truth)):
|
||||||
affinity_threshold: float = 0.0
|
yield None, index, 0
|
||||||
time_buffer: float = 0.005
|
|
||||||
frequency_buffer: float = 1_000
|
return
|
||||||
ignore_start_end: float = 0.01
|
|
||||||
|
gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth])
|
||||||
|
pred_times = np.array([compute_bounds(geom)[0] for geom in predictions])
|
||||||
|
scores = np.array(scores)
|
||||||
|
|
||||||
|
sort_args = np.argsort(scores)[::-1]
|
||||||
|
|
||||||
|
distances = np.abs(gt_times[None, :] - pred_times[:, None])
|
||||||
|
closests = np.argmin(distances, axis=-1)
|
||||||
|
|
||||||
|
unmatched_gt = set(range(len(gt_times)))
|
||||||
|
|
||||||
|
for pred_index in sort_args:
|
||||||
|
# Get the closest ground truth
|
||||||
|
gt_closest_index = closests[pred_index]
|
||||||
|
|
||||||
|
if gt_closest_index not in unmatched_gt:
|
||||||
|
# Does not match if closest has been assigned
|
||||||
|
yield pred_index, None, 0
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get the actual distance
|
||||||
|
distance = distances[pred_index, gt_closest_index]
|
||||||
|
|
||||||
|
if distance > distance_threshold:
|
||||||
|
# Does not match if too far from closest
|
||||||
|
yield pred_index, None, 0
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Return affinity value: linear interpolation between 0 and 1, where a
|
||||||
|
# distance at the threshold maps to 0 affinity and a zero distance maps
|
||||||
|
# to 1.
|
||||||
|
affinity = np.interp(
|
||||||
|
distance,
|
||||||
|
[0, distance_threshold],
|
||||||
|
[1, 0],
|
||||||
|
left=1,
|
||||||
|
right=0,
|
||||||
|
)
|
||||||
|
unmatched_gt.remove(gt_closest_index)
|
||||||
|
yield pred_index, gt_closest_index, affinity
|
||||||
|
|
||||||
|
for missing_index in unmatched_gt:
|
||||||
|
yield None, missing_index, 0
|
||||||
|
|
||||||
|
|
||||||
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
|
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
|
||||||
@ -142,50 +199,65 @@ def _interval_affinity(
|
|||||||
_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
|
_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
|
||||||
"timestamp": _timestamp_affinity,
|
"timestamp": _timestamp_affinity,
|
||||||
"interval": _interval_affinity,
|
"interval": _interval_affinity,
|
||||||
|
"bbox": compute_affinity,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def match_geometries(
|
class GreedyMatchConfig(BaseConfig):
|
||||||
source: List[data.Geometry],
|
name: Literal["greedy_match"] = "greedy_match"
|
||||||
target: List[data.Geometry],
|
geometry: MatchingGeometry = "timestamp"
|
||||||
config: MatchConfig,
|
affinity_threshold: float = 0.0
|
||||||
scores: Optional[List[float]] = None,
|
time_buffer: float = 0.005
|
||||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
frequency_buffer: float = 1_000
|
||||||
geometry_cast = _geometry_cast_functions[config.geometry]
|
|
||||||
affinity_function = _affinity_functions.get(
|
|
||||||
config.geometry,
|
|
||||||
compute_affinity,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.strategy == "optimal":
|
|
||||||
return optimal_match(
|
|
||||||
source=[geometry_cast(geom) for geom in source],
|
|
||||||
target=[geometry_cast(geom) for geom in target],
|
|
||||||
time_buffer=config.time_buffer,
|
|
||||||
freq_buffer=config.frequency_buffer,
|
|
||||||
affinity_threshold=config.affinity_threshold,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.strategy == "greedy":
|
@matching_strategy.register(GreedyMatchConfig)
|
||||||
|
class GreedyMatcher(MatcherProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
geometry: MatchingGeometry,
|
||||||
|
affinity_threshold: float,
|
||||||
|
time_buffer: float,
|
||||||
|
frequency_buffer: float,
|
||||||
|
):
|
||||||
|
self.geometry = geometry
|
||||||
|
self.affinity_threshold = affinity_threshold
|
||||||
|
self.time_buffer = time_buffer
|
||||||
|
self.frequency_buffer = frequency_buffer
|
||||||
|
|
||||||
|
self.affinity_function = _affinity_functions[self.geometry]
|
||||||
|
self.cast_geometry = _geometry_cast_functions[self.geometry]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
ground_truth: Sequence[data.Geometry],
|
||||||
|
predictions: Sequence[data.Geometry],
|
||||||
|
scores: Sequence[float],
|
||||||
|
):
|
||||||
return greedy_match(
|
return greedy_match(
|
||||||
source=[geometry_cast(geom) for geom in source],
|
ground_truth=[self.cast_geometry(geom) for geom in ground_truth],
|
||||||
target=[geometry_cast(geom) for geom in target],
|
predictions=[self.cast_geometry(geom) for geom in predictions],
|
||||||
time_buffer=config.time_buffer,
|
|
||||||
freq_buffer=config.frequency_buffer,
|
|
||||||
affinity_threshold=config.affinity_threshold,
|
|
||||||
affinity_function=affinity_function,
|
|
||||||
scores=scores,
|
scores=scores,
|
||||||
|
affinity_function=self.affinity_function,
|
||||||
|
affinity_threshold=self.affinity_threshold,
|
||||||
|
time_buffer=self.time_buffer,
|
||||||
|
freq_buffer=self.frequency_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
raise NotImplementedError(
|
@classmethod
|
||||||
f"Matching strategy not implemented {config.strategy}"
|
def from_config(cls, config: GreedyMatchConfig):
|
||||||
)
|
return cls(
|
||||||
|
geometry=config.geometry,
|
||||||
|
affinity_threshold=config.affinity_threshold,
|
||||||
|
time_buffer=config.time_buffer,
|
||||||
|
frequency_buffer=config.frequency_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def greedy_match(
|
def greedy_match(
|
||||||
source: List[data.Geometry],
|
ground_truth: Sequence[data.Geometry],
|
||||||
target: List[data.Geometry],
|
predictions: Sequence[data.Geometry],
|
||||||
scores: Optional[List[float]] = None,
|
scores: Sequence[float],
|
||||||
affinity_threshold: float = 0.5,
|
affinity_threshold: float = 0.5,
|
||||||
affinity_function: AffinityFunction = compute_affinity,
|
affinity_function: AffinityFunction = compute_affinity,
|
||||||
time_buffer: float = 0.001,
|
time_buffer: float = 0.001,
|
||||||
@ -221,27 +293,24 @@ def greedy_match(
|
|||||||
- Unmatched Source (False Positive): `(source_idx, None, 0)`
|
- Unmatched Source (False Positive): `(source_idx, None, 0)`
|
||||||
- Unmatched Target (False Negative): `(None, target_idx, 0)`
|
- Unmatched Target (False Negative): `(None, target_idx, 0)`
|
||||||
"""
|
"""
|
||||||
assigned = set()
|
unassigned_gt = set(range(len(ground_truth)))
|
||||||
|
|
||||||
if not source:
|
if not predictions:
|
||||||
for target_idx in range(len(target)):
|
for target_idx in range(len(ground_truth)):
|
||||||
yield None, target_idx, 0
|
yield None, target_idx, 0
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if not target:
|
if not ground_truth:
|
||||||
for source_idx in range(len(source)):
|
for source_idx in range(len(predictions)):
|
||||||
yield source_idx, None, 0
|
yield source_idx, None, 0
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if scores is None:
|
indices = np.argsort(scores)[::-1]
|
||||||
indices = np.arange(len(source))
|
|
||||||
else:
|
|
||||||
indices = np.argsort(scores)[::-1]
|
|
||||||
|
|
||||||
for source_idx in indices:
|
for source_idx in indices:
|
||||||
source_geometry = source[source_idx]
|
source_geometry = predictions[source_idx]
|
||||||
|
|
||||||
affinities = np.array(
|
affinities = np.array(
|
||||||
[
|
[
|
||||||
@ -251,7 +320,7 @@ def greedy_match(
|
|||||||
time_buffer=time_buffer,
|
time_buffer=time_buffer,
|
||||||
freq_buffer=freq_buffer,
|
freq_buffer=freq_buffer,
|
||||||
)
|
)
|
||||||
for target_geometry in target
|
for target_geometry in ground_truth
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -262,18 +331,74 @@ def greedy_match(
|
|||||||
yield source_idx, None, 0
|
yield source_idx, None, 0
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if closest_target in assigned:
|
if closest_target not in unassigned_gt:
|
||||||
yield source_idx, None, 0
|
yield source_idx, None, 0
|
||||||
continue
|
continue
|
||||||
|
|
||||||
assigned.add(closest_target)
|
unassigned_gt.remove(closest_target)
|
||||||
yield source_idx, closest_target, affinity
|
yield source_idx, closest_target, affinity
|
||||||
|
|
||||||
missed_ground_truth = set(range(len(target))) - assigned
|
for target_idx in unassigned_gt:
|
||||||
for target_idx in missed_ground_truth:
|
|
||||||
yield None, target_idx, 0
|
yield None, target_idx, 0
|
||||||
|
|
||||||
|
|
||||||
|
class OptimalMatchConfig(BaseConfig):
|
||||||
|
name: Literal["optimal_match"] = "optimal_match"
|
||||||
|
affinity_threshold: float = 0.0
|
||||||
|
time_buffer: float = 0.005
|
||||||
|
frequency_buffer: float = 1_000
|
||||||
|
|
||||||
|
|
||||||
|
@matching_strategy.register(OptimalMatchConfig)
|
||||||
|
class OptimalMatcher(MatcherProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
affinity_threshold: float,
|
||||||
|
time_buffer: float,
|
||||||
|
frequency_buffer: float,
|
||||||
|
):
|
||||||
|
self.affinity_threshold = affinity_threshold
|
||||||
|
self.time_buffer = time_buffer
|
||||||
|
self.frequency_buffer = frequency_buffer
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
ground_truth: Sequence[data.Geometry],
|
||||||
|
predictions: Sequence[data.Geometry],
|
||||||
|
scores: Sequence[float],
|
||||||
|
):
|
||||||
|
return optimal_match(
|
||||||
|
source=predictions,
|
||||||
|
target=ground_truth,
|
||||||
|
time_buffer=self.time_buffer,
|
||||||
|
freq_buffer=self.frequency_buffer,
|
||||||
|
affinity_threshold=self.affinity_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: OptimalMatchConfig):
|
||||||
|
return cls(
|
||||||
|
affinity_threshold=config.affinity_threshold,
|
||||||
|
time_buffer=config.time_buffer,
|
||||||
|
frequency_buffer=config.frequency_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MatchConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
GreedyMatchConfig,
|
||||||
|
StartTimeMatchConfig,
|
||||||
|
OptimalMatchConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
|
||||||
|
config = config or StartTimeMatchConfig()
|
||||||
|
return matching_strategy.build(config)
|
||||||
|
|
||||||
|
|
||||||
def _is_in_bounds(
|
def _is_in_bounds(
|
||||||
geometry: data.Geometry,
|
geometry: data.Geometry,
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
@ -285,13 +410,18 @@ def _is_in_bounds(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def match_sound_events_and_raw_predictions(
|
def match_sound_events_and_predictions(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
raw_predictions: List[RawPrediction],
|
raw_predictions: List[RawPrediction],
|
||||||
targets: TargetProtocol,
|
targets: Optional[TargetProtocol] = None,
|
||||||
config: Optional[MatchConfig] = None,
|
matcher: Optional[MatcherProtocol] = None,
|
||||||
|
ignore_start_end: float = 0.01,
|
||||||
) -> List[MatchEvaluation]:
|
) -> List[MatchEvaluation]:
|
||||||
config = config or MatchConfig()
|
if matcher is None:
|
||||||
|
matcher = build_matcher()
|
||||||
|
|
||||||
|
if targets is None:
|
||||||
|
targets = build_targets()
|
||||||
|
|
||||||
target_sound_events = [
|
target_sound_events = [
|
||||||
sound_event_annotation
|
sound_event_annotation
|
||||||
@ -301,7 +431,7 @@ def match_sound_events_and_raw_predictions(
|
|||||||
and _is_in_bounds(
|
and _is_in_bounds(
|
||||||
sound_event_annotation.sound_event.geometry,
|
sound_event_annotation.sound_event.geometry,
|
||||||
clip=clip_annotation.clip,
|
clip=clip_annotation.clip,
|
||||||
buffer=config.ignore_start_end,
|
buffer=ignore_start_end,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -317,7 +447,7 @@ def match_sound_events_and_raw_predictions(
|
|||||||
if _is_in_bounds(
|
if _is_in_bounds(
|
||||||
raw_prediction.geometry,
|
raw_prediction.geometry,
|
||||||
clip=clip_annotation.clip,
|
clip=clip_annotation.clip,
|
||||||
buffer=config.ignore_start_end,
|
buffer=ignore_start_end,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -331,10 +461,9 @@ def match_sound_events_and_raw_predictions(
|
|||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
|
|
||||||
for source_idx, target_idx, affinity in match_geometries(
|
for source_idx, target_idx, affinity in matcher(
|
||||||
source=predicted_geometries,
|
ground_truth=target_geometries,
|
||||||
target=target_geometries,
|
predictions=predicted_geometries,
|
||||||
config=config,
|
|
||||||
scores=scores,
|
scores=scores,
|
||||||
):
|
):
|
||||||
target = (
|
target = (
|
||||||
@ -344,7 +473,7 @@ def match_sound_events_and_raw_predictions(
|
|||||||
raw_predictions[source_idx] if source_idx is not None else None
|
raw_predictions[source_idx] if source_idx is not None else None
|
||||||
)
|
)
|
||||||
|
|
||||||
gt_det = target is not None
|
gt_det = target_idx is not None
|
||||||
gt_class = targets.encode_class(target) if target is not None else None
|
gt_class = targets.encode_class(target) if target is not None else None
|
||||||
|
|
||||||
pred_score = float(prediction.detection_score) if prediction else 0
|
pred_score = float(prediction.detection_score) if prediction else 0
|
||||||
@ -383,76 +512,12 @@ def match_sound_events_and_raw_predictions(
|
|||||||
return matches
|
return matches
|
||||||
|
|
||||||
|
|
||||||
def match_predictions_and_annotations(
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
clip_prediction: data.ClipPrediction,
|
|
||||||
config: Optional[MatchConfig] = None,
|
|
||||||
) -> List[data.Match]:
|
|
||||||
config = config or MatchConfig()
|
|
||||||
|
|
||||||
annotated_sound_events = [
|
|
||||||
sound_event_annotation
|
|
||||||
for sound_event_annotation in clip_annotation.sound_events
|
|
||||||
if sound_event_annotation.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
predicted_sound_events = [
|
|
||||||
sound_event_prediction
|
|
||||||
for sound_event_prediction in clip_prediction.sound_events
|
|
||||||
if sound_event_prediction.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
annotated_geometries: List[data.Geometry] = [
|
|
||||||
sound_event.sound_event.geometry
|
|
||||||
for sound_event in annotated_sound_events
|
|
||||||
if sound_event.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
predicted_geometries: List[data.Geometry] = [
|
|
||||||
sound_event.sound_event.geometry
|
|
||||||
for sound_event in predicted_sound_events
|
|
||||||
if sound_event.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
scores = [
|
|
||||||
sound_event.score
|
|
||||||
for sound_event in predicted_sound_events
|
|
||||||
if sound_event.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
matches = []
|
|
||||||
for source_idx, target_idx, affinity in match_geometries(
|
|
||||||
source=predicted_geometries,
|
|
||||||
target=annotated_geometries,
|
|
||||||
config=config,
|
|
||||||
scores=scores,
|
|
||||||
):
|
|
||||||
target = (
|
|
||||||
annotated_sound_events[target_idx]
|
|
||||||
if target_idx is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
source = (
|
|
||||||
predicted_sound_events[source_idx]
|
|
||||||
if source_idx is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
matches.append(
|
|
||||||
data.Match(
|
|
||||||
source=source,
|
|
||||||
target=target,
|
|
||||||
affinity=affinity,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matches
|
|
||||||
|
|
||||||
|
|
||||||
def match_all_predictions(
|
def match_all_predictions(
|
||||||
clip_annotations: List[data.ClipAnnotation],
|
clip_annotations: List[data.ClipAnnotation],
|
||||||
predictions: List[List[RawPrediction]],
|
predictions: List[List[RawPrediction]],
|
||||||
targets: TargetProtocol,
|
targets: Optional[TargetProtocol] = None,
|
||||||
config: Optional[MatchConfig] = None,
|
matcher: Optional[MatcherProtocol] = None,
|
||||||
|
ignore_start_end: float = 0.01,
|
||||||
) -> List[MatchEvaluation]:
|
) -> List[MatchEvaluation]:
|
||||||
logger.info("Matching all annotations and predictions...")
|
logger.info("Matching all annotations and predictions...")
|
||||||
return [
|
return [
|
||||||
@ -461,11 +526,12 @@ def match_all_predictions(
|
|||||||
clip_annotations,
|
clip_annotations,
|
||||||
predictions,
|
predictions,
|
||||||
)
|
)
|
||||||
for match in match_sound_events_and_raw_predictions(
|
for match in match_sound_events_and_predictions(
|
||||||
clip_annotation,
|
clip_annotation,
|
||||||
raw_predictions,
|
raw_predictions,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
config=config,
|
matcher=matcher,
|
||||||
|
ignore_start_end=ignore_start_end,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -24,10 +24,12 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
|||||||
self.class_names = class_names
|
self.class_names = class_names
|
||||||
|
|
||||||
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
||||||
|
# NOTE: Need to exclude generic but unclassified targets
|
||||||
y_true = label_binarize(
|
y_true = label_binarize(
|
||||||
[
|
[
|
||||||
match.gt_class if match.gt_class is not None else "__NONE__"
|
match.gt_class if match.gt_class is not None else "__NONE__"
|
||||||
for match in matches
|
for match in matches
|
||||||
|
if not (match.gt_det and match.gt_class is None)
|
||||||
],
|
],
|
||||||
classes=self.class_names,
|
classes=self.class_names,
|
||||||
)
|
)
|
||||||
@ -38,11 +40,11 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
|||||||
for name in self.class_names
|
for name in self.class_names
|
||||||
}
|
}
|
||||||
for match in matches
|
for match in matches
|
||||||
|
if not (match.gt_det and match.gt_class is None)
|
||||||
]
|
]
|
||||||
).fillna(0)
|
).fillna(0)
|
||||||
|
|
||||||
ret = {}
|
ret = {}
|
||||||
|
|
||||||
for class_index, class_name in enumerate(self.class_names):
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
y_true_class = y_true[:, class_index]
|
y_true_class = y_true[:, class_index]
|
||||||
y_pred_class = y_pred[class_name]
|
y_pred_class = y_pred[class_name]
|
||||||
@ -57,39 +59,3 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class ClassificationAccuracy(MetricsProtocol):
|
|
||||||
def __init__(self, class_names: List[str]):
|
|
||||||
self.class_names = class_names
|
|
||||||
|
|
||||||
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
|
||||||
y_true = [
|
|
||||||
match.gt_class if match.gt_class is not None else "__NONE__"
|
|
||||||
for match in matches
|
|
||||||
]
|
|
||||||
|
|
||||||
y_pred = pd.DataFrame(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
name: match.pred_class_scores.get(name, 0)
|
|
||||||
for name in self.class_names
|
|
||||||
}
|
|
||||||
for match in matches
|
|
||||||
]
|
|
||||||
).fillna(0)
|
|
||||||
y_pred = y_pred.apply(
|
|
||||||
lambda row: row.idxmax()
|
|
||||||
if row.max() >= (1 - row.sum())
|
|
||||||
else "__NONE__",
|
|
||||||
axis=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
accuracy = metrics.balanced_accuracy_score(
|
|
||||||
y_true,
|
|
||||||
y_pred,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"classification_acc": float(accuracy),
|
|
||||||
}
|
|
||||||
|
|||||||
@ -8,13 +8,10 @@ from pydantic import Field, field_validator
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.data.conditions import (
|
from batdetect2.data.conditions import build_sound_event_condition
|
||||||
SoundEventCondition,
|
|
||||||
build_sound_event_condition,
|
|
||||||
)
|
|
||||||
from batdetect2.targets.classes import (
|
from batdetect2.targets.classes import (
|
||||||
DEFAULT_CLASSES,
|
DEFAULT_CLASSES,
|
||||||
DEFAULT_GENERIC_CLASS,
|
DEFAULT_DETECTION_CLASS,
|
||||||
SoundEventDecoder,
|
SoundEventDecoder,
|
||||||
SoundEventEncoder,
|
SoundEventEncoder,
|
||||||
TargetClassConfig,
|
TargetClassConfig,
|
||||||
@ -58,7 +55,9 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class TargetConfig(BaseConfig):
|
class TargetConfig(BaseConfig):
|
||||||
detection_target: TargetClassConfig = Field(default=DEFAULT_GENERIC_CLASS)
|
detection_target: TargetClassConfig = Field(
|
||||||
|
default=DEFAULT_DETECTION_CLASS
|
||||||
|
)
|
||||||
|
|
||||||
classification_targets: List[TargetClassConfig] = Field(
|
classification_targets: List[TargetClassConfig] = Field(
|
||||||
default_factory=lambda: DEFAULT_CLASSES
|
default_factory=lambda: DEFAULT_CLASSES
|
||||||
@ -151,49 +150,36 @@ class Targets(TargetProtocol):
|
|||||||
dimension_names: List[str]
|
dimension_names: List[str]
|
||||||
detection_class_name: str
|
detection_class_name: str
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, config: TargetConfig):
|
||||||
self,
|
"""Initialize the Targets object."""
|
||||||
detection_class_name: str,
|
self.config = config
|
||||||
encode_fn: SoundEventEncoder,
|
|
||||||
decode_fn: SoundEventDecoder,
|
|
||||||
roi_mapper: ROITargetMapper,
|
|
||||||
class_names: list[str],
|
|
||||||
detection_class_tags: List[data.Tag],
|
|
||||||
filter_fn: Optional[SoundEventCondition] = None,
|
|
||||||
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
|
|
||||||
):
|
|
||||||
"""Initialize the Targets object.
|
|
||||||
|
|
||||||
Note: This constructor is typically called internally by the
|
self._filter_fn = build_sound_event_condition(
|
||||||
`build_targets` factory function.
|
config.detection_target.match_if
|
||||||
|
)
|
||||||
|
self._encode_fn = build_sound_event_encoder(
|
||||||
|
config.classification_targets
|
||||||
|
)
|
||||||
|
self._decode_fn = build_sound_event_decoder(
|
||||||
|
config.classification_targets
|
||||||
|
)
|
||||||
|
|
||||||
Parameters
|
self._roi_mapper = build_roi_mapper(config.roi)
|
||||||
----------
|
|
||||||
encode_fn : SoundEventEncoder
|
|
||||||
Configured function to encode annotations to class names.
|
|
||||||
decode_fn : SoundEventDecoder
|
|
||||||
Configured function to decode class names to tags.
|
|
||||||
roi_mapper : ROITargetMapper
|
|
||||||
Configured object for mapping geometry to/from position/size.
|
|
||||||
class_names : list[str]
|
|
||||||
Ordered list of specific target class names.
|
|
||||||
generic_class_tags : List[data.Tag]
|
|
||||||
List of tags representing the generic class.
|
|
||||||
filter_fn : SoundEventFilter, optional
|
|
||||||
Configured function to filter annotations. Defaults to None.
|
|
||||||
transform_fn : SoundEventTransformation, optional
|
|
||||||
Configured function to transform annotation tags. Defaults to None.
|
|
||||||
"""
|
|
||||||
self.detection_class_name = detection_class_name
|
|
||||||
self.class_names = class_names
|
|
||||||
self.detection_class_tags = detection_class_tags
|
|
||||||
self.dimension_names = roi_mapper.dimension_names
|
|
||||||
|
|
||||||
self._roi_mapper = roi_mapper
|
self.dimension_names = self._roi_mapper.dimension_names
|
||||||
self._filter_fn = filter_fn
|
|
||||||
self._encode_fn = encode_fn
|
self.class_names = get_class_names_from_config(
|
||||||
self._decode_fn = decode_fn
|
config.classification_targets
|
||||||
self._roi_mapper_overrides = roi_mapper_overrides or {}
|
)
|
||||||
|
|
||||||
|
self.detection_class_name = config.detection_target.name
|
||||||
|
self.detection_class_tags = config.detection_target.assign_tags
|
||||||
|
|
||||||
|
self._roi_mapper_overrides = {
|
||||||
|
class_config.name: build_roi_mapper(class_config.roi)
|
||||||
|
for class_config in config.classification_targets
|
||||||
|
if class_config.roi is not None
|
||||||
|
}
|
||||||
|
|
||||||
for class_name in self._roi_mapper_overrides:
|
for class_name in self._roi_mapper_overrides:
|
||||||
if class_name not in self.class_names:
|
if class_name not in self.class_names:
|
||||||
@ -218,8 +204,6 @@ class Targets(TargetProtocol):
|
|||||||
True if the annotation should be kept (passes the filter),
|
True if the annotation should be kept (passes the filter),
|
||||||
False otherwise. If no filter was configured, always returns True.
|
False otherwise. If no filter was configured, always returns True.
|
||||||
"""
|
"""
|
||||||
if not self._filter_fn:
|
|
||||||
return True
|
|
||||||
return self._filter_fn(sound_event)
|
return self._filter_fn(sound_event)
|
||||||
|
|
||||||
def encode_class(
|
def encode_class(
|
||||||
@ -331,7 +315,7 @@ class Targets(TargetProtocol):
|
|||||||
|
|
||||||
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||||
classification_targets=DEFAULT_CLASSES,
|
classification_targets=DEFAULT_CLASSES,
|
||||||
detection_target=DEFAULT_GENERIC_CLASS,
|
detection_target=DEFAULT_DETECTION_CLASS,
|
||||||
roi=AnchorBBoxMapperConfig(),
|
roi=AnchorBBoxMapperConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -339,13 +323,6 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
|||||||
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
||||||
"""Build a Targets object from a loaded TargetConfig.
|
"""Build a Targets object from a loaded TargetConfig.
|
||||||
|
|
||||||
This factory function takes the unified `TargetConfig` and constructs all
|
|
||||||
necessary functional components (filter, transform, encoder,
|
|
||||||
decoder, ROI mapper) by calling their respective builder functions. It also
|
|
||||||
extracts metadata (class names, generic tags, dimension names) to create
|
|
||||||
and return a fully initialized `Targets` instance, ready to process
|
|
||||||
annotations.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
config : TargetConfig
|
config : TargetConfig
|
||||||
@ -370,31 +347,7 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
|||||||
lambda: config.to_yaml_string(),
|
lambda: config.to_yaml_string(),
|
||||||
)
|
)
|
||||||
|
|
||||||
filter_fn = build_sound_event_condition(config.detection_target.match_if)
|
return Targets(config=config)
|
||||||
encode_fn = build_sound_event_encoder(config.classification_targets)
|
|
||||||
decode_fn = build_sound_event_decoder(config.classification_targets)
|
|
||||||
|
|
||||||
roi_mapper = build_roi_mapper(config.roi)
|
|
||||||
class_names = get_class_names_from_config(config.classification_targets)
|
|
||||||
|
|
||||||
generic_class_tags = config.detection_target.assign_tags
|
|
||||||
|
|
||||||
roi_overrides = {
|
|
||||||
class_config.name: build_roi_mapper(class_config.roi)
|
|
||||||
for class_config in config.classification_targets
|
|
||||||
if class_config.roi is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
return Targets(
|
|
||||||
filter_fn=filter_fn,
|
|
||||||
encode_fn=encode_fn,
|
|
||||||
decode_fn=decode_fn,
|
|
||||||
class_names=class_names,
|
|
||||||
roi_mapper=roi_mapper,
|
|
||||||
detection_class_name=config.detection_target.name,
|
|
||||||
detection_class_tags=generic_class_tags,
|
|
||||||
roi_mapper_overrides=roi_overrides,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_targets(
|
def load_targets(
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from batdetect2.data.conditions import (
|
|||||||
build_sound_event_condition,
|
build_sound_event_condition,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.rois import ROIMapperConfig
|
from batdetect2.targets.rois import ROIMapperConfig
|
||||||
|
from batdetect2.targets.terms import call_type, generic_class
|
||||||
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -69,24 +70,27 @@ class TargetClassConfig(BaseConfig):
|
|||||||
return self._match_if
|
return self._match_if
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_GENERIC_CLASS = TargetClassConfig(
|
DEFAULT_DETECTION_CLASS = TargetClassConfig(
|
||||||
name="bat",
|
name="bat",
|
||||||
match_if=AllOfConfig(
|
match_if=AllOfConfig(
|
||||||
conditions=[
|
conditions=[
|
||||||
HasTagConfig(tag=data.Tag(key="event", value="Echolocation")),
|
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
|
||||||
NotConfig(
|
NotConfig(
|
||||||
condition=HasAnyTagConfig(
|
condition=HasAnyTagConfig(
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(key="event", value="Feeding"),
|
data.Tag(term=call_type, value="Feeding"),
|
||||||
data.Tag(key="event", value="Unknown"),
|
data.Tag(term=call_type, value="Social"),
|
||||||
data.Tag(key="event", value="Not Bat"),
|
data.Tag(term=call_type, value="Unknown"),
|
||||||
|
data.Tag(term=generic_class, value="Unknown"),
|
||||||
|
data.Tag(term=generic_class, value="Not Bat"),
|
||||||
|
data.Tag(term=call_type, value="Not Bat"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
assign_tags=[
|
assign_tags=[
|
||||||
data.Tag(key="call_type", value="Echolocation"),
|
data.Tag(term=call_type, value="Echolocation"),
|
||||||
data.Tag(key="order", value="Chiroptera"),
|
data.Tag(key="order", value="Chiroptera"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -94,73 +98,73 @@ DEFAULT_GENERIC_CLASS = TargetClassConfig(
|
|||||||
|
|
||||||
DEFAULT_CLASSES = [
|
DEFAULT_CLASSES = [
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="myomys",
|
name="barbar",
|
||||||
tags=[data.Tag(key="class", value="Myotis mystacinus")],
|
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="myoalc",
|
|
||||||
tags=[data.Tag(key="class", value="Myotis alcathoe")],
|
|
||||||
),
|
),
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="eptser",
|
name="eptser",
|
||||||
tags=[data.Tag(key="class", value="Eptesicus serotinus")],
|
tags=[data.Tag(key="class", value="Eptesicus serotinus")],
|
||||||
),
|
),
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="pipnat",
|
name="myoalc",
|
||||||
tags=[data.Tag(key="class", value="Pipistrellus nathusii")],
|
tags=[data.Tag(key="class", value="Myotis alcathoe")],
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="barbar",
|
|
||||||
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="myonat",
|
|
||||||
tags=[data.Tag(key="class", value="Myotis nattereri")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="myodau",
|
|
||||||
tags=[data.Tag(key="class", value="Myotis daubentonii")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="myobra",
|
|
||||||
tags=[data.Tag(key="class", value="Myotis brandtii")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="pippip",
|
|
||||||
tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
|
|
||||||
),
|
),
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="myobec",
|
name="myobec",
|
||||||
tags=[data.Tag(key="class", value="Myotis bechsteinii")],
|
tags=[data.Tag(key="class", value="Myotis bechsteinii")],
|
||||||
),
|
),
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="pippyg",
|
name="myobra",
|
||||||
tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")],
|
tags=[data.Tag(key="class", value="Myotis brandtii")],
|
||||||
),
|
),
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="rhihip",
|
name="myodau",
|
||||||
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
|
tags=[data.Tag(key="class", value="Myotis daubentonii")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="myomys",
|
||||||
|
tags=[data.Tag(key="class", value="Myotis mystacinus")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="myonat",
|
||||||
|
tags=[data.Tag(key="class", value="Myotis nattereri")],
|
||||||
),
|
),
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="nyclei",
|
name="nyclei",
|
||||||
tags=[data.Tag(key="class", value="Nyctalus leisleri")],
|
tags=[data.Tag(key="class", value="Nyctalus leisleri")],
|
||||||
),
|
),
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="rhifer",
|
name="nycnoc",
|
||||||
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
|
tags=[data.Tag(key="class", value="Nyctalus noctula")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="pipnat",
|
||||||
|
tags=[data.Tag(key="class", value="Pipistrellus nathusii")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="pippip",
|
||||||
|
tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="pippyg",
|
||||||
|
tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")],
|
||||||
),
|
),
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="pleaur",
|
name="pleaur",
|
||||||
tags=[data.Tag(key="class", value="Plecotus auritus")],
|
tags=[data.Tag(key="class", value="Plecotus auritus")],
|
||||||
),
|
),
|
||||||
TargetClassConfig(
|
|
||||||
name="nycnoc",
|
|
||||||
tags=[data.Tag(key="class", value="Nyctalus noctula")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
TargetClassConfig(
|
||||||
name="pleaus",
|
name="pleaus",
|
||||||
tags=[data.Tag(key="class", value="Plecotus austriacus")],
|
tags=[data.Tag(key="class", value="Plecotus austriacus")],
|
||||||
),
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="rhifer",
|
||||||
|
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="rhihip",
|
||||||
|
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from torch.utils.data import DataLoader
|
|||||||
|
|
||||||
from batdetect2.evaluate.match import (
|
from batdetect2.evaluate.match import (
|
||||||
MatchConfig,
|
MatchConfig,
|
||||||
|
build_matcher,
|
||||||
match_all_predictions,
|
match_all_predictions,
|
||||||
)
|
)
|
||||||
from batdetect2.plotting.clips import PreprocessorProtocol
|
from batdetect2.plotting.clips import PreprocessorProtocol
|
||||||
@ -42,6 +43,8 @@ class ValidationMetrics(Callback):
|
|||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
self.plot = plot
|
self.plot = plot
|
||||||
|
|
||||||
|
self.matcher = build_matcher(config=match_config)
|
||||||
|
|
||||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||||
self._predictions: List[List[RawPrediction]] = []
|
self._predictions: List[List[RawPrediction]] = []
|
||||||
|
|
||||||
@ -93,7 +96,7 @@ class ValidationMetrics(Callback):
|
|||||||
self._clip_annotations,
|
self._clip_annotations,
|
||||||
self._predictions,
|
self._predictions,
|
||||||
targets=pl_module.model.targets,
|
targets=pl_module.model.targets,
|
||||||
config=self.match_config,
|
matcher=self.matcher,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log_metrics(pl_module, matches)
|
self.log_metrics(pl_module, matches)
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from torch.utils.data import DataLoader
|
|||||||
|
|
||||||
from batdetect2.evaluate.config import EvaluationConfig
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
from batdetect2.evaluate.metrics import (
|
from batdetect2.evaluate.metrics import (
|
||||||
ClassificationAccuracy,
|
|
||||||
ClassificationMeanAveragePrecision,
|
ClassificationMeanAveragePrecision,
|
||||||
DetectionAveragePrecision,
|
DetectionAveragePrecision,
|
||||||
)
|
)
|
||||||
@ -175,7 +174,6 @@ def build_trainer_callbacks(
|
|||||||
ClassificationMeanAveragePrecision(
|
ClassificationMeanAveragePrecision(
|
||||||
class_names=targets.class_names
|
class_names=targets.class_names
|
||||||
),
|
),
|
||||||
ClassificationAccuracy(class_names=targets.class_names),
|
|
||||||
],
|
],
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
match_config=config.match,
|
match_config=config.match,
|
||||||
|
|||||||
@ -1,5 +1,15 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Protocol
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Protocol,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -40,5 +50,27 @@ class MatchEvaluation:
|
|||||||
return self.pred_class_scores[pred_class]
|
return self.pred_class_scores[pred_class]
|
||||||
|
|
||||||
|
|
||||||
|
class MatcherProtocol(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
ground_truth: Sequence[data.Geometry],
|
||||||
|
predictions: Sequence[data.Geometry],
|
||||||
|
scores: Sequence[float],
|
||||||
|
) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AffinityFunction(Protocol, Generic[Geom]):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
geometry1: Geom,
|
||||||
|
geometry2: Geom,
|
||||||
|
time_buffer: float = 0.01,
|
||||||
|
freq_buffer: float = 1000,
|
||||||
|
) -> float: ...
|
||||||
|
|
||||||
|
|
||||||
class MetricsProtocol(Protocol):
|
class MetricsProtocol(Protocol):
|
||||||
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
|
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user