diff --git a/src/batdetect2/evaluate/__init__.py b/src/batdetect2/evaluate/__init__.py index ecd2812..412ed72 100644 --- a/src/batdetect2/evaluate/__init__.py +++ b/src/batdetect2/evaluate/__init__.py @@ -1,11 +1,6 @@ -from batdetect2.evaluate.config import ( - EvaluationConfig, - load_evaluation_config, -) -from batdetect2.evaluate.match import match_predictions_and_annotations +from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config __all__ = [ "EvaluationConfig", "load_evaluation_config", - "match_predictions_and_annotations", ] diff --git a/src/batdetect2/evaluate/config.py b/src/batdetect2/evaluate/config.py index 5fc9d81..cc93a89 100644 --- a/src/batdetect2/evaluate/config.py +++ b/src/batdetect2/evaluate/config.py @@ -4,7 +4,7 @@ from pydantic import Field from soundevent import data from batdetect2.configs import BaseConfig, load_config -from batdetect2.evaluate.match import MatchConfig +from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig __all__ = [ "EvaluationConfig", @@ -13,7 +13,7 @@ __all__ = [ class EvaluationConfig(BaseConfig): - match: MatchConfig = Field(default_factory=MatchConfig) + match: MatchConfig = Field(default_factory=StartTimeMatchConfig) def load_evaluation_config( diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index aed3a08..174f75b 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -4,9 +4,8 @@ import pandas as pd from soundevent import data from batdetect2.evaluate.dataframe import extract_matches_dataframe -from batdetect2.evaluate.match import match_all_predictions +from batdetect2.evaluate.match import build_matcher, match_all_predictions from batdetect2.evaluate.metrics import ( - ClassificationAccuracy, ClassificationMeanAveragePrecision, DetectionAveragePrecision, ) @@ -77,11 +76,13 @@ def evaluate( clip_annotations.extend(clip_annotations) predictions.extend(predictions) + matcher = build_matcher(config.evaluation.match) + matches = match_all_predictions( clip_annotations, predictions, targets=targets, - config=config.evaluation.match, + matcher=matcher, ) df = extract_matches_dataframe(matches) @@ -89,7 +90,6 @@ def evaluate( metrics = [ DetectionAveragePrecision(), ClassificationMeanAveragePrecision(class_names=targets.class_names), - ClassificationAccuracy(class_names=targets.class_names), ] results = { diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index cca9946..914341d 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -1,63 +1,120 @@ from collections.abc import Callable, Iterable, Mapping from dataclasses import dataclass, field -from typing import List, Literal, Optional, Protocol, Tuple +from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union import numpy as np from loguru import logger +from pydantic import Field from soundevent import data from soundevent.evaluation import compute_affinity from soundevent.evaluation import match_geometries as optimal_match from soundevent.geometry import compute_bounds from batdetect2.configs import BaseConfig +from batdetect2.data._core import Registry +from batdetect2.targets import build_targets from batdetect2.typing import ( MatchEvaluation, TargetProtocol, ) +from batdetect2.typing.evaluate import AffinityFunction, MatcherProtocol from batdetect2.typing.postprocess import RawPrediction -MatchingStrategy = Literal["greedy", "optimal"] -"""The type of matching algorithm to use: 'greedy' or 'optimal'.""" - - MatchingGeometry = Literal["bbox", "interval", "timestamp"] """The geometry representation to use for matching.""" +matching_strategy = Registry("matching_strategy") + + +class StartTimeMatchConfig(BaseConfig): + name: Literal["start_time"] = "start_time" + distance_threshold: float = 0.01 + + +@matching_strategy.register(StartTimeMatchConfig) +class StartTimeMatcher(MatcherProtocol): + def __init__(self, distance_threshold: float): + self.distance_threshold = distance_threshold -class AffinityFunction(Protocol): def __call__( self, - geometry1: data.Geometry, - geometry2: data.Geometry, - time_buffer: float = 0.01, - freq_buffer: float = 1000, - ) -> float: ... + ground_truth: Sequence[data.Geometry], + predictions: Sequence[data.Geometry], + scores: Sequence[float], + ): + return match_start_times( + ground_truth, + predictions, + scores, + distance_threshold=self.distance_threshold, + ) + + @classmethod + def from_config(cls, config: StartTimeMatchConfig) -> "StartTimeMatcher": + return cls(distance_threshold=config.distance_threshold) -class MatchConfig(BaseConfig): - """Configuration for matching geometries. +def match_start_times( + ground_truth: Sequence[data.Geometry], + predictions: Sequence[data.Geometry], + scores: Sequence[float], + distance_threshold: float = 0.01, +) -> Iterable[Tuple[Optional[int], Optional[int], float]]: + if not ground_truth: + for index in range(len(predictions)): + yield index, None, 0 - Attributes - ---------- - strategy : MatchingStrategy, default="greedy" - The matching algorithm to use. 'greedy' prioritizes high-confidence - predictions, while 'optimal' finds the globally best set of matches. - geometry : MatchingGeometry, default="timestamp" - The geometric representation to use when computing affinity. - affinity_threshold : float, default=0.0 - The minimum affinity score (e.g., IoU) required for a valid match. - time_buffer : float, default=0.005 - Time tolerance in seconds used in affinity calculations. - frequency_buffer : float, default=1000 - Frequency tolerance in Hertz used in affinity calculations. - """ + return - strategy: MatchingStrategy = "greedy" - geometry: MatchingGeometry = "timestamp" - affinity_threshold: float = 0.0 - time_buffer: float = 0.005 - frequency_buffer: float = 1_000 - ignore_start_end: float = 0.01 + if not predictions: + for index in range(len(ground_truth)): + yield None, index, 0 + + return + + gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth]) + pred_times = np.array([compute_bounds(geom)[0] for geom in predictions]) + scores = np.array(scores) + + sort_args = np.argsort(scores)[::-1] + + distances = np.abs(gt_times[None, :] - pred_times[:, None]) + closests = np.argmin(distances, axis=-1) + + unmatched_gt = set(range(len(gt_times))) + + for pred_index in sort_args: + # Get the closest ground truth + gt_closest_index = closests[pred_index] + + if gt_closest_index not in unmatched_gt: + # Does not match if closest has been assigned + yield pred_index, None, 0 + continue + + # Get the actual distance + distance = distances[pred_index, gt_closest_index] + + if distance > distance_threshold: + # Does not match if too far from closest + yield pred_index, None, 0 + continue + + # Return affinity value: linear interpolation between 0 to 1, where a + # distance at the threshold maps to 0 affinity and a zero distance maps + # to 1. + affinity = np.interp( + distance, + [0, distance_threshold], + [1, 0], + left=1, + right=0, + ) + unmatched_gt.remove(gt_closest_index) + yield pred_index, gt_closest_index, affinity + + for missing_index in unmatched_gt: + yield None, missing_index, 0 def _to_bbox(geometry: data.Geometry) -> data.BoundingBox: @@ -142,50 +199,65 @@ def _interval_affinity( _affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = { "timestamp": _timestamp_affinity, "interval": _interval_affinity, + "bbox": compute_affinity, } -def match_geometries( - source: List[data.Geometry], - target: List[data.Geometry], - config: MatchConfig, - scores: Optional[List[float]] = None, -) -> Iterable[Tuple[Optional[int], Optional[int], float]]: - geometry_cast = _geometry_cast_functions[config.geometry] - affinity_function = _affinity_functions.get( - config.geometry, - compute_affinity, - ) +class GreedyMatchConfig(BaseConfig): + name: Literal["greedy_match"] = "greedy_match" + geometry: MatchingGeometry = "timestamp" + affinity_threshold: float = 0.0 + time_buffer: float = 0.005 + frequency_buffer: float = 1_000 - if config.strategy == "optimal": - return optimal_match( - source=[geometry_cast(geom) for geom in source], - target=[geometry_cast(geom) for geom in target], - time_buffer=config.time_buffer, - freq_buffer=config.frequency_buffer, - affinity_threshold=config.affinity_threshold, - ) - if config.strategy == "greedy": +@matching_strategy.register(GreedyMatchConfig) +class GreedyMatcher(MatcherProtocol): + def __init__( + self, + geometry: MatchingGeometry, + affinity_threshold: float, + time_buffer: float, + frequency_buffer: float, + ): + self.geometry = geometry + self.affinity_threshold = affinity_threshold + self.time_buffer = time_buffer + self.frequency_buffer = frequency_buffer + + self.affinity_function = _affinity_functions[self.geometry] + self.cast_geometry = _geometry_cast_functions[self.geometry] + + def __call__( + self, + ground_truth: Sequence[data.Geometry], + predictions: Sequence[data.Geometry], + scores: Sequence[float], + ): return greedy_match( - source=[geometry_cast(geom) for geom in source], - target=[geometry_cast(geom) for geom in target], - time_buffer=config.time_buffer, - freq_buffer=config.frequency_buffer, - affinity_threshold=config.affinity_threshold, - affinity_function=affinity_function, + ground_truth=[self.cast_geometry(geom) for geom in ground_truth], + predictions=[self.cast_geometry(geom) for geom in predictions], scores=scores, + affinity_function=self.affinity_function, + affinity_threshold=self.affinity_threshold, + time_buffer=self.time_buffer, + freq_buffer=self.frequency_buffer, ) - raise NotImplementedError( - f"Matching strategy not implemented {config.strategy}" - ) + @classmethod + def from_config(cls, config: GreedyMatchConfig): + return cls( + geometry=config.geometry, + affinity_threshold=config.affinity_threshold, + time_buffer=config.time_buffer, + frequency_buffer=config.frequency_buffer, + ) def greedy_match( - source: List[data.Geometry], - target: List[data.Geometry], - scores: Optional[List[float]] = None, + ground_truth: Sequence[data.Geometry], + predictions: Sequence[data.Geometry], + scores: Sequence[float], affinity_threshold: float = 0.5, affinity_function: AffinityFunction = compute_affinity, time_buffer: float = 0.001, @@ -221,27 +293,24 @@ def greedy_match( - Unmatched Source (False Positive): `(source_idx, None, 0)` - Unmatched Target (False Negative): `(None, target_idx, 0)` """ - assigned = set() + unassigned_gt = set(range(len(ground_truth))) - if not source: - for target_idx in range(len(target)): + if not predictions: + for target_idx in range(len(ground_truth)): yield None, target_idx, 0 return - if not target: - for source_idx in range(len(source)): + if not ground_truth: + for source_idx in range(len(predictions)): yield source_idx, None, 0 return - if scores is None: - indices = np.arange(len(source)) - else: - indices = np.argsort(scores)[::-1] + indices = np.argsort(scores)[::-1] for source_idx in indices: - source_geometry = source[source_idx] + source_geometry = predictions[source_idx] affinities = np.array( [ @@ -251,7 +320,7 @@ def greedy_match( time_buffer=time_buffer, freq_buffer=freq_buffer, ) - for target_geometry in target + for target_geometry in ground_truth ] ) @@ -262,18 +331,74 @@ def greedy_match( yield source_idx, None, 0 continue - if closest_target in assigned: + if closest_target not in unassigned_gt: yield source_idx, None, 0 continue - assigned.add(closest_target) + unassigned_gt.remove(closest_target) yield source_idx, closest_target, affinity - missed_ground_truth = set(range(len(target))) - assigned - for target_idx in missed_ground_truth: + for target_idx in unassigned_gt: yield None, target_idx, 0 +class OptimalMatchConfig(BaseConfig): + name: Literal["optimal_match"] = "optimal_match" + affinity_threshold: float = 0.0 + time_buffer: float = 0.005 + frequency_buffer: float = 1_000 + + +@matching_strategy.register(OptimalMatchConfig) +class OptimalMatcher(MatcherProtocol): + def __init__( + self, + affinity_threshold: float, + time_buffer: float, + frequency_buffer: float, + ): + self.affinity_threshold = affinity_threshold + self.time_buffer = time_buffer + self.frequency_buffer = frequency_buffer + + def __call__( + self, + ground_truth: Sequence[data.Geometry], + predictions: Sequence[data.Geometry], + scores: Sequence[float], + ): + return optimal_match( + source=predictions, + target=ground_truth, + time_buffer=self.time_buffer, + freq_buffer=self.frequency_buffer, + affinity_threshold=self.affinity_threshold, + ) + + @classmethod + def from_config(cls, config: OptimalMatchConfig): + return cls( + affinity_threshold=config.affinity_threshold, + time_buffer=config.time_buffer, + frequency_buffer=config.frequency_buffer, + ) + + +MatchConfig = Annotated[ + Union[ + GreedyMatchConfig, + StartTimeMatchConfig, + OptimalMatchConfig, + ], + Field(discriminator="name"), +] + + +def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol: + config = config or StartTimeMatchConfig() + return matching_strategy.build(config) + + def _is_in_bounds( geometry: data.Geometry, clip: data.Clip, @@ -285,13 +410,18 @@ def _is_in_bounds( ) -def match_sound_events_and_raw_predictions( +def match_sound_events_and_predictions( clip_annotation: data.ClipAnnotation, raw_predictions: List[RawPrediction], - targets: TargetProtocol, - config: Optional[MatchConfig] = None, + targets: Optional[TargetProtocol] = None, + matcher: Optional[MatcherProtocol] = None, + ignore_start_end: float = 0.01, ) -> List[MatchEvaluation]: - config = config or MatchConfig() + if matcher is None: + matcher = build_matcher() + + if targets is None: + targets = build_targets() target_sound_events = [ sound_event_annotation @@ -301,7 +431,7 @@ def match_sound_events_and_raw_predictions( and _is_in_bounds( sound_event_annotation.sound_event.geometry, clip=clip_annotation.clip, - buffer=config.ignore_start_end, + buffer=ignore_start_end, ) ] @@ -317,7 +447,7 @@ def match_sound_events_and_raw_predictions( if _is_in_bounds( raw_prediction.geometry, clip=clip_annotation.clip, - buffer=config.ignore_start_end, + buffer=ignore_start_end, ) ] @@ -331,10 +461,9 @@ def match_sound_events_and_raw_predictions( matches = [] - for source_idx, target_idx, affinity in match_geometries( - source=predicted_geometries, - target=target_geometries, - config=config, + for source_idx, target_idx, affinity in matcher( + ground_truth=target_geometries, + predictions=predicted_geometries, scores=scores, ): target = ( @@ -344,7 +473,7 @@ def match_sound_events_and_raw_predictions( raw_predictions[source_idx] if source_idx is not None else None ) - gt_det = target is not None + gt_det = target_idx is not None gt_class = targets.encode_class(target) if target is not None else None pred_score = float(prediction.detection_score) if prediction else 0 @@ -383,76 +512,12 @@ def match_sound_events_and_raw_predictions( return matches -def match_predictions_and_annotations( - clip_annotation: data.ClipAnnotation, - clip_prediction: data.ClipPrediction, - config: Optional[MatchConfig] = None, -) -> List[data.Match]: - config = config or MatchConfig() - - annotated_sound_events = [ - sound_event_annotation - for sound_event_annotation in clip_annotation.sound_events - if sound_event_annotation.sound_event.geometry is not None - ] - - predicted_sound_events = [ - sound_event_prediction - for sound_event_prediction in clip_prediction.sound_events - if sound_event_prediction.sound_event.geometry is not None - ] - - annotated_geometries: List[data.Geometry] = [ - sound_event.sound_event.geometry - for sound_event in annotated_sound_events - if sound_event.sound_event.geometry is not None - ] - - predicted_geometries: List[data.Geometry] = [ - sound_event.sound_event.geometry - for sound_event in predicted_sound_events - if sound_event.sound_event.geometry is not None - ] - - scores = [ - sound_event.score - for sound_event in predicted_sound_events - if sound_event.sound_event.geometry is not None - ] - - matches = [] - for source_idx, target_idx, affinity in match_geometries( - source=predicted_geometries, - target=annotated_geometries, - config=config, - scores=scores, - ): - target = ( - annotated_sound_events[target_idx] - if target_idx is not None - else None - ) - source = ( - predicted_sound_events[source_idx] - if source_idx is not None - else None - ) - matches.append( - data.Match( - source=source, - target=target, - affinity=affinity, - ) - ) - - return matches - - def match_all_predictions( clip_annotations: List[data.ClipAnnotation], predictions: List[List[RawPrediction]], - targets: TargetProtocol, - config: Optional[MatchConfig] = None, + targets: Optional[TargetProtocol] = None, + matcher: Optional[MatcherProtocol] = None, + ignore_start_end: float = 0.01, ) -> List[MatchEvaluation]: logger.info("Matching all annotations and predictions...") return [ @@ -461,11 +526,12 @@ def match_all_predictions( clip_annotations, predictions, ) - for match in match_sound_events_and_raw_predictions( + for match in match_sound_events_and_predictions( clip_annotation, raw_predictions, targets=targets, - config=config, + matcher=matcher, + ignore_start_end=ignore_start_end, ) ] diff --git a/src/batdetect2/evaluate/metrics.py b/src/batdetect2/evaluate/metrics.py index b42230d..c29fa5f 100644 --- a/src/batdetect2/evaluate/metrics.py +++ b/src/batdetect2/evaluate/metrics.py @@ -24,10 +24,12 @@ class ClassificationMeanAveragePrecision(MetricsProtocol): self.class_names = class_names def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: + # NOTE: Need to exclude generic but unclassified targets y_true = label_binarize( [ match.gt_class if match.gt_class is not None else "__NONE__" for match in matches + if not (match.gt_det and match.gt_class is None) ], classes=self.class_names, ) @@ -38,11 +40,11 @@ class ClassificationMeanAveragePrecision(MetricsProtocol): for name in self.class_names } for match in matches + if not (match.gt_det and match.gt_class is None) ] ).fillna(0) ret = {} - for class_index, class_name in enumerate(self.class_names): y_true_class = y_true[:, class_index] y_pred_class = y_pred[class_name] @@ -57,39 +59,3 @@ class ClassificationMeanAveragePrecision(MetricsProtocol): ) return ret - - -class ClassificationAccuracy(MetricsProtocol): - def __init__(self, class_names: List[str]): - self.class_names = class_names - - def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: - y_true = [ - match.gt_class if match.gt_class is not None else "__NONE__" - for match in matches - ] - - y_pred = pd.DataFrame( - [ - { - name: match.pred_class_scores.get(name, 0) - for name in self.class_names - } - for match in matches - ] - ).fillna(0) - y_pred = y_pred.apply( - lambda row: row.idxmax() - if row.max() >= (1 - row.sum()) - else "__NONE__", - axis=1, - ) - - accuracy = metrics.balanced_accuracy_score( - y_true, - y_pred, - ) - - return { - "classification_acc": float(accuracy), - } diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index 19e586b..0384b7a 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -8,13 +8,10 @@ from pydantic import Field, field_validator from soundevent import data from batdetect2.configs import BaseConfig, load_config -from batdetect2.data.conditions import ( - SoundEventCondition, - build_sound_event_condition, -) +from batdetect2.data.conditions import build_sound_event_condition from batdetect2.targets.classes import ( DEFAULT_CLASSES, - DEFAULT_GENERIC_CLASS, + DEFAULT_DETECTION_CLASS, SoundEventDecoder, SoundEventEncoder, TargetClassConfig, @@ -58,7 +55,9 @@ __all__ = [ class TargetConfig(BaseConfig): - detection_target: TargetClassConfig = Field(default=DEFAULT_GENERIC_CLASS) + detection_target: TargetClassConfig = Field( + default=DEFAULT_DETECTION_CLASS + ) classification_targets: List[TargetClassConfig] = Field( default_factory=lambda: DEFAULT_CLASSES @@ -151,49 +150,36 @@ class Targets(TargetProtocol): dimension_names: List[str] detection_class_name: str - def __init__( - self, - detection_class_name: str, - encode_fn: SoundEventEncoder, - decode_fn: SoundEventDecoder, - roi_mapper: ROITargetMapper, - class_names: list[str], - detection_class_tags: List[data.Tag], - filter_fn: Optional[SoundEventCondition] = None, - roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None, - ): - """Initialize the Targets object. + def __init__(self, config: TargetConfig): + """Initialize the Targets object.""" + self.config = config - Note: This constructor is typically called internally by the - `build_targets` factory function. + self._filter_fn = build_sound_event_condition( + config.detection_target.match_if + ) + self._encode_fn = build_sound_event_encoder( + config.classification_targets + ) + self._decode_fn = build_sound_event_decoder( + config.classification_targets + ) - Parameters - ---------- - encode_fn : SoundEventEncoder - Configured function to encode annotations to class names. - decode_fn : SoundEventDecoder - Configured function to decode class names to tags. - roi_mapper : ROITargetMapper - Configured object for mapping geometry to/from position/size. - class_names : list[str] - Ordered list of specific target class names. - generic_class_tags : List[data.Tag] - List of tags representing the generic class. - filter_fn : SoundEventFilter, optional - Configured function to filter annotations. Defaults to None. - transform_fn : SoundEventTransformation, optional - Configured function to transform annotation tags. Defaults to None. - """ - self.detection_class_name = detection_class_name - self.class_names = class_names - self.detection_class_tags = detection_class_tags - self.dimension_names = roi_mapper.dimension_names + self._roi_mapper = build_roi_mapper(config.roi) - self._roi_mapper = roi_mapper - self._filter_fn = filter_fn - self._encode_fn = encode_fn - self._decode_fn = decode_fn - self._roi_mapper_overrides = roi_mapper_overrides or {} + self.dimension_names = self._roi_mapper.dimension_names + + self.class_names = get_class_names_from_config( + config.classification_targets + ) + + self.detection_class_name = config.detection_target.name + self.detection_class_tags = config.detection_target.assign_tags + + self._roi_mapper_overrides = { + class_config.name: build_roi_mapper(class_config.roi) + for class_config in config.classification_targets + if class_config.roi is not None + } for class_name in self._roi_mapper_overrides: if class_name not in self.class_names: @@ -218,8 +204,6 @@ class Targets(TargetProtocol): True if the annotation should be kept (passes the filter), False otherwise. If no filter was configured, always returns True. """ - if not self._filter_fn: - return True return self._filter_fn(sound_event) def encode_class( @@ -331,7 +315,7 @@ class Targets(TargetProtocol): DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( classification_targets=DEFAULT_CLASSES, - detection_target=DEFAULT_GENERIC_CLASS, + detection_target=DEFAULT_DETECTION_CLASS, roi=AnchorBBoxMapperConfig(), ) @@ -339,13 +323,6 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( def build_targets(config: Optional[TargetConfig] = None) -> Targets: """Build a Targets object from a loaded TargetConfig. - This factory function takes the unified `TargetConfig` and constructs all - necessary functional components (filter, transform, encoder, - decoder, ROI mapper) by calling their respective builder functions. It also - extracts metadata (class names, generic tags, dimension names) to create - and return a fully initialized `Targets` instance, ready to process - annotations. - Parameters ---------- config : TargetConfig @@ -370,31 +347,7 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets: lambda: config.to_yaml_string(), ) - filter_fn = build_sound_event_condition(config.detection_target.match_if) - encode_fn = build_sound_event_encoder(config.classification_targets) - decode_fn = build_sound_event_decoder(config.classification_targets) - - roi_mapper = build_roi_mapper(config.roi) - class_names = get_class_names_from_config(config.classification_targets) - - generic_class_tags = config.detection_target.assign_tags - - roi_overrides = { - class_config.name: build_roi_mapper(class_config.roi) - for class_config in config.classification_targets - if class_config.roi is not None - } - - return Targets( - filter_fn=filter_fn, - encode_fn=encode_fn, - decode_fn=decode_fn, - class_names=class_names, - roi_mapper=roi_mapper, - detection_class_name=config.detection_target.name, - detection_class_tags=generic_class_tags, - roi_mapper_overrides=roi_overrides, - ) + return Targets(config=config) def load_targets( diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index 2f660f0..b277d5e 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -15,6 +15,7 @@ from batdetect2.data.conditions import ( build_sound_event_condition, ) from batdetect2.targets.rois import ROIMapperConfig +from batdetect2.targets.terms import call_type, generic_class from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder __all__ = [ @@ -69,24 +70,27 @@ class TargetClassConfig(BaseConfig): return self._match_if -DEFAULT_GENERIC_CLASS = TargetClassConfig( +DEFAULT_DETECTION_CLASS = TargetClassConfig( name="bat", match_if=AllOfConfig( conditions=[ - HasTagConfig(tag=data.Tag(key="event", value="Echolocation")), + HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")), NotConfig( condition=HasAnyTagConfig( tags=[ - data.Tag(key="event", value="Feeding"), - data.Tag(key="event", value="Unknown"), - data.Tag(key="event", value="Not Bat"), + data.Tag(term=call_type, value="Feeding"), + data.Tag(term=call_type, value="Social"), + data.Tag(term=call_type, value="Unknown"), + data.Tag(term=generic_class, value="Unknown"), + data.Tag(term=generic_class, value="Not Bat"), + data.Tag(term=call_type, value="Not Bat"), ] ) ), ] ), assign_tags=[ - data.Tag(key="call_type", value="Echolocation"), + data.Tag(term=call_type, value="Echolocation"), data.Tag(key="order", value="Chiroptera"), ], ) @@ -94,73 +98,73 @@ DEFAULT_GENERIC_CLASS = TargetClassConfig( DEFAULT_CLASSES = [ TargetClassConfig( - name="myomys", - tags=[data.Tag(key="class", value="Myotis mystacinus")], - ), - TargetClassConfig( - name="myoalc", - tags=[data.Tag(key="class", value="Myotis alcathoe")], + name="barbar", + tags=[data.Tag(key="class", value="Barbastellus barbastellus")], ), TargetClassConfig( name="eptser", tags=[data.Tag(key="class", value="Eptesicus serotinus")], ), TargetClassConfig( - name="pipnat", - tags=[data.Tag(key="class", value="Pipistrellus nathusii")], - ), - TargetClassConfig( - name="barbar", - tags=[data.Tag(key="class", value="Barbastellus barbastellus")], - ), - TargetClassConfig( - name="myonat", - tags=[data.Tag(key="class", value="Myotis nattereri")], - ), - TargetClassConfig( - name="myodau", - tags=[data.Tag(key="class", value="Myotis daubentonii")], - ), - TargetClassConfig( - name="myobra", - tags=[data.Tag(key="class", value="Myotis brandtii")], - ), - TargetClassConfig( - name="pippip", - tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")], + name="myoalc", + tags=[data.Tag(key="class", value="Myotis alcathoe")], ), TargetClassConfig( name="myobec", tags=[data.Tag(key="class", value="Myotis bechsteinii")], ), TargetClassConfig( - name="pippyg", - tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")], + name="myobra", + tags=[data.Tag(key="class", value="Myotis brandtii")], ), TargetClassConfig( - name="rhihip", - tags=[data.Tag(key="class", value="Rhinolophus hipposideros")], + name="myodau", + tags=[data.Tag(key="class", value="Myotis daubentonii")], + ), + TargetClassConfig( + name="myomys", + tags=[data.Tag(key="class", value="Myotis mystacinus")], + ), + TargetClassConfig( + name="myonat", + tags=[data.Tag(key="class", value="Myotis nattereri")], ), TargetClassConfig( name="nyclei", tags=[data.Tag(key="class", value="Nyctalus leisleri")], ), TargetClassConfig( - name="rhifer", - tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")], + name="nycnoc", + tags=[data.Tag(key="class", value="Nyctalus noctula")], + ), + TargetClassConfig( + name="pipnat", + tags=[data.Tag(key="class", value="Pipistrellus nathusii")], + ), + TargetClassConfig( + name="pippip", + tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")], + ), + TargetClassConfig( + name="pippyg", + tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")], ), TargetClassConfig( name="pleaur", tags=[data.Tag(key="class", value="Plecotus auritus")], ), - TargetClassConfig( - name="nycnoc", - tags=[data.Tag(key="class", value="Nyctalus noctula")], - ), TargetClassConfig( name="pleaus", tags=[data.Tag(key="class", value="Plecotus austriacus")], ), + TargetClassConfig( + name="rhifer", + tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")], + ), + TargetClassConfig( + name="rhihip", + tags=[data.Tag(key="class", value="Rhinolophus hipposideros")], + ), ] diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index c45bf23..d759766 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -7,6 +7,7 @@ from torch.utils.data import DataLoader from batdetect2.evaluate.match import ( MatchConfig, + build_matcher, match_all_predictions, ) from batdetect2.plotting.clips import PreprocessorProtocol @@ -42,6 +43,8 @@ class ValidationMetrics(Callback): self.preprocessor = preprocessor self.plot = plot + self.matcher = build_matcher(config=match_config) + self._clip_annotations: List[data.ClipAnnotation] = [] self._predictions: List[List[RawPrediction]] = [] @@ -93,7 +96,7 @@ class ValidationMetrics(Callback): self._clip_annotations, self._predictions, targets=pl_module.model.targets, - config=self.match_config, + matcher=self.matcher, ) self.log_metrics(pl_module, matches) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 660a2f1..9822191 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -11,7 +11,6 @@ from torch.utils.data import DataLoader from batdetect2.evaluate.config import EvaluationConfig from batdetect2.evaluate.metrics import ( - ClassificationAccuracy, ClassificationMeanAveragePrecision, DetectionAveragePrecision, ) @@ -175,7 +174,6 @@ def build_trainer_callbacks( ClassificationMeanAveragePrecision( class_names=targets.class_names ), - ClassificationAccuracy(class_names=targets.class_names), ], preprocessor=preprocessor, match_config=config.match, diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/typing/evaluate.py index 06c53e2..b706bed 100644 --- a/src/batdetect2/typing/evaluate.py +++ b/src/batdetect2/typing/evaluate.py @@ -1,5 +1,15 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Protocol +from typing import ( + Dict, + Generic, + Iterable, + List, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, +) from soundevent import data @@ -40,5 +50,27 @@ class MatchEvaluation: return self.pred_class_scores[pred_class] +class MatcherProtocol(Protocol): + def __call__( + self, + ground_truth: Sequence[data.Geometry], + predictions: Sequence[data.Geometry], + scores: Sequence[float], + ) -> Iterable[Tuple[Optional[int], Optional[int], float]]: ... + + +Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True) + + +class AffinityFunction(Protocol, Generic[Geom]): + def __call__( + self, + geometry1: Geom, + geometry2: Geom, + time_buffer: float = 0.01, + freq_buffer: float = 1000, + ) -> float: ... + + class MetricsProtocol(Protocol): def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...