mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Add metrics and plots
This commit is contained in:
parent
6e217380f2
commit
b81a882b58
@ -56,7 +56,10 @@ class BatDetect2API:
|
||||
train(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
targets=self.targets,
|
||||
config=self.config,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
train_workers=train_workers,
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
|
||||
@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
|
||||
and serialization capabilities.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def to_yaml_string(
|
||||
self,
|
||||
|
||||
@ -10,7 +10,7 @@ from batdetect2.evaluate.metrics import (
|
||||
DetectionAPConfig,
|
||||
MetricConfig,
|
||||
)
|
||||
from batdetect2.evaluate.plots import ExampleGalleryConfig, PlotConfig
|
||||
from batdetect2.evaluate.plots import PlotConfig
|
||||
|
||||
__all__ = [
|
||||
"EvaluationConfig",
|
||||
@ -20,18 +20,14 @@ __all__ = [
|
||||
|
||||
class EvaluationConfig(BaseConfig):
|
||||
ignore_start_end: float = 0.01
|
||||
match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
|
||||
match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig)
|
||||
metrics: List[MetricConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
DetectionAPConfig(),
|
||||
ClassificationAPConfig(),
|
||||
]
|
||||
)
|
||||
plots: List[PlotConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
ExampleGalleryConfig(),
|
||||
]
|
||||
)
|
||||
plots: List[PlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
def load_evaluation_config(
|
||||
|
||||
@ -58,7 +58,7 @@ def evaluate(
|
||||
clip_annotations = []
|
||||
predictions = []
|
||||
|
||||
evaluator = build_evaluator(config=config)
|
||||
evaluator = build_evaluator(config=config, targets=targets)
|
||||
|
||||
for batch in loader:
|
||||
outputs = model.detector(batch.spec)
|
||||
|
||||
@ -138,7 +138,7 @@ def build_evaluator(
|
||||
) -> Evaluator:
|
||||
config = config or EvaluationConfig()
|
||||
targets = targets or build_targets()
|
||||
matcher = matcher or build_matcher(config.match)
|
||||
matcher = matcher or build_matcher(config.match_strategy)
|
||||
|
||||
if metrics is None:
|
||||
metrics = [
|
||||
@ -147,7 +147,10 @@ def build_evaluator(
|
||||
]
|
||||
|
||||
if plots is None:
|
||||
plots = [build_plotter(config) for config in config.plots]
|
||||
plots = [
|
||||
build_plotter(config, targets.class_names)
|
||||
for config in config.plots
|
||||
]
|
||||
|
||||
return Evaluator(
|
||||
config=config,
|
||||
|
||||
@ -111,7 +111,7 @@ def match(
|
||||
|
||||
|
||||
class StartTimeMatchConfig(BaseConfig):
|
||||
name: Literal["start_time"] = "start_time"
|
||||
name: Literal["start_time_match"] = "start_time_match"
|
||||
distance_threshold: float = 0.01
|
||||
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import (
|
||||
Annotated,
|
||||
@ -12,8 +13,7 @@ from typing import (
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
from sklearn.preprocessing import label_binarize
|
||||
from sklearn import metrics, preprocessing
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
@ -26,57 +26,18 @@ __all__ = ["DetectionAP", "ClassificationAP"]
|
||||
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
|
||||
|
||||
|
||||
AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
|
||||
APImplementation = Literal["sklearn", "pascal_voc"]
|
||||
|
||||
|
||||
class DetectionAPConfig(BaseConfig):
|
||||
name: Literal["detection_ap"] = "detection_ap"
|
||||
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
|
||||
|
||||
|
||||
def pascal_voc_average_precision(y_true, y_score) -> float:
|
||||
y_true = np.array(y_true)
|
||||
y_score = np.array(y_score)
|
||||
|
||||
sort_ind = np.argsort(y_score)[::-1]
|
||||
y_true_sorted = y_true[sort_ind]
|
||||
|
||||
num_positives = y_true.sum()
|
||||
false_pos_c = np.cumsum(1 - y_true_sorted)
|
||||
true_pos_c = np.cumsum(y_true_sorted)
|
||||
|
||||
recall = true_pos_c / num_positives
|
||||
precision = true_pos_c / np.maximum(
|
||||
true_pos_c + false_pos_c,
|
||||
np.finfo(np.float64).eps,
|
||||
)
|
||||
|
||||
precision[np.isnan(precision)] = 0
|
||||
recall[np.isnan(recall)] = 0
|
||||
|
||||
# pascal 12 way
|
||||
mprec = np.hstack((0, precision, 0))
|
||||
mrec = np.hstack((0, recall, 1))
|
||||
for ii in range(mprec.shape[0] - 2, -1, -1):
|
||||
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
|
||||
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
|
||||
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
|
||||
|
||||
return ave_prec
|
||||
|
||||
|
||||
_ap_impl_mapping: Mapping[
|
||||
AveragePrecisionImplementation, Callable[[Any, Any], float]
|
||||
] = {
|
||||
"sklearn": metrics.average_precision_score,
|
||||
"pascal_voc": pascal_voc_average_precision,
|
||||
}
|
||||
ap_implementation: APImplementation = "pascal_voc"
|
||||
|
||||
|
||||
class DetectionAP(MetricsProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
implementation: AveragePrecisionImplementation = "pascal_voc",
|
||||
implementation: APImplementation = "pascal_voc",
|
||||
):
|
||||
self.implementation = implementation
|
||||
self.metric = _ap_impl_mapping[self.implementation]
|
||||
@ -102,9 +63,37 @@ class DetectionAP(MetricsProtocol):
|
||||
metrics_registry.register(DetectionAPConfig, DetectionAP)
|
||||
|
||||
|
||||
class DetectionROCAUCConfig(BaseConfig):
|
||||
name: Literal["detection_roc_auc"] = "detection_roc_auc"
|
||||
|
||||
|
||||
class DetectionROCAUC(MetricsProtocol):
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||
) -> Dict[str, float]:
|
||||
y_true, y_score = zip(
|
||||
*[
|
||||
(match.gt_det, match.pred_score)
|
||||
for clip_eval in clip_evaluations
|
||||
for match in clip_eval.matches
|
||||
]
|
||||
)
|
||||
score = float(metrics.roc_auc_score(y_true, y_score))
|
||||
return {"detection_ROC_AUC": score}
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: DetectionROCAUCConfig, class_names: List[str]
|
||||
):
|
||||
return cls()
|
||||
|
||||
|
||||
metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC)
|
||||
|
||||
|
||||
class ClassificationAPConfig(BaseConfig):
|
||||
name: Literal["classification_ap"] = "classification_ap"
|
||||
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
|
||||
ap_implementation: APImplementation = "pascal_voc"
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
|
||||
@ -113,7 +102,7 @@ class ClassificationAP(MetricsProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
class_names: List[str],
|
||||
implementation: AveragePrecisionImplementation = "pascal_voc",
|
||||
implementation: APImplementation = "pascal_voc",
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
):
|
||||
@ -164,7 +153,7 @@ class ClassificationAP(MetricsProtocol):
|
||||
)
|
||||
)
|
||||
|
||||
y_true = label_binarize(y_true, classes=self.class_names)
|
||||
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
|
||||
y_pred = np.stack(y_pred)
|
||||
|
||||
class_scores = {}
|
||||
@ -203,11 +192,435 @@ class ClassificationAP(MetricsProtocol):
|
||||
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
|
||||
|
||||
|
||||
class ClassificationROCAUCConfig(BaseConfig):
|
||||
name: Literal["classification_roc_auc"] = "classification_roc_auc"
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ClassificationROCAUC(MetricsProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
class_names: List[str],
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
):
|
||||
self.class_names = class_names
|
||||
self.selected = class_names
|
||||
|
||||
if include is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name in include
|
||||
]
|
||||
|
||||
if exclude is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name not in exclude
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for match in clip_eval.matches:
|
||||
# Ignore generic unclassified targets
|
||||
if match.gt_det and match.gt_class is None:
|
||||
continue
|
||||
|
||||
y_true.append(
|
||||
match.gt_class
|
||||
if match.gt_class is not None
|
||||
else "__NONE__"
|
||||
)
|
||||
|
||||
y_pred.append(
|
||||
np.array(
|
||||
[
|
||||
match.pred_class_scores.get(name, 0)
|
||||
for name in self.class_names
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
|
||||
y_pred = np.stack(y_pred)
|
||||
|
||||
class_scores = {}
|
||||
for class_index, class_name in enumerate(self.class_names):
|
||||
y_true_class = y_true[:, class_index]
|
||||
y_pred_class = y_pred[:, class_index]
|
||||
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
|
||||
class_scores[class_name] = float(class_roc_auc)
|
||||
|
||||
mean_roc_auc = np.mean(
|
||||
[value for value in class_scores.values() if value != 0]
|
||||
)
|
||||
|
||||
return {
|
||||
"classification_macro_average_ROC_AUC": float(mean_roc_auc),
|
||||
**{
|
||||
f"classification_ROC_AUC/{class_name}": class_scores[
|
||||
class_name
|
||||
]
|
||||
for class_name in self.selected
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClassificationROCAUCConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls(
|
||||
class_names,
|
||||
include=config.include,
|
||||
exclude=config.exclude,
|
||||
)
|
||||
|
||||
|
||||
metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC)
|
||||
|
||||
|
||||
class TopClassAPConfig(BaseConfig):
|
||||
name: Literal["top_class_ap"] = "top_class_ap"
|
||||
ap_implementation: APImplementation = "pascal_voc"
|
||||
|
||||
|
||||
class TopClassAP(MetricsProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
implementation: APImplementation = "pascal_voc",
|
||||
):
|
||||
self.implementation = implementation
|
||||
self.metric = _ap_impl_mapping[self.implementation]
|
||||
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for match in clip_eval.matches:
|
||||
# Ignore generic unclassified targets
|
||||
if match.gt_det and match.gt_class is None:
|
||||
continue
|
||||
|
||||
top_class = match.pred_class
|
||||
|
||||
y_true.append(top_class == match.gt_class)
|
||||
y_score.append(match.pred_class_score)
|
||||
|
||||
score = float(self.metric(y_true, y_score))
|
||||
return {"top_class_AP": score}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: TopClassAPConfig, class_names: List[str]):
|
||||
return cls(implementation=config.ap_implementation)
|
||||
|
||||
|
||||
metrics_registry.register(TopClassAPConfig, TopClassAP)
|
||||
|
||||
|
||||
class ClassificationBalancedAccuracyConfig(BaseConfig):
|
||||
name: Literal["classification_balanced_accuracy"] = (
|
||||
"classification_balanced_accuracy"
|
||||
)
|
||||
|
||||
|
||||
class ClassificationBalancedAccuracy(MetricsProtocol):
|
||||
def __init__(self, class_names: List[str]):
|
||||
self.class_names = class_names
|
||||
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for match in clip_eval.matches:
|
||||
top_class = match.pred_class
|
||||
|
||||
# Focus on matches
|
||||
if match.gt_class is None or top_class is None:
|
||||
continue
|
||||
|
||||
y_true.append(self.class_names.index(match.gt_class))
|
||||
y_pred.append(self.class_names.index(top_class))
|
||||
|
||||
score = float(metrics.balanced_accuracy_score(y_true, y_pred))
|
||||
return {"classification_balanced_accuracy": score}
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClassificationBalancedAccuracyConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls(class_names)
|
||||
|
||||
|
||||
metrics_registry.register(
|
||||
ClassificationBalancedAccuracyConfig,
|
||||
ClassificationBalancedAccuracy,
|
||||
)
|
||||
|
||||
|
||||
class ClipAPConfig(BaseConfig):
|
||||
name: Literal["clip_ap"] = "clip_ap"
|
||||
ap_implementation: APImplementation = "pascal_voc"
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ClipAP(MetricsProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
class_names: List[str],
|
||||
implementation: APImplementation,
|
||||
include: Optional[Sequence[str]] = None,
|
||||
exclude: Optional[Sequence[str]] = None,
|
||||
):
|
||||
self.implementation = implementation
|
||||
self.metric = _ap_impl_mapping[self.implementation]
|
||||
self.class_names = class_names
|
||||
|
||||
self.selected = class_names
|
||||
|
||||
if include is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name in include
|
||||
]
|
||||
|
||||
if exclude is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name not in exclude
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
clip_classes = set()
|
||||
clip_scores = defaultdict(list)
|
||||
|
||||
for match in clip_eval.matches:
|
||||
if match.gt_class is not None:
|
||||
clip_classes.add(match.gt_class)
|
||||
|
||||
for class_name, score in match.pred_class_scores.items():
|
||||
clip_scores[class_name].append(score)
|
||||
|
||||
y_true.append(clip_classes)
|
||||
y_pred.append(
|
||||
np.array(
|
||||
[
|
||||
# Get max score for each class
|
||||
max(clip_scores.get(class_name, [0]))
|
||||
for class_name in self.class_names
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
y_true = preprocessing.MultiLabelBinarizer(
|
||||
classes=self.class_names
|
||||
).fit_transform(y_true)
|
||||
y_pred = np.stack(y_pred)
|
||||
|
||||
class_scores = {}
|
||||
for class_index, class_name in enumerate(self.class_names):
|
||||
y_true_class = y_true[:, class_index]
|
||||
y_pred_class = y_pred[:, class_index]
|
||||
class_ap = self.metric(y_true_class, y_pred_class)
|
||||
class_scores[class_name] = float(class_ap)
|
||||
|
||||
mean_ap = np.mean(
|
||||
[value for value in class_scores.values() if value != 0]
|
||||
)
|
||||
return {
|
||||
"clip_mAP": float(mean_ap),
|
||||
**{
|
||||
f"clip_AP/{class_name}": class_scores[class_name]
|
||||
for class_name in self.selected
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ClipAPConfig, class_names: List[str]):
|
||||
return cls(
|
||||
implementation=config.ap_implementation,
|
||||
include=config.include,
|
||||
exclude=config.exclude,
|
||||
class_names=class_names,
|
||||
)
|
||||
|
||||
|
||||
metrics_registry.register(ClipAPConfig, ClipAP)
|
||||
|
||||
|
||||
class ClipROCAUCConfig(BaseConfig):
|
||||
name: Literal["clip_roc_auc"] = "clip_roc_auc"
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ClipROCAUC(MetricsProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
class_names: List[str],
|
||||
include: Optional[Sequence[str]] = None,
|
||||
exclude: Optional[Sequence[str]] = None,
|
||||
):
|
||||
self.class_names = class_names
|
||||
self.selected = class_names
|
||||
|
||||
if include is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name in include
|
||||
]
|
||||
|
||||
if exclude is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name not in exclude
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
clip_classes = set()
|
||||
clip_scores = defaultdict(list)
|
||||
|
||||
for match in clip_eval.matches:
|
||||
if match.gt_class is not None:
|
||||
clip_classes.add(match.gt_class)
|
||||
|
||||
for class_name, score in match.pred_class_scores.items():
|
||||
clip_scores[class_name].append(score)
|
||||
|
||||
y_true.append(clip_classes)
|
||||
y_pred.append(
|
||||
np.array(
|
||||
[
|
||||
# Get maximum score for each class
|
||||
max(clip_scores.get(class_name, [0]))
|
||||
for class_name in self.class_names
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
y_true = preprocessing.MultiLabelBinarizer(
|
||||
classes=self.class_names
|
||||
).fit_transform(y_true)
|
||||
y_pred = np.stack(y_pred)
|
||||
|
||||
class_scores = {}
|
||||
for class_index, class_name in enumerate(self.class_names):
|
||||
y_true_class = y_true[:, class_index]
|
||||
y_pred_class = y_pred[:, class_index]
|
||||
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
|
||||
class_scores[class_name] = float(class_roc_auc)
|
||||
|
||||
mean_roc_auc = np.mean(
|
||||
[value for value in class_scores.values() if value != 0]
|
||||
)
|
||||
return {
|
||||
"clip_macro_ROC_AUC": float(mean_roc_auc),
|
||||
**{
|
||||
f"clip_ROC_AUC/{class_name}": class_scores[class_name]
|
||||
for class_name in self.selected
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClipROCAUCConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls(
|
||||
include=config.include,
|
||||
exclude=config.exclude,
|
||||
class_names=class_names,
|
||||
)
|
||||
|
||||
|
||||
metrics_registry.register(ClipROCAUCConfig, ClipROCAUC)
|
||||
|
||||
MetricConfig = Annotated[
|
||||
Union[ClassificationAPConfig, DetectionAPConfig],
|
||||
Union[
|
||||
DetectionAPConfig,
|
||||
DetectionROCAUCConfig,
|
||||
ClassificationAPConfig,
|
||||
ClassificationROCAUCConfig,
|
||||
TopClassAPConfig,
|
||||
ClassificationBalancedAccuracyConfig,
|
||||
ClipAPConfig,
|
||||
ClipROCAUCConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_metric(config: MetricConfig, class_names: List[str]):
|
||||
return metrics_registry.build(config, class_names)
|
||||
|
||||
|
||||
def pascal_voc_average_precision(y_true, y_score) -> float:
|
||||
y_true = np.array(y_true)
|
||||
y_score = np.array(y_score)
|
||||
|
||||
sort_ind = np.argsort(y_score)[::-1]
|
||||
y_true_sorted = y_true[sort_ind]
|
||||
|
||||
num_positives = y_true.sum()
|
||||
false_pos_c = np.cumsum(1 - y_true_sorted)
|
||||
true_pos_c = np.cumsum(y_true_sorted)
|
||||
|
||||
recall = true_pos_c / num_positives
|
||||
precision = true_pos_c / np.maximum(
|
||||
true_pos_c + false_pos_c,
|
||||
np.finfo(np.float64).eps,
|
||||
)
|
||||
|
||||
precision[np.isnan(precision)] = 0
|
||||
recall[np.isnan(recall)] = 0
|
||||
|
||||
# pascal 12 way
|
||||
mprec = np.hstack((0, precision, 0))
|
||||
mrec = np.hstack((0, recall, 1))
|
||||
for ii in range(mprec.shape[0] - 2, -1, -1):
|
||||
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
|
||||
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
|
||||
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
|
||||
|
||||
return ave_prec
|
||||
|
||||
|
||||
_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = {
|
||||
"sklearn": metrics.average_precision_score,
|
||||
"pascal_voc": pascal_voc_average_precision,
|
||||
}
|
||||
|
||||
@ -4,13 +4,18 @@ from dataclasses import dataclass, field
|
||||
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
from sklearn.preprocessing import label_binarize
|
||||
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
|
||||
from batdetect2.plotting.gallery import plot_match_gallery
|
||||
from batdetect2.plotting.matches import plot_matches
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.typing.evaluate import (
|
||||
ClipEvaluation,
|
||||
@ -26,12 +31,13 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
plots_registry: Registry[PlotterProtocol, []] = Registry("plot")
|
||||
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")
|
||||
|
||||
|
||||
class ExampleGalleryConfig(BaseConfig):
|
||||
name: Literal["example_gallery"] = "example_gallery"
|
||||
examples_per_class: int = 5
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
@ -87,9 +93,12 @@ class ExampleGallery(PlotterProtocol):
|
||||
plt.close(fig)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ExampleGalleryConfig):
|
||||
preprocessor = build_preprocessor(config.preprocessing)
|
||||
audio_loader = build_audio_loader(config.preprocessing.audio_transforms)
|
||||
def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
|
||||
audio_loader = build_audio_loader(config.audio)
|
||||
preprocessor = build_preprocessor(
|
||||
config.preprocessing,
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
)
|
||||
return cls(
|
||||
examples_per_class=config.examples_per_class,
|
||||
preprocessor=preprocessor,
|
||||
@ -100,13 +109,345 @@ class ExampleGallery(PlotterProtocol):
|
||||
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
|
||||
|
||||
|
||||
class ClipEvaluationPlotConfig(BaseConfig):
|
||||
name: Literal["example_clip"] = "example_clip"
|
||||
num_plots: int = 5
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
|
||||
|
||||
class PlotClipEvaluation(PlotterProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
num_plots: int = 3,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
):
|
||||
self.preprocessor = preprocessor
|
||||
self.audio_loader = audio_loader
|
||||
self.num_plots = num_plots
|
||||
|
||||
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||
examples = random.sample(
|
||||
clip_evaluations,
|
||||
k=min(self.num_plots, len(clip_evaluations)),
|
||||
)
|
||||
|
||||
for index, clip_evaluation in enumerate(examples):
|
||||
fig, ax = plt.subplots()
|
||||
plot_matches(
|
||||
clip_evaluation.matches,
|
||||
clip=clip_evaluation.clip,
|
||||
audio_loader=self.audio_loader,
|
||||
ax=ax,
|
||||
)
|
||||
yield f"clip_evaluation/example_{index}", fig
|
||||
plt.close(fig)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClipEvaluationPlotConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
audio_loader = build_audio_loader(config.audio)
|
||||
preprocessor = build_preprocessor(
|
||||
config.preprocessing,
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
)
|
||||
return cls(
|
||||
num_plots=config.num_plots,
|
||||
preprocessor=preprocessor,
|
||||
audio_loader=audio_loader,
|
||||
)
|
||||
|
||||
|
||||
plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)
|
||||
|
||||
|
||||
class DetectionPRCurveConfig(BaseConfig):
|
||||
name: Literal["detection_pr_curve"] = "detection_pr_curve"
|
||||
|
||||
|
||||
class DetectionPRCurve(PlotterProtocol):
|
||||
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||
y_true, y_score = zip(
|
||||
*[
|
||||
(match.gt_det, match.pred_score)
|
||||
for clip_eval in clip_evaluations
|
||||
for match in clip_eval.matches
|
||||
]
|
||||
)
|
||||
precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
ax.plot(recall, precision, label="Detector")
|
||||
ax.set_xlabel("Recall")
|
||||
ax.set_ylabel("Precision")
|
||||
ax.legend()
|
||||
|
||||
yield "detection_pr_curve", fig
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: DetectionPRCurveConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls()
|
||||
|
||||
|
||||
plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)
|
||||
|
||||
|
||||
class ClassificationPRCurvesConfig(BaseConfig):
|
||||
name: Literal["classification_pr_curves"] = "classification_pr_curves"
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ClassificationPRCurves(PlotterProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
class_names: List[str],
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
):
|
||||
self.class_names = class_names
|
||||
self.selected = class_names
|
||||
|
||||
if include is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name in include
|
||||
]
|
||||
|
||||
if exclude is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name not in exclude
|
||||
]
|
||||
|
||||
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for match in clip_eval.matches:
|
||||
# Ignore generic unclassified targets
|
||||
if match.gt_det and match.gt_class is None:
|
||||
continue
|
||||
|
||||
y_true.append(
|
||||
match.gt_class
|
||||
if match.gt_class is not None
|
||||
else "__NONE__"
|
||||
)
|
||||
|
||||
y_pred.append(
|
||||
np.array(
|
||||
[
|
||||
match.pred_class_scores.get(name, 0)
|
||||
for name in self.class_names
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
y_true = label_binarize(y_true, classes=self.class_names)
|
||||
y_pred = np.stack(y_pred)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 10))
|
||||
for class_index, class_name in enumerate(self.class_names):
|
||||
if class_name not in self.selected:
|
||||
continue
|
||||
|
||||
y_true_class = y_true[:, class_index]
|
||||
y_pred_class = y_pred[:, class_index]
|
||||
precision, recall, _ = metrics.precision_recall_curve(
|
||||
y_true_class,
|
||||
y_pred_class,
|
||||
)
|
||||
ax.plot(recall, precision, label=class_name)
|
||||
|
||||
ax.set_xlabel("Recall")
|
||||
ax.set_ylabel("Precision")
|
||||
ax.legend(
|
||||
bbox_to_anchor=(1.05, 1),
|
||||
loc="upper left",
|
||||
borderaxespad=0.0,
|
||||
)
|
||||
|
||||
yield "classification_pr_curve", fig
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClassificationPRCurvesConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls(
|
||||
class_names=class_names,
|
||||
include=config.include,
|
||||
exclude=config.exclude,
|
||||
)
|
||||
|
||||
|
||||
plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)
|
||||
|
||||
|
||||
class DetectionROCCurveConfig(BaseConfig):
|
||||
name: Literal["detection_roc_curve"] = "detection_roc_curve"
|
||||
|
||||
|
||||
class DetectionROCCurve(PlotterProtocol):
|
||||
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||
y_true, y_score = zip(
|
||||
*[
|
||||
(match.gt_det, match.pred_score)
|
||||
for clip_eval in clip_evaluations
|
||||
for match in clip_eval.matches
|
||||
]
|
||||
)
|
||||
fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
ax.plot(fpr, tpr, label="Detection")
|
||||
ax.set_xlabel("False Positive Rate")
|
||||
ax.set_ylabel("True Positive Rate")
|
||||
ax.legend()
|
||||
|
||||
yield "detection_roc_curve", fig
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: DetectionROCCurveConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls()
|
||||
|
||||
|
||||
plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)
|
||||
|
||||
|
||||
class ClassificationROCCurvesConfig(BaseConfig):
|
||||
name: Literal["classification_roc_curves"] = "classification_roc_curves"
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ClassificationROCCurves(PlotterProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
class_names: List[str],
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
):
|
||||
self.class_names = class_names
|
||||
self.selected = class_names
|
||||
|
||||
if include is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name in include
|
||||
]
|
||||
|
||||
if exclude is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name not in exclude
|
||||
]
|
||||
|
||||
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for match in clip_eval.matches:
|
||||
# Ignore generic unclassified targets
|
||||
if match.gt_det and match.gt_class is None:
|
||||
continue
|
||||
|
||||
y_true.append(
|
||||
match.gt_class
|
||||
if match.gt_class is not None
|
||||
else "__NONE__"
|
||||
)
|
||||
|
||||
y_pred.append(
|
||||
np.array(
|
||||
[
|
||||
match.pred_class_scores.get(name, 0)
|
||||
for name in self.class_names
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
y_true = label_binarize(y_true, classes=self.class_names)
|
||||
y_pred = np.stack(y_pred)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 10))
|
||||
for class_index, class_name in enumerate(self.class_names):
|
||||
if class_name not in self.selected:
|
||||
continue
|
||||
|
||||
y_true_class = y_true[:, class_index]
|
||||
y_roced_class = y_pred[:, class_index]
|
||||
fpr, tpr, _ = metrics.roc_curve(
|
||||
y_true_class,
|
||||
y_roced_class,
|
||||
)
|
||||
ax.plot(fpr, tpr, label=class_name)
|
||||
|
||||
ax.set_xlabel("False Positive Rate")
|
||||
ax.set_ylabel("True Positive Rate")
|
||||
ax.legend(
|
||||
bbox_to_anchor=(1.05, 1),
|
||||
loc="upper left",
|
||||
borderaxespad=0.0,
|
||||
)
|
||||
|
||||
yield "classification_roc_curve", fig
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClassificationROCCurvesConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls(
|
||||
class_names=class_names,
|
||||
include=config.include,
|
||||
exclude=config.exclude,
|
||||
)
|
||||
|
||||
|
||||
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
|
||||
|
||||
|
||||
PlotConfig = Annotated[
|
||||
Union[ExampleGalleryConfig,], Field(discriminator="name")
|
||||
Union[
|
||||
ExampleGalleryConfig,
|
||||
ClipEvaluationPlotConfig,
|
||||
DetectionPRCurveConfig,
|
||||
ClassificationPRCurvesConfig,
|
||||
DetectionROCCurveConfig,
|
||||
ClassificationROCCurvesConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_plotter(config: PlotConfig) -> PlotterProtocol:
|
||||
return plots_registry.build(config)
|
||||
def build_plotter(
|
||||
config: PlotConfig, class_names: List[str]
|
||||
) -> PlotterProtocol:
|
||||
return plots_registry.build(config, class_names)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -6,7 +6,6 @@ from soundevent import data, plot
|
||||
from soundevent.geometry import compute_bounds
|
||||
from soundevent.plot.tags import TagColorMapper
|
||||
|
||||
from batdetect2.plotting.clip_predictions import plot_prediction
|
||||
from batdetect2.plotting.clips import AudioLoader, plot_clip
|
||||
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
|
||||
|
||||
@ -29,7 +28,7 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
|
||||
|
||||
|
||||
def plot_matches(
|
||||
matches: List[data.Match],
|
||||
matches: List[MatchEvaluation],
|
||||
clip: data.Clip,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
@ -43,8 +42,7 @@ def plot_matches(
|
||||
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
||||
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||
cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
||||
) -> Axes:
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
@ -60,52 +58,48 @@ def plot_matches(
|
||||
color_mapper = TagColorMapper()
|
||||
|
||||
for match in matches:
|
||||
if match.source is None and match.target is not None:
|
||||
plot.plot_annotation(
|
||||
annotation=match.target,
|
||||
if match.is_cross_trigger():
|
||||
plot_cross_trigger_match(
|
||||
match,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
fill=fill,
|
||||
add_points=add_points,
|
||||
add_spectrogram=False,
|
||||
use_score=True,
|
||||
color=cross_trigger_color,
|
||||
add_text=False,
|
||||
)
|
||||
elif match.is_true_positive():
|
||||
plot_true_positive_match(
|
||||
match,
|
||||
ax=ax,
|
||||
fill=fill,
|
||||
add_spectrogram=False,
|
||||
use_score=True,
|
||||
add_points=add_points,
|
||||
color=true_positive_color,
|
||||
add_text=False,
|
||||
)
|
||||
elif match.is_false_negative():
|
||||
plot_false_negative_match(
|
||||
match,
|
||||
ax=ax,
|
||||
fill=fill,
|
||||
add_spectrogram=False,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=false_negative_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=annotation_linestyle,
|
||||
add_text=False,
|
||||
)
|
||||
elif match.target is None and match.source is not None:
|
||||
plot_prediction(
|
||||
prediction=match.source,
|
||||
elif match.is_false_positive:
|
||||
plot_false_positive_match(
|
||||
match,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
fill=fill,
|
||||
add_spectrogram=False,
|
||||
use_score=True,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=false_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=prediction_linestyle,
|
||||
)
|
||||
elif match.target is not None and match.source is not None:
|
||||
plot.plot_annotation(
|
||||
annotation=match.target,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=true_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=annotation_linestyle,
|
||||
)
|
||||
plot_prediction(
|
||||
prediction=match.source,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=true_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=prediction_linestyle,
|
||||
add_text=False,
|
||||
)
|
||||
else:
|
||||
continue
|
||||
@ -121,6 +115,9 @@ def plot_false_positive_match(
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
use_score: bool = True,
|
||||
add_spectrogram: bool = True,
|
||||
add_text: bool = True,
|
||||
add_points: bool = False,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
@ -141,34 +138,36 @@ def plot_false_positive_match(
|
||||
recording=match.clip.recording,
|
||||
)
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
if add_spectrogram:
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
plot.plot_geometry(
|
||||
ax = plot.plot_geometry(
|
||||
match.pred_geometry,
|
||||
ax=ax,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
alpha=1,
|
||||
alpha=match.pred_score if use_score else 1,
|
||||
color=color,
|
||||
)
|
||||
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
if add_text:
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
|
||||
return ax
|
||||
|
||||
@ -181,7 +180,9 @@ def plot_false_negative_match(
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
add_spectrogram: bool = True,
|
||||
add_points: bool = False,
|
||||
add_text: bool = True,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||
@ -203,17 +204,18 @@ def plot_false_negative_match(
|
||||
recording=sound_event.recording,
|
||||
)
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
if add_spectrogram:
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
plot.plot_annotation(
|
||||
ax = plot.plot_annotation(
|
||||
match.sound_event_annotation,
|
||||
ax=ax,
|
||||
time_offset=0.001,
|
||||
@ -224,15 +226,16 @@ def plot_false_negative_match(
|
||||
color=color,
|
||||
)
|
||||
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"False Negative \nClass: {match.gt_class} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
if add_text:
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"False Negative \nClass: {match.gt_class} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
|
||||
return ax
|
||||
|
||||
@ -245,7 +248,10 @@ def plot_true_positive_match(
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
use_score: bool = True,
|
||||
add_spectrogram: bool = True,
|
||||
add_points: bool = False,
|
||||
add_text: bool = True,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||
@ -269,17 +275,18 @@ def plot_true_positive_match(
|
||||
recording=sound_event.recording,
|
||||
)
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
if add_spectrogram:
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
plot.plot_annotation(
|
||||
ax = plot.plot_annotation(
|
||||
match.sound_event_annotation,
|
||||
ax=ax,
|
||||
time_offset=0.001,
|
||||
@ -296,20 +303,21 @@ def plot_true_positive_match(
|
||||
ax=ax,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
alpha=1,
|
||||
alpha=match.pred_score if use_score else 1,
|
||||
color=color,
|
||||
linestyle=prediction_linestyle,
|
||||
)
|
||||
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
if add_text:
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
|
||||
return ax
|
||||
|
||||
@ -322,7 +330,10 @@ def plot_cross_trigger_match(
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
use_score: bool = True,
|
||||
add_spectrogram: bool = True,
|
||||
add_points: bool = False,
|
||||
add_text: bool = True,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
||||
@ -346,17 +357,18 @@ def plot_cross_trigger_match(
|
||||
recording=sound_event.recording,
|
||||
)
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
if add_spectrogram:
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
plot.plot_annotation(
|
||||
ax = plot.plot_annotation(
|
||||
match.sound_event_annotation,
|
||||
ax=ax,
|
||||
time_offset=0.001,
|
||||
@ -368,24 +380,25 @@ def plot_cross_trigger_match(
|
||||
linestyle=annotation_linestyle,
|
||||
)
|
||||
|
||||
plot.plot_geometry(
|
||||
ax = plot.plot_geometry(
|
||||
match.pred_geometry,
|
||||
ax=ax,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
alpha=1,
|
||||
alpha=match.pred_score if use_score else 1,
|
||||
color=color,
|
||||
linestyle=prediction_linestyle,
|
||||
)
|
||||
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
if add_text:
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
|
||||
return ax
|
||||
|
||||
@ -4,6 +4,7 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.train.augmentations import (
|
||||
DEFAULT_AUGMENTATION_CONFIG,
|
||||
AugmentationsConfig,
|
||||
@ -82,6 +83,7 @@ class TrainingConfig(BaseConfig):
|
||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||
|
||||
|
||||
def load_train_config(
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Annotated,
|
||||
@ -13,8 +15,14 @@ from typing import (
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
|
||||
from lightning.pytorch.loggers import (
|
||||
CSVLogger,
|
||||
Logger,
|
||||
MLFlowLogger,
|
||||
TensorBoardLogger,
|
||||
)
|
||||
from loguru import logger
|
||||
from matplotlib.figure import Figure
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
@ -231,18 +239,17 @@ def build_logger(
|
||||
)
|
||||
|
||||
|
||||
def get_image_plotter(logger: Logger):
|
||||
Plotter = Callable[[str, Figure, int], None]
|
||||
|
||||
|
||||
def get_image_plotter(logger: Logger) -> Optional[Plotter]:
|
||||
if isinstance(logger, TensorBoardLogger):
|
||||
|
||||
def plot_figure(name, figure, step):
|
||||
return logger.experiment.add_figure(name, figure, step)
|
||||
|
||||
return plot_figure
|
||||
return logger.experiment.add_figure
|
||||
|
||||
if isinstance(logger, MLFlowLogger):
|
||||
|
||||
def plot_figure(name, figure, step):
|
||||
image = _convert_figure_to_image(figure)
|
||||
image = _convert_figure_to_array(figure)
|
||||
return logger.experiment.log_image(
|
||||
logger.run_id,
|
||||
image,
|
||||
@ -252,8 +259,20 @@ def get_image_plotter(logger: Logger):
|
||||
|
||||
return plot_figure
|
||||
|
||||
if isinstance(logger, CSVLogger):
|
||||
return partial(save_figure, dir=Path(logger.log_dir))
|
||||
|
||||
def _convert_figure_to_image(figure):
|
||||
|
||||
def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
|
||||
path = dir / "plots" / f"{name}_step_{step}.png"
|
||||
|
||||
if not path.parent.exists():
|
||||
path.parent.mkdir(parents=True)
|
||||
|
||||
fig.savefig(path, transparent=True, bbox_inches="tight")
|
||||
|
||||
|
||||
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
|
||||
with io.BytesIO() as buff:
|
||||
figure.savefig(buff, format="raw")
|
||||
buff.seek(0)
|
||||
|
||||
@ -12,7 +12,6 @@ from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
@ -103,9 +102,9 @@ def train(
|
||||
)
|
||||
|
||||
trainer = trainer or build_trainer(
|
||||
config.train,
|
||||
config,
|
||||
targets=targets,
|
||||
evaluator=build_evaluator(config.evaluation, targets=targets),
|
||||
evaluator=build_evaluator(config.train.validation, targets=targets),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
@ -151,7 +150,7 @@ def build_trainer_callbacks(
|
||||
|
||||
|
||||
def build_trainer(
|
||||
conf: TrainingConfig,
|
||||
conf: "BatDetect2Config",
|
||||
targets: "TargetProtocol",
|
||||
evaluator: Optional[Evaluator] = None,
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
@ -159,13 +158,13 @@ def build_trainer(
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
) -> Trainer:
|
||||
trainer_conf = conf.trainer
|
||||
trainer_conf = conf.train.trainer
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building trainer with config: \n{config}",
|
||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
train_logger = build_logger(
|
||||
conf.logger,
|
||||
conf.train.logger,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
|
||||
@ -50,6 +50,26 @@ class MatchEvaluation:
|
||||
|
||||
return self.pred_class_scores[pred_class]
|
||||
|
||||
def is_true_positive(self, threshold: float = 0) -> bool:
|
||||
return (
|
||||
self.gt_det
|
||||
and self.pred_score > threshold
|
||||
and self.gt_class == self.pred_class
|
||||
)
|
||||
|
||||
def is_false_positive(self, threshold: float = 0) -> bool:
|
||||
return self.gt_det is None and self.pred_score > threshold
|
||||
|
||||
def is_false_negative(self, threshold: float = 0) -> bool:
|
||||
return self.gt_det and self.pred_score <= threshold
|
||||
|
||||
def is_cross_trigger(self, threshold: float = 0) -> bool:
|
||||
return (
|
||||
self.gt_det
|
||||
and self.pred_score > threshold
|
||||
and self.gt_class != self.pred_class
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipEvaluation:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user