diff --git a/src/batdetect2/api/base.py b/src/batdetect2/api/base.py index df0cae0..f56b7b6 100644 --- a/src/batdetect2/api/base.py +++ b/src/batdetect2/api/base.py @@ -56,7 +56,10 @@ class BatDetect2API: train( train_annotations=train_annotations, val_annotations=val_annotations, + targets=self.targets, config=self.config, + audio_loader=self.audio_loader, + preprocessor=self.preprocessor, train_workers=train_workers, val_workers=val_workers, checkpoint_dir=checkpoint_dir, diff --git a/src/batdetect2/core/configs.py b/src/batdetect2/core/configs.py index 7399d6e..c7ffcd3 100644 --- a/src/batdetect2/core/configs.py +++ b/src/batdetect2/core/configs.py @@ -27,7 +27,7 @@ class BaseConfig(BaseModel): and serialization capabilities. """ - model_config = ConfigDict(extra="ignore") + model_config = ConfigDict(extra="forbid") def to_yaml_string( self, diff --git a/src/batdetect2/evaluate/config.py b/src/batdetect2/evaluate/config.py index 3a02265..90e3f14 100644 --- a/src/batdetect2/evaluate/config.py +++ b/src/batdetect2/evaluate/config.py @@ -10,7 +10,7 @@ from batdetect2.evaluate.metrics import ( DetectionAPConfig, MetricConfig, ) -from batdetect2.evaluate.plots import ExampleGalleryConfig, PlotConfig +from batdetect2.evaluate.plots import PlotConfig __all__ = [ "EvaluationConfig", @@ -20,18 +20,14 @@ __all__ = [ class EvaluationConfig(BaseConfig): ignore_start_end: float = 0.01 - match: MatchConfig = Field(default_factory=StartTimeMatchConfig) + match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig) metrics: List[MetricConfig] = Field( default_factory=lambda: [ DetectionAPConfig(), ClassificationAPConfig(), ] ) - plots: List[PlotConfig] = Field( - default_factory=lambda: [ - ExampleGalleryConfig(), - ] - ) + plots: List[PlotConfig] = Field(default_factory=list) def load_evaluation_config( diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 768673a..1a70b76 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -58,7 +58,7 @@ def evaluate( clip_annotations = [] predictions = [] - evaluator = build_evaluator(config=config) + evaluator = build_evaluator(config=config, targets=targets) for batch in loader: outputs = model.detector(batch.spec) diff --git a/src/batdetect2/evaluate/evaluator.py b/src/batdetect2/evaluate/evaluator.py index d60f333..f0bca83 100644 --- a/src/batdetect2/evaluate/evaluator.py +++ b/src/batdetect2/evaluate/evaluator.py @@ -138,7 +138,7 @@ def build_evaluator( ) -> Evaluator: config = config or EvaluationConfig() targets = targets or build_targets() - matcher = matcher or build_matcher(config.match) + matcher = matcher or build_matcher(config.match_strategy) if metrics is None: metrics = [ @@ -147,7 +147,10 @@ def build_evaluator( ] if plots is None: - plots = [build_plotter(config) for config in config.plots] + plots = [ + build_plotter(config, targets.class_names) + for config in config.plots + ] return Evaluator( config=config, diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index 2d67c13..af3545b 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -111,7 +111,7 @@ def match( class StartTimeMatchConfig(BaseConfig): - name: Literal["start_time"] = "start_time" + name: Literal["start_time_match"] = "start_time_match" distance_threshold: float = 0.01 diff --git a/src/batdetect2/evaluate/metrics.py b/src/batdetect2/evaluate/metrics.py index 6b52f7a..ea29ba4 100644 --- a/src/batdetect2/evaluate/metrics.py +++ 
b/src/batdetect2/evaluate/metrics.py @@ -1,3 +1,4 @@ +from collections import defaultdict from collections.abc import Callable, Mapping from typing import ( Annotated, @@ -12,8 +13,7 @@ from typing import ( import numpy as np from pydantic import Field -from sklearn import metrics -from sklearn.preprocessing import label_binarize +from sklearn import metrics, preprocessing from batdetect2.core.configs import BaseConfig from batdetect2.core.registries import Registry @@ -26,57 +26,18 @@ __all__ = ["DetectionAP", "ClassificationAP"] metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric") -AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"] +APImplementation = Literal["sklearn", "pascal_voc"] class DetectionAPConfig(BaseConfig): name: Literal["detection_ap"] = "detection_ap" - ap_implementation: AveragePrecisionImplementation = "pascal_voc" - - -def pascal_voc_average_precision(y_true, y_score) -> float: - y_true = np.array(y_true) - y_score = np.array(y_score) - - sort_ind = np.argsort(y_score)[::-1] - y_true_sorted = y_true[sort_ind] - - num_positives = y_true.sum() - false_pos_c = np.cumsum(1 - y_true_sorted) - true_pos_c = np.cumsum(y_true_sorted) - - recall = true_pos_c / num_positives - precision = true_pos_c / np.maximum( - true_pos_c + false_pos_c, - np.finfo(np.float64).eps, - ) - - precision[np.isnan(precision)] = 0 - recall[np.isnan(recall)] = 0 - - # pascal 12 way - mprec = np.hstack((0, precision, 0)) - mrec = np.hstack((0, recall, 1)) - for ii in range(mprec.shape[0] - 2, -1, -1): - mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1]) - inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1 - ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum() - - return ave_prec - - -_ap_impl_mapping: Mapping[ - AveragePrecisionImplementation, Callable[[Any, Any], float] -] = { - "sklearn": metrics.average_precision_score, - "pascal_voc": pascal_voc_average_precision, -} + ap_implementation: APImplementation = "pascal_voc" class DetectionAP(MetricsProtocol): def __init__( self, - implementation: AveragePrecisionImplementation = "pascal_voc", + implementation: APImplementation = "pascal_voc", ): self.implementation = implementation self.metric = _ap_impl_mapping[self.implementation] @@ -102,9 +63,37 @@ class DetectionAP(MetricsProtocol): metrics_registry.register(DetectionAPConfig, DetectionAP) +class DetectionROCAUCConfig(BaseConfig): + name: Literal["detection_roc_auc"] = "detection_roc_auc" + + +class DetectionROCAUC(MetricsProtocol): + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: + y_true, y_score = zip( + *[ + (match.gt_det, match.pred_score) + for clip_eval in clip_evaluations + for match in clip_eval.matches + ] + ) + score = float(metrics.roc_auc_score(y_true, y_score)) + return {"detection_ROC_AUC": score} + + @classmethod + def from_config( + cls, config: DetectionROCAUCConfig, class_names: List[str] + ): + return cls() + + +metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC) + + class ClassificationAPConfig(BaseConfig): name: Literal["classification_ap"] = "classification_ap" - ap_implementation: AveragePrecisionImplementation = "pascal_voc" + ap_implementation: APImplementation = "pascal_voc" include: Optional[List[str]] = None exclude: Optional[List[str]] = None @@ -113,7 +102,7 @@ class ClassificationAP(MetricsProtocol): def __init__( self, class_names: List[str], - implementation: AveragePrecisionImplementation = "pascal_voc", + implementation: APImplementation = "pascal_voc", 
include: Optional[List[str]] = None, exclude: Optional[List[str]] = None, ): @@ -164,7 +153,7 @@ class ClassificationAP(MetricsProtocol): ) ) - y_true = label_binarize(y_true, classes=self.class_names) + y_true = preprocessing.label_binarize(y_true, classes=self.class_names) y_pred = np.stack(y_pred) class_scores = {} @@ -203,11 +192,435 @@ class ClassificationAP(MetricsProtocol): metrics_registry.register(ClassificationAPConfig, ClassificationAP) +class ClassificationROCAUCConfig(BaseConfig): + name: Literal["classification_roc_auc"] = "classification_roc_auc" + include: Optional[List[str]] = None + exclude: Optional[List[str]] = None + + +class ClassificationROCAUC(MetricsProtocol): + def __init__( + self, + class_names: List[str], + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ): + self.class_names = class_names + self.selected = class_names + + if include is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name in include + ] + + if exclude is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name not in exclude + ] + + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: + y_true = [] + y_pred = [] + + for clip_eval in clip_evaluations: + for match in clip_eval.matches: + # Ignore generic unclassified targets + if match.gt_det and match.gt_class is None: + continue + + y_true.append( + match.gt_class + if match.gt_class is not None + else "__NONE__" + ) + + y_pred.append( + np.array( + [ + match.pred_class_scores.get(name, 0) + for name in self.class_names + ] + ) + ) + + y_true = preprocessing.label_binarize(y_true, classes=self.class_names) + y_pred = np.stack(y_pred) + + class_scores = {} + for class_index, class_name in enumerate(self.class_names): + y_true_class = y_true[:, class_index] + y_pred_class = y_pred[:, class_index] + class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class) + class_scores[class_name] = float(class_roc_auc) + + mean_roc_auc = np.mean( + [value for value in class_scores.values() if value != 0] + ) + + return { + "classification_macro_average_ROC_AUC": float(mean_roc_auc), + **{ + f"classification_ROC_AUC/{class_name}": class_scores[ + class_name + ] + for class_name in self.selected + }, + } + + @classmethod + def from_config( + cls, + config: ClassificationROCAUCConfig, + class_names: List[str], + ): + return cls( + class_names, + include=config.include, + exclude=config.exclude, + ) + + +metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC) + + +class TopClassAPConfig(BaseConfig): + name: Literal["top_class_ap"] = "top_class_ap" + ap_implementation: APImplementation = "pascal_voc" + + +class TopClassAP(MetricsProtocol): + def __init__( + self, + implementation: APImplementation = "pascal_voc", + ): + self.implementation = implementation + self.metric = _ap_impl_mapping[self.implementation] + + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: + y_true = [] + y_score = [] + + for clip_eval in clip_evaluations: + for match in clip_eval.matches: + # Ignore generic unclassified targets + if match.gt_det and match.gt_class is None: + continue + + top_class = match.pred_class + + y_true.append(top_class == match.gt_class) + y_score.append(match.pred_class_score) + + score = float(self.metric(y_true, y_score)) + return {"top_class_AP": score} + + @classmethod + def from_config(cls, config: TopClassAPConfig, class_names: List[str]): + return 
cls(implementation=config.ap_implementation) + + +metrics_registry.register(TopClassAPConfig, TopClassAP) + + +class ClassificationBalancedAccuracyConfig(BaseConfig): + name: Literal["classification_balanced_accuracy"] = ( + "classification_balanced_accuracy" + ) + + +class ClassificationBalancedAccuracy(MetricsProtocol): + def __init__(self, class_names: List[str]): + self.class_names = class_names + + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: + y_true = [] + y_pred = [] + + for clip_eval in clip_evaluations: + for match in clip_eval.matches: + top_class = match.pred_class + + # Focus on matches + if match.gt_class is None or top_class is None: + continue + + y_true.append(self.class_names.index(match.gt_class)) + y_pred.append(self.class_names.index(top_class)) + + score = float(metrics.balanced_accuracy_score(y_true, y_pred)) + return {"classification_balanced_accuracy": score} + + @classmethod + def from_config( + cls, + config: ClassificationBalancedAccuracyConfig, + class_names: List[str], + ): + return cls(class_names) + + +metrics_registry.register( + ClassificationBalancedAccuracyConfig, + ClassificationBalancedAccuracy, +) + + +class ClipAPConfig(BaseConfig): + name: Literal["clip_ap"] = "clip_ap" + ap_implementation: APImplementation = "pascal_voc" + include: Optional[List[str]] = None + exclude: Optional[List[str]] = None + + +class ClipAP(MetricsProtocol): + def __init__( + self, + class_names: List[str], + implementation: APImplementation, + include: Optional[Sequence[str]] = None, + exclude: Optional[Sequence[str]] = None, + ): + self.implementation = implementation + self.metric = _ap_impl_mapping[self.implementation] + self.class_names = class_names + + self.selected = class_names + + if include is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name in include + ] + + if exclude is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name not in exclude + ] + + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: + y_true = [] + y_pred = [] + + for clip_eval in clip_evaluations: + clip_classes = set() + clip_scores = defaultdict(list) + + for match in clip_eval.matches: + if match.gt_class is not None: + clip_classes.add(match.gt_class) + + for class_name, score in match.pred_class_scores.items(): + clip_scores[class_name].append(score) + + y_true.append(clip_classes) + y_pred.append( + np.array( + [ + # Get max score for each class + max(clip_scores.get(class_name, [0])) + for class_name in self.class_names + ] + ) + ) + + y_true = preprocessing.MultiLabelBinarizer( + classes=self.class_names + ).fit_transform(y_true) + y_pred = np.stack(y_pred) + + class_scores = {} + for class_index, class_name in enumerate(self.class_names): + y_true_class = y_true[:, class_index] + y_pred_class = y_pred[:, class_index] + class_ap = self.metric(y_true_class, y_pred_class) + class_scores[class_name] = float(class_ap) + + mean_ap = np.mean( + [value for value in class_scores.values() if value != 0] + ) + return { + "clip_mAP": float(mean_ap), + **{ + f"clip_AP/{class_name}": class_scores[class_name] + for class_name in self.selected + }, + } + + @classmethod + def from_config(cls, config: ClipAPConfig, class_names: List[str]): + return cls( + implementation=config.ap_implementation, + include=config.include, + exclude=config.exclude, + class_names=class_names, + ) + + +metrics_registry.register(ClipAPConfig, ClipAP) + + +class 
ClipROCAUCConfig(BaseConfig): + name: Literal["clip_roc_auc"] = "clip_roc_auc" + include: Optional[List[str]] = None + exclude: Optional[List[str]] = None + + +class ClipROCAUC(MetricsProtocol): + def __init__( + self, + class_names: List[str], + include: Optional[Sequence[str]] = None, + exclude: Optional[Sequence[str]] = None, + ): + self.class_names = class_names + self.selected = class_names + + if include is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name in include + ] + + if exclude is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name not in exclude + ] + + def __call__( + self, clip_evaluations: Sequence[ClipEvaluation] + ) -> Dict[str, float]: + y_true = [] + y_pred = [] + + for clip_eval in clip_evaluations: + clip_classes = set() + clip_scores = defaultdict(list) + + for match in clip_eval.matches: + if match.gt_class is not None: + clip_classes.add(match.gt_class) + + for class_name, score in match.pred_class_scores.items(): + clip_scores[class_name].append(score) + + y_true.append(clip_classes) + y_pred.append( + np.array( + [ + # Get maximum score for each class + max(clip_scores.get(class_name, [0])) + for class_name in self.class_names + ] + ) + ) + + y_true = preprocessing.MultiLabelBinarizer( + classes=self.class_names + ).fit_transform(y_true) + y_pred = np.stack(y_pred) + + class_scores = {} + for class_index, class_name in enumerate(self.class_names): + y_true_class = y_true[:, class_index] + y_pred_class = y_pred[:, class_index] + class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class) + class_scores[class_name] = float(class_roc_auc) + + mean_roc_auc = np.mean( + [value for value in class_scores.values() if value != 0] + ) + return { + "clip_macro_ROC_AUC": float(mean_roc_auc), + **{ + f"clip_ROC_AUC/{class_name}": class_scores[class_name] + for class_name in self.selected + }, + } + + @classmethod + def from_config( + cls, + config: ClipROCAUCConfig, + class_names: List[str], + ): + return cls( + include=config.include, + exclude=config.exclude, + class_names=class_names, + ) + + +metrics_registry.register(ClipROCAUCConfig, ClipROCAUC) + MetricConfig = Annotated[ - Union[ClassificationAPConfig, DetectionAPConfig], + Union[ + DetectionAPConfig, + DetectionROCAUCConfig, + ClassificationAPConfig, + ClassificationROCAUCConfig, + TopClassAPConfig, + ClassificationBalancedAccuracyConfig, + ClipAPConfig, + ClipROCAUCConfig, + ], Field(discriminator="name"), ] def build_metric(config: MetricConfig, class_names: List[str]): return metrics_registry.build(config, class_names) + + +def pascal_voc_average_precision(y_true, y_score) -> float: + y_true = np.array(y_true) + y_score = np.array(y_score) + + sort_ind = np.argsort(y_score)[::-1] + y_true_sorted = y_true[sort_ind] + + num_positives = y_true.sum() + false_pos_c = np.cumsum(1 - y_true_sorted) + true_pos_c = np.cumsum(y_true_sorted) + + recall = true_pos_c / num_positives + precision = true_pos_c / np.maximum( + true_pos_c + false_pos_c, + np.finfo(np.float64).eps, + ) + + precision[np.isnan(precision)] = 0 + recall[np.isnan(recall)] = 0 + + # pascal 12 way + mprec = np.hstack((0, precision, 0)) + mrec = np.hstack((0, recall, 1)) + for ii in range(mprec.shape[0] - 2, -1, -1): + mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1]) + inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1 + ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum() + + return ave_prec + + +_ap_impl_mapping: Mapping[APImplementation, 
Callable[[Any, Any], float]] = { + "sklearn": metrics.average_precision_score, + "pascal_voc": pascal_voc_average_precision, +} diff --git a/src/batdetect2/evaluate/plots.py b/src/batdetect2/evaluate/plots.py index 436e094..6680c2e 100644 --- a/src/batdetect2/evaluate/plots.py +++ b/src/batdetect2/evaluate/plots.py @@ -4,13 +4,18 @@ from dataclasses import dataclass, field from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union import matplotlib.pyplot as plt +import numpy as np import pandas as pd from pydantic import Field +from sklearn import metrics +from sklearn.preprocessing import label_binarize +from batdetect2.audio import AudioConfig from batdetect2.core.configs import BaseConfig from batdetect2.core.registries import Registry from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader from batdetect2.plotting.gallery import plot_match_gallery +from batdetect2.plotting.matches import plot_matches from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.typing.evaluate import ( ClipEvaluation, @@ -26,12 +31,13 @@ __all__ = [ ] -plots_registry: Registry[PlotterProtocol, []] = Registry("plot") +plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot") class ExampleGalleryConfig(BaseConfig): name: Literal["example_gallery"] = "example_gallery" examples_per_class: int = 5 + audio: AudioConfig = Field(default_factory=AudioConfig) preprocessing: PreprocessingConfig = Field( default_factory=PreprocessingConfig ) @@ -87,9 +93,12 @@ class ExampleGallery(PlotterProtocol): plt.close(fig) @classmethod - def from_config(cls, config: ExampleGalleryConfig): - preprocessor = build_preprocessor(config.preprocessing) - audio_loader = build_audio_loader(config.preprocessing.audio_transforms) + def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]): + audio_loader = build_audio_loader(config.audio) + preprocessor = build_preprocessor( + config.preprocessing, + input_samplerate=audio_loader.samplerate, + ) return cls( examples_per_class=config.examples_per_class, preprocessor=preprocessor, @@ -100,13 +109,345 @@ class ExampleGallery(PlotterProtocol): plots_registry.register(ExampleGalleryConfig, ExampleGallery) +class ClipEvaluationPlotConfig(BaseConfig): + name: Literal["example_clip"] = "example_clip" + num_plots: int = 5 + audio: AudioConfig = Field(default_factory=AudioConfig) + preprocessing: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + + +class PlotClipEvaluation(PlotterProtocol): + def __init__( + self, + num_plots: int = 3, + preprocessor: Optional[PreprocessorProtocol] = None, + audio_loader: Optional[AudioLoader] = None, + ): + self.preprocessor = preprocessor + self.audio_loader = audio_loader + self.num_plots = num_plots + + def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + examples = random.sample( + clip_evaluations, + k=min(self.num_plots, len(clip_evaluations)), + ) + + for index, clip_evaluation in enumerate(examples): + fig, ax = plt.subplots() + plot_matches( + clip_evaluation.matches, + clip=clip_evaluation.clip, + audio_loader=self.audio_loader, + ax=ax, + ) + yield f"clip_evaluation/example_{index}", fig + plt.close(fig) + + @classmethod + def from_config( + cls, + config: ClipEvaluationPlotConfig, + class_names: List[str], + ): + audio_loader = build_audio_loader(config.audio) + preprocessor = build_preprocessor( + config.preprocessing, + input_samplerate=audio_loader.samplerate, + ) + return cls( + num_plots=config.num_plots, + 
preprocessor=preprocessor, + audio_loader=audio_loader, + ) + + +plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation) + + +class DetectionPRCurveConfig(BaseConfig): + name: Literal["detection_pr_curve"] = "detection_pr_curve" + + +class DetectionPRCurve(PlotterProtocol): + def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + y_true, y_score = zip( + *[ + (match.gt_det, match.pred_score) + for clip_eval in clip_evaluations + for match in clip_eval.matches + ] + ) + precision, recall, _ = metrics.precision_recall_curve(y_true, y_score) + fig, ax = plt.subplots() + + ax.plot(recall, precision, label="Detector") + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") + ax.legend() + + yield "detection_pr_curve", fig + + @classmethod + def from_config( + cls, + config: DetectionPRCurveConfig, + class_names: List[str], + ): + return cls() + + +plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve) + + +class ClassificationPRCurvesConfig(BaseConfig): + name: Literal["classification_pr_curves"] = "classification_pr_curves" + include: Optional[List[str]] = None + exclude: Optional[List[str]] = None + + +class ClassificationPRCurves(PlotterProtocol): + def __init__( + self, + class_names: List[str], + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ): + self.class_names = class_names + self.selected = class_names + + if include is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name in include + ] + + if exclude is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name not in exclude + ] + + def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + y_true = [] + y_pred = [] + + for clip_eval in clip_evaluations: + for match in clip_eval.matches: + # Ignore generic unclassified targets + if match.gt_det and match.gt_class is None: + continue + + y_true.append( + match.gt_class + if match.gt_class is not None + else "__NONE__" + ) + + y_pred.append( + np.array( + [ + match.pred_class_scores.get(name, 0) + for name in self.class_names + ] + ) + ) + + y_true = label_binarize(y_true, classes=self.class_names) + y_pred = np.stack(y_pred) + + fig, ax = plt.subplots(figsize=(10, 10)) + for class_index, class_name in enumerate(self.class_names): + if class_name not in self.selected: + continue + + y_true_class = y_true[:, class_index] + y_pred_class = y_pred[:, class_index] + precision, recall, _ = metrics.precision_recall_curve( + y_true_class, + y_pred_class, + ) + ax.plot(recall, precision, label=class_name) + + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") + ax.legend( + bbox_to_anchor=(1.05, 1), + loc="upper left", + borderaxespad=0.0, + ) + + yield "classification_pr_curve", fig + + @classmethod + def from_config( + cls, + config: ClassificationPRCurvesConfig, + class_names: List[str], + ): + return cls( + class_names=class_names, + include=config.include, + exclude=config.exclude, + ) + + +plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves) + + +class DetectionROCCurveConfig(BaseConfig): + name: Literal["detection_roc_curve"] = "detection_roc_curve" + + +class DetectionROCCurve(PlotterProtocol): + def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + y_true, y_score = zip( + *[ + (match.gt_det, match.pred_score) + for clip_eval in clip_evaluations + for match in clip_eval.matches + ] + ) + fpr, tpr, _ = metrics.roc_curve(y_true, y_score) + fig, ax = plt.subplots() + + ax.plot(fpr, tpr, label="Detection") 
+ ax.set_xlabel("False Positive Rate") + ax.set_ylabel("True Positive Rate") + ax.legend() + + yield "detection_roc_curve", fig + + @classmethod + def from_config( + cls, + config: DetectionROCCurveConfig, + class_names: List[str], + ): + return cls() + + +plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve) + + +class ClassificationROCCurvesConfig(BaseConfig): + name: Literal["classification_roc_curves"] = "classification_roc_curves" + include: Optional[List[str]] = None + exclude: Optional[List[str]] = None + + +class ClassificationROCCurves(PlotterProtocol): + def __init__( + self, + class_names: List[str], + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ): + self.class_names = class_names + self.selected = class_names + + if include is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name in include + ] + + if exclude is not None: + self.selected = [ + class_name + for class_name in self.selected + if class_name not in exclude + ] + + def __call__(self, clip_evaluations: Sequence[ClipEvaluation]): + y_true = [] + y_pred = [] + + for clip_eval in clip_evaluations: + for match in clip_eval.matches: + # Ignore generic unclassified targets + if match.gt_det and match.gt_class is None: + continue + + y_true.append( + match.gt_class + if match.gt_class is not None + else "__NONE__" + ) + + y_pred.append( + np.array( + [ + match.pred_class_scores.get(name, 0) + for name in self.class_names + ] + ) + ) + + y_true = label_binarize(y_true, classes=self.class_names) + y_pred = np.stack(y_pred) + + fig, ax = plt.subplots(figsize=(10, 10)) + for class_index, class_name in enumerate(self.class_names): + if class_name not in self.selected: + continue + + y_true_class = y_true[:, class_index] + y_roced_class = y_pred[:, class_index] + fpr, tpr, _ = metrics.roc_curve( + y_true_class, + y_roced_class, + ) + ax.plot(fpr, tpr, label=class_name) + + ax.set_xlabel("False Positive Rate") + ax.set_ylabel("True Positive Rate") + ax.legend( + bbox_to_anchor=(1.05, 1), + loc="upper left", + borderaxespad=0.0, + ) + + yield "classification_roc_curve", fig + + @classmethod + def from_config( + cls, + config: ClassificationROCCurvesConfig, + class_names: List[str], + ): + return cls( + class_names=class_names, + include=config.include, + exclude=config.exclude, + ) + + +plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves) + + PlotConfig = Annotated[ - Union[ExampleGalleryConfig,], Field(discriminator="name") + Union[ + ExampleGalleryConfig, + ClipEvaluationPlotConfig, + DetectionPRCurveConfig, + ClassificationPRCurvesConfig, + DetectionROCCurveConfig, + ClassificationROCCurvesConfig, + ], + Field(discriminator="name"), ] -def build_plotter(config: PlotConfig) -> PlotterProtocol: - return plots_registry.build(config) +def build_plotter( + config: PlotConfig, class_names: List[str] +) -> PlotterProtocol: + return plots_registry.build(config, class_names) @dataclass diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index 1fc7c73..ae4775d 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ -6,7 +6,6 @@ from soundevent import data, plot from soundevent.geometry import compute_bounds from soundevent.plot.tags import TagColorMapper -from batdetect2.plotting.clip_predictions import plot_prediction from batdetect2.plotting.clips import AudioLoader, plot_clip from batdetect2.typing import MatchEvaluation, PreprocessorProtocol @@ -29,7 +28,7 
@@ DEFAULT_PREDICTION_LINE_STYLE = "--" def plot_matches( - matches: List[data.Match], + matches: List[MatchEvaluation], clip: data.Clip, audio_loader: Optional[AudioLoader] = None, preprocessor: Optional[PreprocessorProtocol] = None, @@ -43,8 +42,7 @@ false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR, false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR, true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR, - annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, - prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, + cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR, ) -> Axes: ax = plot_clip( clip, @@ -60,52 +58,48 @@ color_mapper = TagColorMapper() for match in matches: - if match.source is None and match.target is not None: - plot.plot_annotation( - annotation=match.target, + if match.is_cross_trigger(): + plot_cross_trigger_match( + match, ax=ax, - time_offset=0.004, - freq_offset=2_000, + fill=fill, + add_points=add_points, + add_spectrogram=False, + use_score=True, + color=cross_trigger_color, + add_text=False, + ) + elif match.is_true_positive(): + plot_true_positive_match( + match, + ax=ax, + fill=fill, + add_spectrogram=False, + use_score=True, + add_points=add_points, + color=true_positive_color, + add_text=False, + ) + elif match.is_false_negative(): + plot_false_negative_match( + match, + ax=ax, + fill=fill, + add_spectrogram=False, add_points=add_points, - facecolor="none" if not fill else None, color=false_negative_color, - color_mapper=color_mapper, - linestyle=annotation_linestyle, + add_text=False, ) - elif match.target is None and match.source is not None: - plot_prediction( - prediction=match.source, + elif match.is_false_positive(): + plot_false_positive_match( + match, ax=ax, - time_offset=0.004, - freq_offset=2_000, + fill=fill, + add_spectrogram=False, + use_score=True, add_points=add_points, - facecolor="none" if not fill else None, color=false_positive_color, - color_mapper=color_mapper, - linestyle=prediction_linestyle, - ) - elif match.target is not None and match.source is not None: - plot.plot_annotation( - annotation=match.target, - ax=ax, - time_offset=0.004, - freq_offset=2_000, - add_points=add_points, - facecolor="none" if not fill else None, - color=true_positive_color, - color_mapper=color_mapper, - linestyle=annotation_linestyle, - ) - plot_prediction( - prediction=match.source, - ax=ax, - time_offset=0.004, - freq_offset=2_000, - add_points=add_points, - facecolor="none" if not fill else None, - color=true_positive_color, - color_mapper=color_mapper, - linestyle=prediction_linestyle, + add_text=False, ) else: continue @@ -121,6 +115,9 @@ def plot_false_positive_match( ax: Optional[Axes] = None, audio_dir: Optional[data.PathLike] = None, duration: float = DEFAULT_DURATION, + use_score: bool = True, + add_spectrogram: bool = True, + add_text: bool = True, add_points: bool = False, fill: bool = False, spec_cmap: str = "gray", @@ -141,34 +138,36 @@ recording=match.clip.recording, ) - ax = plot_clip( - clip, - audio_loader=audio_loader, - preprocessor=preprocessor, - figsize=figsize, - ax=ax, - audio_dir=audio_dir, - spec_cmap=spec_cmap, - ) + if add_spectrogram: + ax = plot_clip( + clip, + audio_loader=audio_loader, + preprocessor=preprocessor, + figsize=figsize, + ax=ax, + audio_dir=audio_dir, + spec_cmap=spec_cmap, + ) - plot.plot_geometry( + ax = plot.plot_geometry( match.pred_geometry, ax=ax, add_points=add_points, facecolor="none" if not fill else None, -
alpha=1, + alpha=match.pred_score if use_score else 1, color=color, ) - plt.text( - start_time, - high_freq, - f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ", - va="top", - ha="right", - color=color, - fontsize=fontsize, - ) + if add_text: + plt.text( + start_time, + high_freq, + f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ", + va="top", + ha="right", + color=color, + fontsize=fontsize, + ) return ax @@ -181,7 +180,9 @@ def plot_false_negative_match( ax: Optional[Axes] = None, audio_dir: Optional[data.PathLike] = None, duration: float = DEFAULT_DURATION, + add_spectrogram: bool = True, add_points: bool = False, + add_text: bool = True, fill: bool = False, spec_cmap: str = "gray", color: str = DEFAULT_FALSE_NEGATIVE_COLOR, @@ -203,17 +204,18 @@ def plot_false_negative_match( recording=sound_event.recording, ) - ax = plot_clip( - clip, - audio_loader=audio_loader, - preprocessor=preprocessor, - figsize=figsize, - ax=ax, - audio_dir=audio_dir, - spec_cmap=spec_cmap, - ) + if add_spectrogram: + ax = plot_clip( + clip, + audio_loader=audio_loader, + preprocessor=preprocessor, + figsize=figsize, + ax=ax, + audio_dir=audio_dir, + spec_cmap=spec_cmap, + ) - plot.plot_annotation( + ax = plot.plot_annotation( match.sound_event_annotation, ax=ax, time_offset=0.001, @@ -224,15 +226,16 @@ def plot_false_negative_match( color=color, ) - plt.text( - start_time, - high_freq, - f"False Negative \nClass: {match.gt_class} ", - va="top", - ha="right", - color=color, - fontsize=fontsize, - ) + if add_text: + plt.text( + start_time, + high_freq, + f"False Negative \nClass: {match.gt_class} ", + va="top", + ha="right", + color=color, + fontsize=fontsize, + ) return ax @@ -245,7 +248,10 @@ def plot_true_positive_match( ax: Optional[Axes] = None, audio_dir: Optional[data.PathLike] = None, duration: float = DEFAULT_DURATION, + use_score: bool = True, + add_spectrogram: bool = True, add_points: bool = False, + add_text: bool = True, fill: bool = False, spec_cmap: str = "gray", color: str = DEFAULT_TRUE_POSITIVE_COLOR, @@ -269,17 +275,18 @@ def plot_true_positive_match( recording=sound_event.recording, ) - ax = plot_clip( - clip, - audio_loader=audio_loader, - preprocessor=preprocessor, - figsize=figsize, - ax=ax, - audio_dir=audio_dir, - spec_cmap=spec_cmap, - ) + if add_spectrogram: + ax = plot_clip( + clip, + audio_loader=audio_loader, + preprocessor=preprocessor, + figsize=figsize, + ax=ax, + audio_dir=audio_dir, + spec_cmap=spec_cmap, + ) - plot.plot_annotation( + ax = plot.plot_annotation( match.sound_event_annotation, ax=ax, time_offset=0.001, @@ -296,20 +303,21 @@ def plot_true_positive_match( ax=ax, add_points=add_points, facecolor="none" if not fill else None, - alpha=1, + alpha=match.pred_score if use_score else 1, color=color, linestyle=prediction_linestyle, ) - plt.text( - start_time, - high_freq, - f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", - va="top", - ha="right", - color=color, - fontsize=fontsize, - ) + if add_text: + plt.text( + start_time, + high_freq, + f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", + va="top", + ha="right", + color=color, + fontsize=fontsize, + ) return ax @@ -322,7 +330,10 @@ def plot_cross_trigger_match( ax: Optional[Axes] = None, audio_dir: 
Optional[data.PathLike] = None, duration: float = DEFAULT_DURATION, + use_score: bool = True, + add_spectrogram: bool = True, add_points: bool = False, + add_text: bool = True, fill: bool = False, spec_cmap: str = "gray", color: str = DEFAULT_CROSS_TRIGGER_COLOR, @@ -346,17 +357,18 @@ def plot_cross_trigger_match( recording=sound_event.recording, ) - ax = plot_clip( - clip, - audio_loader=audio_loader, - preprocessor=preprocessor, - figsize=figsize, - ax=ax, - audio_dir=audio_dir, - spec_cmap=spec_cmap, - ) + if add_spectrogram: + ax = plot_clip( + clip, + audio_loader=audio_loader, + preprocessor=preprocessor, + figsize=figsize, + ax=ax, + audio_dir=audio_dir, + spec_cmap=spec_cmap, + ) - plot.plot_annotation( + ax = plot.plot_annotation( match.sound_event_annotation, ax=ax, time_offset=0.001, @@ -368,24 +380,25 @@ def plot_cross_trigger_match( linestyle=annotation_linestyle, ) - plot.plot_geometry( + ax = plot.plot_geometry( match.pred_geometry, ax=ax, add_points=add_points, facecolor="none" if not fill else None, - alpha=1, + alpha=match.pred_score if use_score else 1, color=color, linestyle=prediction_linestyle, ) - plt.text( - start_time, - high_freq, - f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", - va="top", - ha="right", - color=color, - fontsize=fontsize, - ) + if add_text: + plt.text( + start_time, + high_freq, + f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ", + va="top", + ha="right", + color=color, + fontsize=fontsize, + ) return ax diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index a5ec359..ad4fa27 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -4,6 +4,7 @@ from pydantic import Field from soundevent import data from batdetect2.core.configs import BaseConfig, load_config +from batdetect2.evaluate.config import EvaluationConfig from batdetect2.train.augmentations import ( DEFAULT_AUGMENTATION_CONFIG, AugmentationsConfig, @@ -82,6 +83,7 @@ class TrainingConfig(BaseConfig): trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig) logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) labels: LabelConfig = Field(default_factory=LabelConfig) + validation: EvaluationConfig = Field(default_factory=EvaluationConfig) def load_train_config( diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py index 66344d1..5b1b8c6 100644 --- a/src/batdetect2/train/logging.py +++ b/src/batdetect2/train/logging.py @@ -1,4 +1,6 @@ import io +from collections.abc import Callable +from functools import partial from pathlib import Path from typing import ( Annotated, @@ -13,8 +15,14 @@ from typing import ( ) import numpy as np -from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger +from lightning.pytorch.loggers import ( + CSVLogger, + Logger, + MLFlowLogger, + TensorBoardLogger, +) from loguru import logger +from matplotlib.figure import Figure from pydantic import Field from soundevent import data @@ -231,18 +239,17 @@ def build_logger( ) -def get_image_plotter(logger: Logger): +Plotter = Callable[[str, Figure, int], None] + + +def get_image_plotter(logger: Logger) -> Optional[Plotter]: if isinstance(logger, TensorBoardLogger): - - def plot_figure(name, figure, step): - return logger.experiment.add_figure(name, figure, step) - - return 
plot_figure + return logger.experiment.add_figure if isinstance(logger, MLFlowLogger): def plot_figure(name, figure, step): - image = _convert_figure_to_image(figure) + image = _convert_figure_to_array(figure) return logger.experiment.log_image( logger.run_id, image, @@ -252,8 +259,20 @@ def get_image_plotter(logger: Logger): return plot_figure + if isinstance(logger, CSVLogger): + return partial(save_figure, dir=Path(logger.log_dir)) -def _convert_figure_to_image(figure): + +def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None: + path = dir / "plots" / f"{name}_step_{step}.png" + + if not path.parent.exists(): + path.parent.mkdir(parents=True) + + fig.savefig(path, transparent=True, bbox_inches="tight") + + +def _convert_figure_to_array(figure: Figure) -> np.ndarray: with io.BytesIO() as buff: figure.savefig(buff, format="raw") buff.seek(0) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 4247b2b..18fb248 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -12,7 +12,6 @@ from batdetect2.evaluate.evaluator import Evaluator, build_evaluator from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets from batdetect2.train.callbacks import ValidationMetrics -from batdetect2.train.config import TrainingConfig from batdetect2.train.dataset import build_train_loader, build_val_loader from batdetect2.train.labels import build_clip_labeler from batdetect2.train.lightning import build_training_module @@ -103,9 +102,9 @@ def train( ) trainer = trainer or build_trainer( - config.train, + config, targets=targets, - evaluator=build_evaluator(config.evaluation, targets=targets), + evaluator=build_evaluator(config.train.validation, targets=targets), checkpoint_dir=checkpoint_dir, log_dir=log_dir, experiment_name=experiment_name, @@ -151,7 +150,7 @@ def build_trainer_callbacks( def build_trainer( - conf: TrainingConfig, + conf: "BatDetect2Config", targets: "TargetProtocol", evaluator: Optional[Evaluator] = None, checkpoint_dir: Optional[Path] = None, @@ -159,13 +158,13 @@ def build_trainer( experiment_name: Optional[str] = None, run_name: Optional[str] = None, ) -> Trainer: - trainer_conf = conf.trainer + trainer_conf = conf.train.trainer logger.opt(lazy=True).debug( "Building trainer with config: \n{config}", config=lambda: trainer_conf.to_yaml_string(exclude_none=True), ) train_logger = build_logger( - conf.logger, + conf.train.logger, log_dir=log_dir, experiment_name=experiment_name, run_name=run_name, diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/typing/evaluate.py index e3bf7e0..fad09db 100644 --- a/src/batdetect2/typing/evaluate.py +++ b/src/batdetect2/typing/evaluate.py @@ -50,6 +50,26 @@ class MatchEvaluation: return self.pred_class_scores[pred_class] + def is_true_positive(self, threshold: float = 0) -> bool: + return ( + self.gt_det + and self.pred_score > threshold + and self.gt_class == self.pred_class + ) + + def is_false_positive(self, threshold: float = 0) -> bool: + return self.gt_det is None and self.pred_score > threshold + + def is_false_negative(self, threshold: float = 0) -> bool: + return self.gt_det and self.pred_score <= threshold + + def is_cross_trigger(self, threshold: float = 0) -> bool: + return ( + self.gt_det + and self.pred_score > threshold + and self.gt_class != self.pred_class + ) + @dataclass class ClipEvaluation:
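Note: with BaseConfig now using extra="forbid", configs that still carry the old `match` key (or any other stale field) will fail validation, so call sites have to move to the renamed `match_strategy` field and the registered metric/plot names added in this patch. Below is a minimal sketch, assuming default targets and the import paths shown in the diff, of how the extended evaluation config could be assembled; the excluded class name "noise" is purely illustrative and not taken from this patch.

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.match import StartTimeMatchConfig
from batdetect2.evaluate.metrics import (
    ClipAPConfig,
    DetectionAPConfig,
    DetectionROCAUCConfig,
)
from batdetect2.evaluate.plots import DetectionPRCurveConfig
from batdetect2.targets import build_targets

# Renamed field: `match_strategy` (was `match`); the matcher discriminator is
# now "start_time_match" rather than "start_time".
config = EvaluationConfig(
    match_strategy=StartTimeMatchConfig(distance_threshold=0.01),
    metrics=[
        DetectionAPConfig(ap_implementation="pascal_voc"),
        DetectionROCAUCConfig(),
        ClipAPConfig(exclude=["noise"]),  # hypothetical class name
    ],
    plots=[DetectionPRCurveConfig()],  # plots now default to an empty list
)

targets = build_targets()
# Pass targets explicitly so metric and plot builders receive the class names.
evaluator = build_evaluator(config=config, targets=targets)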
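The plotters built from these configs are generators yielding (name, figure) pairs, and get_image_plotter now returns an Optional callable with the (name, figure, step) signature, including a CSVLogger fallback that saves figures to disk. A rough sketch of how a validation callback might bridge the two is shown below; the log_evaluation_plots helper and its arguments are illustrative and not part of this patch.

from typing import Optional

from batdetect2.train.logging import Plotter, get_image_plotter


def log_evaluation_plots(pl_logger, plotters, clip_evaluations, step: int) -> None:
    # Hypothetical helper, not defined in this diff.
    plot_figure: Optional[Plotter] = get_image_plotter(pl_logger)
    if plot_figure is None:
        # Logger backend without figure support; nothing to log.
        return
    for plotter in plotters:
        # Each plotter yields (name, figure) pairs, e.g. "detection_pr_curve".
        for name, figure in plotter(clip_evaluations):
            plot_figure(name, figure, step)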