diff --git a/src/batdetect2/evaluate/metrics/classification.py b/src/batdetect2/evaluate/metrics/classification.py index 4345b26..602600b 100644 --- a/src/batdetect2/evaluate/metrics/classification.py +++ b/src/batdetect2/evaluate/metrics/classification.py @@ -30,6 +30,7 @@ __all__ = [ @dataclass class MatchEval: + clip: data.Clip gt: Optional[data.SoundEventAnnotation] pred: Optional[RawPrediction] diff --git a/src/batdetect2/evaluate/metrics/top_class.py b/src/batdetect2/evaluate/metrics/top_class.py index ee837c8..0f76e2a 100644 --- a/src/batdetect2/evaluate/metrics/top_class.py +++ b/src/batdetect2/evaluate/metrics/top_class.py @@ -28,6 +28,7 @@ __all__ = [ @dataclass class MatchEval: + clip: data.Clip gt: Optional[data.SoundEventAnnotation] pred: Optional[RawPrediction] diff --git a/src/batdetect2/evaluate/plots.py b/src/batdetect2/evaluate/plots.py deleted file mode 100644 index 53a0420..0000000 --- a/src/batdetect2/evaluate/plots.py +++ /dev/null @@ -1,560 +0,0 @@ -import random -from collections import defaultdict -from dataclasses import dataclass, field -from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -from pydantic import Field -from sklearn import metrics -from sklearn.preprocessing import label_binarize - -from batdetect2.audio import AudioConfig, build_audio_loader -from batdetect2.core import BaseConfig, Registry -from batdetect2.plotting.gallery import plot_match_gallery -from batdetect2.plotting.matches import plot_matches -from batdetect2.preprocess import PreprocessingConfig, build_preprocessor -from batdetect2.typing import ( - AudioLoader, - ClipMatches, - MatchEvaluation, - PlotterProtocol, - PreprocessorProtocol, -) - -__all__ = [ - "build_plotter", - "ExampleGallery", - "ExampleGalleryConfig", -] - - -plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot") - - -class ExampleGalleryConfig(BaseConfig): - name: Literal["example_gallery"] = "example_gallery" - examples_per_class: int = 5 - audio: AudioConfig = Field(default_factory=AudioConfig) - preprocessing: PreprocessingConfig = Field( - default_factory=PreprocessingConfig - ) - - -class ExampleGallery(PlotterProtocol): - def __init__( - self, - examples_per_class: int, - preprocessor: Optional[PreprocessorProtocol] = None, - audio_loader: Optional[AudioLoader] = None, - ): - self.examples_per_class = examples_per_class - self.preprocessor = preprocessor or build_preprocessor() - self.audio_loader = audio_loader or build_audio_loader() - - def __call__(self, clip_evaluations: Sequence[ClipMatches]): - per_class_matches = group_matches(clip_evaluations) - - for class_name, matches in per_class_matches.items(): - true_positives = get_binned_sample( - matches.true_positives, - n_examples=self.examples_per_class, - ) - - false_positives = get_binned_sample( - matches.false_positives, - n_examples=self.examples_per_class, - ) - - false_negatives = random.sample( - matches.false_negatives, - k=min(self.examples_per_class, len(matches.false_negatives)), - ) - - cross_triggers = get_binned_sample( - matches.cross_triggers, - n_examples=self.examples_per_class, - ) - - fig = plot_match_gallery( - true_positives, - false_positives, - false_negatives, - cross_triggers, - preprocessor=self.preprocessor, - audio_loader=self.audio_loader, - n_examples=self.examples_per_class, - ) - - yield f"example_gallery/{class_name}", fig - - plt.close(fig) - - @classmethod - def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]): - audio_loader = build_audio_loader(config.audio) - preprocessor = build_preprocessor( - config.preprocessing, - input_samplerate=audio_loader.samplerate, - ) - return cls( - examples_per_class=config.examples_per_class, - preprocessor=preprocessor, - audio_loader=audio_loader, - ) - - -plots_registry.register(ExampleGalleryConfig, ExampleGallery) - - -class ClipEvaluationPlotConfig(BaseConfig): - name: Literal["example_clip"] = "example_clip" - num_plots: int = 5 - audio: AudioConfig = Field(default_factory=AudioConfig) - preprocessing: PreprocessingConfig = Field( - default_factory=PreprocessingConfig - ) - - -class PlotClipEvaluation(PlotterProtocol): - def __init__( - self, - num_plots: int = 3, - preprocessor: Optional[PreprocessorProtocol] = None, - audio_loader: Optional[AudioLoader] = None, - ): - self.preprocessor = preprocessor - self.audio_loader = audio_loader - self.num_plots = num_plots - - def __call__(self, clip_evaluations: Sequence[ClipMatches]): - examples = random.sample( - clip_evaluations, - k=min(self.num_plots, len(clip_evaluations)), - ) - - for index, clip_evaluation in enumerate(examples): - fig, ax = plt.subplots() - plot_matches( - clip_evaluation.matches, - clip=clip_evaluation.clip, - audio_loader=self.audio_loader, - ax=ax, - ) - yield f"clip_evaluation/example_{index}", fig - plt.close(fig) - - @classmethod - def from_config( - cls, - config: ClipEvaluationPlotConfig, - class_names: List[str], - ): - audio_loader = build_audio_loader(config.audio) - preprocessor = build_preprocessor( - config.preprocessing, - input_samplerate=audio_loader.samplerate, - ) - return cls( - num_plots=config.num_plots, - preprocessor=preprocessor, - audio_loader=audio_loader, - ) - - -plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation) - - -class DetectionPRCurveConfig(BaseConfig): - name: Literal["detection_pr_curve"] = "detection_pr_curve" - - -class DetectionPRCurve(PlotterProtocol): - def __call__(self, clip_evaluations: Sequence[ClipMatches]): - y_true, y_score = zip( - *[ - (match.gt_det, match.pred_score) - for clip_eval in clip_evaluations - for match in clip_eval.matches - ] - ) - precision, recall, _ = metrics.precision_recall_curve(y_true, y_score) - fig, ax = plt.subplots() - - ax.plot(recall, precision, label="Detector") - ax.set_xlabel("Recall") - ax.set_ylabel("Precision") - ax.legend() - - yield "detection_pr_curve", fig - - @classmethod - def from_config( - cls, - config: DetectionPRCurveConfig, - class_names: List[str], - ): - return cls() - - -plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve) - - -class ClassificationPRCurvesConfig(BaseConfig): - name: Literal["classification_pr_curves"] = "classification_pr_curves" - include: Optional[List[str]] = None - exclude: Optional[List[str]] = None - - -class ClassificationPRCurves(PlotterProtocol): - def __init__( - self, - class_names: List[str], - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - ): - self.class_names = class_names - self.selected = class_names - - if include is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name in include - ] - - if exclude is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name not in exclude - ] - - def __call__(self, clip_evaluations: Sequence[ClipMatches]): - y_true = [] - y_pred = [] - - for clip_eval in clip_evaluations: - for match in clip_eval.matches: - # Ignore generic unclassified targets - if match.gt_det and match.gt_class is None: - continue - - y_true.append( - match.gt_class - if match.gt_class is not None - else "__NONE__" - ) - - y_pred.append( - np.array( - [ - match.pred_class_scores.get(name, 0) - for name in self.class_names - ] - ) - ) - - y_true = label_binarize(y_true, classes=self.class_names) - y_pred = np.stack(y_pred) - - fig, ax = plt.subplots(figsize=(10, 10)) - for class_index, class_name in enumerate(self.class_names): - if class_name not in self.selected: - continue - - y_true_class = y_true[:, class_index] - y_pred_class = y_pred[:, class_index] - precision, recall, _ = metrics.precision_recall_curve( - y_true_class, - y_pred_class, - ) - ax.plot(recall, precision, label=class_name) - - ax.set_xlabel("Recall") - ax.set_ylabel("Precision") - ax.legend( - bbox_to_anchor=(1.05, 1), - loc="upper left", - borderaxespad=0.0, - ) - - yield "classification_pr_curve", fig - - @classmethod - def from_config( - cls, - config: ClassificationPRCurvesConfig, - class_names: List[str], - ): - return cls( - class_names=class_names, - include=config.include, - exclude=config.exclude, - ) - - -plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves) - - -class DetectionROCCurveConfig(BaseConfig): - name: Literal["detection_roc_curve"] = "detection_roc_curve" - - -class DetectionROCCurve(PlotterProtocol): - def __call__(self, clip_evaluations: Sequence[ClipMatches]): - y_true, y_score = zip( - *[ - (match.gt_det, match.pred_score) - for clip_eval in clip_evaluations - for match in clip_eval.matches - ] - ) - fpr, tpr, _ = metrics.roc_curve(y_true, y_score) - fig, ax = plt.subplots() - - ax.plot(fpr, tpr, label="Detection") - ax.set_xlabel("False Positive Rate") - ax.set_ylabel("True Positive Rate") - ax.legend() - - yield "detection_roc_curve", fig - - @classmethod - def from_config( - cls, - config: DetectionROCCurveConfig, - class_names: List[str], - ): - return cls() - - -plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve) - - -class ClassificationROCCurvesConfig(BaseConfig): - name: Literal["classification_roc_curves"] = "classification_roc_curves" - include: Optional[List[str]] = None - exclude: Optional[List[str]] = None - - -class ClassificationROCCurves(PlotterProtocol): - def __init__( - self, - class_names: List[str], - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - ): - self.class_names = class_names - self.selected = class_names - - if include is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name in include - ] - - if exclude is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name not in exclude - ] - - def __call__(self, clip_evaluations: Sequence[ClipMatches]): - y_true = [] - y_pred = [] - - for clip_eval in clip_evaluations: - for match in clip_eval.matches: - # Ignore generic unclassified targets - if match.gt_det and match.gt_class is None: - continue - - y_true.append( - match.gt_class - if match.gt_class is not None - else "__NONE__" - ) - - y_pred.append( - np.array( - [ - match.pred_class_scores.get(name, 0) - for name in self.class_names - ] - ) - ) - - y_true = label_binarize(y_true, classes=self.class_names) - y_pred = np.stack(y_pred) - - fig, ax = plt.subplots(figsize=(10, 10)) - for class_index, class_name in enumerate(self.class_names): - if class_name not in self.selected: - continue - - y_true_class = y_true[:, class_index] - y_roced_class = y_pred[:, class_index] - fpr, tpr, _ = metrics.roc_curve( - y_true_class, - y_roced_class, - ) - ax.plot(fpr, tpr, label=class_name) - - ax.set_xlabel("False Positive Rate") - ax.set_ylabel("True Positive Rate") - ax.legend( - bbox_to_anchor=(1.05, 1), - loc="upper left", - borderaxespad=0.0, - ) - - yield "classification_roc_curve", fig - - @classmethod - def from_config( - cls, - config: ClassificationROCCurvesConfig, - class_names: List[str], - ): - return cls( - class_names=class_names, - include=config.include, - exclude=config.exclude, - ) - - -plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves) - - -class ConfusionMatrixConfig(BaseConfig): - name: Literal["confusion_matrix"] = "confusion_matrix" - background_class: str = "noise" - - -class ConfusionMatrix(PlotterProtocol): - def __init__(self, background_class: str, class_names: List[str]): - self.background_class = background_class - self.class_names = class_names - - def __call__(self, clip_evaluations: Sequence[ClipMatches]): - y_true = [] - y_pred = [] - - for clip_eval in clip_evaluations: - for match in clip_eval.matches: - # Ignore generic unclassified targets - if match.gt_det and match.gt_class is None: - continue - - y_true.append( - match.gt_class - if match.gt_class is not None - else self.background_class - ) - - top_class = match.top_class - y_pred.append( - top_class - if top_class is not None - else self.background_class - ) - - display = metrics.ConfusionMatrixDisplay.from_predictions( - y_true, - y_pred, - labels=[*self.class_names, self.background_class], - ) - - yield "confusion_matrix", display.figure_ - - @classmethod - def from_config( - cls, - config: ConfusionMatrixConfig, - class_names: List[str], - ): - return cls( - background_class=config.background_class, - class_names=class_names, - ) - - -plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix) - - -PlotConfig = Annotated[ - Union[ - ExampleGalleryConfig, - ClipEvaluationPlotConfig, - DetectionPRCurveConfig, - ClassificationPRCurvesConfig, - DetectionROCCurveConfig, - ClassificationROCCurvesConfig, - ConfusionMatrixConfig, - ], - Field(discriminator="name"), -] - - -def build_plotter( - config: PlotConfig, class_names: List[str] -) -> PlotterProtocol: - return plots_registry.build(config, class_names) - - -@dataclass -class ClassMatches: - false_positives: List[MatchEvaluation] = field(default_factory=list) - false_negatives: List[MatchEvaluation] = field(default_factory=list) - true_positives: List[MatchEvaluation] = field(default_factory=list) - cross_triggers: List[MatchEvaluation] = field(default_factory=list) - - -def group_matches( - clip_evaluations: Sequence[ClipMatches], -) -> Dict[str, ClassMatches]: - class_examples = defaultdict(ClassMatches) - - for clip_evaluation in clip_evaluations: - for match in clip_evaluation.matches: - gt_class = match.gt_class - pred_class = match.top_class - - if pred_class is None: - class_examples[gt_class].false_negatives.append(match) - continue - - if gt_class is None: - class_examples[pred_class].false_positives.append(match) - continue - - if gt_class != pred_class: - class_examples[gt_class].cross_triggers.append(match) - class_examples[pred_class].cross_triggers.append(match) - continue - - class_examples[gt_class].true_positives.append(match) - - return class_examples - - -def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5): - if len(matches) < n_examples: - return matches - - indices, pred_scores = zip( - *[ - (index, match.pred_class_scores[pred_class]) - for index, match in enumerate(matches) - if (pred_class := match.top_class) is not None - ] - ) - - bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop") - df = pd.DataFrame({"indices": indices, "bins": bins}) - sample = df.groupby("bins").sample(1) - return [matches[ind] for ind in sample["indices"]] diff --git a/src/batdetect2/evaluate/plots/base.py b/src/batdetect2/evaluate/plots/base.py index e54e675..c01406b 100644 --- a/src/batdetect2/evaluate/plots/base.py +++ b/src/batdetect2/evaluate/plots/base.py @@ -11,7 +11,7 @@ class BasePlotConfig(BaseConfig): label: str = "plot" theme: str = "default" title: Optional[str] = None - figsize: tuple[int, int] = (5, 5) + figsize: tuple[int, int] = (10, 10) dpi: int = 100 @@ -20,7 +20,7 @@ class BasePlot: self, targets: TargetProtocol, label: str = "plot", - figsize: tuple[int, int] = (5, 5), + figsize: tuple[int, int] = (10, 10), title: Optional[str] = None, dpi: int = 100, theme: str = "default", @@ -32,7 +32,7 @@ class BasePlot: self.theme = theme self.title = title - def get_figure(self) -> Figure: + def create_figure(self) -> Figure: plt.style.use(self.theme) fig = plt.figure(figsize=self.figsize, dpi=self.dpi) diff --git a/src/batdetect2/evaluate/plots/classification.py b/src/batdetect2/evaluate/plots/classification.py index 55fce6c..bc3faac 100644 --- a/src/batdetect2/evaluate/plots/classification.py +++ b/src/batdetect2/evaluate/plots/classification.py @@ -1,5 +1,15 @@ -from typing import Annotated, Callable, Literal, Sequence, Tuple, Union +from typing import ( + Annotated, + Callable, + Iterable, + Literal, + Optional, + Sequence, + Tuple, + Union, +) +import matplotlib.pyplot as plt from matplotlib.figure import Figure from pydantic import Field from sklearn import metrics @@ -12,14 +22,20 @@ from batdetect2.evaluate.metrics.classification import ( from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.plotting.metrics import ( + plot_pr_curve, plot_pr_curves, + plot_roc_curve, plot_roc_curves, + plot_threshold_precision_curve, plot_threshold_precision_curves, + plot_threshold_recall_curve, plot_threshold_recall_curves, ) from batdetect2.typing import TargetProtocol -ClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]] +ClassificationPlotter = Callable[ + [Sequence[ClipEval]], Iterable[Tuple[str, Figure]] +] classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = ( Registry("classification_plot") @@ -29,8 +45,10 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = ( class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" + title: Optional[str] = "Classification Precision-Recall Curve" ignore_non_predictions: bool = True ignore_generic: bool = True + separate_figures: bool = False class PRCurve(BasePlot): @@ -39,25 +57,24 @@ class PRCurve(BasePlot): *args, ignore_non_predictions: bool = True, ignore_generic: bool = True, + separate_figures: bool = False, **kwargs, ): super().__init__(*args, **kwargs) self.ignore_non_predictions = ignore_non_predictions self.ignore_generic = ignore_generic + self.separate_figures = separate_figures def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true, y_score, num_positives = _extract_per_class_metric_data( clip_evaluations, ignore_non_predictions=self.ignore_non_predictions, ignore_generic=self.ignore_generic, ) - fig = self.get_figure() - ax = fig.subplots() - data = { class_name: compute_precision_recall( y_true[class_name], @@ -67,9 +84,23 @@ class PRCurve(BasePlot): for class_name in self.targets.class_names } - plot_pr_curves(data, ax=ax) + if not self.separate_figures: + fig = self.create_figure() + ax = fig.subplots() + plot_pr_curves(data, ax=ax) + yield self.label, fig + return - return self.label, fig + for class_name, (precision, recall, thresholds) in data.items(): + fig = self.create_figure() + ax = fig.subplots() + + ax = plot_pr_curve(precision, recall, thresholds, ax=ax) + ax.set_title(class_name) + + yield f"{self.label}/{class_name}", fig + + plt.close(fig) @classification_plots.register(PRCurveConfig) @staticmethod @@ -79,33 +110,37 @@ class PRCurve(BasePlot): targets=targets, ignore_non_predictions=config.ignore_non_predictions, ignore_generic=config.ignore_generic, + separate_figures=config.separate_figures, ) -class ThresholdPRCurveConfig(BasePlotConfig): - name: Literal["threshold_pr_curve"] = "threshold_pr_curve" - label: str = "threshold_pr_curve" - figsize: tuple[int, int] = (10, 5) +class ThresholdPrecisionCurveConfig(BasePlotConfig): + name: Literal["threshold_precision_curve"] = "threshold_precision_curve" + label: str = "threshold_precision_curve" + title: Optional[str] = "Classification Threshold-Precision Curve" ignore_non_predictions: bool = True ignore_generic: bool = True + separate_figures: bool = False -class ThresholdPRCurve(BasePlot): +class ThresholdPrecisionCurve(BasePlot): def __init__( self, *args, ignore_non_predictions: bool = True, ignore_generic: bool = True, + separate_figures: bool = False, **kwargs, ): super().__init__(*args, **kwargs) self.ignore_non_predictions = ignore_non_predictions self.ignore_generic = ignore_generic + self.separate_figures = separate_figures def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true, y_score, num_positives = _extract_per_class_metric_data( clip_evaluations, ignore_non_predictions=self.ignore_non_predictions, @@ -121,30 +156,135 @@ class ThresholdPRCurve(BasePlot): for class_name in self.targets.class_names } - fig = self.get_figure() - ax1, ax2 = fig.subplots(nrows=1, ncols=2, sharey=True) + if not self.separate_figures: + fig = self.create_figure() + ax = fig.subplots() - plot_threshold_precision_curves(data, ax=ax1, add_legend=False) - plot_threshold_recall_curves(data, ax=ax2, add_legend=True) + plot_threshold_precision_curves(data, ax=ax) - return self.label, fig + yield self.label, fig - @classification_plots.register(ThresholdPRCurveConfig) + return + + for class_name, (precision, _, thresholds) in data.items(): + fig = self.create_figure() + ax = fig.subplots() + + ax = plot_threshold_precision_curve( + thresholds, + precision, + ax=ax, + ) + + ax.set_title(class_name) + + yield f"{self.label}/{class_name}", fig + + plt.close(fig) + + @classification_plots.register(ThresholdPrecisionCurveConfig) @staticmethod - def from_config(config: ThresholdPRCurveConfig, targets: TargetProtocol): - return ThresholdPRCurve.build( + def from_config( + config: ThresholdPrecisionCurveConfig, targets: TargetProtocol + ): + return ThresholdPrecisionCurve.build( config=config, targets=targets, ignore_non_predictions=config.ignore_non_predictions, ignore_generic=config.ignore_generic, + separate_figures=config.separate_figures, + ) + + +class ThresholdRecallCurveConfig(BasePlotConfig): + name: Literal["threshold_recall_curve"] = "threshold_recall_curve" + label: str = "threshold_recall_curve" + title: Optional[str] = "Classification Threshold-Recall Curve" + ignore_non_predictions: bool = True + ignore_generic: bool = True + separate_figures: bool = False + + +class ThresholdRecallCurve(BasePlot): + def __init__( + self, + *args, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + separate_figures: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + self.separate_figures = separate_figures + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Iterable[Tuple[str, Figure]]: + y_true, y_score, num_positives = _extract_per_class_metric_data( + clip_evaluations, + ignore_non_predictions=self.ignore_non_predictions, + ignore_generic=self.ignore_generic, + ) + + data = { + class_name: compute_precision_recall( + y_true[class_name], + y_score[class_name], + num_positives[class_name], + ) + for class_name in self.targets.class_names + } + + if not self.separate_figures: + fig = self.create_figure() + ax = fig.subplots() + + plot_threshold_recall_curves(data, ax=ax, add_legend=True) + + yield self.label, fig + + return + + for class_name, (_, recall, thresholds) in data.items(): + fig = self.create_figure() + ax = fig.subplots() + + ax = plot_threshold_recall_curve( + thresholds, + recall, + ax=ax, + ) + + ax.set_title(class_name) + + yield f"{self.label}/{class_name}", fig + + plt.close(fig) + + @classification_plots.register(ThresholdRecallCurveConfig) + @staticmethod + def from_config( + config: ThresholdRecallCurveConfig, targets: TargetProtocol + ): + return ThresholdRecallCurve.build( + config=config, + targets=targets, + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + separate_figures=config.separate_figures, ) class ROCCurveConfig(BasePlotConfig): name: Literal["roc_curve"] = "roc_curve" label: str = "roc_curve" + title: Optional[str] = "Classification ROC Curve" ignore_non_predictions: bool = True ignore_generic: bool = True + separate_figures: bool = False class ROCCurve(BasePlot): @@ -153,16 +293,18 @@ class ROCCurve(BasePlot): *args, ignore_non_predictions: bool = True, ignore_generic: bool = True, + separate_figures: bool = False, **kwargs, ): super().__init__(*args, **kwargs) self.ignore_non_predictions = ignore_non_predictions self.ignore_generic = ignore_generic + self.separate_figures = separate_figures def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true, y_score, _ = _extract_per_class_metric_data( clip_evaluations, ignore_non_predictions=self.ignore_non_predictions, @@ -177,12 +319,26 @@ class ROCCurve(BasePlot): for class_name in self.targets.class_names } - fig = self.get_figure() - ax = fig.subplots() + if not self.separate_figures: + fig = self.create_figure() + ax = fig.subplots() - plot_roc_curves(data, ax=ax) + plot_roc_curves(data, ax=ax) - return self.label, fig + yield self.label, fig + + return + + for class_name, (fpr, tpr, thresholds) in data.items(): + fig = self.create_figure() + ax = fig.subplots() + + ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax) + ax.set_title(class_name) + + yield f"{self.label}/{class_name}", fig + + plt.close(fig) @classification_plots.register(ROCCurveConfig) @staticmethod @@ -192,6 +348,7 @@ class ROCCurve(BasePlot): targets=targets, ignore_non_predictions=config.ignore_non_predictions, ignore_generic=config.ignore_generic, + separate_figures=config.separate_figures, ) @@ -199,7 +356,8 @@ ClassificationPlotConfig = Annotated[ Union[ PRCurveConfig, ROCCurveConfig, - ThresholdPRCurveConfig, + ThresholdPrecisionCurveConfig, + ThresholdRecallCurveConfig, ], Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/plots/clip_classification.py b/src/batdetect2/evaluate/plots/clip_classification.py index 322650d..388e999 100644 --- a/src/batdetect2/evaluate/plots/clip_classification.py +++ b/src/batdetect2/evaluate/plots/clip_classification.py @@ -1,6 +1,7 @@ from typing import ( Annotated, Callable, + Iterable, Literal, Optional, Sequence, @@ -8,6 +9,7 @@ from typing import ( Union, ) +import matplotlib.pyplot as plt from matplotlib.figure import Figure from pydantic import Field from sklearn import metrics @@ -17,7 +19,9 @@ from batdetect2.evaluate.metrics.clip_classification import ClipEval from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.plotting.metrics import ( + plot_pr_curve, plot_pr_curves, + plot_roc_curve, plot_roc_curves, ) from batdetect2.typing import TargetProtocol @@ -28,7 +32,9 @@ __all__ = [ "build_clip_classification_plotter", ] -ClipClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]] +ClipClassificationPlotter = Callable[ + [Sequence[ClipEval]], Iterable[Tuple[str, Figure]] +] clip_classification_plots: Registry[ ClipClassificationPlotter, [TargetProtocol] @@ -38,14 +44,24 @@ clip_classification_plots: Registry[ class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" - title: Optional[str] = "Precision-Recall Curve" + title: Optional[str] = "Clip Classification Precision-Recall Curve" + separate_figures: bool = False class PRCurve(BasePlot): + def __init__( + self, + *args, + separate_figures: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.separate_figures = separate_figures + def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: data = {} for class_name in self.targets.class_names: @@ -61,10 +77,26 @@ class PRCurve(BasePlot): data[class_name] = (precision, recall, thresholds) - fig = self.get_figure() - ax = fig.subplots() - plot_pr_curves(data, ax=ax) - return self.label, fig + if not self.separate_figures: + fig = self.create_figure() + ax = fig.subplots() + + plot_pr_curves(data, ax=ax) + + yield self.label, fig + + return + + for class_name, (precision, recall, thresholds) in data.items(): + fig = self.create_figure() + ax = fig.subplots() + + ax = plot_pr_curve(precision, recall, thresholds, ax=ax) + ax.set_title(class_name) + + yield f"{self.label}/{class_name}", fig + + plt.close(fig) @clip_classification_plots.register(PRCurveConfig) @staticmethod @@ -72,20 +104,31 @@ class PRCurve(BasePlot): return PRCurve.build( config=config, targets=targets, + separate_figures=config.separate_figures, ) class ROCCurveConfig(BasePlotConfig): name: Literal["roc_curve"] = "roc_curve" label: str = "roc_curve" - title: Optional[str] = "ROC Curve" + title: Optional[str] = "Clip Classification ROC Curve" + separate_figures: bool = False class ROCCurve(BasePlot): + def __init__( + self, + *args, + separate_figures: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.separate_figures = separate_figures + def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: data = {} for class_name in self.targets.class_names: @@ -101,10 +144,24 @@ class ROCCurve(BasePlot): data[class_name] = (fpr, tpr, thresholds) - fig = self.get_figure() - ax = fig.subplots() - plot_roc_curves(data, ax=ax) - return self.label, fig + if not self.separate_figures: + fig = self.create_figure() + ax = fig.subplots() + plot_roc_curves(data, ax=ax) + yield self.label, fig + + return + + for class_name, (fpr, tpr, thresholds) in data.items(): + fig = self.create_figure() + ax = fig.subplots() + + ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax) + ax.set_title(class_name) + + yield f"{self.label}/{class_name}", fig + + plt.close(fig) @clip_classification_plots.register(ROCCurveConfig) @staticmethod @@ -112,6 +169,7 @@ class ROCCurve(BasePlot): return ROCCurve.build( config=config, targets=targets, + separate_figures=config.separate_figures, ) diff --git a/src/batdetect2/evaluate/plots/clip_detection.py b/src/batdetect2/evaluate/plots/clip_detection.py index 8a34d65..cfcfb58 100644 --- a/src/batdetect2/evaluate/plots/clip_detection.py +++ b/src/batdetect2/evaluate/plots/clip_detection.py @@ -1,6 +1,7 @@ from typing import ( Annotated, Callable, + Iterable, Literal, Optional, Sequence, @@ -27,7 +28,9 @@ __all__ = [ "build_clip_detection_plotter", ] -ClipDetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]] +ClipDetectionPlotter = Callable[ + [Sequence[ClipEval]], Iterable[Tuple[str, Figure]] +] clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = ( @@ -38,14 +41,14 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = ( class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" - title: Optional[str] = "Precision-Recall Curve" + title: Optional[str] = "Clip Detection Precision-Recall Curve" class PRCurve(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true = [c.gt_det for c in clip_evaluations] y_score = [c.score for c in clip_evaluations] @@ -54,10 +57,10 @@ class PRCurve(BasePlot): y_score, ) - fig = self.get_figure() + fig = self.create_figure() ax = fig.subplots() plot_pr_curve(precision, recall, thresholds, ax=ax) - return self.label, fig + yield self.label, fig @clip_detection_plots.register(PRCurveConfig) @staticmethod @@ -71,14 +74,14 @@ class PRCurve(BasePlot): class ROCCurveConfig(BasePlotConfig): name: Literal["roc_curve"] = "roc_curve" label: str = "roc_curve" - title: Optional[str] = "ROC Curve" + title: Optional[str] = "Clip Detection ROC Curve" class ROCCurve(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true = [c.gt_det for c in clip_evaluations] y_score = [c.score for c in clip_evaluations] @@ -87,10 +90,10 @@ class ROCCurve(BasePlot): y_score, ) - fig = self.get_figure() + fig = self.create_figure() ax = fig.subplots() plot_roc_curve(fpr, tpr, thresholds, ax=ax) - return self.label, fig + yield self.label, fig @clip_detection_plots.register(ROCCurveConfig) @staticmethod @@ -104,18 +107,18 @@ class ROCCurve(BasePlot): class ScoreDistributionPlotConfig(BasePlotConfig): name: Literal["score_distribution"] = "score_distribution" label: str = "score_distribution" - title: Optional[str] = "Score Distribution" + title: Optional[str] = "Clip Detection Score Distribution" class ScoreDistributionPlot(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true = [c.gt_det for c in clip_evaluations] y_score = [c.score for c in clip_evaluations] - fig = self.get_figure() + fig = self.create_figure() ax = fig.subplots() df = pd.DataFrame({"is_true": y_true, "score": y_score}) @@ -130,7 +133,7 @@ class ScoreDistributionPlot(BasePlot): common_norm=False, ) - return self.label, fig + yield self.label, fig @clip_detection_plots.register(ScoreDistributionPlotConfig) @staticmethod diff --git a/src/batdetect2/evaluate/plots/detection.py b/src/batdetect2/evaluate/plots/detection.py index dbcd9fc..29e2b86 100644 --- a/src/batdetect2/evaluate/plots/detection.py +++ b/src/batdetect2/evaluate/plots/detection.py @@ -1,26 +1,33 @@ import random -from typing import Annotated, Callable, Literal, Sequence, Tuple, Union +from typing import ( + Annotated, + Callable, + Iterable, + Literal, + Optional, + Sequence, + Tuple, + Union, +) import matplotlib.pyplot as plt import pandas as pd import seaborn as sns -from matplotlib import patches from matplotlib.figure import Figure from pydantic import Field from sklearn import metrics -from soundevent.plot import plot_geometry from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.core import Registry from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.detection import ClipEval from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig -from batdetect2.plotting.clips import plot_clip +from batdetect2.plotting.detections import plot_clip_detections from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol -DetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]] +DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]] detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry( name="detection_plot" @@ -30,6 +37,7 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry( class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" + title: Optional[str] = "Detection Precision-Recall Curve" ignore_non_predictions: bool = True ignore_generic: bool = True @@ -49,7 +57,7 @@ class PRCurve(BasePlot): def __call__( self, clip_evals: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true = [] y_score = [] num_positives = 0 @@ -71,10 +79,12 @@ class PRCurve(BasePlot): num_positives=num_positives, ) - fig = self.get_figure() + fig = self.create_figure() ax = fig.subplots() + plot_pr_curve(precision, recall, thresholds, ax=ax) - return self.label, fig + + yield self.label, fig @detection_plots.register(PRCurveConfig) @staticmethod @@ -90,6 +100,7 @@ class PRCurve(BasePlot): class ROCCurveConfig(BasePlotConfig): name: Literal["roc_curve"] = "roc_curve" label: str = "roc_curve" + title: Optional[str] = "Detection ROC Curve" ignore_non_predictions: bool = True ignore_generic: bool = True @@ -109,7 +120,7 @@ class ROCCurve(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true = [] y_score = [] @@ -127,10 +138,12 @@ class ROCCurve(BasePlot): y_score, ) - fig = self.get_figure() + fig = self.create_figure() ax = fig.subplots() + plot_roc_curve(fpr, tpr, thresholds, ax=ax) - return self.label, fig + + yield self.label, fig @detection_plots.register(ROCCurveConfig) @staticmethod @@ -146,6 +159,7 @@ class ROCCurve(BasePlot): class ScoreDistributionPlotConfig(BasePlotConfig): name: Literal["score_distribution"] = "score_distribution" label: str = "score_distribution" + title: Optional[str] = "Detection Score Distribution" ignore_non_predictions: bool = True ignore_generic: bool = True @@ -165,7 +179,7 @@ class ScoreDistributionPlot(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true = [] y_score = [] @@ -180,7 +194,7 @@ class ScoreDistributionPlot(BasePlot): df = pd.DataFrame({"is_true": y_true, "score": y_score}) - fig = self.get_figure() + fig = self.create_figure() ax = fig.subplots() sns.histplot( @@ -194,7 +208,7 @@ class ScoreDistributionPlot(BasePlot): common_norm=False, ) - return self.label, fig + yield self.label, fig @detection_plots.register(ScoreDistributionPlotConfig) @staticmethod @@ -212,7 +226,8 @@ class ScoreDistributionPlot(BasePlot): class ExampleDetectionPlotConfig(BasePlotConfig): name: Literal["example_detection"] = "example_detection" label: str = "example_detection" - figsize: tuple[int, int] = (10, 15) + title: Optional[str] = "Example Detection" + figsize: tuple[int, int] = (10, 4) num_examples: int = 5 threshold: float = 0.2 audio: AudioConfig = Field(default_factory=AudioConfig) @@ -240,82 +255,26 @@ class ExampleDetectionPlot(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: - fig = self.get_figure() - + ) -> Iterable[Tuple[str, Figure]]: sample = clip_evaluations if self.num_examples < len(sample): sample = random.sample(sample, self.num_examples) - axes = fig.subplots(nrows=self.num_examples, ncols=1) + for num_example, clip_eval in enumerate(sample): + fig = self.create_figure() + ax = fig.subplots() - for ax, clip_eval in zip(axes, sample): - plot_clip( - clip_eval.clip, + plot_clip_detections( + clip_eval, + ax=ax, audio_loader=self.audio_loader, preprocessor=self.preprocessor, - ax=ax, ) - for m in clip_eval.matches: - is_match = ( - m.pred is not None - and m.gt is not None - and m.score >= self.threshold - ) + yield f"{self.label}/example_{num_example}", fig - if m.pred is not None: - plot_geometry( - m.pred.geometry, - ax=ax, - add_points=False, - facecolor="none", - alpha=m.pred.detection_score, - linestyle="-" if not is_match else "--", - color="red" if not is_match else "orange", - ) - - if m.gt is not None: - plot_geometry( - m.gt.sound_event.geometry, # type: ignore - ax=ax, - add_points=False, - facecolor="none", - color="green" if not is_match else "orange", - ) - - ax.set_title(clip_eval.clip.recording.path.name) - - # ax.legend( - # handles=[ - # patches.Patch( - # edgecolor="green", - # label="Ground Truth (Unmatched)", - # facecolor="none", - # ), - # patches.Patch( - # edgecolor="orange", - # label="Ground Truth (Matched)", - # facecolor="none", - # ), - # patches.Patch( - # edgecolor="red", - # label="Detection (Unmatched)", - # facecolor="none", - # ), - # patches.Patch( - # edgecolor="orange", - # label="Detection (Matched)", - # facecolor="none", - # linestyle="--", - # ), - # ] - # ) - - plt.tight_layout() - - return self.label, fig + plt.close(fig) @detection_plots.register(ExampleDetectionPlotConfig) @staticmethod diff --git a/src/batdetect2/evaluate/plots/top_class.py b/src/batdetect2/evaluate/plots/top_class.py index a398b79..32f354c 100644 --- a/src/batdetect2/evaluate/plots/top_class.py +++ b/src/batdetect2/evaluate/plots/top_class.py @@ -1,17 +1,36 @@ -from typing import Annotated, Callable, List, Literal, Sequence, Tuple, Union +import random +from collections import defaultdict +from dataclasses import dataclass, field +from typing import ( + Annotated, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, +) +import matplotlib.pyplot as plt +import pandas as pd from matplotlib.figure import Figure from pydantic import Field from sklearn import metrics +from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.core import Registry from batdetect2.evaluate.metrics.common import compute_precision_recall -from batdetect2.evaluate.metrics.top_class import ClipEval +from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig +from batdetect2.plotting.gallery import plot_match_gallery from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve -from batdetect2.typing import TargetProtocol +from batdetect2.preprocess import PreprocessingConfig, build_preprocessor +from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol -TopClassPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]] +TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]] top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry( name="top_class_plot" @@ -21,6 +40,7 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry( class PRCurveConfig(BasePlotConfig): name: Literal["pr_curve"] = "pr_curve" label: str = "pr_curve" + title: Optional[str] = "Top Class Precision-Recall Curve" ignore_non_predictions: bool = True ignore_generic: bool = True @@ -40,7 +60,7 @@ class PRCurve(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true = [] y_score = [] num_positives = 0 @@ -66,10 +86,12 @@ class PRCurve(BasePlot): num_positives=num_positives, ) - fig = self.get_figure() + fig = self.create_figure() ax = fig.subplots() + plot_pr_curve(precision, recall, thresholds, ax=ax) - return self.label, fig + + yield self.label, fig @top_class_plots.register(PRCurveConfig) @staticmethod @@ -85,6 +107,7 @@ class PRCurve(BasePlot): class ROCCurveConfig(BasePlotConfig): name: Literal["roc_curve"] = "roc_curve" label: str = "roc_curve" + title: Optional[str] = "Top Class ROC Curve" ignore_non_predictions: bool = True ignore_generic: bool = True @@ -104,7 +127,7 @@ class ROCCurve(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true = [] y_score = [] @@ -126,10 +149,12 @@ class ROCCurve(BasePlot): y_score, ) - fig = self.get_figure() + fig = self.create_figure() ax = fig.subplots() + plot_roc_curve(fpr, tpr, thresholds, ax=ax) - return self.label, fig + + yield self.label, fig @top_class_plots.register(ROCCurveConfig) @staticmethod @@ -144,6 +169,7 @@ class ROCCurve(BasePlot): class ConfusionMatrixConfig(BasePlotConfig): name: Literal["confusion_matrix"] = "confusion_matrix" + title: Optional[str] = "Top Class Confusion Matrix" figsize: tuple[int, int] = (10, 10) label: str = "confusion_matrix" exclude_generic: bool = True @@ -180,7 +206,7 @@ class ConfusionMatrix(BasePlot): def __call__( self, clip_evaluations: Sequence[ClipEval], - ) -> Tuple[str, Figure]: + ) -> Iterable[Tuple[str, Figure]]: y_true: List[str] = [] y_pred: List[str] = [] @@ -213,7 +239,7 @@ class ConfusionMatrix(BasePlot): y_true.append(true_class or self.noise_class) y_pred.append(pred_class or self.noise_class) - fig = self.get_figure() + fig = self.create_figure() ax = fig.subplots() class_names = [*self.targets.class_names] @@ -236,7 +262,7 @@ class ConfusionMatrix(BasePlot): values_format=".2f", ) - return self.label, fig + yield self.label, fig @top_class_plots.register(ConfusionMatrixConfig) @staticmethod @@ -253,11 +279,105 @@ class ConfusionMatrix(BasePlot): ) +class ExampleClassificationPlotConfig(BasePlotConfig): + name: Literal["example_classification"] = "example_classification" + label: str = "example_classification" + title: Optional[str] = "Example Classification" + num_examples: int = 4 + threshold: float = 0.2 + audio: AudioConfig = Field(default_factory=AudioConfig) + preprocessing: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + + +class ExampleClassificationPlot(BasePlot): + def __init__( + self, + *args, + num_examples: int = 4, + threshold: float = 0.2, + audio_loader: AudioLoader, + preprocessor: PreprocessorProtocol, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.num_examples = num_examples + self.audio_loader = audio_loader + self.threshold = threshold + self.preprocessor = preprocessor + self.num_examples = num_examples + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Iterable[Tuple[str, Figure]]: + grouped = group_matches(clip_evaluations, threshold=self.threshold) + + for class_name, matches in grouped.items(): + true_positives: List[MatchEval] = get_binned_sample( + matches.true_positives, + n_examples=self.num_examples, + ) + + false_positives: List[MatchEval] = get_binned_sample( + matches.false_positives, + n_examples=self.num_examples, + ) + + false_negatives: List[MatchEval] = random.sample( + matches.false_negatives, + k=min(self.num_examples, len(matches.false_negatives)), + ) + + cross_triggers: List[MatchEval] = get_binned_sample( + matches.cross_triggers, n_examples=self.num_examples + ) + + fig = self.create_figure() + + fig = plot_match_gallery( + true_positives, + false_positives, + false_negatives, + cross_triggers, + preprocessor=self.preprocessor, + audio_loader=self.audio_loader, + n_examples=self.num_examples, + fig=fig, + ) + + if self.title is not None: + fig.suptitle(f"{self.title}: {class_name}") + else: + fig.suptitle(class_name) + + yield f"{self.label}/{class_name}", fig + + plt.close(fig) + + @top_class_plots.register(ExampleClassificationPlotConfig) + @staticmethod + def from_config( + config: ExampleClassificationPlotConfig, + targets: TargetProtocol, + ): + return ExampleClassificationPlot.build( + config=config, + targets=targets, + num_examples=config.num_examples, + threshold=config.threshold, + audio_loader=build_audio_loader(config.audio), + preprocessor=build_preprocessor(config.preprocessing), + ) + + TopClassPlotConfig = Annotated[ Union[ PRCurveConfig, ROCCurveConfig, ConfusionMatrixConfig, + ExampleClassificationPlotConfig, ], Field(discriminator="name"), ] @@ -268,3 +388,57 @@ def build_top_class_plotter( targets: TargetProtocol, ) -> TopClassPlotter: return top_class_plots.build(config, targets) + + +@dataclass +class ClassMatches: + false_positives: List[MatchEval] = field(default_factory=list) + false_negatives: List[MatchEval] = field(default_factory=list) + true_positives: List[MatchEval] = field(default_factory=list) + cross_triggers: List[MatchEval] = field(default_factory=list) + + +def group_matches( + clip_evals: Sequence[ClipEval], + threshold: float = 0.2, +) -> Dict[str, ClassMatches]: + class_examples = defaultdict(ClassMatches) + + for clip_eval in clip_evals: + for match in clip_eval.matches: + gt_class = match.true_class + pred_class = match.pred_class + is_pred = match.score >= threshold + + if not is_pred and gt_class is not None: + class_examples[gt_class].false_negatives.append(match) + continue + + if not is_pred: + continue + + if gt_class is None: + class_examples[pred_class].false_positives.append(match) + continue + + if gt_class != pred_class: + class_examples[pred_class].cross_triggers.append(match) + continue + + class_examples[gt_class].true_positives.append(match) + + return class_examples + + +def get_binned_sample(matches: List[MatchEval], n_examples: int = 5): + if len(matches) < n_examples: + return matches + + indices, pred_scores = zip( + *[(index, match.score) for index, match in enumerate(matches)] + ) + + bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop") + df = pd.DataFrame({"indices": indices, "bins": bins}) + sample = df.groupby("bins").sample(1) + return [matches[ind] for ind in sample["indices"]] diff --git a/src/batdetect2/evaluate/tasks/base.py b/src/batdetect2/evaluate/tasks/base.py index e545259..9a3ea8a 100644 --- a/src/batdetect2/evaluate/tasks/base.py +++ b/src/batdetect2/evaluate/tasks/base.py @@ -54,7 +54,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]] - plots: List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]] + plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] ignore_start_end: float @@ -68,7 +68,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): prefix: str, ignore_start_end: float = 0.01, plots: Optional[ - List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]] + List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] ] = None, ): self.matcher = matcher @@ -93,7 +93,8 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): self, eval_outputs: List[T_Output] ) -> Iterable[Tuple[str, Figure]]: for plot in self.plots: - yield plot(eval_outputs) + for name, fig in plot(eval_outputs): + yield f"{self.prefix}/{name}", fig def evaluate( self, @@ -147,7 +148,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): targets: TargetProtocol, metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], plots: Optional[ - List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]] + List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] ] = None, **kwargs, ): diff --git a/src/batdetect2/evaluate/tasks/classification.py b/src/batdetect2/evaluate/tasks/classification.py index 886d473..5d63c3f 100644 --- a/src/batdetect2/evaluate/tasks/classification.py +++ b/src/batdetect2/evaluate/tasks/classification.py @@ -98,6 +98,7 @@ class ClassificationTask(BaseTask[ClipEval]): matches.append( MatchEval( + clip=clip, gt=gt, pred=pred, is_prediction=pred is not None, diff --git a/src/batdetect2/evaluate/tasks/top_class.py b/src/batdetect2/evaluate/tasks/top_class.py index db082eb..78533d8 100644 --- a/src/batdetect2/evaluate/tasks/top_class.py +++ b/src/batdetect2/evaluate/tasks/top_class.py @@ -79,6 +79,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]): matches.append( MatchEval( + clip=clip, gt=gt, pred=pred, is_ground_truth=gt is not None, diff --git a/src/batdetect2/plotting/__init__.py b/src/batdetect2/plotting/__init__.py index 824ef86..08f8378 100644 --- a/src/batdetect2/plotting/__init__.py +++ b/src/batdetect2/plotting/__init__.py @@ -11,7 +11,6 @@ from batdetect2.plotting.matches import ( plot_cross_trigger_match, plot_false_negative_match, plot_false_positive_match, - plot_matches, plot_true_positive_match, ) @@ -22,7 +21,6 @@ __all__ = [ "plot_cross_trigger_match", "plot_false_negative_match", "plot_false_positive_match", - "plot_matches", "plot_spectrogram", "plot_true_positive_match", "plot_detection_heatmap", diff --git a/src/batdetect2/plotting/common.py b/src/batdetect2/plotting/common.py index 3e2eea9..d79ae02 100644 --- a/src/batdetect2/plotting/common.py +++ b/src/batdetect2/plotting/common.py @@ -66,6 +66,9 @@ def plot_spectrogram( vmax=vmax, ) + ax.set_xlim(start_time, end_time) + ax.set_ylim(min_freq, max_freq) + if add_colorbar: plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {})) diff --git a/src/batdetect2/plotting/detections.py b/src/batdetect2/plotting/detections.py new file mode 100644 index 0000000..800b8b6 --- /dev/null +++ b/src/batdetect2/plotting/detections.py @@ -0,0 +1,113 @@ +from typing import Optional + +from matplotlib import axes, patches +from soundevent.plot import plot_geometry + +from batdetect2.evaluate.metrics.detection import ClipEval +from batdetect2.plotting.clips import ( + AudioLoader, + PreprocessorProtocol, + plot_clip, +) +from batdetect2.plotting.common import create_ax + +__all__ = [ + "plot_clip_detections", +] + + +def plot_clip_detections( + clip_eval: ClipEval, + figsize: tuple[int, int] = (10, 10), + ax: Optional[axes.Axes] = None, + audio_loader: Optional[AudioLoader] = None, + preprocessor: Optional[PreprocessorProtocol] = None, + threshold: float = 0.2, + add_legend: bool = True, + add_title: bool = True, + fill: bool = False, + linewidth: float = 1.0, + gt_color: str = "green", + gt_linestyle: str = "-", + true_pred_color: str = "yellow", + true_pred_linestyle: str = "--", + false_pred_color: str = "blue", + false_pred_linestyle: str = "-", + missed_gt_color: str = "red", + missed_gt_linestyle: str = "-", +) -> axes.Axes: + ax = create_ax(figsize=figsize, ax=ax) + + plot_clip( + clip_eval.clip, + audio_loader=audio_loader, + preprocessor=preprocessor, + ax=ax, + ) + + for m in clip_eval.matches: + is_match = ( + m.pred is not None and m.gt is not None and m.score >= threshold + ) + + if m.pred is not None: + color = true_pred_color if is_match else false_pred_color + plot_geometry( + m.pred.geometry, + ax=ax, + add_points=False, + facecolor="none" if not fill else color, + alpha=m.pred.detection_score, + linewidth=linewidth, + linestyle=true_pred_linestyle + if is_match + else missed_gt_linestyle, + color=color, + ) + + if m.gt is not None: + color = gt_color if is_match else missed_gt_color + plot_geometry( + m.gt.sound_event.geometry, # type: ignore + ax=ax, + add_points=False, + linewidth=linewidth, + facecolor="none" if not fill else color, + linestyle=gt_linestyle if is_match else false_pred_linestyle, + color=color, + ) + + if add_title: + ax.set_title(clip_eval.clip.recording.path.name) + + if add_legend: + ax.legend( + handles=[ + patches.Patch( + label="found GT", + edgecolor=gt_color, + facecolor="none" if not fill else gt_color, + linestyle=gt_linestyle, + ), + patches.Patch( + label="missed GT", + edgecolor=missed_gt_color, + facecolor="none" if not fill else missed_gt_color, + linestyle=missed_gt_linestyle, + ), + patches.Patch( + label="true Det", + edgecolor=true_pred_color, + facecolor="none" if not fill else true_pred_color, + linestyle=true_pred_linestyle, + ), + patches.Patch( + label="false Det", + edgecolor=false_pred_color, + facecolor="none" if not fill else false_pred_color, + linestyle=false_pred_linestyle, + ), + ] + ) + + return ax diff --git a/src/batdetect2/plotting/gallery.py b/src/batdetect2/plotting/gallery.py index 175f4c7..1a06f9e 100644 --- a/src/batdetect2/plotting/gallery.py +++ b/src/batdetect2/plotting/gallery.py @@ -1,81 +1,109 @@ -from typing import List, Optional +from typing import Optional, Sequence import matplotlib.pyplot as plt +from matplotlib.figure import Figure from batdetect2.plotting.matches import ( + MatchProtocol, plot_cross_trigger_match, plot_false_negative_match, plot_false_positive_match, plot_true_positive_match, ) -from batdetect2.typing.evaluate import MatchEvaluation from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol __all__ = ["plot_match_gallery"] def plot_match_gallery( - true_positives: List[MatchEvaluation], - false_positives: List[MatchEvaluation], - false_negatives: List[MatchEvaluation], - cross_triggers: List[MatchEvaluation], + true_positives: Sequence[MatchProtocol], + false_positives: Sequence[MatchProtocol], + false_negatives: Sequence[MatchProtocol], + cross_triggers: Sequence[MatchProtocol], audio_loader: Optional[AudioLoader] = None, preprocessor: Optional[PreprocessorProtocol] = None, n_examples: int = 5, duration: float = 0.1, + fig: Optional[Figure] = None, ): - fig = plt.figure(figsize=(20, 20)) + if fig is None: + fig = plt.figure(figsize=(20, 20)) - for index, match in enumerate(true_positives[:n_examples]): - ax = plt.subplot(4, n_examples, index + 1) + axes = fig.subplots( + nrows=4, + ncols=n_examples, + sharex="none", + sharey="row", + ) + + for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]): try: plot_true_positive_match( - match, - ax=ax, + tp_match, + ax=tp_ax, audio_loader=audio_loader, preprocessor=preprocessor, duration=duration, ) - except (ValueError, AssertionError, RuntimeError, FileNotFoundError): + except ( + ValueError, + AssertionError, + RuntimeError, + FileNotFoundError, + ): continue - for index, match in enumerate(false_positives[:n_examples]): - ax = plt.subplot(4, n_examples, n_examples + index + 1) + for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]): try: plot_false_positive_match( - match, - ax=ax, + fp_match, + ax=fp_ax, audio_loader=audio_loader, preprocessor=preprocessor, duration=duration, ) - except (ValueError, AssertionError, RuntimeError, FileNotFoundError): + except ( + ValueError, + AssertionError, + RuntimeError, + FileNotFoundError, + ): continue - for index, match in enumerate(false_negatives[:n_examples]): - ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1) + for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]): try: plot_false_negative_match( - match, - ax=ax, + fn_match, + ax=fn_ax, audio_loader=audio_loader, preprocessor=preprocessor, duration=duration, ) - except (ValueError, AssertionError, RuntimeError, FileNotFoundError): + except ( + ValueError, + AssertionError, + RuntimeError, + FileNotFoundError, + ): continue - for index, match in enumerate(cross_triggers[:n_examples]): - ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1) + for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]): try: plot_cross_trigger_match( - match, - ax=ax, + ct_match, + ax=ct_ax, audio_loader=audio_loader, preprocessor=preprocessor, duration=duration, ) - except (ValueError, AssertionError, RuntimeError, FileNotFoundError): + except ( + ValueError, + AssertionError, + RuntimeError, + FileNotFoundError, + ): continue + fig.tight_layout() + return fig diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index 29a6cff..1803dd3 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ -1,16 +1,17 @@ -from typing import List, Optional, Tuple, Union +from typing import Optional, Protocol, Tuple, Union -import matplotlib.pyplot as plt from matplotlib.axes import Axes from soundevent import data, plot from soundevent.geometry import compute_bounds -from soundevent.plot.tags import TagColorMapper -from batdetect2.plotting.clips import AudioLoader, plot_clip -from batdetect2.typing import MatchEvaluation, PreprocessorProtocol +from batdetect2.plotting.clips import plot_clip +from batdetect2.typing import ( + AudioLoader, + PreprocessorProtocol, + RawPrediction, +) __all__ = [ - "plot_matches", "plot_false_positive_match", "plot_true_positive_match", "plot_false_negative_match", @@ -18,6 +19,14 @@ __all__ = [ ] +class MatchProtocol(Protocol): + clip: data.Clip + gt: Optional[data.SoundEventAnnotation] + pred: Optional[RawPrediction] + score: float + true_class: Optional[str] + + DEFAULT_DURATION = 0.05 DEFAULT_FALSE_POSITIVE_COLOR = "orange" DEFAULT_FALSE_NEGATIVE_COLOR = "red" @@ -27,88 +36,8 @@ DEFAULT_ANNOTATION_LINE_STYLE = "-" DEFAULT_PREDICTION_LINE_STYLE = "--" -def plot_matches( - matches: List[MatchEvaluation], - clip: data.Clip, - audio_loader: Optional[AudioLoader] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - figsize: Optional[Tuple[int, int]] = None, - ax: Optional[Axes] = None, - audio_dir: Optional[data.PathLike] = None, - color_mapper: Optional[TagColorMapper] = None, - add_points: bool = False, - fill: bool = False, - spec_cmap: str = "gray", - false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR, - false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR, - true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR, - cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR, -) -> Axes: - ax = plot_clip( - clip, - ax=ax, - audio_loader=audio_loader, - preprocessor=preprocessor, - figsize=figsize, - audio_dir=audio_dir, - spec_cmap=spec_cmap, - ) - - if color_mapper is None: - color_mapper = TagColorMapper() - - for match in matches: - if match.is_cross_trigger(): - plot_cross_trigger_match( - match, - ax=ax, - fill=fill, - add_points=add_points, - add_spectrogram=False, - use_score=True, - color=cross_trigger_color, - add_text=False, - ) - elif match.is_true_positive(): - plot_true_positive_match( - match, - ax=ax, - fill=fill, - add_spectrogram=False, - use_score=True, - add_points=add_points, - color=true_positive_color, - add_text=False, - ) - elif match.is_false_negative(): - plot_false_negative_match( - match, - ax=ax, - fill=fill, - add_spectrogram=False, - add_points=add_points, - color=false_negative_color, - add_text=False, - ) - elif match.is_false_positive: - plot_false_positive_match( - match, - ax=ax, - fill=fill, - add_spectrogram=False, - use_score=True, - add_points=add_points, - color=false_positive_color, - add_text=False, - ) - else: - continue - - return ax - - def plot_false_positive_match( - match: MatchEvaluation, + match: MatchProtocol, audio_loader: Optional[AudioLoader] = None, preprocessor: Optional[PreprocessorProtocol] = None, figsize: Optional[Tuple[int, int]] = None, @@ -119,21 +48,24 @@ def plot_false_positive_match( add_spectrogram: bool = True, add_text: bool = True, add_points: bool = False, + add_title: bool = True, fill: bool = False, spec_cmap: str = "gray", color: str = DEFAULT_FALSE_POSITIVE_COLOR, fontsize: Union[float, str] = "small", ) -> Axes: - assert match.pred_geometry is not None - assert match.sound_event_annotation is None + assert match.pred is not None - start_time, _, _, high_freq = compute_bounds(match.pred_geometry) + start_time, _, _, high_freq = compute_bounds(match.pred.geometry) clip = data.Clip( - start_time=max(start_time - duration / 2, 0), + start_time=max( + start_time - duration / 2, + 0, + ), end_time=min( start_time + duration / 2, - match.clip.end_time, + match.clip.recording.duration, ), recording=match.clip.recording, ) @@ -150,30 +82,33 @@ def plot_false_positive_match( ) ax = plot.plot_geometry( - match.pred_geometry, + match.pred.geometry, ax=ax, add_points=add_points, facecolor="none" if not fill else None, - alpha=match.pred_score if use_score else 1, + alpha=match.score if use_score else 1, color=color, ) if add_text: - plt.text( + ax.text( start_time, high_freq, - f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.top_class} \nTop Class Score: {match.top_class_score:.2f} ", + f"score={match.score:.2f}", va="top", ha="right", color=color, fontsize=fontsize, ) + if add_title: + ax.set_title("False Positive") + return ax def plot_false_negative_match( - match: MatchEvaluation, + match: MatchProtocol, audio_loader: Optional[AudioLoader] = None, preprocessor: Optional[PreprocessorProtocol] = None, figsize: Optional[Tuple[int, int]] = None, @@ -182,26 +117,28 @@ def plot_false_negative_match( duration: float = DEFAULT_DURATION, add_spectrogram: bool = True, add_points: bool = False, - add_text: bool = True, + add_title: bool = True, fill: bool = False, spec_cmap: str = "gray", color: str = DEFAULT_FALSE_NEGATIVE_COLOR, - fontsize: Union[float, str] = "small", ) -> Axes: - assert match.pred_geometry is None - assert match.sound_event_annotation is not None - sound_event = match.sound_event_annotation.sound_event - geometry = sound_event.geometry + assert match.gt is not None + + geometry = match.gt.sound_event.geometry assert geometry is not None - start_time, _, _, high_freq = compute_bounds(geometry) + start_time = compute_bounds(geometry)[0] clip = data.Clip( - start_time=max(start_time - duration / 2, 0), - end_time=min( - start_time + duration / 2, sound_event.recording.duration + start_time=max( + start_time - duration / 2, + 0, ), - recording=sound_event.recording, + end_time=min( + start_time + duration / 2, + match.clip.recording.duration, + ), + recording=match.clip.recording, ) if add_spectrogram: @@ -215,33 +152,23 @@ def plot_false_negative_match( spec_cmap=spec_cmap, ) - ax = plot.plot_annotation( - match.sound_event_annotation, + ax = plot.plot_geometry( + geometry, ax=ax, - time_offset=0.001, - freq_offset=2_000, add_points=add_points, facecolor="none" if not fill else None, alpha=1, color=color, ) - if add_text: - plt.text( - start_time, - high_freq, - f"False Negative \nClass: {match.gt_class} ", - va="top", - ha="right", - color=color, - fontsize=fontsize, - ) + if add_title: + ax.set_title("False Negative") return ax def plot_true_positive_match( - match: MatchEvaluation, + match: MatchProtocol, preprocessor: Optional[PreprocessorProtocol] = None, audio_loader: Optional[AudioLoader] = None, figsize: Optional[Tuple[int, int]] = None, @@ -258,39 +185,42 @@ def plot_true_positive_match( fontsize: Union[float, str] = "small", annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, + add_title: bool = True, ) -> Axes: - assert match.sound_event_annotation is not None - assert match.pred_geometry is not None - sound_event = match.sound_event_annotation.sound_event - geometry = sound_event.geometry + assert match.gt is not None + assert match.pred is not None + + geometry = match.gt.sound_event.geometry assert geometry is not None start_time, _, _, high_freq = compute_bounds(geometry) clip = data.Clip( - start_time=max(start_time - duration / 2, 0), - end_time=min( - start_time + duration / 2, sound_event.recording.duration + start_time=max( + start_time - duration / 2, + 0, ), - recording=sound_event.recording, + end_time=min( + start_time + duration / 2, + match.clip.recording.duration, + ), + recording=match.clip.recording, ) if add_spectrogram: ax = plot_clip( clip, + ax=ax, audio_loader=audio_loader, preprocessor=preprocessor, figsize=figsize, - ax=ax, audio_dir=audio_dir, spec_cmap=spec_cmap, ) - ax = plot.plot_annotation( - match.sound_event_annotation, + ax = plot.plot_geometry( + geometry, ax=ax, - time_offset=0.001, - freq_offset=2_000, add_points=add_points, facecolor="none" if not fill else None, alpha=1, @@ -299,31 +229,34 @@ def plot_true_positive_match( ) plot.plot_geometry( - match.pred_geometry, + match.pred.geometry, ax=ax, add_points=add_points, facecolor="none" if not fill else None, - alpha=match.pred_score if use_score else 1, + alpha=match.score if use_score else 1, color=color, linestyle=prediction_linestyle, ) if add_text: - plt.text( + ax.text( start_time, high_freq, - f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ", + f"score={match.score:.2f}", va="top", ha="right", color=color, fontsize=fontsize, ) + if add_title: + ax.set_title("True Positive") + return ax def plot_cross_trigger_match( - match: MatchEvaluation, + match: MatchProtocol, preprocessor: Optional[PreprocessorProtocol] = None, audio_loader: Optional[AudioLoader] = None, figsize: Optional[Tuple[int, int]] = None, @@ -334,6 +267,7 @@ def plot_cross_trigger_match( add_spectrogram: bool = True, add_points: bool = False, add_text: bool = True, + add_title: bool = True, fill: bool = False, spec_cmap: str = "gray", color: str = DEFAULT_CROSS_TRIGGER_COLOR, @@ -341,20 +275,24 @@ def plot_cross_trigger_match( annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, ) -> Axes: - assert match.sound_event_annotation is not None - assert match.pred_geometry is not None - sound_event = match.sound_event_annotation.sound_event - geometry = sound_event.geometry + assert match.gt is not None + assert match.pred is not None + + geometry = match.gt.sound_event.geometry assert geometry is not None start_time, _, _, high_freq = compute_bounds(geometry) clip = data.Clip( - start_time=max(start_time - duration / 2, 0), - end_time=min( - start_time + duration / 2, sound_event.recording.duration + start_time=max( + start_time - duration / 2, + 0, ), - recording=sound_event.recording, + end_time=min( + start_time + duration / 2, + match.clip.recording.duration, + ), + recording=match.clip.recording, ) if add_spectrogram: @@ -368,11 +306,9 @@ def plot_cross_trigger_match( spec_cmap=spec_cmap, ) - ax = plot.plot_annotation( - match.sound_event_annotation, + ax = plot.plot_geometry( + geometry, ax=ax, - time_offset=0.001, - freq_offset=2_000, add_points=add_points, facecolor="none" if not fill else None, alpha=1, @@ -381,24 +317,28 @@ def plot_cross_trigger_match( ) ax = plot.plot_geometry( - match.pred_geometry, + match.pred.geometry, ax=ax, add_points=add_points, facecolor="none" if not fill else None, - alpha=match.pred_score if use_score else 1, + alpha=match.score if use_score else 1, color=color, linestyle=prediction_linestyle, ) if add_text: - plt.text( + ax.text( start_time, high_freq, - f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.top_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ", + f"score={match.score:.2f}\nclass={match.true_class}", va="top", ha="right", color=color, fontsize=fontsize, ) + if add_title: + ax.set_title("Cross Trigger") + return ax + diff --git a/src/batdetect2/plotting/metrics.py b/src/batdetect2/plotting/metrics.py index 09fcff3..52bf6fe 100644 --- a/src/batdetect2/plotting/metrics.py +++ b/src/batdetect2/plotting/metrics.py @@ -35,6 +35,8 @@ def plot_pr_curve( ax: Optional[axes.Axes] = None, figsize: Optional[Tuple[int, int]] = None, add_labels: bool = True, + add_legend: bool = False, + label: str = "PR Curve", ) -> axes.Axes: ax = create_ax(ax=ax, figsize=figsize) @@ -43,7 +45,7 @@ def plot_pr_curve( ax.plot( recall, precision, - label="PR Curve", + label=label, marker="o", markevery=_get_marker_positions(thresholds), ) @@ -51,6 +53,9 @@ def plot_pr_curve( ax.set_xlim(0, 1.05) ax.set_ylim(0, 1.05) + if add_legend: + ax.legend() + if add_labels: ax.set_xlabel("Recall") ax.set_ylabel("Precision")