diff --git a/src/batdetect2/core/configs.py b/src/batdetect2/core/configs.py index f2642f8..c39533c 100644 --- a/src/batdetect2/core/configs.py +++ b/src/batdetect2/core/configs.py @@ -62,6 +62,10 @@ class BaseConfig(BaseModel): ) ) + @classmethod + def from_yaml(cls, yaml_str: str): + return cls.model_validate(yaml.safe_load(yaml_str)) + T = TypeVar("T", bound=BaseModel) diff --git a/src/batdetect2/evaluate/affinity.py b/src/batdetect2/evaluate/affinity.py index a95cb75..8f5084e 100644 --- a/src/batdetect2/evaluate/affinity.py +++ b/src/batdetect2/evaluate/affinity.py @@ -2,75 +2,98 @@ from typing import Annotated, Literal from pydantic import Field from soundevent import data -from soundevent.evaluation import compute_affinity -from soundevent.geometry import compute_interval_overlap +from soundevent.geometry import ( + buffer_geometry, + compute_bbox_iou, + compute_geometric_iou, + compute_temporal_closeness, + compute_temporal_iou, +) -from batdetect2.core.configs import BaseConfig -from batdetect2.core.registries import Registry -from batdetect2.typing.evaluate import AffinityFunction +from batdetect2.core import BaseConfig, Registry +from batdetect2.typing import AffinityFunction, RawPrediction affinity_functions: Registry[AffinityFunction, []] = Registry( - "matching_strategy" + "affinity_function" ) class TimeAffinityConfig(BaseConfig): name: Literal["time_affinity"] = "time_affinity" - time_buffer: float = 0.01 + position: Literal["start", "end", "center"] | float = "start" + max_distance: float = 0.01 class TimeAffinity(AffinityFunction): - def __init__(self, time_buffer: float): - self.time_buffer = time_buffer + def __init__( + self, + max_distance: float = 0.01, + position: Literal["start", "end", "center"] | float = "start", + ): + if position == "start": + position = 0 + elif position == "end": + position = 1 + elif position == "center": + position = 0.5 - def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): - return compute_timestamp_affinity( - geometry1, geometry2, time_buffer=self.time_buffer + self.position = position + self.max_distance = max_distance + + def __call__( + self, + detection: RawPrediction, + ground_truth: data.SoundEventAnnotation, + ) -> float: + target_geometry = ground_truth.sound_event.geometry + source_geometry = detection.geometry + return compute_temporal_closeness( + target_geometry, + source_geometry, + ratio=self.position, + max_distance=self.max_distance, ) @affinity_functions.register(TimeAffinityConfig) @staticmethod def from_config(config: TimeAffinityConfig): - return TimeAffinity(time_buffer=config.time_buffer) - - -def compute_timestamp_affinity( - geometry1: data.Geometry, - geometry2: data.Geometry, - time_buffer: float = 0.01, -) -> float: - assert isinstance(geometry1, data.TimeStamp) - assert isinstance(geometry2, data.TimeStamp) - - start_time1 = geometry1.coordinates - start_time2 = geometry2.coordinates - - a = min(start_time1, start_time2) - b = max(start_time1, start_time2) - - if b - a >= 2 * time_buffer: - return 0 - - intersection = a - b + 2 * time_buffer - union = b - a + 2 * time_buffer - return intersection / union + return TimeAffinity( + max_distance=config.max_distance, + position=config.position, + ) class IntervalIOUConfig(BaseConfig): name: Literal["interval_iou"] = "interval_iou" - time_buffer: float = 0.01 + time_buffer: float = 0.0 class IntervalIOU(AffinityFunction): def __init__(self, time_buffer: float): + if time_buffer < 0: + raise ValueError("time_buffer must be non-negative") + self.time_buffer = time_buffer - def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): - return compute_interval_iou( - geometry1, - geometry2, - time_buffer=self.time_buffer, - ) + def __call__( + self, + detection: RawPrediction, + ground_truth: data.SoundEventAnnotation, + ) -> float: + target_geometry = ground_truth.sound_event.geometry + source_geometry = detection.geometry + + if self.time_buffer > 0: + target_geometry = buffer_geometry( + target_geometry, + time=self.time_buffer, + ) + source_geometry = buffer_geometry( + source_geometry, + time=self.time_buffer, + ) + + return compute_temporal_iou(target_geometry, source_geometry) @affinity_functions.register(IntervalIOUConfig) @staticmethod @@ -78,64 +101,44 @@ class IntervalIOU(AffinityFunction): return IntervalIOU(time_buffer=config.time_buffer) -def compute_interval_iou( - geometry1: data.Geometry, - geometry2: data.Geometry, - time_buffer: float = 0.01, -) -> float: - assert isinstance(geometry1, data.TimeInterval) - assert isinstance(geometry2, data.TimeInterval) - - start_time1, end_time1 = geometry1.coordinates - start_time2, end_time2 = geometry1.coordinates - - start_time1 -= time_buffer - start_time2 -= time_buffer - end_time1 += time_buffer - end_time2 += time_buffer - - intersection = compute_interval_overlap( - (start_time1, end_time1), - (start_time2, end_time2), - ) - - union = ( - (end_time1 - start_time1) + (end_time2 - start_time2) - intersection - ) - - if union == 0: - return 0 - - return intersection / union - - class BBoxIOUConfig(BaseConfig): name: Literal["bbox_iou"] = "bbox_iou" - time_buffer: float = 0.01 - freq_buffer: float = 1000 + time_buffer: float = 0.0 + freq_buffer: float = 0.0 class BBoxIOU(AffinityFunction): def __init__(self, time_buffer: float, freq_buffer: float): + if time_buffer < 0: + raise ValueError("time_buffer must be non-negative") + + if freq_buffer < 0: + raise ValueError("freq_buffer must be non-negative") + self.time_buffer = time_buffer self.freq_buffer = freq_buffer - def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): - if not isinstance(geometry1, data.BoundingBox): - raise TypeError( - f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}" + def __call__( + self, + prediction: RawPrediction, + gt: data.SoundEventAnnotation, + ): + target_geometry = gt.sound_event.geometry + source_geometry = prediction.geometry + + if self.time_buffer > 0 or self.freq_buffer > 0: + target_geometry = buffer_geometry( + target_geometry, + time=self.time_buffer, + freq=self.freq_buffer, + ) + source_geometry = buffer_geometry( + source_geometry, + time=self.time_buffer, + freq=self.freq_buffer, ) - if not isinstance(geometry2, data.BoundingBox): - raise TypeError( - f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}" - ) - return bbox_iou( - geometry1, - geometry2, - time_buffer=self.time_buffer, - freq_buffer=self.freq_buffer, - ) + return compute_bbox_iou(target_geometry, source_geometry) @affinity_functions.register(BBoxIOUConfig) @staticmethod @@ -146,65 +149,44 @@ class BBoxIOU(AffinityFunction): ) -def bbox_iou( - geometry1: data.BoundingBox, - geometry2: data.BoundingBox, - time_buffer: float = 0.01, - freq_buffer: float = 1000, -) -> float: - start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates - start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates - - start_time1 -= time_buffer - start_time2 -= time_buffer - end_time1 += time_buffer - end_time2 += time_buffer - - low_freq1 -= freq_buffer - low_freq2 -= freq_buffer - high_freq1 += freq_buffer - high_freq2 += freq_buffer - - time_intersection = compute_interval_overlap( - (start_time1, end_time1), - (start_time2, end_time2), - ) - - freq_intersection = max( - 0, - min(high_freq1, high_freq2) - max(low_freq1, low_freq2), - ) - - intersection = time_intersection * freq_intersection - - if intersection == 0: - return 0 - - union = ( - (end_time1 - start_time1) * (high_freq1 - low_freq1) - + (end_time2 - start_time2) * (high_freq2 - low_freq2) - - intersection - ) - - return intersection / union - - class GeometricIOUConfig(BaseConfig): name: Literal["geometric_iou"] = "geometric_iou" - time_buffer: float = 0.01 - freq_buffer: float = 1000 + time_buffer: float = 0.0 + freq_buffer: float = 0.0 class GeometricIOU(AffinityFunction): - def __init__(self, time_buffer: float): - self.time_buffer = time_buffer + def __init__(self, time_buffer: float = 0, freq_buffer: float = 0): + if time_buffer < 0: + raise ValueError("time_buffer must be non-negative") - def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): - return compute_affinity( - geometry1, - geometry2, - time_buffer=self.time_buffer, - ) + if freq_buffer < 0: + raise ValueError("freq_buffer must be non-negative") + + self.time_buffer = time_buffer + self.freq_buffer = freq_buffer + + def __call__( + self, + prediction: RawPrediction, + gt: data.SoundEventAnnotation, + ): + target_geometry = gt.sound_event.geometry + source_geometry = prediction.geometry + + if self.time_buffer > 0 or self.freq_buffer > 0: + target_geometry = buffer_geometry( + target_geometry, + time=self.time_buffer, + freq=self.freq_buffer, + ) + source_geometry = buffer_geometry( + source_geometry, + time=self.time_buffer, + freq=self.freq_buffer, + ) + + return compute_geometric_iou(target_geometry, source_geometry) @affinity_functions.register(GeometricIOUConfig) @staticmethod @@ -213,7 +195,10 @@ class GeometricIOU(AffinityFunction): AffinityConfig = Annotated[ - TimeAffinityConfig | IntervalIOUConfig | BBoxIOUConfig | GeometricIOUConfig, + TimeAffinityConfig + | IntervalIOUConfig + | BBoxIOUConfig + | GeometricIOUConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index 247ddb9..e1c654a 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -31,93 +31,6 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"] matching_strategies = Registry("matching_strategy") -def match( - sound_event_annotations: Sequence[data.SoundEventAnnotation], - raw_predictions: Sequence[RawPrediction], - clip: data.Clip, - scores: Sequence[float] | None = None, - targets: TargetProtocol | None = None, - matcher: MatcherProtocol | None = None, -) -> ClipMatches: - if matcher is None: - matcher = build_matcher() - - if targets is None: - targets = build_targets() - - target_geometries: List[data.Geometry] = [ # type: ignore - sound_event_annotation.sound_event.geometry - for sound_event_annotation in sound_event_annotations - ] - - predicted_geometries = [ - raw_prediction.geometry for raw_prediction in raw_predictions - ] - - if scores is None: - scores = [ - raw_prediction.detection_score - for raw_prediction in raw_predictions - ] - - matches = [] - - for source_idx, target_idx, affinity in matcher( - ground_truth=target_geometries, - predictions=predicted_geometries, - scores=scores, - ): - target = ( - sound_event_annotations[target_idx] - if target_idx is not None - else None - ) - prediction = ( - raw_predictions[source_idx] if source_idx is not None else None - ) - - gt_det = target_idx is not None - gt_class = targets.encode_class(target) if target is not None else None - gt_geometry = ( - target_geometries[target_idx] if target_idx is not None else None - ) - - pred_score = float(prediction.detection_score) if prediction else 0 - pred_geometry = ( - predicted_geometries[source_idx] - if source_idx is not None - else None - ) - - class_scores = ( - { - class_name: score - for class_name, score in zip( - targets.class_names, - prediction.class_scores, strict=False, - ) - } - if prediction is not None - else {} - ) - - matches.append( - MatchEvaluation( - clip=clip, - sound_event_annotation=target, - gt_det=gt_det, - gt_class=gt_class, - gt_geometry=gt_geometry, - pred_score=pred_score, - pred_class_scores=class_scores, - pred_geometry=pred_geometry, - affinity=affinity, - ) - ) - - return ClipMatches(clip=clip, matches=matches) - - class StartTimeMatchConfig(BaseConfig): name: Literal["start_time_match"] = "start_time_match" distance_threshold: float = 0.01 @@ -514,99 +427,9 @@ class OptimalMatcher(MatcherProtocol): MatchConfig = Annotated[ - GreedyMatchConfig | StartTimeMatchConfig | OptimalMatchConfig | GreedyAffinityMatchConfig, + GreedyMatchConfig + | StartTimeMatchConfig + | OptimalMatchConfig + | GreedyAffinityMatchConfig, Field(discriminator="name"), ] - - -def compute_affinity_matrix( - ground_truth: Sequence[data.Geometry], - predictions: Sequence[data.Geometry], - affinity_function: AffinityFunction, - time_scale: float = 1, - frequency_scale: float = 1, -) -> np.ndarray: - # Scale geometries if necessary - if time_scale != 1 or frequency_scale != 1: - ground_truth = [ - scale_geometry(geometry, time_scale, frequency_scale) - for geometry in ground_truth - ] - - predictions = [ - scale_geometry(geometry, time_scale, frequency_scale) - for geometry in predictions - ] - - affinity_matrix = np.zeros((len(ground_truth), len(predictions))) - for gt_idx, gt_geometry in enumerate(ground_truth): - for pred_idx, pred_geometry in enumerate(predictions): - affinity = affinity_function( - gt_geometry, - pred_geometry, - ) - affinity_matrix[gt_idx, pred_idx] = affinity - - return affinity_matrix - - -def select_optimal_matches( - affinity_matrix: np.ndarray, - affinity_threshold: float = 0.5, -) -> Iterable[Tuple[int | None, int | None, float]]: - num_gt, num_pred = affinity_matrix.shape - gts = set(range(num_gt)) - preds = set(range(num_pred)) - - assiged_rows, assigned_columns = linear_sum_assignment( - affinity_matrix, - maximize=True, - ) - - for gt_idx, pred_idx in zip(assiged_rows, assigned_columns, strict=False): - affinity = float(affinity_matrix[gt_idx, pred_idx]) - - if affinity <= affinity_threshold: - continue - - yield gt_idx, pred_idx, affinity - gts.remove(gt_idx) - preds.remove(pred_idx) - - for gt_idx in gts: - yield gt_idx, None, 0 - - for pred_idx in preds: - yield None, pred_idx, 0 - - -def select_greedy_matches( - affinity_matrix: np.ndarray, - affinity_threshold: float = 0.5, -) -> Iterable[Tuple[int | None, int | None, float]]: - num_gt, num_pred = affinity_matrix.shape - unmatched_pred = set(range(num_pred)) - - for gt_idx in range(num_gt): - row = affinity_matrix[gt_idx] - - top_pred = int(np.argmax(row)) - top_affinity = float(row[top_pred]) - - if ( - top_affinity <= affinity_threshold - or top_pred not in unmatched_pred - ): - yield None, gt_idx, 0 - continue - - unmatched_pred.remove(top_pred) - yield top_pred, gt_idx, top_affinity - - for pred_idx in unmatched_pred: - yield pred_idx, None, 0 - - -def build_matcher(config: MatchConfig | None = None) -> MatcherProtocol: - config = config or StartTimeMatchConfig() - return matching_strategies.build(config) diff --git a/src/batdetect2/evaluate/metrics/detection.py b/src/batdetect2/evaluate/metrics/detection.py index 493f3cb..7d000e4 100644 --- a/src/batdetect2/evaluate/metrics/detection.py +++ b/src/batdetect2/evaluate/metrics/detection.py @@ -210,7 +210,10 @@ class DetectionPrecision: DetectionMetricConfig = Annotated[ - DetectionAveragePrecisionConfig | DetectionROCAUCConfig | DetectionRecallConfig | DetectionPrecisionConfig, + DetectionAveragePrecisionConfig + | DetectionROCAUCConfig + | DetectionRecallConfig + | DetectionPrecisionConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/tasks/base.py b/src/batdetect2/evaluate/tasks/base.py index 4fa15cd..64fd5b9 100644 --- a/src/batdetect2/evaluate/tasks/base.py +++ b/src/batdetect2/evaluate/tasks/base.py @@ -14,16 +14,15 @@ from pydantic import Field from soundevent import data from soundevent.geometry import compute_bounds -from batdetect2.core import BaseConfig -from batdetect2.core.registries import Registry -from batdetect2.evaluate.match import ( - MatchConfig, - StartTimeMatchConfig, - build_matcher, +from batdetect2.core import BaseConfig, Registry +from batdetect2.evaluate.affinity import AffinityConfig, TimeAffinityConfig +from batdetect2.typing import ( + AffinityFunction, + BatDetect2Prediction, + EvaluatorProtocol, + RawPrediction, + TargetProtocol, ) -from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol -from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction -from batdetect2.typing.targets import TargetProtocol __all__ = [ "BaseTaskConfig", @@ -40,39 +39,34 @@ T_Output = TypeVar("T_Output") class BaseTaskConfig(BaseConfig): prefix: str - ignore_start_end: float = 0.01 - matching_strategy: MatchConfig = Field( - default_factory=StartTimeMatchConfig - ) class BaseTask(EvaluatorProtocol, Generic[T_Output]): targets: TargetProtocol - matcher: MatcherProtocol - metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]] plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] - ignore_start_end: float - prefix: str + ignore_start_end: float + def __init__( self, - matcher: MatcherProtocol, targets: TargetProtocol, metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], prefix: str, + plots: List[ + Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]] + ] + | None = None, ignore_start_end: float = 0.01, - plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None, ): - self.matcher = matcher + self.prefix = prefix + self.targets = targets self.metrics = metrics self.plots = plots or [] - self.targets = targets - self.prefix = prefix self.ignore_start_end = ignore_start_end def compute_metrics( @@ -100,7 +94,9 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): ) -> List[T_Output]: return [ self.evaluate_clip(clip_annotation, preds) - for clip_annotation, preds in zip(clip_annotations, predictions, strict=False) + for clip_annotation, preds in zip( + clip_annotations, predictions, strict=False + ) ] def evaluate_clip( @@ -118,9 +114,6 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): return False geometry = sound_event_annotation.sound_event.geometry - if geometry is None: - return False - return is_in_bounds( geometry, clip, @@ -138,25 +131,40 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): self.ignore_start_end, ) - @classmethod - def build( - cls, - config: BaseTaskConfig, + +class BaseSEDTaskConfig(BaseTaskConfig): + affinity: AffinityConfig = Field(default_factory=TimeAffinityConfig) + affinity_threshold: float = 0 + strict_match: bool = True + + +class BaseSEDTask(BaseTask[T_Output]): + affinity: AffinityFunction + + def __init__( + self, + prefix: str, targets: TargetProtocol, metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], - plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] | None = None, - **kwargs, + affinity: AffinityFunction, + plots: List[ + Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]] + ] + | None = None, + affinity_threshold: float = 0, + ignore_start_end: float = 0.01, + strict_match: bool = True, ): - matcher = build_matcher(config.matching_strategy) - return cls( - matcher=matcher, - targets=targets, + super().__init__( + prefix=prefix, metrics=metrics, plots=plots, - prefix=config.prefix, - ignore_start_end=config.ignore_start_end, - **kwargs, + targets=targets, + ignore_start_end=ignore_start_end, ) + self.affinity = affinity + self.affinity_threshold = affinity_threshold + self.strict_match = strict_match def is_in_bounds( diff --git a/src/batdetect2/evaluate/tasks/classification.py b/src/batdetect2/evaluate/tasks/classification.py index 9a31db2..1a1fc6a 100644 --- a/src/batdetect2/evaluate/tasks/classification.py +++ b/src/batdetect2/evaluate/tasks/classification.py @@ -1,11 +1,11 @@ -from typing import ( - List, - Literal, -) +from functools import partial +from typing import Literal from pydantic import Field from soundevent import data +from soundevent.evaluation import match_detections_and_gts +from batdetect2.evaluate.affinity import build_affinity_function from batdetect2.evaluate.metrics.classification import ( ClassificationAveragePrecisionConfig, ClassificationMetricConfig, @@ -18,24 +18,28 @@ from batdetect2.evaluate.plots.classification import ( build_classification_plotter, ) from batdetect2.evaluate.tasks.base import ( - BaseTask, - BaseTaskConfig, + BaseSEDTask, + BaseSEDTaskConfig, tasks_registry, ) -from batdetect2.typing import BatDetect2Prediction, TargetProtocol +from batdetect2.typing import ( + BatDetect2Prediction, + RawPrediction, + TargetProtocol, +) -class ClassificationTaskConfig(BaseTaskConfig): +class ClassificationTaskConfig(BaseSEDTaskConfig): name: Literal["sound_event_classification"] = "sound_event_classification" prefix: str = "classification" - metrics: List[ClassificationMetricConfig] = Field( + metrics: list[ClassificationMetricConfig] = Field( default_factory=lambda: [ClassificationAveragePrecisionConfig()] ) - plots: List[ClassificationPlotConfig] = Field(default_factory=list) + plots: list[ClassificationPlotConfig] = Field(default_factory=list) include_generics: bool = True -class ClassificationTask(BaseTask[ClipEval]): +class ClassificationTask(BaseSEDTask[ClipEval]): def __init__( self, *args, @@ -73,40 +77,39 @@ class ClassificationTask(BaseTask[ClipEval]): gts = [ sound_event for sound_event in all_gts - if self.is_class(sound_event, class_name) + if is_target_class( + sound_event, + class_name, + self.targets, + include_generics=self.include_generics, + ) ] - scores = [float(pred.class_scores[class_idx]) for pred in preds] matches = [] - for pred_idx, gt_idx, _ in self.matcher( - ground_truth=[se.sound_event.geometry for se in gts], # type: ignore - predictions=[pred.geometry for pred in preds], - scores=scores, + for match in match_detections_and_gts( + detections=preds, + ground_truths=gts, + affinity=self.affinity, + score=partial(get_class_score, class_idx=class_idx), + strict_match=self.strict_match, ): - gt = gts[gt_idx] if gt_idx is not None else None - pred = preds[pred_idx] if pred_idx is not None else None - true_class = ( - self.targets.encode_class(gt) if gt is not None else None + self.targets.encode_class(match.annotation) + if match.annotation is not None + else None ) - - score = ( - float(pred.class_scores[class_idx]) - if pred is not None - else 0 - ) - matches.append( MatchEval( clip=clip, - gt=gt, - pred=pred, - is_prediction=pred is not None, - is_ground_truth=gt is not None, - is_generic=gt is not None and true_class is None, + gt=match.annotation, + pred=match.prediction, + is_prediction=match.prediction is not None, + is_ground_truth=match.annotation is not None, + is_generic=match.annotation is not None + and true_class is None, true_class=true_class, - score=score, + score=match.prediction_score, ) ) @@ -114,20 +117,6 @@ class ClassificationTask(BaseTask[ClipEval]): return ClipEval(clip=clip, matches=per_class_matches) - def is_class( - self, - sound_event: data.SoundEventAnnotation, - class_name: str, - ) -> bool: - sound_event_class = self.targets.encode_class(sound_event) - - if sound_event_class is None and self.include_generics: - # Sound events that are generic could be of the given - # class - return True - - return sound_event_class == class_name - @tasks_registry.register(ClassificationTaskConfig) @staticmethod def from_config( @@ -142,9 +131,32 @@ class ClassificationTask(BaseTask[ClipEval]): build_classification_plotter(plot, targets) for plot in config.plots ] - return ClassificationTask.build( - config=config, + affinity = build_affinity_function(config.affinity) + return ClassificationTask( + affinity=affinity, + prefix=config.prefix, plots=plots, targets=targets, metrics=metrics, + strict_match=config.strict_match, ) + + +def get_class_score(pred: RawPrediction, class_idx: int) -> float: + return pred.class_scores[class_idx] + + +def is_target_class( + sound_event: data.SoundEventAnnotation, + class_name: str, + targets: TargetProtocol, + include_generics: bool = True, +) -> bool: + sound_event_class = targets.encode_class(sound_event) + + if sound_event_class is None and include_generics: + # Sound events that are generic could be of the given + # class + return True + + return sound_event_class == class_name diff --git a/src/batdetect2/evaluate/tasks/clip_classification.py b/src/batdetect2/evaluate/tasks/clip_classification.py index 8215555..67b2c71 100644 --- a/src/batdetect2/evaluate/tasks/clip_classification.py +++ b/src/batdetect2/evaluate/tasks/clip_classification.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import List, Literal +from typing import Literal from pydantic import Field from soundevent import data @@ -19,19 +19,18 @@ from batdetect2.evaluate.tasks.base import ( BaseTaskConfig, tasks_registry, ) -from batdetect2.typing import TargetProtocol -from batdetect2.typing.postprocess import BatDetect2Prediction +from batdetect2.typing import BatDetect2Prediction, TargetProtocol class ClipClassificationTaskConfig(BaseTaskConfig): name: Literal["clip_classification"] = "clip_classification" prefix: str = "clip_classification" - metrics: List[ClipClassificationMetricConfig] = Field( + metrics: list[ClipClassificationMetricConfig] = Field( default_factory=lambda: [ ClipClassificationAveragePrecisionConfig(), ] ) - plots: List[ClipClassificationPlotConfig] = Field(default_factory=list) + plots: list[ClipClassificationPlotConfig] = Field(default_factory=list) class ClipClassificationTask(BaseTask[ClipEval]): @@ -78,8 +77,8 @@ class ClipClassificationTask(BaseTask[ClipEval]): build_clip_classification_plotter(plot, targets) for plot in config.plots ] - return ClipClassificationTask.build( - config=config, + return ClipClassificationTask( + prefix=config.prefix, plots=plots, metrics=metrics, targets=targets, diff --git a/src/batdetect2/evaluate/tasks/clip_detection.py b/src/batdetect2/evaluate/tasks/clip_detection.py index 66a7a9d..ddb9d1a 100644 --- a/src/batdetect2/evaluate/tasks/clip_detection.py +++ b/src/batdetect2/evaluate/tasks/clip_detection.py @@ -1,4 +1,4 @@ -from typing import List, Literal +from typing import Literal from pydantic import Field from soundevent import data @@ -18,19 +18,18 @@ from batdetect2.evaluate.tasks.base import ( BaseTaskConfig, tasks_registry, ) -from batdetect2.typing import TargetProtocol -from batdetect2.typing.postprocess import BatDetect2Prediction +from batdetect2.typing import BatDetect2Prediction, TargetProtocol class ClipDetectionTaskConfig(BaseTaskConfig): name: Literal["clip_detection"] = "clip_detection" prefix: str = "clip_detection" - metrics: List[ClipDetectionMetricConfig] = Field( + metrics: list[ClipDetectionMetricConfig] = Field( default_factory=lambda: [ ClipDetectionAveragePrecisionConfig(), ] ) - plots: List[ClipDetectionPlotConfig] = Field(default_factory=list) + plots: list[ClipDetectionPlotConfig] = Field(default_factory=list) class ClipDetectionTask(BaseTask[ClipEval]): @@ -69,8 +68,8 @@ class ClipDetectionTask(BaseTask[ClipEval]): build_clip_detection_plotter(plot, targets) for plot in config.plots ] - return ClipDetectionTask.build( - config=config, + return ClipDetectionTask( + prefix=config.prefix, metrics=metrics, targets=targets, plots=plots, diff --git a/src/batdetect2/evaluate/tasks/detection.py b/src/batdetect2/evaluate/tasks/detection.py index 8e404d9..3ba034e 100644 --- a/src/batdetect2/evaluate/tasks/detection.py +++ b/src/batdetect2/evaluate/tasks/detection.py @@ -1,8 +1,10 @@ -from typing import List, Literal +from typing import Literal from pydantic import Field from soundevent import data +from soundevent.evaluation import match_detections_and_gts +from batdetect2.evaluate.affinity import build_affinity_function from batdetect2.evaluate.metrics.detection import ( ClipEval, DetectionAveragePrecisionConfig, @@ -15,24 +17,24 @@ from batdetect2.evaluate.plots.detection import ( build_detection_plotter, ) from batdetect2.evaluate.tasks.base import ( - BaseTask, - BaseTaskConfig, + BaseSEDTask, + BaseSEDTaskConfig, tasks_registry, ) from batdetect2.typing import TargetProtocol from batdetect2.typing.postprocess import BatDetect2Prediction -class DetectionTaskConfig(BaseTaskConfig): +class DetectionTaskConfig(BaseSEDTaskConfig): name: Literal["sound_event_detection"] = "sound_event_detection" prefix: str = "detection" - metrics: List[DetectionMetricConfig] = Field( + metrics: list[DetectionMetricConfig] = Field( default_factory=lambda: [DetectionAveragePrecisionConfig()] ) - plots: List[DetectionPlotConfig] = Field(default_factory=list) + plots: list[DetectionPlotConfig] = Field(default_factory=list) -class DetectionTask(BaseTask[ClipEval]): +class DetectionTask(BaseSEDTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, @@ -50,24 +52,22 @@ class DetectionTask(BaseTask[ClipEval]): for pred in prediction.predictions if self.include_prediction(pred, clip) ] - scores = [pred.detection_score for pred in preds] matches = [] - for pred_idx, gt_idx, _ in self.matcher( - ground_truth=[se.sound_event.geometry for se in gts], # type: ignore - predictions=[pred.geometry for pred in preds], - scores=scores, + for match in match_detections_and_gts( + detections=preds, + ground_truths=gts, + affinity=self.affinity, + score=lambda pred: pred.detection_score, + strict_match=self.strict_match, ): - gt = gts[gt_idx] if gt_idx is not None else None - pred = preds[pred_idx] if pred_idx is not None else None - matches.append( MatchEval( - gt=gt, - pred=pred, - is_prediction=pred is not None, - is_ground_truth=gt is not None, - score=pred.detection_score if pred is not None else 0, + gt=match.annotation, + pred=match.prediction, + is_prediction=match.prediction is not None, + is_ground_truth=match.annotation is not None, + score=match.prediction_score, ) ) @@ -83,9 +83,12 @@ class DetectionTask(BaseTask[ClipEval]): plots = [ build_detection_plotter(plot, targets) for plot in config.plots ] - return DetectionTask.build( - config=config, + affinity = build_affinity_function(config.affinity) + return DetectionTask( + prefix=config.prefix, + affinity=affinity, metrics=metrics, targets=targets, plots=plots, + strict_match=config.strict_match, ) diff --git a/src/batdetect2/evaluate/tasks/top_class.py b/src/batdetect2/evaluate/tasks/top_class.py index ef94041..a625ecc 100644 --- a/src/batdetect2/evaluate/tasks/top_class.py +++ b/src/batdetect2/evaluate/tasks/top_class.py @@ -1,8 +1,10 @@ -from typing import List, Literal +from typing import Literal from pydantic import Field from soundevent import data +from soundevent.evaluation import match_detections_and_gts +from batdetect2.evaluate.affinity import build_affinity_function from batdetect2.evaluate.metrics.top_class import ( ClipEval, MatchEval, @@ -15,24 +17,23 @@ from batdetect2.evaluate.plots.top_class import ( build_top_class_plotter, ) from batdetect2.evaluate.tasks.base import ( - BaseTask, - BaseTaskConfig, + BaseSEDTask, + BaseSEDTaskConfig, tasks_registry, ) -from batdetect2.typing import TargetProtocol -from batdetect2.typing.postprocess import BatDetect2Prediction +from batdetect2.typing import BatDetect2Prediction, TargetProtocol -class TopClassDetectionTaskConfig(BaseTaskConfig): +class TopClassDetectionTaskConfig(BaseSEDTaskConfig): name: Literal["top_class_detection"] = "top_class_detection" prefix: str = "top_class" - metrics: List[TopClassMetricConfig] = Field( + metrics: list[TopClassMetricConfig] = Field( default_factory=lambda: [TopClassAveragePrecisionConfig()] ) - plots: List[TopClassPlotConfig] = Field(default_factory=list) + plots: list[TopClassPlotConfig] = Field(default_factory=list) -class TopClassDetectionTask(BaseTask[ClipEval]): +class TopClassDetectionTask(BaseSEDTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, @@ -50,18 +51,17 @@ class TopClassDetectionTask(BaseTask[ClipEval]): for pred in prediction.predictions if self.include_prediction(pred, clip) ] - # Take the highest score for each prediction - scores = [pred.class_scores.max() for pred in preds] matches = [] - for pred_idx, gt_idx, _ in self.matcher( - ground_truth=[se.sound_event.geometry for se in gts], # type: ignore - predictions=[pred.geometry for pred in preds], - scores=scores, + for match in match_detections_and_gts( + ground_truths=gts, + detections=preds, + affinity=self.affinity, + score=lambda pred: pred.class_scores.max(), + strict_match=self.strict_match, ): - gt = gts[gt_idx] if gt_idx is not None else None - pred = preds[pred_idx] if pred_idx is not None else None - + gt = match.annotation + pred = match.prediction true_class = ( self.targets.encode_class(gt) if gt is not None else None ) @@ -69,11 +69,6 @@ class TopClassDetectionTask(BaseTask[ClipEval]): class_idx = ( pred.class_scores.argmax() if pred is not None else None ) - - score = ( - float(pred.class_scores[class_idx]) if pred is not None else 0 - ) - pred_class = ( self.targets.class_names[class_idx] if class_idx is not None @@ -90,7 +85,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]): true_class=true_class, is_generic=gt is not None and true_class is None, pred_class=pred_class, - score=score, + score=match.prediction_score, ) ) @@ -106,9 +101,12 @@ class TopClassDetectionTask(BaseTask[ClipEval]): plots = [ build_top_class_plotter(plot, targets) for plot in config.plots ] - return TopClassDetectionTask.build( - config=config, + affinity = build_affinity_function(config.affinity) + return TopClassDetectionTask( + prefix=config.prefix, plots=plots, metrics=metrics, targets=targets, + affinity=affinity, + strict_match=config.strict_match, ) diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/typing/evaluate.py index 6903af2..56438ff 100644 --- a/src/batdetect2/typing/evaluate.py +++ b/src/batdetect2/typing/evaluate.py @@ -81,11 +81,11 @@ class MatcherProtocol(Protocol): Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True) -class AffinityFunction(Protocol, Generic[Geom]): +class AffinityFunction(Protocol): def __call__( self, - geometry1: Geom, - geometry2: Geom, + detection: RawPrediction, + ground_truth: data.SoundEventAnnotation, ) -> float: ... diff --git a/tests/test_data/test_transforms/test_conditions.py b/tests/test_data/test_transforms/test_conditions.py index 6d5068e..bdb4eb9 100644 --- a/tests/test_data/test_transforms/test_conditions.py +++ b/tests/test_data/test_transforms/test_conditions.py @@ -28,13 +28,13 @@ def test_has_tag(sound_event: data.SoundEvent): sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, - tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + tags=[data.Tag(key="species", value="Myotis myotis")], ) assert condition(sound_event_annotation) sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, - tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore + tags=[data.Tag(key="species", value="Eptesicus fuscus")], ) assert not condition(sound_event_annotation) @@ -51,15 +51,15 @@ def test_has_all_tags(sound_event: data.SoundEvent): sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, - tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + tags=[data.Tag(key="species", value="Myotis myotis")], ) assert not condition(sound_event_annotation) sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, tags=[ - data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore - data.Tag(key="event", value="Echolocation"), # type: ignore + data.Tag(key="species", value="Eptesicus fuscus"), + data.Tag(key="event", value="Echolocation"), ], ) assert not condition(sound_event_annotation) @@ -67,8 +67,8 @@ def test_has_all_tags(sound_event: data.SoundEvent): sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, tags=[ - data.Tag(key="species", value="Myotis myotis"), # type: ignore - data.Tag(key="event", value="Echolocation"), # type: ignore + data.Tag(key="species", value="Myotis myotis"), + data.Tag(key="event", value="Echolocation"), ], ) assert condition(sound_event_annotation) @@ -76,9 +76,9 @@ def test_has_all_tags(sound_event: data.SoundEvent): sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, tags=[ - data.Tag(key="species", value="Myotis myotis"), # type: ignore - data.Tag(key="event", value="Echolocation"), # type: ignore - data.Tag(key="sex", value="Female"), # type: ignore + data.Tag(key="species", value="Myotis myotis"), + data.Tag(key="event", value="Echolocation"), + data.Tag(key="sex", value="Female"), ], ) assert condition(sound_event_annotation) @@ -96,15 +96,15 @@ def test_has_any_tags(sound_event: data.SoundEvent): sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, - tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + tags=[data.Tag(key="species", value="Myotis myotis")], ) assert condition(sound_event_annotation) sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, tags=[ - data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore - data.Tag(key="event", value="Echolocation"), # type: ignore + data.Tag(key="species", value="Eptesicus fuscus"), + data.Tag(key="event", value="Echolocation"), ], ) assert condition(sound_event_annotation) @@ -112,8 +112,8 @@ def test_has_any_tags(sound_event: data.SoundEvent): sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, tags=[ - data.Tag(key="species", value="Myotis myotis"), # type: ignore - data.Tag(key="event", value="Echolocation"), # type: ignore + data.Tag(key="species", value="Myotis myotis"), + data.Tag(key="event", value="Echolocation"), ], ) assert condition(sound_event_annotation) @@ -121,8 +121,8 @@ def test_has_any_tags(sound_event: data.SoundEvent): sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, tags=[ - data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore - data.Tag(key="event", value="Social"), # type: ignore + data.Tag(key="species", value="Eptesicus fuscus"), + data.Tag(key="event", value="Social"), ], ) assert not condition(sound_event_annotation) @@ -140,21 +140,21 @@ def test_not(sound_event: data.SoundEvent): sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, - tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + tags=[data.Tag(key="species", value="Myotis myotis")], ) assert not condition(sound_event_annotation) sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, - tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore + tags=[data.Tag(key="species", value="Eptesicus fuscus")], ) assert condition(sound_event_annotation) sound_event_annotation = data.SoundEventAnnotation( sound_event=sound_event, tags=[ - data.Tag(key="species", value="Myotis myotis"), # type: ignore - data.Tag(key="event", value="Echolocation"), # type: ignore + data.Tag(key="species", value="Myotis myotis"), + data.Tag(key="event", value="Echolocation"), ], ) assert not condition(sound_event_annotation) @@ -402,31 +402,6 @@ def test_has_tags_fails_if_empty(): """) -def test_frequency_is_false_if_no_geometry(recording: data.Recording): - condition = build_condition_from_str(""" - name: frequency - boundary: low - operator: eq - hertz: 200 - """) - se = data.SoundEventAnnotation( - sound_event=data.SoundEvent(geometry=None, recording=recording) - ) - assert not condition(se) - - -def test_duration_is_false_if_no_geometry(recording: data.Recording): - condition = build_condition_from_str(""" - name: duration - operator: eq - seconds: 1 - """) - se = data.SoundEventAnnotation( - sound_event=data.SoundEvent(geometry=None, recording=recording) - ) - assert not condition(se) - - def test_all_of(recording: data.Recording): condition = build_condition_from_str(""" name: all_of @@ -444,7 +419,7 @@ def test_all_of(recording: data.Recording): geometry=data.TimeInterval(coordinates=[0, 0.5]), recording=recording, ), - tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + tags=[data.Tag(key="species", value="Myotis myotis")], ) assert condition(se) @@ -453,7 +428,7 @@ def test_all_of(recording: data.Recording): geometry=data.TimeInterval(coordinates=[0, 2]), recording=recording, ), - tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + tags=[data.Tag(key="species", value="Myotis myotis")], ) assert not condition(se) @@ -462,7 +437,7 @@ def test_all_of(recording: data.Recording): geometry=data.TimeInterval(coordinates=[0, 0.5]), recording=recording, ), - tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore + tags=[data.Tag(key="species", value="Eptesicus fuscus")], ) assert not condition(se) @@ -484,7 +459,7 @@ def test_any_of(recording: data.Recording): geometry=data.TimeInterval(coordinates=[0, 2]), recording=recording, ), - tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore + tags=[data.Tag(key="species", value="Eptesicus fuscus")], ) assert not condition(se) @@ -493,7 +468,7 @@ def test_any_of(recording: data.Recording): geometry=data.TimeInterval(coordinates=[0, 0.5]), recording=recording, ), - tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + tags=[data.Tag(key="species", value="Myotis myotis")], ) assert condition(se) @@ -502,7 +477,7 @@ def test_any_of(recording: data.Recording): geometry=data.TimeInterval(coordinates=[0, 2]), recording=recording, ), - tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + tags=[data.Tag(key="species", value="Myotis myotis")], ) assert condition(se) @@ -511,6 +486,6 @@ def test_any_of(recording: data.Recording): geometry=data.TimeInterval(coordinates=[0, 0.5]), recording=recording, ), - tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore + tags=[data.Tag(key="species", value="Eptesicus fuscus")], ) assert condition(se)