Add metrics and plots

mbsantiago 2025-09-17 10:30:24 +01:00
parent 6e217380f2
commit b81a882b58
13 changed files with 1014 additions and 205 deletions

View File

@ -56,7 +56,10 @@ class BatDetect2API:
train(
train_annotations=train_annotations,
val_annotations=val_annotations,
targets=self.targets,
config=self.config,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=checkpoint_dir,

View File

@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
and serialization capabilities.
"""
model_config = ConfigDict(extra="ignore")
model_config = ConfigDict(extra="forbid")
def to_yaml_string(
self,

View File

@ -10,7 +10,7 @@ from batdetect2.evaluate.metrics import (
DetectionAPConfig,
MetricConfig,
)
from batdetect2.evaluate.plots import ExampleGalleryConfig, PlotConfig
from batdetect2.evaluate.plots import PlotConfig
__all__ = [
"EvaluationConfig",
@ -20,18 +20,14 @@ __all__ = [
class EvaluationConfig(BaseConfig):
ignore_start_end: float = 0.01
match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig)
metrics: List[MetricConfig] = Field(
default_factory=lambda: [
DetectionAPConfig(),
ClassificationAPConfig(),
]
)
plots: List[PlotConfig] = Field(
default_factory=lambda: [
ExampleGalleryConfig(),
]
)
plots: List[PlotConfig] = Field(default_factory=list)
def load_evaluation_config(
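For orientation only (not part of this commit's diff), the reworked EvaluationConfig can be assembled directly in Python; the class and field names come from the hunks above, while the concrete values are illustrative assumptions:

config = EvaluationConfig(
    ignore_start_end=0.01,
    match_strategy=StartTimeMatchConfig(distance_threshold=0.01),
    metrics=[
        DetectionAPConfig(),
        ClassificationROCAUCConfig(),
    ],
    plots=[DetectionPRCurveConfig()],  # plots now default to an empty list
)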

View File

@ -58,7 +58,7 @@ def evaluate(
clip_annotations = []
predictions = []
evaluator = build_evaluator(config=config)
evaluator = build_evaluator(config=config, targets=targets)
for batch in loader:
outputs = model.detector(batch.spec)

View File

@ -138,7 +138,7 @@ def build_evaluator(
) -> Evaluator:
config = config or EvaluationConfig()
targets = targets or build_targets()
matcher = matcher or build_matcher(config.match)
matcher = matcher or build_matcher(config.match_strategy)
if metrics is None:
metrics = [
@ -147,7 +147,10 @@ def build_evaluator(
]
if plots is None:
plots = [build_plotter(config) for config in config.plots]
plots = [
build_plotter(config, targets.class_names)
for config in config.plots
]
return Evaluator(
config=config,

View File

@ -111,7 +111,7 @@ def match(
class StartTimeMatchConfig(BaseConfig):
name: Literal["start_time"] = "start_time"
name: Literal["start_time_match"] = "start_time_match"
distance_threshold: float = 0.01

View File

@ -1,3 +1,4 @@
from collections import defaultdict
from collections.abc import Callable, Mapping
from typing import (
Annotated,
@ -12,8 +13,7 @@ from typing import (
import numpy as np
from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from sklearn import metrics, preprocessing
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
@ -26,57 +26,18 @@ __all__ = ["DetectionAP", "ClassificationAP"]
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
APImplementation = Literal["sklearn", "pascal_voc"]
class DetectionAPConfig(BaseConfig):
name: Literal["detection_ap"] = "detection_ap"
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
def pascal_voc_average_precision(y_true, y_score) -> float:
y_true = np.array(y_true)
y_score = np.array(y_score)
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
num_positives = y_true.sum()
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c,
np.finfo(np.float64).eps,
)
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
# pascal 12 way
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return ave_prec
_ap_impl_mapping: Mapping[
AveragePrecisionImplementation, Callable[[Any, Any], float]
] = {
"sklearn": metrics.average_precision_score,
"pascal_voc": pascal_voc_average_precision,
}
ap_implementation: APImplementation = "pascal_voc"
class DetectionAP(MetricsProtocol):
def __init__(
self,
implementation: AveragePrecisionImplementation = "pascal_voc",
implementation: APImplementation = "pascal_voc",
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
@ -102,9 +63,37 @@ class DetectionAP(MetricsProtocol):
metrics_registry.register(DetectionAPConfig, DetectionAP)
class DetectionROCAUCConfig(BaseConfig):
name: Literal["detection_roc_auc"] = "detection_roc_auc"
class DetectionROCAUC(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
score = float(metrics.roc_auc_score(y_true, y_score))
return {"detection_ROC_AUC": score}
@classmethod
def from_config(
cls, config: DetectionROCAUCConfig, class_names: List[str]
):
return cls()
metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC)
class ClassificationAPConfig(BaseConfig):
name: Literal["classification_ap"] = "classification_ap"
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
@ -113,7 +102,7 @@ class ClassificationAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
implementation: AveragePrecisionImplementation = "pascal_voc",
implementation: APImplementation = "pascal_voc",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
@ -164,7 +153,7 @@ class ClassificationAP(MetricsProtocol):
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
class_scores = {}
@ -203,11 +192,435 @@ class ClassificationAP(MetricsProtocol):
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
class ClassificationROCAUCConfig(BaseConfig):
name: Literal["classification_roc_auc"] = "classification_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
class_scores[class_name] = float(class_roc_auc)
mean_roc_auc = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"classification_macro_average_ROC_AUC": float(mean_roc_auc),
**{
f"classification_ROC_AUC/{class_name}": class_scores[
class_name
]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClassificationROCAUCConfig,
class_names: List[str],
):
return cls(
class_names,
include=config.include,
exclude=config.exclude,
)
metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC)
class TopClassAPConfig(BaseConfig):
name: Literal["top_class_ap"] = "top_class_ap"
ap_implementation: APImplementation = "pascal_voc"
class TopClassAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation = "pascal_voc",
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
top_class = match.pred_class
y_true.append(top_class == match.gt_class)
y_score.append(match.pred_class_score)
score = float(self.metric(y_true, y_score))
return {"top_class_AP": score}
@classmethod
def from_config(cls, config: TopClassAPConfig, class_names: List[str]):
return cls(implementation=config.ap_implementation)
metrics_registry.register(TopClassAPConfig, TopClassAP)
class ClassificationBalancedAccuracyConfig(BaseConfig):
name: Literal["classification_balanced_accuracy"] = (
"classification_balanced_accuracy"
)
class ClassificationBalancedAccuracy(MetricsProtocol):
def __init__(self, class_names: List[str]):
self.class_names = class_names
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
top_class = match.pred_class
# Focus on matches
if match.gt_class is None or top_class is None:
continue
y_true.append(self.class_names.index(match.gt_class))
y_pred.append(self.class_names.index(top_class))
score = float(metrics.balanced_accuracy_score(y_true, y_pred))
return {"classification_balanced_accuracy": score}
@classmethod
def from_config(
cls,
config: ClassificationBalancedAccuracyConfig,
class_names: List[str],
):
return cls(class_names)
metrics_registry.register(
ClassificationBalancedAccuracyConfig,
ClassificationBalancedAccuracy,
)
class ClipAPConfig(BaseConfig):
name: Literal["clip_ap"] = "clip_ap"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
implementation: APImplementation,
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
clip_classes = set()
clip_scores = defaultdict(list)
for match in clip_eval.matches:
if match.gt_class is not None:
clip_classes.add(match.gt_class)
for class_name, score in match.pred_class_scores.items():
clip_scores[class_name].append(score)
y_true.append(clip_classes)
y_pred.append(
np.array(
[
# Get max score for each class
max(clip_scores.get(class_name, [0]))
for class_name in self.class_names
]
)
)
y_true = preprocessing.MultiLabelBinarizer(
classes=self.class_names
).fit_transform(y_true)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_ap = self.metric(y_true_class, y_pred_class)
class_scores[class_name] = float(class_ap)
mean_ap = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"clip_mAP": float(mean_ap),
**{
f"clip_AP/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
def from_config(cls, config: ClipAPConfig, class_names: List[str]):
return cls(
implementation=config.ap_implementation,
include=config.include,
exclude=config.exclude,
class_names=class_names,
)
metrics_registry.register(ClipAPConfig, ClipAP)
class ClipROCAUCConfig(BaseConfig):
name: Literal["clip_roc_auc"] = "clip_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
clip_classes = set()
clip_scores = defaultdict(list)
for match in clip_eval.matches:
if match.gt_class is not None:
clip_classes.add(match.gt_class)
for class_name, score in match.pred_class_scores.items():
clip_scores[class_name].append(score)
y_true.append(clip_classes)
y_pred.append(
np.array(
[
# Get maximum score for each class
max(clip_scores.get(class_name, [0]))
for class_name in self.class_names
]
)
)
y_true = preprocessing.MultiLabelBinarizer(
classes=self.class_names
).fit_transform(y_true)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
class_scores[class_name] = float(class_roc_auc)
mean_roc_auc = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"clip_macro_ROC_AUC": float(mean_roc_auc),
**{
f"clip_ROC_AUC/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClipROCAUCConfig,
class_names: List[str],
):
return cls(
include=config.include,
exclude=config.exclude,
class_names=class_names,
)
metrics_registry.register(ClipROCAUCConfig, ClipROCAUC)
MetricConfig = Annotated[
Union[ClassificationAPConfig, DetectionAPConfig],
Union[
DetectionAPConfig,
DetectionROCAUCConfig,
ClassificationAPConfig,
ClassificationROCAUCConfig,
TopClassAPConfig,
ClassificationBalancedAccuracyConfig,
ClipAPConfig,
ClipROCAUCConfig,
],
Field(discriminator="name"),
]
def build_metric(config: MetricConfig, class_names: List[str]):
return metrics_registry.build(config, class_names)
def pascal_voc_average_precision(y_true, y_score) -> float:
y_true = np.array(y_true)
y_score = np.array(y_score)
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
num_positives = y_true.sum()
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c,
np.finfo(np.float64).eps,
)
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
# pascal 12 way
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return ave_prec
_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = {
"sklearn": metrics.average_precision_score,
"pascal_voc": pascal_voc_average_precision,
}
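A quick sanity check (not from the commit) of the PASCAL VOC implementation, which interpolates precision before integrating over the recall steps; the numbers below are hand-computed for this toy input:

y_true = [1, 0, 1]
y_score = [0.9, 0.8, 0.7]
# sorted by score: cumulative TP [1, 1, 2], cumulative FP [0, 1, 1]
# recall    [0.5, 0.5, 1.0]
# precision [1.0, 0.5, 0.667] -> interpolated to 1.0 at r=0.5 and 0.667 at r=1.0
ap = pascal_voc_average_precision(y_true, y_score)  # 0.5 * 1.0 + 0.5 * 0.667 ≈ 0.83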

View File

@ -4,13 +4,18 @@ from dataclasses import dataclass, field
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.audio import AudioConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.matches import plot_matches
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing.evaluate import (
ClipEvaluation,
@ -26,12 +31,13 @@ __all__ = [
]
plots_registry: Registry[PlotterProtocol, []] = Registry("plot")
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")
class ExampleGalleryConfig(BaseConfig):
name: Literal["example_gallery"] = "example_gallery"
examples_per_class: int = 5
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
@ -87,9 +93,12 @@ class ExampleGallery(PlotterProtocol):
plt.close(fig)
@classmethod
def from_config(cls, config: ExampleGalleryConfig):
preprocessor = build_preprocessor(config.preprocessing)
audio_loader = build_audio_loader(config.preprocessing.audio_transforms)
def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
audio_loader = build_audio_loader(config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return cls(
examples_per_class=config.examples_per_class,
preprocessor=preprocessor,
@ -100,13 +109,345 @@ class ExampleGallery(PlotterProtocol):
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
class ClipEvaluationPlotConfig(BaseConfig):
name: Literal["example_clip"] = "example_clip"
num_plots: int = 5
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class PlotClipEvaluation(PlotterProtocol):
def __init__(
self,
num_plots: int = 3,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
):
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.num_plots = num_plots
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
examples = random.sample(
clip_evaluations,
k=min(self.num_plots, len(clip_evaluations)),
)
for index, clip_evaluation in enumerate(examples):
fig, ax = plt.subplots()
plot_matches(
clip_evaluation.matches,
clip=clip_evaluation.clip,
audio_loader=self.audio_loader,
ax=ax,
)
yield f"clip_evaluation/example_{index}", fig
plt.close(fig)
@classmethod
def from_config(
cls,
config: ClipEvaluationPlotConfig,
class_names: List[str],
):
audio_loader = build_audio_loader(config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return cls(
num_plots=config.num_plots,
preprocessor=preprocessor,
audio_loader=audio_loader,
)
plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)
class DetectionPRCurveConfig(BaseConfig):
name: Literal["detection_pr_curve"] = "detection_pr_curve"
class DetectionPRCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
fig, ax = plt.subplots()
ax.plot(recall, precision, label="Detector")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()
yield "detection_pr_curve", fig
@classmethod
def from_config(
cls,
config: DetectionPRCurveConfig,
class_names: List[str],
):
return cls()
plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)
class ClassificationPRCurvesConfig(BaseConfig):
name: Literal["classification_pr_curves"] = "classification_pr_curves"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationPRCurves(PlotterProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
fig, ax = plt.subplots(figsize=(10, 10))
for class_index, class_name in enumerate(self.class_names):
if class_name not in self.selected:
continue
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
precision, recall, _ = metrics.precision_recall_curve(
y_true_class,
y_pred_class,
)
ax.plot(recall, precision, label=class_name)
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
yield "classification_pr_curve", fig
@classmethod
def from_config(
cls,
config: ClassificationPRCurvesConfig,
class_names: List[str],
):
return cls(
class_names=class_names,
include=config.include,
exclude=config.exclude,
)
plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)
class DetectionROCCurveConfig(BaseConfig):
name: Literal["detection_roc_curve"] = "detection_roc_curve"
class DetectionROCCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
fig, ax = plt.subplots()
ax.plot(fpr, tpr, label="Detection")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend()
yield "detection_roc_curve", fig
@classmethod
def from_config(
cls,
config: DetectionROCCurveConfig,
class_names: List[str],
):
return cls()
plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)
class ClassificationROCCurvesConfig(BaseConfig):
name: Literal["classification_roc_curves"] = "classification_roc_curves"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationROCCurves(PlotterProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
fig, ax = plt.subplots(figsize=(10, 10))
for class_index, class_name in enumerate(self.class_names):
if class_name not in self.selected:
continue
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
fpr, tpr, _ = metrics.roc_curve(
y_true_class,
y_pred_class,
)
ax.plot(fpr, tpr, label=class_name)
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
yield "classification_roc_curve", fig
@classmethod
def from_config(
cls,
config: ClassificationROCCurvesConfig,
class_names: List[str],
):
return cls(
class_names=class_names,
include=config.include,
exclude=config.exclude,
)
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
PlotConfig = Annotated[
Union[ExampleGalleryConfig,], Field(discriminator="name")
Union[
ExampleGalleryConfig,
ClipEvaluationPlotConfig,
DetectionPRCurveConfig,
ClassificationPRCurvesConfig,
DetectionROCCurveConfig,
ClassificationROCCurvesConfig,
],
Field(discriminator="name"),
]
def build_plotter(config: PlotConfig) -> PlotterProtocol:
return plots_registry.build(config)
def build_plotter(
config: PlotConfig, class_names: List[str]
) -> PlotterProtocol:
return plots_registry.build(config, class_names)
@dataclass
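A rough usage sketch of the updated plot registry (variable names such as class_names and clip_evaluations are assumptions; class_names would come from the configured targets): every plotter is now built with the class names and yields (name, figure) pairs:

plotters = [build_plotter(cfg, class_names) for cfg in config.plots]
for plotter in plotters:
    for name, fig in plotter(clip_evaluations):
        ...  # e.g. hand the named figure to the logger's image plotter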

View File

@ -6,7 +6,6 @@ from soundevent import data, plot
from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper
from batdetect2.plotting.clip_predictions import plot_prediction
from batdetect2.plotting.clips import AudioLoader, plot_clip
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
@ -29,7 +28,7 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_matches(
matches: List[data.Match],
matches: List[MatchEvaluation],
clip: data.Clip,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
@ -43,8 +42,7 @@ def plot_matches(
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
) -> Axes:
ax = plot_clip(
clip,
@ -60,52 +58,48 @@ def plot_matches(
color_mapper = TagColorMapper()
for match in matches:
if match.source is None and match.target is not None:
plot.plot_annotation(
annotation=match.target,
if match.is_cross_trigger():
plot_cross_trigger_match(
match,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
fill=fill,
add_points=add_points,
add_spectrogram=False,
use_score=True,
color=cross_trigger_color,
add_text=False,
)
elif match.is_true_positive():
plot_true_positive_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
use_score=True,
add_points=add_points,
color=true_positive_color,
add_text=False,
)
elif match.is_false_negative():
plot_false_negative_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
add_points=add_points,
facecolor="none" if not fill else None,
color=false_negative_color,
color_mapper=color_mapper,
linestyle=annotation_linestyle,
add_text=False,
)
elif match.target is None and match.source is not None:
plot_prediction(
prediction=match.source,
elif match.is_false_positive():
plot_false_positive_match(
match,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
fill=fill,
add_spectrogram=False,
use_score=True,
add_points=add_points,
facecolor="none" if not fill else None,
color=false_positive_color,
color_mapper=color_mapper,
linestyle=prediction_linestyle,
)
elif match.target is not None and match.source is not None:
plot.plot_annotation(
annotation=match.target,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=true_positive_color,
color_mapper=color_mapper,
linestyle=annotation_linestyle,
)
plot_prediction(
prediction=match.source,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=true_positive_color,
color_mapper=color_mapper,
linestyle=prediction_linestyle,
add_text=False,
)
else:
continue
@ -121,6 +115,9 @@ def plot_false_positive_match(
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
use_score: bool = True,
add_spectrogram: bool = True,
add_text: bool = True,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
@ -141,34 +138,36 @@ def plot_false_positive_match(
recording=match.clip.recording,
)
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot.plot_geometry(
ax = plot.plot_geometry(
match.pred_geometry,
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
alpha=match.pred_score if use_score else 1,
color=color,
)
plt.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_text:
plt.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
@ -181,7 +180,9 @@ def plot_false_negative_match(
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_spectrogram: bool = True,
add_points: bool = False,
add_text: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
@ -203,17 +204,18 @@ def plot_false_negative_match(
recording=sound_event.recording,
)
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
ax = plot.plot_annotation(
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
@ -224,15 +226,16 @@ def plot_false_negative_match(
color=color,
)
plt.text(
start_time,
high_freq,
f"False Negative \nClass: {match.gt_class} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_text:
plt.text(
start_time,
high_freq,
f"False Negative \nClass: {match.gt_class} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
@ -245,7 +248,10 @@ def plot_true_positive_match(
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
use_score: bool = True,
add_spectrogram: bool = True,
add_points: bool = False,
add_text: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
@ -269,17 +275,18 @@ def plot_true_positive_match(
recording=sound_event.recording,
)
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
ax = plot.plot_annotation(
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
@ -296,20 +303,21 @@ def plot_true_positive_match(
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
alpha=match.pred_score if use_score else 1,
color=color,
linestyle=prediction_linestyle,
)
plt.text(
start_time,
high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_text:
plt.text(
start_time,
high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
@ -322,7 +330,10 @@ def plot_cross_trigger_match(
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
use_score: bool = True,
add_spectrogram: bool = True,
add_points: bool = False,
add_text: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
@ -346,17 +357,18 @@ def plot_cross_trigger_match(
recording=sound_event.recording,
)
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
ax = plot.plot_annotation(
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
@ -368,24 +380,25 @@ def plot_cross_trigger_match(
linestyle=annotation_linestyle,
)
plot.plot_geometry(
ax = plot.plot_geometry(
match.pred_geometry,
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
alpha=match.pred_score if use_score else 1,
color=color,
linestyle=prediction_linestyle,
)
plt.text(
start_time,
high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_text:
plt.text(
start_time,
high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax

View File

@ -4,6 +4,7 @@ from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
@ -82,6 +83,7 @@ class TrainingConfig(BaseConfig):
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
def load_train_config(

View File

@ -1,4 +1,6 @@
import io
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import (
Annotated,
@ -13,8 +15,14 @@ from typing import (
)
import numpy as np
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
from lightning.pytorch.loggers import (
CSVLogger,
Logger,
MLFlowLogger,
TensorBoardLogger,
)
from loguru import logger
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
@ -231,18 +239,17 @@ def build_logger(
)
def get_image_plotter(logger: Logger):
Plotter = Callable[[str, Figure, int], None]
def get_image_plotter(logger: Logger) -> Optional[Plotter]:
if isinstance(logger, TensorBoardLogger):
def plot_figure(name, figure, step):
return logger.experiment.add_figure(name, figure, step)
return plot_figure
return logger.experiment.add_figure
if isinstance(logger, MLFlowLogger):
def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure)
image = _convert_figure_to_array(figure)
return logger.experiment.log_image(
logger.run_id,
image,
@ -252,8 +259,20 @@ def get_image_plotter(logger: Logger):
return plot_figure
if isinstance(logger, CSVLogger):
return partial(save_figure, dir=Path(logger.log_dir))
def _convert_figure_to_image(figure):
def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
path = dir / "plots" / f"{name}_step_{step}.png"
if not path.parent.exists():
path.parent.mkdir(parents=True)
fig.savefig(path, transparent=True, bbox_inches="tight")
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)
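Illustrative use of the unified plotter signature (the train_logger, fig, and step variables are assumptions): get_image_plotter now also covers CSVLogger by saving PNGs under the log directory, and all backends share the Callable[[str, Figure, int], None] shape:

plot_image = get_image_plotter(train_logger)
if plot_image is not None:
    plot_image("val/detection_pr_curve", fig, step)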

View File

@ -12,7 +12,6 @@ from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import build_training_module
@ -103,9 +102,9 @@ def train(
)
trainer = trainer or build_trainer(
config.train,
config,
targets=targets,
evaluator=build_evaluator(config.evaluation, targets=targets),
evaluator=build_evaluator(config.train.validation, targets=targets),
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
experiment_name=experiment_name,
@ -151,7 +150,7 @@ def build_trainer_callbacks(
def build_trainer(
conf: TrainingConfig,
conf: "BatDetect2Config",
targets: "TargetProtocol",
evaluator: Optional[Evaluator] = None,
checkpoint_dir: Optional[Path] = None,
@ -159,13 +158,13 @@ def build_trainer(
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Trainer:
trainer_conf = conf.trainer
trainer_conf = conf.train.trainer
logger.opt(lazy=True).debug(
"Building trainer with config: \n{config}",
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
)
train_logger = build_logger(
conf.logger,
conf.train.logger,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,

View File

@ -50,6 +50,26 @@ class MatchEvaluation:
return self.pred_class_scores[pred_class]
def is_true_positive(self, threshold: float = 0) -> bool:
return (
self.gt_det
and self.pred_score > threshold
and self.gt_class == self.pred_class
)
def is_false_positive(self, threshold: float = 0) -> bool:
return not self.gt_det and self.pred_score > threshold
def is_false_negative(self, threshold: float = 0) -> bool:
return self.gt_det and self.pred_score <= threshold
def is_cross_trigger(self, threshold: float = 0) -> bool:
return (
self.gt_det
and self.pred_score > threshold
and self.gt_class != self.pred_class
)
@dataclass
class ClipEvaluation:
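A minimal sketch (not part of the diff) of how the new MatchEvaluation predicates are meant to partition matches, mirroring their use in plot_matches above, with the default threshold of 0:

for match in clip_evaluation.matches:
    if match.is_cross_trigger():      # detected, but top class differs from ground truth
        ...
    elif match.is_true_positive():    # detected and correctly classified
        ...
    elif match.is_false_negative():   # ground truth present but no confident prediction
        ...
    elif match.is_false_positive():   # confident prediction with no ground truth
        ...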