diff --git a/src/batdetect2/evaluate/metrics/classification.py b/src/batdetect2/evaluate/metrics/classification.py index 399c852..4345b26 100644 --- a/src/batdetect2/evaluate/metrics/classification.py +++ b/src/batdetect2/evaluate/metrics/classification.py @@ -19,9 +19,13 @@ from soundevent import data from batdetect2.core import BaseConfig, Registry from batdetect2.evaluate.metrics.common import average_precision -from batdetect2.typing import RawPrediction +from batdetect2.typing import RawPrediction, TargetProtocol -__all__ = [] +__all__ = [ + "ClassificationMetric", + "ClassificationMetricConfig", + "build_classification_metric", +] @dataclass @@ -45,8 +49,8 @@ class ClipEval: ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]] -classification_metrics: Registry[ClassificationMetric, []] = Registry( - "classification_metric" +classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = ( + Registry("classification_metric") ) @@ -58,9 +62,11 @@ class BaseClassificationConfig(BaseConfig): class BaseClassificationMetric: def __init__( self, + targets: TargetProtocol, include: Optional[List[str]] = None, exclude: Optional[List[str]] = None, ): + self.targets = targets self.include = include self.exclude = exclude @@ -84,13 +90,14 @@ class ClassificationAveragePrecisionConfig(BaseClassificationConfig): class ClassificationAveragePrecision(BaseClassificationMetric): def __init__( self, + targets: TargetProtocol, ignore_non_predictions: bool = True, ignore_generic: bool = True, label: str = "average_precision", include: Optional[List[str]] = None, exclude: Optional[List[str]] = None, ): - super().__init__(include=include, exclude=exclude) + super().__init__(include=include, exclude=exclude, targets=targets) self.ignore_non_predictions = ignore_non_predictions self.ignore_generic = ignore_generic self.label = label @@ -98,33 +105,11 @@ class ClassificationAveragePrecision(BaseClassificationMetric): def __call__( self, clip_evaluations: Sequence[ClipEval] ) -> Dict[str, float]: - y_true = defaultdict(list) - y_score = defaultdict(list) - num_positives = defaultdict(lambda: 0) - - class_names = set() - - for clip_eval in clip_evaluations: - for class_name, matches in clip_eval.matches.items(): - class_names.add(class_name) - - for m in matches: - # Exclude matches with ground truth sounds where the class - # is unknown - if m.is_generic and self.ignore_generic: - continue - - is_class = m.true_class == class_name - - if is_class: - num_positives[class_name] += 1 - - # Ignore matches that don't correspond to a prediction - if not m.is_prediction and self.ignore_non_predictions: - continue - - y_true[class_name].append(is_class) - y_score[class_name].append(m.score) + y_true, y_score, num_positives = _extract_per_class_metric_data( + clip_evaluations, + ignore_non_predictions=self.ignore_non_predictions, + ignore_generic=self.ignore_generic, + ) class_scores = { class_name: average_precision( @@ -132,7 +117,7 @@ class ClassificationAveragePrecision(BaseClassificationMetric): y_score[class_name], num_positives=num_positives[class_name], ) - for class_name in class_names + for class_name in self.targets.class_names } mean_score = float( @@ -150,8 +135,12 @@ class ClassificationAveragePrecision(BaseClassificationMetric): @classification_metrics.register(ClassificationAveragePrecisionConfig) @staticmethod - def from_config(config: ClassificationAveragePrecisionConfig): + def from_config( + config: ClassificationAveragePrecisionConfig, + targets: TargetProtocol, + ): 
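+        # The registry invokes this hook as from_config(config, targets);
+        # `targets` supplies the class_names used when computing scores.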
return ClassificationAveragePrecision( + targets=targets, ignore_non_predictions=config.ignore_non_predictions, ignore_generic=config.ignore_generic, label=config.label, @@ -170,12 +159,14 @@ class ClassificationROCAUCConfig(BaseClassificationConfig): class ClassificationROCAUC(BaseClassificationMetric): def __init__( self, + targets: TargetProtocol, ignore_non_predictions: bool = True, ignore_generic: bool = True, label: str = "roc_auc", include: Optional[List[str]] = None, exclude: Optional[List[str]] = None, ): + self.targets = targets self.ignore_non_predictions = ignore_non_predictions self.ignore_generic = ignore_generic self.label = label @@ -185,27 +176,11 @@ class ClassificationROCAUC(BaseClassificationMetric): def __call__( self, clip_evaluations: Sequence[ClipEval] ) -> Dict[str, float]: - y_true = defaultdict(list) - y_score = defaultdict(list) - - class_names = set() - - for clip_eval in clip_evaluations: - for class_name, matches in clip_eval.matches.items(): - class_names.add(class_name) - - for m in matches: - # Exclude matches with ground truth sounds where the class - # is unknown - if m.is_generic and self.ignore_generic: - continue - - # Ignore matches that don't correspond to a prediction - if not m.is_prediction and self.ignore_non_predictions: - continue - - y_true[class_name].append(m.true_class == class_name) - y_score[class_name].append(m.score) + y_true, y_score, _ = _extract_per_class_metric_data( + clip_evaluations, + ignore_non_predictions=self.ignore_non_predictions, + ignore_generic=self.ignore_generic, + ) class_scores = { class_name: float( @@ -214,7 +189,7 @@ class ClassificationROCAUC(BaseClassificationMetric): y_score[class_name], ) ) - for class_name in class_names + for class_name in self.targets.class_names } mean_score = float( @@ -232,8 +207,11 @@ class ClassificationROCAUC(BaseClassificationMetric): @classification_metrics.register(ClassificationROCAUCConfig) @staticmethod - def from_config(config: ClassificationROCAUCConfig): + def from_config( + config: ClassificationROCAUCConfig, targets: TargetProtocol + ): return ClassificationROCAUC( + targets=targets, ignore_non_predictions=config.ignore_non_predictions, ignore_generic=config.ignore_generic, label=config.label, @@ -249,5 +227,40 @@ ClassificationMetricConfig = Annotated[ ] -def build_classification_metrics(config: ClassificationMetricConfig): - return classification_metrics.build(config) +def build_classification_metric( + config: ClassificationMetricConfig, + targets: TargetProtocol, +) -> ClassificationMetric: + return classification_metrics.build(config, targets) + + +def _extract_per_class_metric_data( + clip_evaluations: Sequence[ClipEval], + ignore_non_predictions: bool = True, + ignore_generic: bool = True, +): + y_true = defaultdict(list) + y_score = defaultdict(list) + num_positives = defaultdict(lambda: 0) + + for clip_eval in clip_evaluations: + for class_name, matches in clip_eval.matches.items(): + for m in matches: + # Exclude matches with ground truth sounds where the class + # is unknown + if m.is_generic and ignore_generic: + continue + + is_class = m.true_class == class_name + + if is_class: + num_positives[class_name] += 1 + + # Ignore matches that don't correspond to a prediction + if not m.is_prediction and ignore_non_predictions: + continue + + y_true[class_name].append(is_class) + y_score[class_name].append(m.score) + + return y_true, y_score, num_positives diff --git a/src/batdetect2/evaluate/metrics/common.py b/src/batdetect2/evaluate/metrics/common.py index 
44ce045..7c2925a 100644 --- a/src/batdetect2/evaluate/metrics/common.py +++ b/src/batdetect2/evaluate/metrics/common.py @@ -12,7 +12,7 @@ def compute_precision_recall( y_true, y_score, num_positives: Optional[int] = None, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: y_true = np.array(y_true) y_score = np.array(y_score) @@ -22,6 +22,7 @@ def compute_precision_recall( # Sort by score sort_ind = np.argsort(y_score)[::-1] y_true_sorted = y_true[sort_ind] + y_score_sorted = y_score[sort_ind] false_pos_c = np.cumsum(1 - y_true_sorted) true_pos_c = np.cumsum(y_true_sorted) @@ -34,7 +35,7 @@ def compute_precision_recall( precision[np.isnan(precision)] = 0 recall[np.isnan(recall)] = 0 - return precision, recall + return precision, recall, y_score_sorted def average_precision( @@ -42,7 +43,7 @@ def average_precision( y_score, num_positives: Optional[int] = None, ) -> float: - precision, recall = compute_precision_recall( + precision, recall, _ = compute_precision_recall( y_true, y_score, num_positives=num_positives, diff --git a/src/batdetect2/evaluate/plots/__init__.py b/src/batdetect2/evaluate/plots/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/batdetect2/evaluate/plots/base.py b/src/batdetect2/evaluate/plots/base.py new file mode 100644 index 0000000..e54e675 --- /dev/null +++ b/src/batdetect2/evaluate/plots/base.py @@ -0,0 +1,54 @@ +from typing import Optional + +import matplotlib.pyplot as plt +from matplotlib.figure import Figure + +from batdetect2.core import BaseConfig +from batdetect2.typing import TargetProtocol + + +class BasePlotConfig(BaseConfig): + label: str = "plot" + theme: str = "default" + title: Optional[str] = None + figsize: tuple[int, int] = (5, 5) + dpi: int = 100 + + +class BasePlot: + def __init__( + self, + targets: TargetProtocol, + label: str = "plot", + figsize: tuple[int, int] = (5, 5), + title: Optional[str] = None, + dpi: int = 100, + theme: str = "default", + ): + self.targets = targets + self.label = label + self.figsize = figsize + self.dpi = dpi + self.theme = theme + self.title = title + + def get_figure(self) -> Figure: + plt.style.use(self.theme) + fig = plt.figure(figsize=self.figsize, dpi=self.dpi) + + if self.title is not None: + fig.suptitle(self.title) + + return fig + + @classmethod + def build(cls, config: BasePlotConfig, targets: TargetProtocol, **kwargs): + return cls( + targets=targets, + figsize=config.figsize, + dpi=config.dpi, + theme=config.theme, + label=config.label, + title=config.title, + **kwargs, + ) diff --git a/src/batdetect2/evaluate/plots/classification.py b/src/batdetect2/evaluate/plots/classification.py new file mode 100644 index 0000000..55fce6c --- /dev/null +++ b/src/batdetect2/evaluate/plots/classification.py @@ -0,0 +1,212 @@ +from typing import Annotated, Callable, Literal, Sequence, Tuple, Union + +from matplotlib.figure import Figure +from pydantic import Field +from sklearn import metrics + +from batdetect2.core import Registry +from batdetect2.evaluate.metrics.classification import ( + ClipEval, + _extract_per_class_metric_data, +) +from batdetect2.evaluate.metrics.common import compute_precision_recall +from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig +from batdetect2.plotting.metrics import ( + plot_pr_curves, + plot_roc_curves, + plot_threshold_precision_curves, + plot_threshold_recall_curves, +) +from batdetect2.typing import TargetProtocol + +ClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]] + 
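+# A classification plotter consumes the per-clip evaluation results and
+# returns a (label, Figure) pair that the validation callback can hand
+# to an image logger.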
+classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = ( + Registry("classification_plot") +) + + +class PRCurveConfig(BasePlotConfig): + name: Literal["pr_curve"] = "pr_curve" + label: str = "pr_curve" + ignore_non_predictions: bool = True + ignore_generic: bool = True + + +class PRCurve(BasePlot): + def __init__( + self, + *args, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + y_true, y_score, num_positives = _extract_per_class_metric_data( + clip_evaluations, + ignore_non_predictions=self.ignore_non_predictions, + ignore_generic=self.ignore_generic, + ) + + fig = self.get_figure() + ax = fig.subplots() + + data = { + class_name: compute_precision_recall( + y_true[class_name], + y_score[class_name], + num_positives=num_positives[class_name], + ) + for class_name in self.targets.class_names + } + + plot_pr_curves(data, ax=ax) + + return self.label, fig + + @classification_plots.register(PRCurveConfig) + @staticmethod + def from_config(config: PRCurveConfig, targets: TargetProtocol): + return PRCurve.build( + config=config, + targets=targets, + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + ) + + +class ThresholdPRCurveConfig(BasePlotConfig): + name: Literal["threshold_pr_curve"] = "threshold_pr_curve" + label: str = "threshold_pr_curve" + figsize: tuple[int, int] = (10, 5) + ignore_non_predictions: bool = True + ignore_generic: bool = True + + +class ThresholdPRCurve(BasePlot): + def __init__( + self, + *args, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + y_true, y_score, num_positives = _extract_per_class_metric_data( + clip_evaluations, + ignore_non_predictions=self.ignore_non_predictions, + ignore_generic=self.ignore_generic, + ) + + data = { + class_name: compute_precision_recall( + y_true[class_name], + y_score[class_name], + num_positives[class_name], + ) + for class_name in self.targets.class_names + } + + fig = self.get_figure() + ax1, ax2 = fig.subplots(nrows=1, ncols=2, sharey=True) + + plot_threshold_precision_curves(data, ax=ax1, add_legend=False) + plot_threshold_recall_curves(data, ax=ax2, add_legend=True) + + return self.label, fig + + @classification_plots.register(ThresholdPRCurveConfig) + @staticmethod + def from_config(config: ThresholdPRCurveConfig, targets: TargetProtocol): + return ThresholdPRCurve.build( + config=config, + targets=targets, + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + ) + + +class ROCCurveConfig(BasePlotConfig): + name: Literal["roc_curve"] = "roc_curve" + label: str = "roc_curve" + ignore_non_predictions: bool = True + ignore_generic: bool = True + + +class ROCCurve(BasePlot): + def __init__( + self, + *args, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, 
Figure]: + y_true, y_score, _ = _extract_per_class_metric_data( + clip_evaluations, + ignore_non_predictions=self.ignore_non_predictions, + ignore_generic=self.ignore_generic, + ) + + data = { + class_name: metrics.roc_curve( + y_true[class_name], + y_score[class_name], + ) + for class_name in self.targets.class_names + } + + fig = self.get_figure() + ax = fig.subplots() + + plot_roc_curves(data, ax=ax) + + return self.label, fig + + @classification_plots.register(ROCCurveConfig) + @staticmethod + def from_config(config: ROCCurveConfig, targets: TargetProtocol): + return ROCCurve.build( + config=config, + targets=targets, + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + ) + + +ClassificationPlotConfig = Annotated[ + Union[ + PRCurveConfig, + ROCCurveConfig, + ThresholdPRCurveConfig, + ], + Field(discriminator="name"), +] + + +def build_classification_plotter( + config: ClassificationPlotConfig, + targets: TargetProtocol, +) -> ClassificationPlotter: + return classification_plots.build(config, targets) diff --git a/src/batdetect2/evaluate/plots/clip_classification.py b/src/batdetect2/evaluate/plots/clip_classification.py new file mode 100644 index 0000000..322650d --- /dev/null +++ b/src/batdetect2/evaluate/plots/clip_classification.py @@ -0,0 +1,131 @@ +from typing import ( + Annotated, + Callable, + Literal, + Optional, + Sequence, + Tuple, + Union, +) + +from matplotlib.figure import Figure +from pydantic import Field +from sklearn import metrics + +from batdetect2.core import Registry +from batdetect2.evaluate.metrics.clip_classification import ClipEval +from batdetect2.evaluate.metrics.common import compute_precision_recall +from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig +from batdetect2.plotting.metrics import ( + plot_pr_curves, + plot_roc_curves, +) +from batdetect2.typing import TargetProtocol + +__all__ = [ + "ClipClassificationPlotConfig", + "ClipClassificationPlotter", + "build_clip_classification_plotter", +] + +ClipClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]] + +clip_classification_plots: Registry[ + ClipClassificationPlotter, [TargetProtocol] +] = Registry("clip_classification_plot") + + +class PRCurveConfig(BasePlotConfig): + name: Literal["pr_curve"] = "pr_curve" + label: str = "pr_curve" + title: Optional[str] = "Precision-Recall Curve" + + +class PRCurve(BasePlot): + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + data = {} + + for class_name in self.targets.class_names: + y_true = [class_name in c.true_classes for c in clip_evaluations] + y_score = [ + c.class_scores.get(class_name, 0) for c in clip_evaluations + ] + + precision, recall, thresholds = compute_precision_recall( + y_true, + y_score, + ) + + data[class_name] = (precision, recall, thresholds) + + fig = self.get_figure() + ax = fig.subplots() + plot_pr_curves(data, ax=ax) + return self.label, fig + + @clip_classification_plots.register(PRCurveConfig) + @staticmethod + def from_config(config: PRCurveConfig, targets: TargetProtocol): + return PRCurve.build( + config=config, + targets=targets, + ) + + +class ROCCurveConfig(BasePlotConfig): + name: Literal["roc_curve"] = "roc_curve" + label: str = "roc_curve" + title: Optional[str] = "ROC Curve" + + +class ROCCurve(BasePlot): + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + data = {} + + for class_name in self.targets.class_names: + y_true = [class_name in c.true_classes for c in 
clip_evaluations] + y_score = [ + c.class_scores.get(class_name, 0) for c in clip_evaluations + ] + + fpr, tpr, thresholds = metrics.roc_curve( + y_true, + y_score, + ) + + data[class_name] = (fpr, tpr, thresholds) + + fig = self.get_figure() + ax = fig.subplots() + plot_roc_curves(data, ax=ax) + return self.label, fig + + @clip_classification_plots.register(ROCCurveConfig) + @staticmethod + def from_config(config: ROCCurveConfig, targets: TargetProtocol): + return ROCCurve.build( + config=config, + targets=targets, + ) + + +ClipClassificationPlotConfig = Annotated[ + Union[ + PRCurveConfig, + ROCCurveConfig, + ], + Field(discriminator="name"), +] + + +def build_clip_classification_plotter( + config: ClipClassificationPlotConfig, + targets: TargetProtocol, +) -> ClipClassificationPlotter: + return clip_classification_plots.build(config, targets) diff --git a/src/batdetect2/evaluate/plots/clip_detection.py b/src/batdetect2/evaluate/plots/clip_detection.py new file mode 100644 index 0000000..8a34d65 --- /dev/null +++ b/src/batdetect2/evaluate/plots/clip_detection.py @@ -0,0 +1,160 @@ +from typing import ( + Annotated, + Callable, + Literal, + Optional, + Sequence, + Tuple, + Union, +) + +import pandas as pd +import seaborn as sns +from matplotlib.figure import Figure +from pydantic import Field +from sklearn import metrics + +from batdetect2.core import Registry +from batdetect2.evaluate.metrics.clip_detection import ClipEval +from batdetect2.evaluate.metrics.common import compute_precision_recall +from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig +from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve +from batdetect2.typing import TargetProtocol + +__all__ = [ + "ClipDetectionPlotConfig", + "ClipDetectionPlotter", + "build_clip_detection_plotter", +] + +ClipDetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]] + + +clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = ( + Registry("clip_detection_plot") +) + + +class PRCurveConfig(BasePlotConfig): + name: Literal["pr_curve"] = "pr_curve" + label: str = "pr_curve" + title: Optional[str] = "Precision-Recall Curve" + + +class PRCurve(BasePlot): + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + y_true = [c.gt_det for c in clip_evaluations] + y_score = [c.score for c in clip_evaluations] + + precision, recall, thresholds = compute_precision_recall( + y_true, + y_score, + ) + + fig = self.get_figure() + ax = fig.subplots() + plot_pr_curve(precision, recall, thresholds, ax=ax) + return self.label, fig + + @clip_detection_plots.register(PRCurveConfig) + @staticmethod + def from_config(config: PRCurveConfig, targets: TargetProtocol): + return PRCurve.build( + config=config, + targets=targets, + ) + + +class ROCCurveConfig(BasePlotConfig): + name: Literal["roc_curve"] = "roc_curve" + label: str = "roc_curve" + title: Optional[str] = "ROC Curve" + + +class ROCCurve(BasePlot): + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + y_true = [c.gt_det for c in clip_evaluations] + y_score = [c.score for c in clip_evaluations] + + fpr, tpr, thresholds = metrics.roc_curve( + y_true, + y_score, + ) + + fig = self.get_figure() + ax = fig.subplots() + plot_roc_curve(fpr, tpr, thresholds, ax=ax) + return self.label, fig + + @clip_detection_plots.register(ROCCurveConfig) + @staticmethod + def from_config(config: ROCCurveConfig, targets: TargetProtocol): + return ROCCurve.build( + config=config, + targets=targets, + 
) + + +class ScoreDistributionPlotConfig(BasePlotConfig): + name: Literal["score_distribution"] = "score_distribution" + label: str = "score_distribution" + title: Optional[str] = "Score Distribution" + + +class ScoreDistributionPlot(BasePlot): + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + y_true = [c.gt_det for c in clip_evaluations] + y_score = [c.score for c in clip_evaluations] + + fig = self.get_figure() + ax = fig.subplots() + + df = pd.DataFrame({"is_true": y_true, "score": y_score}) + sns.histplot( + data=df, + x="score", + binwidth=0.025, + binrange=(0, 1), + hue="is_true", + ax=ax, + stat="probability", + common_norm=False, + ) + + return self.label, fig + + @clip_detection_plots.register(ScoreDistributionPlotConfig) + @staticmethod + def from_config( + config: ScoreDistributionPlotConfig, targets: TargetProtocol + ): + return ScoreDistributionPlot.build( + config=config, + targets=targets, + ) + + +ClipDetectionPlotConfig = Annotated[ + Union[ + PRCurveConfig, + ROCCurveConfig, + ScoreDistributionPlotConfig, + ], + Field(discriminator="name"), +] + + +def build_clip_detection_plotter( + config: ClipDetectionPlotConfig, + targets: TargetProtocol, +) -> ClipDetectionPlotter: + return clip_detection_plots.build(config, targets) diff --git a/src/batdetect2/evaluate/plots/detection.py b/src/batdetect2/evaluate/plots/detection.py new file mode 100644 index 0000000..dbcd9fc --- /dev/null +++ b/src/batdetect2/evaluate/plots/detection.py @@ -0,0 +1,350 @@ +import random +from typing import Annotated, Callable, Literal, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from matplotlib import patches +from matplotlib.figure import Figure +from pydantic import Field +from sklearn import metrics +from soundevent.plot import plot_geometry + +from batdetect2.audio import AudioConfig, build_audio_loader +from batdetect2.core import Registry +from batdetect2.evaluate.metrics.common import compute_precision_recall +from batdetect2.evaluate.metrics.detection import ClipEval +from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig +from batdetect2.plotting.clips import plot_clip +from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve +from batdetect2.preprocess import PreprocessingConfig, build_preprocessor +from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol + +DetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]] + +detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry( + name="detection_plot" +) + + +class PRCurveConfig(BasePlotConfig): + name: Literal["pr_curve"] = "pr_curve" + label: str = "pr_curve" + ignore_non_predictions: bool = True + ignore_generic: bool = True + + +class PRCurve(BasePlot): + def __init__( + self, + *args, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + + def __call__( + self, + clip_evals: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + y_true = [] + y_score = [] + num_positives = 0 + + for clip_eval in clip_evals: + for m in clip_eval.matches: + num_positives += int(m.is_ground_truth) + + # Ignore matches that don't correspond to a prediction + if not m.is_prediction and self.ignore_non_predictions: + continue + + y_true.append(m.is_ground_truth) + y_score.append(m.score) + + precision, recall, thresholds = 
compute_precision_recall( + y_true, + y_score, + num_positives=num_positives, + ) + + fig = self.get_figure() + ax = fig.subplots() + plot_pr_curve(precision, recall, thresholds, ax=ax) + return self.label, fig + + @detection_plots.register(PRCurveConfig) + @staticmethod + def from_config(config: PRCurveConfig, targets: TargetProtocol): + return PRCurve.build( + config=config, + targets=targets, + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + ) + + +class ROCCurveConfig(BasePlotConfig): + name: Literal["roc_curve"] = "roc_curve" + label: str = "roc_curve" + ignore_non_predictions: bool = True + ignore_generic: bool = True + + +class ROCCurve(BasePlot): + def __init__( + self, + *args, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + y_true = [] + y_score = [] + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + if not m.is_prediction and self.ignore_non_predictions: + # Ignore matches that don't correspond to a prediction + continue + + y_true.append(m.is_ground_truth) + y_score.append(m.score) + + fpr, tpr, thresholds = metrics.roc_curve( + y_true, + y_score, + ) + + fig = self.get_figure() + ax = fig.subplots() + plot_roc_curve(fpr, tpr, thresholds, ax=ax) + return self.label, fig + + @detection_plots.register(ROCCurveConfig) + @staticmethod + def from_config(config: ROCCurveConfig, targets: TargetProtocol): + return ROCCurve.build( + config=config, + targets=targets, + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + ) + + +class ScoreDistributionPlotConfig(BasePlotConfig): + name: Literal["score_distribution"] = "score_distribution" + label: str = "score_distribution" + ignore_non_predictions: bool = True + ignore_generic: bool = True + + +class ScoreDistributionPlot(BasePlot): + def __init__( + self, + *args, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + y_true = [] + y_score = [] + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + if not m.is_prediction and self.ignore_non_predictions: + # Ignore matches that don't correspond to a prediction + continue + + y_true.append(m.is_ground_truth) + y_score.append(m.score) + + df = pd.DataFrame({"is_true": y_true, "score": y_score}) + + fig = self.get_figure() + ax = fig.subplots() + + sns.histplot( + data=df, + x="score", + binwidth=0.025, + binrange=(0, 1), + hue="is_true", + ax=ax, + stat="probability", + common_norm=False, + ) + + return self.label, fig + + @detection_plots.register(ScoreDistributionPlotConfig) + @staticmethod + def from_config( + config: ScoreDistributionPlotConfig, targets: TargetProtocol + ): + return ScoreDistributionPlot.build( + config=config, + targets=targets, + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + ) + + +class ExampleDetectionPlotConfig(BasePlotConfig): + name: Literal["example_detection"] = "example_detection" + label: str = "example_detection" + figsize: tuple[int, int] = (10, 15) + num_examples: int = 
5
+    threshold: float = 0.2
+    audio: AudioConfig = Field(default_factory=AudioConfig)
+    preprocessing: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+
+
+class ExampleDetectionPlot(BasePlot):
+    def __init__(
+        self,
+        *args,
+        num_examples: int = 5,
+        threshold: float = 0.2,
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.num_examples = num_examples
+        self.audio_loader = audio_loader
+        self.threshold = threshold
+        self.preprocessor = preprocessor
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        fig = self.get_figure()
+
+        sample = clip_evaluations
+
+        if self.num_examples < len(sample):
+            sample = random.sample(sample, self.num_examples)
+
+        # squeeze=False keeps the axes as an array so a single
+        # example does not break the iteration below.
+        axes = fig.subplots(nrows=self.num_examples, ncols=1, squeeze=False)
+
+        for ax, clip_eval in zip(axes.flat, sample):
+            plot_clip(
+                clip_eval.clip,
+                audio_loader=self.audio_loader,
+                preprocessor=self.preprocessor,
+                ax=ax,
+            )
+
+            for m in clip_eval.matches:
+                is_match = (
+                    m.pred is not None
+                    and m.gt is not None
+                    and m.score >= self.threshold
+                )
+
+                if m.pred is not None:
+                    plot_geometry(
+                        m.pred.geometry,
+                        ax=ax,
+                        add_points=False,
+                        facecolor="none",
+                        alpha=m.pred.detection_score,
+                        linestyle="-" if not is_match else "--",
+                        color="red" if not is_match else "orange",
+                    )
+
+                if m.gt is not None:
+                    plot_geometry(
+                        m.gt.sound_event.geometry,  # type: ignore
+                        ax=ax,
+                        add_points=False,
+                        facecolor="none",
+                        color="green" if not is_match else "orange",
+                    )
+
+            ax.set_title(clip_eval.clip.recording.path.name)
+
+            # ax.legend(
+            #     handles=[
+            #         patches.Patch(
+            #             edgecolor="green",
+            #             label="Ground Truth (Unmatched)",
+            #             facecolor="none",
+            #         ),
+            #         patches.Patch(
+            #             edgecolor="orange",
+            #             label="Ground Truth (Matched)",
+            #             facecolor="none",
+            #         ),
+            #         patches.Patch(
+            #             edgecolor="red",
+            #             label="Detection (Unmatched)",
+            #             facecolor="none",
+            #         ),
+            #         patches.Patch(
+            #             edgecolor="orange",
+            #             label="Detection (Matched)",
+            #             facecolor="none",
+            #             linestyle="--",
+            #         ),
+            #     ]
+            # )
+
+        plt.tight_layout()
+
+        return self.label, fig
+
+    @detection_plots.register(ExampleDetectionPlotConfig)
+    @staticmethod
+    def from_config(
+        config: ExampleDetectionPlotConfig,
+        targets: TargetProtocol,
+    ):
+        return ExampleDetectionPlot.build(
+            config=config,
+            targets=targets,
+            num_examples=config.num_examples,
+            threshold=config.threshold,
+            audio_loader=build_audio_loader(config.audio),
+            preprocessor=build_preprocessor(config.preprocessing),
+        )
+
+
+DetectionPlotConfig = Annotated[
+    Union[
+        PRCurveConfig,
+        ROCCurveConfig,
+        ScoreDistributionPlotConfig,
+        ExampleDetectionPlotConfig,
+    ],
+    Field(discriminator="name"),
+]
+
+
+def build_detection_plotter(
+    config: DetectionPlotConfig,
+    targets: TargetProtocol,
+) -> DetectionPlotter:
+    return detection_plots.build(config, targets)
diff --git a/src/batdetect2/evaluate/plots/top_class.py b/src/batdetect2/evaluate/plots/top_class.py
new file mode 100644
index 0000000..a398b79
--- /dev/null
+++ b/src/batdetect2/evaluate/plots/top_class.py
@@ -0,0 +1,270 @@
+from typing import Annotated, Callable, List, Literal, Sequence, Tuple, Union
+
+from matplotlib.figure import Figure
+from pydantic import Field
+from sklearn import metrics
+
+from batdetect2.core import Registry
+from batdetect2.evaluate.metrics.common import compute_precision_recall
+from batdetect2.evaluate.metrics.top_class import ClipEval
+from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
+from
batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve +from batdetect2.typing import TargetProtocol + +TopClassPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]] + +top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry( + name="top_class_plot" +) + + +class PRCurveConfig(BasePlotConfig): + name: Literal["pr_curve"] = "pr_curve" + label: str = "pr_curve" + ignore_non_predictions: bool = True + ignore_generic: bool = True + + +class PRCurve(BasePlot): + def __init__( + self, + *args, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + y_true = [] + y_score = [] + num_positives = 0 + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + if m.is_generic and self.ignore_generic: + # Ignore gt sounds with unknown class + continue + + num_positives += int(m.is_ground_truth) + + if not m.is_prediction and self.ignore_non_predictions: + # Ignore non predictions + continue + + y_true.append(m.pred_class == m.true_class) + y_score.append(m.score) + + precision, recall, thresholds = compute_precision_recall( + y_true, + y_score, + num_positives=num_positives, + ) + + fig = self.get_figure() + ax = fig.subplots() + plot_pr_curve(precision, recall, thresholds, ax=ax) + return self.label, fig + + @top_class_plots.register(PRCurveConfig) + @staticmethod + def from_config(config: PRCurveConfig, targets: TargetProtocol): + return PRCurve.build( + config=config, + targets=targets, + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + ) + + +class ROCCurveConfig(BasePlotConfig): + name: Literal["roc_curve"] = "roc_curve" + label: str = "roc_curve" + ignore_non_predictions: bool = True + ignore_generic: bool = True + + +class ROCCurve(BasePlot): + def __init__( + self, + *args, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Tuple[str, Figure]: + y_true = [] + y_score = [] + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + if m.is_generic and self.ignore_generic: + # Ignore gt sounds with unknown class + continue + + if not m.is_prediction and self.ignore_non_predictions: + # Ignore non predictions + continue + + y_true.append(m.pred_class == m.true_class) + y_score.append(m.score) + + fpr, tpr, thresholds = metrics.roc_curve( + y_true, + y_score, + ) + + fig = self.get_figure() + ax = fig.subplots() + plot_roc_curve(fpr, tpr, thresholds, ax=ax) + return self.label, fig + + @top_class_plots.register(ROCCurveConfig) + @staticmethod + def from_config(config: ROCCurveConfig, targets: TargetProtocol): + return ROCCurve.build( + config=config, + targets=targets, + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + ) + + +class ConfusionMatrixConfig(BasePlotConfig): + name: Literal["confusion_matrix"] = "confusion_matrix" + figsize: tuple[int, int] = (10, 10) + label: str = "confusion_matrix" + exclude_generic: bool = True + exclude_noise: bool = False + noise_class: str = "noise" + normalize: Literal["true", "pred", "all", "none"] = "true" + threshold: float = 
0.2
+    add_colorbar: bool = True
+    cmap: str = "Blues"
+
+
+class ConfusionMatrix(BasePlot):
+    def __init__(
+        self,
+        *args,
+        exclude_generic: bool = True,
+        exclude_noise: bool = False,
+        noise_class: str = "noise",
+        add_colorbar: bool = True,
+        normalize: Literal["true", "pred", "all", "none"] = "true",
+        cmap: str = "Blues",
+        threshold: float = 0.2,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.exclude_generic = exclude_generic
+        self.exclude_noise = exclude_noise
+        self.noise_class = noise_class
+        self.normalize = normalize
+        self.add_colorbar = add_colorbar
+        self.threshold = threshold
+        self.cmap = cmap
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true: List[str] = []
+        y_pred: List[str] = []
+
+        for clip_eval in clip_evaluations:
+            for m in clip_eval.matches:
+                true_class = m.true_class
+                pred_class = m.pred_class
+
+                if not m.is_prediction and self.exclude_noise:
+                    # Ignore matches that don't correspond to a prediction
+                    continue
+
+                if not m.is_ground_truth and self.exclude_noise:
+                    # Ignore matches that don't correspond to a ground truth
+                    continue
+
+                if m.score < self.threshold:
+                    if self.exclude_noise:
+                        continue
+
+                    pred_class = self.noise_class
+
+                if m.is_generic:
+                    if self.exclude_generic:
+                        # Ignore gt sounds with unknown class
+                        continue
+
+                    true_class = self.targets.detection_class_name
+
+                y_true.append(true_class or self.noise_class)
+                y_pred.append(pred_class or self.noise_class)
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+
+        class_names = [*self.targets.class_names]
+
+        if not self.exclude_generic:
+            class_names.append(self.targets.detection_class_name)
+
+        if not self.exclude_noise:
+            class_names.append(self.noise_class)
+
+        metrics.ConfusionMatrixDisplay.from_predictions(
+            y_true,
+            y_pred,
+            labels=class_names,
+            ax=ax,
+            xticks_rotation="vertical",
+            cmap=self.cmap,
+            colorbar=self.add_colorbar,
+            normalize=self.normalize if self.normalize != "none" else None,
+            values_format=".2f",
+        )
+
+        return self.label, fig
+
+    @top_class_plots.register(ConfusionMatrixConfig)
+    @staticmethod
+    def from_config(config: ConfusionMatrixConfig, targets: TargetProtocol):
+        return ConfusionMatrix.build(
+            config=config,
+            targets=targets,
+            exclude_generic=config.exclude_generic,
+            exclude_noise=config.exclude_noise,
+            noise_class=config.noise_class,
+            add_colorbar=config.add_colorbar,
+            normalize=config.normalize,
+            threshold=config.threshold,
+            cmap=config.cmap,
+        )
+
+
+TopClassPlotConfig = Annotated[
+    Union[
+        PRCurveConfig,
+        ROCCurveConfig,
+        ConfusionMatrixConfig,
+    ],
+    Field(discriminator="name"),
+]
+
+
+def build_top_class_plotter(
+    config: TopClassPlotConfig,
+    targets: TargetProtocol,
+) -> TopClassPlotter:
+    return top_class_plots.build(config, targets)
diff --git a/src/batdetect2/evaluate/tasks/base.py b/src/batdetect2/evaluate/tasks/base.py
index 29b61ae..e545259 100644
--- a/src/batdetect2/evaluate/tasks/base.py
+++ b/src/batdetect2/evaluate/tasks/base.py
@@ -1,5 +1,16 @@
-from typing import Callable, Dict, Generic, List, Sequence, TypeVar
+from typing import (
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+)
 
+from matplotlib.figure import Figure
 from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds
@@ -43,6 +54,8 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
 
     metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
 
+    plots: List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
+
     ignore_start_end:
float prefix: str @@ -54,9 +67,13 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], prefix: str, ignore_start_end: float = 0.01, + plots: Optional[ + List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]] + ] = None, ): self.matcher = matcher self.metrics = metrics + self.plots = plots or [] self.targets = targets self.prefix = prefix self.ignore_start_end = ignore_start_end @@ -72,6 +89,12 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): for name, score in metric_output.items() } + def generate_plots( + self, eval_outputs: List[T_Output] + ) -> Iterable[Tuple[str, Figure]]: + for plot in self.plots: + yield plot(eval_outputs) + def evaluate( self, clip_annotations: Sequence[data.ClipAnnotation], @@ -123,6 +146,9 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): config: BaseTaskConfig, targets: TargetProtocol, metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], + plots: Optional[ + List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]] + ] = None, **kwargs, ): matcher = build_matcher(config.matching_strategy) @@ -130,6 +156,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): matcher=matcher, targets=targets, metrics=metrics, + plots=plots, prefix=config.prefix, ignore_start_end=config.ignore_start_end, **kwargs, diff --git a/src/batdetect2/evaluate/tasks/classification.py b/src/batdetect2/evaluate/tasks/classification.py index 3481bd1..886d473 100644 --- a/src/batdetect2/evaluate/tasks/classification.py +++ b/src/batdetect2/evaluate/tasks/classification.py @@ -12,7 +12,11 @@ from batdetect2.evaluate.metrics.classification import ( ClassificationMetricConfig, ClipEval, MatchEval, - build_classification_metrics, + build_classification_metric, +) +from batdetect2.evaluate.plots.classification import ( + ClassificationPlotConfig, + build_classification_plotter, ) from batdetect2.evaluate.tasks.base import ( BaseTask, @@ -28,6 +32,7 @@ class ClassificationTaskConfig(BaseTaskConfig): metrics: List[ClassificationMetricConfig] = Field( default_factory=lambda: [ClassificationAveragePrecisionConfig()] ) + plots: List[ClassificationPlotConfig] = Field(default_factory=list) include_generics: bool = True @@ -128,10 +133,16 @@ class ClassificationTask(BaseTask[ClipEval]): targets: TargetProtocol, ): metrics = [ - build_classification_metrics(metric) for metric in config.metrics + build_classification_metric(metric, targets) + for metric in config.metrics + ] + plots = [ + build_classification_plotter(plot, targets) + for plot in config.plots ] return ClassificationTask.build( config=config, + plots=plots, targets=targets, metrics=metrics, ) diff --git a/src/batdetect2/evaluate/tasks/clip_classification.py b/src/batdetect2/evaluate/tasks/clip_classification.py index 91392ee..798f79b 100644 --- a/src/batdetect2/evaluate/tasks/clip_classification.py +++ b/src/batdetect2/evaluate/tasks/clip_classification.py @@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.clip_classification import ( ClipEval, build_clip_metric, ) +from batdetect2.evaluate.plots.clip_classification import ( + ClipClassificationPlotConfig, + build_clip_classification_plotter, +) from batdetect2.evaluate.tasks.base import ( BaseTask, BaseTaskConfig, @@ -26,6 +30,7 @@ class ClipClassificationTaskConfig(BaseTaskConfig): ClipClassificationAveragePrecisionConfig(), ] ) + plots: List[ClipClassificationPlotConfig] = Field(default_factory=list) class ClipClassificationTask(BaseTask[ClipEval]): @@ -68,8 +73,13 @@ class 
ClipClassificationTask(BaseTask[ClipEval]): targets: TargetProtocol, ): metrics = [build_clip_metric(metric) for metric in config.metrics] + plots = [ + build_clip_classification_plotter(plot, targets) + for plot in config.plots + ] return ClipClassificationTask.build( config=config, + plots=plots, metrics=metrics, targets=targets, ) diff --git a/src/batdetect2/evaluate/tasks/clip_detection.py b/src/batdetect2/evaluate/tasks/clip_detection.py index 15f2311..2fb60a7 100644 --- a/src/batdetect2/evaluate/tasks/clip_detection.py +++ b/src/batdetect2/evaluate/tasks/clip_detection.py @@ -9,6 +9,10 @@ from batdetect2.evaluate.metrics.clip_detection import ( ClipEval, build_clip_metric, ) +from batdetect2.evaluate.plots.clip_detection import ( + ClipDetectionPlotConfig, + build_clip_detection_plotter, +) from batdetect2.evaluate.tasks.base import ( BaseTask, BaseTaskConfig, @@ -25,6 +29,7 @@ class ClipDetectionTaskConfig(BaseTaskConfig): ClipDetectionAveragePrecisionConfig(), ] ) + plots: List[ClipDetectionPlotConfig] = Field(default_factory=list) class ClipDetectionTask(BaseTask[ClipEval]): @@ -59,8 +64,13 @@ class ClipDetectionTask(BaseTask[ClipEval]): targets: TargetProtocol, ): metrics = [build_clip_metric(metric) for metric in config.metrics] + plots = [ + build_clip_detection_plotter(plot, targets) + for plot in config.plots + ] return ClipDetectionTask.build( config=config, metrics=metrics, targets=targets, + plots=plots, ) diff --git a/src/batdetect2/evaluate/tasks/detection.py b/src/batdetect2/evaluate/tasks/detection.py index 1308d43..0c13914 100644 --- a/src/batdetect2/evaluate/tasks/detection.py +++ b/src/batdetect2/evaluate/tasks/detection.py @@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.detection import ( MatchEval, build_detection_metric, ) +from batdetect2.evaluate.plots.detection import ( + DetectionPlotConfig, + build_detection_plotter, +) from batdetect2.evaluate.tasks.base import ( BaseTask, BaseTaskConfig, @@ -24,6 +28,7 @@ class DetectionTaskConfig(BaseTaskConfig): metrics: List[DetectionMetricConfig] = Field( default_factory=lambda: [DetectionAveragePrecisionConfig()] ) + plots: List[DetectionPlotConfig] = Field(default_factory=list) class DetectionTask(BaseTask[ClipEval]): @@ -72,8 +77,12 @@ class DetectionTask(BaseTask[ClipEval]): targets: TargetProtocol, ): metrics = [build_detection_metric(metric) for metric in config.metrics] + plots = [ + build_detection_plotter(plot, targets) for plot in config.plots + ] return DetectionTask.build( config=config, metrics=metrics, targets=targets, + plots=plots, ) diff --git a/src/batdetect2/evaluate/tasks/top_class.py b/src/batdetect2/evaluate/tasks/top_class.py index c2a215c..db082eb 100644 --- a/src/batdetect2/evaluate/tasks/top_class.py +++ b/src/batdetect2/evaluate/tasks/top_class.py @@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.top_class import ( TopClassMetricConfig, build_top_class_metric, ) +from batdetect2.evaluate.plots.top_class import ( + TopClassPlotConfig, + build_top_class_plotter, +) from batdetect2.evaluate.tasks.base import ( BaseTask, BaseTaskConfig, @@ -24,6 +28,7 @@ class TopClassDetectionTaskConfig(BaseTaskConfig): metrics: List[TopClassMetricConfig] = Field( default_factory=lambda: [TopClassAveragePrecisionConfig()] ) + plots: List[TopClassPlotConfig] = Field(default_factory=list) class TopClassDetectionTask(BaseTask[ClipEval]): @@ -94,8 +99,12 @@ class TopClassDetectionTask(BaseTask[ClipEval]): targets: TargetProtocol, ): metrics = [build_top_class_metric(metric) for metric in config.metrics] + 
plots = [ + build_top_class_plotter(plot, targets) for plot in config.plots + ] return TopClassDetectionTask.build( config=config, + plots=plots, metrics=metrics, targets=targets, ) diff --git a/src/batdetect2/plotting/common.py b/src/batdetect2/plotting/common.py index ff47802..3e2eea9 100644 --- a/src/batdetect2/plotting/common.py +++ b/src/batdetect2/plotting/common.py @@ -19,7 +19,7 @@ def create_ax( ) -> axes.Axes: """Create a new axis if none is provided""" if ax is None: - _, ax = plt.subplots(figsize=figsize, **kwargs) # type: ignore + _, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs) # type: ignore return ax # type: ignore diff --git a/src/batdetect2/plotting/metrics.py b/src/batdetect2/plotting/metrics.py new file mode 100644 index 0000000..09fcff3 --- /dev/null +++ b/src/batdetect2/plotting/metrics.py @@ -0,0 +1,281 @@ +from typing import Dict, Optional, Tuple + +import numpy as np +import seaborn as sns +from cycler import cycler +from matplotlib import axes + +from batdetect2.plotting.common import create_ax + + +def set_default_styler(ax: axes.Axes) -> axes.Axes: + color_cycler = cycler(color=sns.color_palette("muted")) + style_cycler = cycler(linestyle=["-", "--", ":"]) * cycler( + marker=["o", "s", "^"] + ) + custom_cycler = color_cycler * len(style_cycler) + style_cycler * len( + color_cycler + ) + + ax.set_prop_cycle(custom_cycler) + return ax + + +def set_default_style(ax: axes.Axes) -> axes.Axes: + ax = set_default_styler(ax) + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) + return ax + + +def plot_pr_curve( + precision: np.ndarray, + recall: np.ndarray, + thresholds: np.ndarray, + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + add_labels: bool = True, +) -> axes.Axes: + ax = create_ax(ax=ax, figsize=figsize) + + ax = set_default_style(ax) + + ax.plot( + recall, + precision, + label="PR Curve", + marker="o", + markevery=_get_marker_positions(thresholds), + ) + + ax.set_xlim(0, 1.05) + ax.set_ylim(0, 1.05) + + if add_labels: + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") + + return ax + + +def plot_pr_curves( + data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + add_legend: bool = True, + add_labels: bool = True, +) -> axes.Axes: + ax = create_ax(ax=ax, figsize=figsize) + ax = set_default_style(ax) + + for name, (precision, recall, thresholds) in data.items(): + ax.plot( + recall, + precision, + label=name, + markevery=_get_marker_positions(thresholds), + ) + + ax.set_xlim(0, 1.05) + ax.set_ylim(0, 1.05) + + if add_labels: + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") + + if add_legend: + ax.legend( + bbox_to_anchor=(1.05, 1), + loc="upper left", + borderaxespad=0.0, + ) + return ax + + +def plot_threshold_precision_curve( + threshold: np.ndarray, + precision: np.ndarray, + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + add_labels: bool = True, +): + ax = create_ax(ax=ax, figsize=figsize) + + ax = set_default_style(ax) + + ax.plot(threshold, precision, markevery=_get_marker_positions(threshold)) + + ax.set_xlim(0, 1.05) + ax.set_ylim(0, 1.05) + + if add_labels: + ax.set_xlabel("Threshold") + ax.set_ylabel("Precision") + + return ax + + +def plot_threshold_precision_curves( + data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + add_legend: bool = True, + add_labels: bool = True, 
+): + ax = create_ax(ax=ax, figsize=figsize) + ax = set_default_style(ax) + + for name, (precision, _, thresholds) in data.items(): + ax.plot( + thresholds, + precision, + label=name, + markevery=_get_marker_positions(thresholds), + ) + + if add_legend: + ax.legend( + bbox_to_anchor=(1.05, 1), + loc="upper left", + borderaxespad=0.0, + ) + + ax.set_xlim(0, 1.05) + ax.set_ylim(0, 1.05) + + if add_labels: + ax.set_xlabel("Threshold") + ax.set_ylabel("Precision") + + return ax + + +def plot_threshold_recall_curve( + threshold: np.ndarray, + recall: np.ndarray, + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + add_labels: bool = True, +): + ax = create_ax(ax=ax, figsize=figsize) + + ax = set_default_style(ax) + + ax.plot(threshold, recall, markevery=_get_marker_positions(threshold)) + + ax.set_xlim(0, 1.05) + ax.set_ylim(0, 1.05) + + if add_labels: + ax.set_xlabel("Threshold") + ax.set_ylabel("Recall") + + return ax + + +def plot_threshold_recall_curves( + data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + add_legend: bool = True, + add_labels: bool = True, +): + ax = create_ax(ax=ax, figsize=figsize) + ax = set_default_style(ax) + + for name, (_, recall, thresholds) in data.items(): + ax.plot( + thresholds, + recall, + label=name, + markevery=_get_marker_positions(thresholds), + ) + + if add_legend: + ax.legend( + bbox_to_anchor=(1.05, 1), + loc="upper left", + borderaxespad=0.0, + ) + + ax.set_xlim(0, 1.05) + ax.set_ylim(0, 1.05) + + if add_labels: + ax.set_xlabel("Threshold") + ax.set_ylabel("Recall") + + return ax + + +def plot_roc_curve( + fpr: np.ndarray, + tpr: np.ndarray, + thresholds: np.ndarray, + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + add_labels: bool = True, +) -> axes.Axes: + ax = create_ax(ax=ax, figsize=figsize) + + ax = set_default_style(ax) + + ax.plot( + fpr, + tpr, + markevery=_get_marker_positions(thresholds), + ) + + ax.set_xlim(0, 1.05) + ax.set_ylim(0, 1.05) + + if add_labels: + ax.set_xlabel("False Positive Rate") + ax.set_ylabel("True Positive Rate") + + return ax + + +def plot_roc_curves( + data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]], + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + add_legend: bool = True, + add_labels: bool = True, +) -> axes.Axes: + ax = create_ax(ax=ax, figsize=figsize) + ax = set_default_style(ax) + + for name, (fpr, tpr, thresholds) in data.items(): + ax.plot( + fpr, + tpr, + label=name, + markevery=_get_marker_positions(thresholds), + ) + + if add_legend: + ax.legend( + bbox_to_anchor=(1.05, 1), + loc="upper left", + borderaxespad=0.0, + ) + + ax.set_xlim(0, 1.05) + ax.set_ylim(0, 1.05) + + if add_labels: + ax.set_xlabel("False Positive Rate") + ax.set_ylabel("True Positive Rate") + + return ax + + +def _get_marker_positions( + thresholds: np.ndarray, + n_points: int = 11, +) -> np.ndarray: + size = len(thresholds) + cut_points = np.linspace(0, 1, n_points) + indices = np.searchsorted(thresholds[::-1], cut_points) + return np.clip(size - indices, 0, size - 1) # type: ignore diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 3c1ed24..fc24cf1 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, List from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import Callback @@ -35,6 
+35,7 @@ class ValidationMetrics(Callback): def generate_plots( self, + eval_outputs: Any, pl_module: LightningModule, ): plotter = get_image_logger(pl_module.logger) # type: ignore @@ -42,20 +43,15 @@ class ValidationMetrics(Callback): if plotter is None: return - for figure_name, fig in self.evaluator.generate_plots( - self._clip_annotations, - self._predictions, - ): + for figure_name, fig in self.evaluator.generate_plots(eval_outputs): plotter(figure_name, fig, pl_module.global_step) def log_metrics( self, + eval_outputs: Any, pl_module: LightningModule, ): - metrics = self.evaluator.compute_metrics( - self._clip_annotations, - self._predictions, - ) + metrics = self.evaluator.compute_metrics(eval_outputs) pl_module.log_dict(metrics) def on_validation_epoch_end( @@ -63,8 +59,13 @@ class ValidationMetrics(Callback): trainer: Trainer, pl_module: LightningModule, ) -> None: - self.log_metrics(pl_module) - self.generate_plots(pl_module) + eval_outputs = self.evaluator.evaluate( + self._clip_annotations, + self._predictions, + ) + + self.log_metrics(eval_outputs, pl_module) + self.generate_plots(eval_outputs, pl_module) return super().on_validation_epoch_end(trainer, pl_module)
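
Usage sketch (illustrative, not part of the diff): how one of the new plot configs is wired to a plotter. `targets` stands in for any TargetProtocol implementation and `clip_evals` for the ClipEval sequence produced by the detection task's evaluate step; the output path is arbitrary.

from batdetect2.evaluate.plots.detection import (
    PRCurveConfig,
    build_detection_plotter,
)

# Placeholders: a TargetProtocol implementation and the per-clip
# evaluation results produced by DetectionTask.evaluate.
targets = ...
clip_evals = ...

config = PRCurveConfig(label="val_pr_curve", figsize=(6, 6), dpi=120)
plotter = build_detection_plotter(config, targets)

# Every registered plotter returns a (label, Figure) pair, mirroring
# what BaseTask.generate_plots yields to the image-logger callback.
label, fig = plotter(clip_evals)
fig.savefig(f"{label}.png")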