diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index b911e67..c8f244b 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -1,4 +1,5 @@ from collections.abc import Callable, Iterable, Mapping +from dataclasses import dataclass, field from typing import List, Literal, Optional, Tuple import numpy as np @@ -340,3 +341,36 @@ def match_predictions_and_annotations( ) return matches + + +@dataclass +class ClassExamples: + false_positives: List[MatchEvaluation] = field(default_factory=list) + false_negatives: List[MatchEvaluation] = field(default_factory=list) + true_positives: List[MatchEvaluation] = field(default_factory=list) + cross_triggers: List[MatchEvaluation] = field(default_factory=list) + + +def group_matches(matches: List[MatchEvaluation]) -> ClassExamples: + class_examples = ClassExamples() + + for match in matches: + gt_class = match.gt_class + pred_class = match.pred_class + + if pred_class is None: + class_examples.false_negatives.append(match) + continue + + if gt_class is None: + class_examples.false_positives.append(match) + continue + + if gt_class != pred_class: + class_examples.cross_triggers.append(match) + class_examples.cross_triggers.append(match) + continue + + class_examples.true_positives.append(match) + + return class_examples diff --git a/src/batdetect2/plotting/clips.py b/src/batdetect2/plotting/clips.py index 11cc670..051bf6c 100644 --- a/src/batdetect2/plotting/clips.py +++ b/src/batdetect2/plotting/clips.py @@ -37,6 +37,10 @@ def plot_clip( plot_spectrogram( spec, + start_time=clip.start_time, + end_time=clip.end_time, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, ax=ax, cmap=spec_cmap, ) diff --git a/src/batdetect2/plotting/common.py b/src/batdetect2/plotting/common.py index a1b1b93..9a8b930 100644 --- a/src/batdetect2/plotting/common.py +++ b/src/batdetect2/plotting/common.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple import matplotlib.pyplot as plt +import numpy as np import torch from matplotlib import axes @@ -25,10 +26,20 @@ def create_ax( def plot_spectrogram( spec: torch.Tensor, + start_time: float, + end_time: float, + min_freq: float, + max_freq: float, ax: Optional[axes.Axes] = None, figsize: Optional[Tuple[int, int]] = None, cmap="gray", ) -> axes.Axes: ax = create_ax(ax=ax, figsize=figsize) - ax.pcolormesh(spec.numpy(), cmap=cmap) + + ax.pcolormesh( + np.linspace(start_time, end_time, spec.shape[-1], endpoint=False), + np.linspace(min_freq, max_freq, spec.shape[-2], endpoint=False), + spec.numpy(), + cmap=cmap, + ) return ax diff --git a/src/batdetect2/plotting/evaluation.py b/src/batdetect2/plotting/evaluation.py index 6345b35..9ec9220 100644 --- a/src/batdetect2/plotting/evaluation.py +++ b/src/batdetect2/plotting/evaluation.py @@ -100,7 +100,7 @@ def plot_class_examples( preprocessor=preprocessor, duration=duration, ) - except (ValueError, AssertionError): + except (ValueError, AssertionError, RuntimeError): continue for index, match in enumerate(false_positives[:n_examples]): @@ -112,7 +112,7 @@ def plot_class_examples( preprocessor=preprocessor, duration=duration, ) - except (ValueError, AssertionError): + except (ValueError, AssertionError, RuntimeError): continue for index, match in enumerate(false_negatives[:n_examples]): @@ -124,7 +124,7 @@ def plot_class_examples( preprocessor=preprocessor, duration=duration, ) - except (ValueError, AssertionError): + except (ValueError, AssertionError, RuntimeError): continue for index, match in enumerate(cross_triggers[:n_examples]): @@ -136,7 +136,7 @@ def plot_class_examples( preprocessor=preprocessor, duration=duration, ) - except (ValueError, AssertionError): + except (ValueError, AssertionError, RuntimeError): continue return fig