mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Add metrics and plots
This commit is contained in:
parent
6e217380f2
commit
b81a882b58
@ -56,7 +56,10 @@ class BatDetect2API:
|
|||||||
train(
|
train(
|
||||||
train_annotations=train_annotations,
|
train_annotations=train_annotations,
|
||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
|
targets=self.targets,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
|
audio_loader=self.audio_loader,
|
||||||
|
preprocessor=self.preprocessor,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
val_workers=val_workers,
|
val_workers=val_workers,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
|
|||||||
and serialization capabilities.
|
and serialization capabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="ignore")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
def to_yaml_string(
|
def to_yaml_string(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from batdetect2.evaluate.metrics import (
|
|||||||
DetectionAPConfig,
|
DetectionAPConfig,
|
||||||
MetricConfig,
|
MetricConfig,
|
||||||
)
|
)
|
||||||
from batdetect2.evaluate.plots import ExampleGalleryConfig, PlotConfig
|
from batdetect2.evaluate.plots import PlotConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EvaluationConfig",
|
"EvaluationConfig",
|
||||||
@ -20,18 +20,14 @@ __all__ = [
|
|||||||
|
|
||||||
class EvaluationConfig(BaseConfig):
|
class EvaluationConfig(BaseConfig):
|
||||||
ignore_start_end: float = 0.01
|
ignore_start_end: float = 0.01
|
||||||
match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
|
match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig)
|
||||||
metrics: List[MetricConfig] = Field(
|
metrics: List[MetricConfig] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
DetectionAPConfig(),
|
DetectionAPConfig(),
|
||||||
ClassificationAPConfig(),
|
ClassificationAPConfig(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
plots: List[PlotConfig] = Field(
|
plots: List[PlotConfig] = Field(default_factory=list)
|
||||||
default_factory=lambda: [
|
|
||||||
ExampleGalleryConfig(),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_evaluation_config(
|
def load_evaluation_config(
|
||||||
|
|||||||
@ -58,7 +58,7 @@ def evaluate(
|
|||||||
clip_annotations = []
|
clip_annotations = []
|
||||||
predictions = []
|
predictions = []
|
||||||
|
|
||||||
evaluator = build_evaluator(config=config)
|
evaluator = build_evaluator(config=config, targets=targets)
|
||||||
|
|
||||||
for batch in loader:
|
for batch in loader:
|
||||||
outputs = model.detector(batch.spec)
|
outputs = model.detector(batch.spec)
|
||||||
|
|||||||
@ -138,7 +138,7 @@ def build_evaluator(
|
|||||||
) -> Evaluator:
|
) -> Evaluator:
|
||||||
config = config or EvaluationConfig()
|
config = config or EvaluationConfig()
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
matcher = matcher or build_matcher(config.match)
|
matcher = matcher or build_matcher(config.match_strategy)
|
||||||
|
|
||||||
if metrics is None:
|
if metrics is None:
|
||||||
metrics = [
|
metrics = [
|
||||||
@ -147,7 +147,10 @@ def build_evaluator(
|
|||||||
]
|
]
|
||||||
|
|
||||||
if plots is None:
|
if plots is None:
|
||||||
plots = [build_plotter(config) for config in config.plots]
|
plots = [
|
||||||
|
build_plotter(config, targets.class_names)
|
||||||
|
for config in config.plots
|
||||||
|
]
|
||||||
|
|
||||||
return Evaluator(
|
return Evaluator(
|
||||||
config=config,
|
config=config,
|
||||||
|
|||||||
@ -111,7 +111,7 @@ def match(
|
|||||||
|
|
||||||
|
|
||||||
class StartTimeMatchConfig(BaseConfig):
|
class StartTimeMatchConfig(BaseConfig):
|
||||||
name: Literal["start_time"] = "start_time"
|
name: Literal["start_time_match"] = "start_time_match"
|
||||||
distance_threshold: float = 0.01
|
distance_threshold: float = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from collections import defaultdict
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable, Mapping
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
@ -12,8 +13,7 @@ from typing import (
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from sklearn import metrics
|
from sklearn import metrics, preprocessing
|
||||||
from sklearn.preprocessing import label_binarize
|
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.core.registries import Registry
|
from batdetect2.core.registries import Registry
|
||||||
@ -26,57 +26,18 @@ __all__ = ["DetectionAP", "ClassificationAP"]
|
|||||||
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
|
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
|
||||||
|
|
||||||
|
|
||||||
AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
|
APImplementation = Literal["sklearn", "pascal_voc"]
|
||||||
|
|
||||||
|
|
||||||
class DetectionAPConfig(BaseConfig):
|
class DetectionAPConfig(BaseConfig):
|
||||||
name: Literal["detection_ap"] = "detection_ap"
|
name: Literal["detection_ap"] = "detection_ap"
|
||||||
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
|
ap_implementation: APImplementation = "pascal_voc"
|
||||||
|
|
||||||
|
|
||||||
def pascal_voc_average_precision(y_true, y_score) -> float:
|
|
||||||
y_true = np.array(y_true)
|
|
||||||
y_score = np.array(y_score)
|
|
||||||
|
|
||||||
sort_ind = np.argsort(y_score)[::-1]
|
|
||||||
y_true_sorted = y_true[sort_ind]
|
|
||||||
|
|
||||||
num_positives = y_true.sum()
|
|
||||||
false_pos_c = np.cumsum(1 - y_true_sorted)
|
|
||||||
true_pos_c = np.cumsum(y_true_sorted)
|
|
||||||
|
|
||||||
recall = true_pos_c / num_positives
|
|
||||||
precision = true_pos_c / np.maximum(
|
|
||||||
true_pos_c + false_pos_c,
|
|
||||||
np.finfo(np.float64).eps,
|
|
||||||
)
|
|
||||||
|
|
||||||
precision[np.isnan(precision)] = 0
|
|
||||||
recall[np.isnan(recall)] = 0
|
|
||||||
|
|
||||||
# pascal 12 way
|
|
||||||
mprec = np.hstack((0, precision, 0))
|
|
||||||
mrec = np.hstack((0, recall, 1))
|
|
||||||
for ii in range(mprec.shape[0] - 2, -1, -1):
|
|
||||||
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
|
|
||||||
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
|
|
||||||
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
|
|
||||||
|
|
||||||
return ave_prec
|
|
||||||
|
|
||||||
|
|
||||||
_ap_impl_mapping: Mapping[
|
|
||||||
AveragePrecisionImplementation, Callable[[Any, Any], float]
|
|
||||||
] = {
|
|
||||||
"sklearn": metrics.average_precision_score,
|
|
||||||
"pascal_voc": pascal_voc_average_precision,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionAP(MetricsProtocol):
|
class DetectionAP(MetricsProtocol):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
implementation: AveragePrecisionImplementation = "pascal_voc",
|
implementation: APImplementation = "pascal_voc",
|
||||||
):
|
):
|
||||||
self.implementation = implementation
|
self.implementation = implementation
|
||||||
self.metric = _ap_impl_mapping[self.implementation]
|
self.metric = _ap_impl_mapping[self.implementation]
|
||||||
@ -102,9 +63,37 @@ class DetectionAP(MetricsProtocol):
|
|||||||
metrics_registry.register(DetectionAPConfig, DetectionAP)
|
metrics_registry.register(DetectionAPConfig, DetectionAP)
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionROCAUCConfig(BaseConfig):
|
||||||
|
name: Literal["detection_roc_auc"] = "detection_roc_auc"
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionROCAUC(MetricsProtocol):
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true, y_score = zip(
|
||||||
|
*[
|
||||||
|
(match.gt_det, match.pred_score)
|
||||||
|
for clip_eval in clip_evaluations
|
||||||
|
for match in clip_eval.matches
|
||||||
|
]
|
||||||
|
)
|
||||||
|
score = float(metrics.roc_auc_score(y_true, y_score))
|
||||||
|
return {"detection_ROC_AUC": score}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls, config: DetectionROCAUCConfig, class_names: List[str]
|
||||||
|
):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC)
|
||||||
|
|
||||||
|
|
||||||
class ClassificationAPConfig(BaseConfig):
|
class ClassificationAPConfig(BaseConfig):
|
||||||
name: Literal["classification_ap"] = "classification_ap"
|
name: Literal["classification_ap"] = "classification_ap"
|
||||||
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
|
ap_implementation: APImplementation = "pascal_voc"
|
||||||
include: Optional[List[str]] = None
|
include: Optional[List[str]] = None
|
||||||
exclude: Optional[List[str]] = None
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
@ -113,7 +102,7 @@ class ClassificationAP(MetricsProtocol):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
class_names: List[str],
|
class_names: List[str],
|
||||||
implementation: AveragePrecisionImplementation = "pascal_voc",
|
implementation: APImplementation = "pascal_voc",
|
||||||
include: Optional[List[str]] = None,
|
include: Optional[List[str]] = None,
|
||||||
exclude: Optional[List[str]] = None,
|
exclude: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
@ -164,7 +153,7 @@ class ClassificationAP(MetricsProtocol):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
y_true = label_binarize(y_true, classes=self.class_names)
|
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
|
||||||
y_pred = np.stack(y_pred)
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
class_scores = {}
|
class_scores = {}
|
||||||
@ -203,11 +192,435 @@ class ClassificationAP(MetricsProtocol):
|
|||||||
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
|
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationROCAUCConfig(BaseConfig):
|
||||||
|
name: Literal["classification_roc_auc"] = "classification_roc_auc"
|
||||||
|
include: Optional[List[str]] = None
|
||||||
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationROCAUC(MetricsProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
class_names: List[str],
|
||||||
|
include: Optional[List[str]] = None,
|
||||||
|
exclude: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
self.class_names = class_names
|
||||||
|
self.selected = class_names
|
||||||
|
|
||||||
|
if include is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name in include
|
||||||
|
]
|
||||||
|
|
||||||
|
if exclude is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name not in exclude
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
# Ignore generic unclassified targets
|
||||||
|
if match.gt_det and match.gt_class is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true.append(
|
||||||
|
match.gt_class
|
||||||
|
if match.gt_class is not None
|
||||||
|
else "__NONE__"
|
||||||
|
)
|
||||||
|
|
||||||
|
y_pred.append(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
match.pred_class_scores.get(name, 0)
|
||||||
|
for name in self.class_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
|
||||||
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
|
class_scores = {}
|
||||||
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
|
y_true_class = y_true[:, class_index]
|
||||||
|
y_pred_class = y_pred[:, class_index]
|
||||||
|
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
|
||||||
|
class_scores[class_name] = float(class_roc_auc)
|
||||||
|
|
||||||
|
mean_roc_auc = np.mean(
|
||||||
|
[value for value in class_scores.values() if value != 0]
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"classification_macro_average_ROC_AUC": float(mean_roc_auc),
|
||||||
|
**{
|
||||||
|
f"classification_ROC_AUC/{class_name}": class_scores[
|
||||||
|
class_name
|
||||||
|
]
|
||||||
|
for class_name in self.selected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClassificationROCAUCConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(
|
||||||
|
class_names,
|
||||||
|
include=config.include,
|
||||||
|
exclude=config.exclude,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC)
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassAPConfig(BaseConfig):
|
||||||
|
name: Literal["top_class_ap"] = "top_class_ap"
|
||||||
|
ap_implementation: APImplementation = "pascal_voc"
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassAP(MetricsProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
implementation: APImplementation = "pascal_voc",
|
||||||
|
):
|
||||||
|
self.implementation = implementation
|
||||||
|
self.metric = _ap_impl_mapping[self.implementation]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_score = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
# Ignore generic unclassified targets
|
||||||
|
if match.gt_det and match.gt_class is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
top_class = match.pred_class
|
||||||
|
|
||||||
|
y_true.append(top_class == match.gt_class)
|
||||||
|
y_score.append(match.pred_class_score)
|
||||||
|
|
||||||
|
score = float(self.metric(y_true, y_score))
|
||||||
|
return {"top_class_AP": score}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: TopClassAPConfig, class_names: List[str]):
|
||||||
|
return cls(implementation=config.ap_implementation)
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(TopClassAPConfig, TopClassAP)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationBalancedAccuracyConfig(BaseConfig):
|
||||||
|
name: Literal["classification_balanced_accuracy"] = (
|
||||||
|
"classification_balanced_accuracy"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationBalancedAccuracy(MetricsProtocol):
|
||||||
|
def __init__(self, class_names: List[str]):
|
||||||
|
self.class_names = class_names
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
top_class = match.pred_class
|
||||||
|
|
||||||
|
# Focus on matches
|
||||||
|
if match.gt_class is None or top_class is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true.append(self.class_names.index(match.gt_class))
|
||||||
|
y_pred.append(self.class_names.index(top_class))
|
||||||
|
|
||||||
|
score = float(metrics.balanced_accuracy_score(y_true, y_pred))
|
||||||
|
return {"classification_balanced_accuracy": score}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClassificationBalancedAccuracyConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(class_names)
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(
|
||||||
|
ClassificationBalancedAccuracyConfig,
|
||||||
|
ClassificationBalancedAccuracy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipAPConfig(BaseConfig):
|
||||||
|
name: Literal["clip_ap"] = "clip_ap"
|
||||||
|
ap_implementation: APImplementation = "pascal_voc"
|
||||||
|
include: Optional[List[str]] = None
|
||||||
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClipAP(MetricsProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
class_names: List[str],
|
||||||
|
implementation: APImplementation,
|
||||||
|
include: Optional[Sequence[str]] = None,
|
||||||
|
exclude: Optional[Sequence[str]] = None,
|
||||||
|
):
|
||||||
|
self.implementation = implementation
|
||||||
|
self.metric = _ap_impl_mapping[self.implementation]
|
||||||
|
self.class_names = class_names
|
||||||
|
|
||||||
|
self.selected = class_names
|
||||||
|
|
||||||
|
if include is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name in include
|
||||||
|
]
|
||||||
|
|
||||||
|
if exclude is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name not in exclude
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
clip_classes = set()
|
||||||
|
clip_scores = defaultdict(list)
|
||||||
|
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
if match.gt_class is not None:
|
||||||
|
clip_classes.add(match.gt_class)
|
||||||
|
|
||||||
|
for class_name, score in match.pred_class_scores.items():
|
||||||
|
clip_scores[class_name].append(score)
|
||||||
|
|
||||||
|
y_true.append(clip_classes)
|
||||||
|
y_pred.append(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
# Get max score for each class
|
||||||
|
max(clip_scores.get(class_name, [0]))
|
||||||
|
for class_name in self.class_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
y_true = preprocessing.MultiLabelBinarizer(
|
||||||
|
classes=self.class_names
|
||||||
|
).fit_transform(y_true)
|
||||||
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
|
class_scores = {}
|
||||||
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
|
y_true_class = y_true[:, class_index]
|
||||||
|
y_pred_class = y_pred[:, class_index]
|
||||||
|
class_ap = self.metric(y_true_class, y_pred_class)
|
||||||
|
class_scores[class_name] = float(class_ap)
|
||||||
|
|
||||||
|
mean_ap = np.mean(
|
||||||
|
[value for value in class_scores.values() if value != 0]
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"clip_mAP": float(mean_ap),
|
||||||
|
**{
|
||||||
|
f"clip_AP/{class_name}": class_scores[class_name]
|
||||||
|
for class_name in self.selected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: ClipAPConfig, class_names: List[str]):
|
||||||
|
return cls(
|
||||||
|
implementation=config.ap_implementation,
|
||||||
|
include=config.include,
|
||||||
|
exclude=config.exclude,
|
||||||
|
class_names=class_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(ClipAPConfig, ClipAP)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipROCAUCConfig(BaseConfig):
|
||||||
|
name: Literal["clip_roc_auc"] = "clip_roc_auc"
|
||||||
|
include: Optional[List[str]] = None
|
||||||
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClipROCAUC(MetricsProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
class_names: List[str],
|
||||||
|
include: Optional[Sequence[str]] = None,
|
||||||
|
exclude: Optional[Sequence[str]] = None,
|
||||||
|
):
|
||||||
|
self.class_names = class_names
|
||||||
|
self.selected = class_names
|
||||||
|
|
||||||
|
if include is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name in include
|
||||||
|
]
|
||||||
|
|
||||||
|
if exclude is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name not in exclude
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
clip_classes = set()
|
||||||
|
clip_scores = defaultdict(list)
|
||||||
|
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
if match.gt_class is not None:
|
||||||
|
clip_classes.add(match.gt_class)
|
||||||
|
|
||||||
|
for class_name, score in match.pred_class_scores.items():
|
||||||
|
clip_scores[class_name].append(score)
|
||||||
|
|
||||||
|
y_true.append(clip_classes)
|
||||||
|
y_pred.append(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
# Get maximum score for each class
|
||||||
|
max(clip_scores.get(class_name, [0]))
|
||||||
|
for class_name in self.class_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
y_true = preprocessing.MultiLabelBinarizer(
|
||||||
|
classes=self.class_names
|
||||||
|
).fit_transform(y_true)
|
||||||
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
|
class_scores = {}
|
||||||
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
|
y_true_class = y_true[:, class_index]
|
||||||
|
y_pred_class = y_pred[:, class_index]
|
||||||
|
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
|
||||||
|
class_scores[class_name] = float(class_roc_auc)
|
||||||
|
|
||||||
|
mean_roc_auc = np.mean(
|
||||||
|
[value for value in class_scores.values() if value != 0]
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"clip_macro_ROC_AUC": float(mean_roc_auc),
|
||||||
|
**{
|
||||||
|
f"clip_ROC_AUC/{class_name}": class_scores[class_name]
|
||||||
|
for class_name in self.selected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClipROCAUCConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(
|
||||||
|
include=config.include,
|
||||||
|
exclude=config.exclude,
|
||||||
|
class_names=class_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(ClipROCAUCConfig, ClipROCAUC)
|
||||||
|
|
||||||
MetricConfig = Annotated[
|
MetricConfig = Annotated[
|
||||||
Union[ClassificationAPConfig, DetectionAPConfig],
|
Union[
|
||||||
|
DetectionAPConfig,
|
||||||
|
DetectionROCAUCConfig,
|
||||||
|
ClassificationAPConfig,
|
||||||
|
ClassificationROCAUCConfig,
|
||||||
|
TopClassAPConfig,
|
||||||
|
ClassificationBalancedAccuracyConfig,
|
||||||
|
ClipAPConfig,
|
||||||
|
ClipROCAUCConfig,
|
||||||
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_metric(config: MetricConfig, class_names: List[str]):
|
def build_metric(config: MetricConfig, class_names: List[str]):
|
||||||
return metrics_registry.build(config, class_names)
|
return metrics_registry.build(config, class_names)
|
||||||
|
|
||||||
|
|
||||||
|
def pascal_voc_average_precision(y_true, y_score) -> float:
|
||||||
|
y_true = np.array(y_true)
|
||||||
|
y_score = np.array(y_score)
|
||||||
|
|
||||||
|
sort_ind = np.argsort(y_score)[::-1]
|
||||||
|
y_true_sorted = y_true[sort_ind]
|
||||||
|
|
||||||
|
num_positives = y_true.sum()
|
||||||
|
false_pos_c = np.cumsum(1 - y_true_sorted)
|
||||||
|
true_pos_c = np.cumsum(y_true_sorted)
|
||||||
|
|
||||||
|
recall = true_pos_c / num_positives
|
||||||
|
precision = true_pos_c / np.maximum(
|
||||||
|
true_pos_c + false_pos_c,
|
||||||
|
np.finfo(np.float64).eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
precision[np.isnan(precision)] = 0
|
||||||
|
recall[np.isnan(recall)] = 0
|
||||||
|
|
||||||
|
# pascal 12 way
|
||||||
|
mprec = np.hstack((0, precision, 0))
|
||||||
|
mrec = np.hstack((0, recall, 1))
|
||||||
|
for ii in range(mprec.shape[0] - 2, -1, -1):
|
||||||
|
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
|
||||||
|
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
|
||||||
|
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
|
||||||
|
|
||||||
|
return ave_prec
|
||||||
|
|
||||||
|
|
||||||
|
_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = {
|
||||||
|
"sklearn": metrics.average_precision_score,
|
||||||
|
"pascal_voc": pascal_voc_average_precision,
|
||||||
|
}
|
||||||
|
|||||||
@ -4,13 +4,18 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
|
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from sklearn import metrics
|
||||||
|
from sklearn.preprocessing import label_binarize
|
||||||
|
|
||||||
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.core.registries import Registry
|
from batdetect2.core.registries import Registry
|
||||||
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
|
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
|
||||||
from batdetect2.plotting.gallery import plot_match_gallery
|
from batdetect2.plotting.gallery import plot_match_gallery
|
||||||
|
from batdetect2.plotting.matches import plot_matches
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.typing.evaluate import (
|
from batdetect2.typing.evaluate import (
|
||||||
ClipEvaluation,
|
ClipEvaluation,
|
||||||
@ -26,12 +31,13 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
plots_registry: Registry[PlotterProtocol, []] = Registry("plot")
|
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")
|
||||||
|
|
||||||
|
|
||||||
class ExampleGalleryConfig(BaseConfig):
|
class ExampleGalleryConfig(BaseConfig):
|
||||||
name: Literal["example_gallery"] = "example_gallery"
|
name: Literal["example_gallery"] = "example_gallery"
|
||||||
examples_per_class: int = 5
|
examples_per_class: int = 5
|
||||||
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
preprocessing: PreprocessingConfig = Field(
|
preprocessing: PreprocessingConfig = Field(
|
||||||
default_factory=PreprocessingConfig
|
default_factory=PreprocessingConfig
|
||||||
)
|
)
|
||||||
@ -87,9 +93,12 @@ class ExampleGallery(PlotterProtocol):
|
|||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: ExampleGalleryConfig):
|
def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
|
||||||
preprocessor = build_preprocessor(config.preprocessing)
|
audio_loader = build_audio_loader(config.audio)
|
||||||
audio_loader = build_audio_loader(config.preprocessing.audio_transforms)
|
preprocessor = build_preprocessor(
|
||||||
|
config.preprocessing,
|
||||||
|
input_samplerate=audio_loader.samplerate,
|
||||||
|
)
|
||||||
return cls(
|
return cls(
|
||||||
examples_per_class=config.examples_per_class,
|
examples_per_class=config.examples_per_class,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
@ -100,13 +109,345 @@ class ExampleGallery(PlotterProtocol):
|
|||||||
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
|
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipEvaluationPlotConfig(BaseConfig):
|
||||||
|
name: Literal["example_clip"] = "example_clip"
|
||||||
|
num_plots: int = 5
|
||||||
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
|
preprocessing: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PlotClipEvaluation(PlotterProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_plots: int = 3,
|
||||||
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
|
):
|
||||||
|
self.preprocessor = preprocessor
|
||||||
|
self.audio_loader = audio_loader
|
||||||
|
self.num_plots = num_plots
|
||||||
|
|
||||||
|
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||||
|
examples = random.sample(
|
||||||
|
clip_evaluations,
|
||||||
|
k=min(self.num_plots, len(clip_evaluations)),
|
||||||
|
)
|
||||||
|
|
||||||
|
for index, clip_evaluation in enumerate(examples):
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
plot_matches(
|
||||||
|
clip_evaluation.matches,
|
||||||
|
clip=clip_evaluation.clip,
|
||||||
|
audio_loader=self.audio_loader,
|
||||||
|
ax=ax,
|
||||||
|
)
|
||||||
|
yield f"clip_evaluation/example_{index}", fig
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClipEvaluationPlotConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
audio_loader = build_audio_loader(config.audio)
|
||||||
|
preprocessor = build_preprocessor(
|
||||||
|
config.preprocessing,
|
||||||
|
input_samplerate=audio_loader.samplerate,
|
||||||
|
)
|
||||||
|
return cls(
|
||||||
|
num_plots=config.num_plots,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionPRCurveConfig(BaseConfig):
|
||||||
|
name: Literal["detection_pr_curve"] = "detection_pr_curve"
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionPRCurve(PlotterProtocol):
|
||||||
|
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||||
|
y_true, y_score = zip(
|
||||||
|
*[
|
||||||
|
(match.gt_det, match.pred_score)
|
||||||
|
for clip_eval in clip_evaluations
|
||||||
|
for match in clip_eval.matches
|
||||||
|
]
|
||||||
|
)
|
||||||
|
precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
ax.plot(recall, precision, label="Detector")
|
||||||
|
ax.set_xlabel("Recall")
|
||||||
|
ax.set_ylabel("Precision")
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
yield "detection_pr_curve", fig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: DetectionPRCurveConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationPRCurvesConfig(BaseConfig):
|
||||||
|
name: Literal["classification_pr_curves"] = "classification_pr_curves"
|
||||||
|
include: Optional[List[str]] = None
|
||||||
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationPRCurves(PlotterProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
class_names: List[str],
|
||||||
|
include: Optional[List[str]] = None,
|
||||||
|
exclude: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
self.class_names = class_names
|
||||||
|
self.selected = class_names
|
||||||
|
|
||||||
|
if include is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name in include
|
||||||
|
]
|
||||||
|
|
||||||
|
if exclude is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name not in exclude
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
# Ignore generic unclassified targets
|
||||||
|
if match.gt_det and match.gt_class is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true.append(
|
||||||
|
match.gt_class
|
||||||
|
if match.gt_class is not None
|
||||||
|
else "__NONE__"
|
||||||
|
)
|
||||||
|
|
||||||
|
y_pred.append(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
match.pred_class_scores.get(name, 0)
|
||||||
|
for name in self.class_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
y_true = label_binarize(y_true, classes=self.class_names)
|
||||||
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 10))
|
||||||
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
|
if class_name not in self.selected:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true_class = y_true[:, class_index]
|
||||||
|
y_pred_class = y_pred[:, class_index]
|
||||||
|
precision, recall, _ = metrics.precision_recall_curve(
|
||||||
|
y_true_class,
|
||||||
|
y_pred_class,
|
||||||
|
)
|
||||||
|
ax.plot(recall, precision, label=class_name)
|
||||||
|
|
||||||
|
ax.set_xlabel("Recall")
|
||||||
|
ax.set_ylabel("Precision")
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(1.05, 1),
|
||||||
|
loc="upper left",
|
||||||
|
borderaxespad=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "classification_pr_curve", fig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClassificationPRCurvesConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(
|
||||||
|
class_names=class_names,
|
||||||
|
include=config.include,
|
||||||
|
exclude=config.exclude,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionROCCurveConfig(BaseConfig):
|
||||||
|
name: Literal["detection_roc_curve"] = "detection_roc_curve"
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionROCCurve(PlotterProtocol):
|
||||||
|
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||||
|
y_true, y_score = zip(
|
||||||
|
*[
|
||||||
|
(match.gt_det, match.pred_score)
|
||||||
|
for clip_eval in clip_evaluations
|
||||||
|
for match in clip_eval.matches
|
||||||
|
]
|
||||||
|
)
|
||||||
|
fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
ax.plot(fpr, tpr, label="Detection")
|
||||||
|
ax.set_xlabel("False Positive Rate")
|
||||||
|
ax.set_ylabel("True Positive Rate")
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
yield "detection_roc_curve", fig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: DetectionROCCurveConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationROCCurvesConfig(BaseConfig):
|
||||||
|
name: Literal["classification_roc_curves"] = "classification_roc_curves"
|
||||||
|
include: Optional[List[str]] = None
|
||||||
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationROCCurves(PlotterProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
class_names: List[str],
|
||||||
|
include: Optional[List[str]] = None,
|
||||||
|
exclude: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
self.class_names = class_names
|
||||||
|
self.selected = class_names
|
||||||
|
|
||||||
|
if include is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name in include
|
||||||
|
]
|
||||||
|
|
||||||
|
if exclude is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name not in exclude
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
# Ignore generic unclassified targets
|
||||||
|
if match.gt_det and match.gt_class is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true.append(
|
||||||
|
match.gt_class
|
||||||
|
if match.gt_class is not None
|
||||||
|
else "__NONE__"
|
||||||
|
)
|
||||||
|
|
||||||
|
y_pred.append(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
match.pred_class_scores.get(name, 0)
|
||||||
|
for name in self.class_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
y_true = label_binarize(y_true, classes=self.class_names)
|
||||||
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 10))
|
||||||
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
|
if class_name not in self.selected:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true_class = y_true[:, class_index]
|
||||||
|
y_roced_class = y_pred[:, class_index]
|
||||||
|
fpr, tpr, _ = metrics.roc_curve(
|
||||||
|
y_true_class,
|
||||||
|
y_roced_class,
|
||||||
|
)
|
||||||
|
ax.plot(fpr, tpr, label=class_name)
|
||||||
|
|
||||||
|
ax.set_xlabel("False Positive Rate")
|
||||||
|
ax.set_ylabel("True Positive Rate")
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(1.05, 1),
|
||||||
|
loc="upper left",
|
||||||
|
borderaxespad=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "classification_roc_curve", fig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClassificationROCCurvesConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(
|
||||||
|
class_names=class_names,
|
||||||
|
include=config.include,
|
||||||
|
exclude=config.exclude,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
|
||||||
|
|
||||||
|
|
||||||
PlotConfig = Annotated[
|
PlotConfig = Annotated[
|
||||||
Union[ExampleGalleryConfig,], Field(discriminator="name")
|
Union[
|
||||||
|
ExampleGalleryConfig,
|
||||||
|
ClipEvaluationPlotConfig,
|
||||||
|
DetectionPRCurveConfig,
|
||||||
|
ClassificationPRCurvesConfig,
|
||||||
|
DetectionROCCurveConfig,
|
||||||
|
ClassificationROCCurvesConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_plotter(config: PlotConfig) -> PlotterProtocol:
|
def build_plotter(
|
||||||
return plots_registry.build(config)
|
config: PlotConfig, class_names: List[str]
|
||||||
|
) -> PlotterProtocol:
|
||||||
|
return plots_registry.build(config, class_names)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -6,7 +6,6 @@ from soundevent import data, plot
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
from soundevent.plot.tags import TagColorMapper
|
from soundevent.plot.tags import TagColorMapper
|
||||||
|
|
||||||
from batdetect2.plotting.clip_predictions import plot_prediction
|
|
||||||
from batdetect2.plotting.clips import AudioLoader, plot_clip
|
from batdetect2.plotting.clips import AudioLoader, plot_clip
|
||||||
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
|
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
|
||||||
|
|
||||||
@ -29,7 +28,7 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
|
|||||||
|
|
||||||
|
|
||||||
def plot_matches(
|
def plot_matches(
|
||||||
matches: List[data.Match],
|
matches: List[MatchEvaluation],
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
@ -43,8 +42,7 @@ def plot_matches(
|
|||||||
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
||||||
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||||
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
||||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
ax = plot_clip(
|
ax = plot_clip(
|
||||||
clip,
|
clip,
|
||||||
@ -60,52 +58,48 @@ def plot_matches(
|
|||||||
color_mapper = TagColorMapper()
|
color_mapper = TagColorMapper()
|
||||||
|
|
||||||
for match in matches:
|
for match in matches:
|
||||||
if match.source is None and match.target is not None:
|
if match.is_cross_trigger():
|
||||||
plot.plot_annotation(
|
plot_cross_trigger_match(
|
||||||
annotation=match.target,
|
match,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.004,
|
fill=fill,
|
||||||
freq_offset=2_000,
|
add_points=add_points,
|
||||||
|
add_spectrogram=False,
|
||||||
|
use_score=True,
|
||||||
|
color=cross_trigger_color,
|
||||||
|
add_text=False,
|
||||||
|
)
|
||||||
|
elif match.is_true_positive():
|
||||||
|
plot_true_positive_match(
|
||||||
|
match,
|
||||||
|
ax=ax,
|
||||||
|
fill=fill,
|
||||||
|
add_spectrogram=False,
|
||||||
|
use_score=True,
|
||||||
|
add_points=add_points,
|
||||||
|
color=true_positive_color,
|
||||||
|
add_text=False,
|
||||||
|
)
|
||||||
|
elif match.is_false_negative():
|
||||||
|
plot_false_negative_match(
|
||||||
|
match,
|
||||||
|
ax=ax,
|
||||||
|
fill=fill,
|
||||||
|
add_spectrogram=False,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
|
||||||
color=false_negative_color,
|
color=false_negative_color,
|
||||||
color_mapper=color_mapper,
|
add_text=False,
|
||||||
linestyle=annotation_linestyle,
|
|
||||||
)
|
)
|
||||||
elif match.target is None and match.source is not None:
|
elif match.is_false_positive:
|
||||||
plot_prediction(
|
plot_false_positive_match(
|
||||||
prediction=match.source,
|
match,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.004,
|
fill=fill,
|
||||||
freq_offset=2_000,
|
add_spectrogram=False,
|
||||||
|
use_score=True,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
|
||||||
color=false_positive_color,
|
color=false_positive_color,
|
||||||
color_mapper=color_mapper,
|
add_text=False,
|
||||||
linestyle=prediction_linestyle,
|
|
||||||
)
|
|
||||||
elif match.target is not None and match.source is not None:
|
|
||||||
plot.plot_annotation(
|
|
||||||
annotation=match.target,
|
|
||||||
ax=ax,
|
|
||||||
time_offset=0.004,
|
|
||||||
freq_offset=2_000,
|
|
||||||
add_points=add_points,
|
|
||||||
facecolor="none" if not fill else None,
|
|
||||||
color=true_positive_color,
|
|
||||||
color_mapper=color_mapper,
|
|
||||||
linestyle=annotation_linestyle,
|
|
||||||
)
|
|
||||||
plot_prediction(
|
|
||||||
prediction=match.source,
|
|
||||||
ax=ax,
|
|
||||||
time_offset=0.004,
|
|
||||||
freq_offset=2_000,
|
|
||||||
add_points=add_points,
|
|
||||||
facecolor="none" if not fill else None,
|
|
||||||
color=true_positive_color,
|
|
||||||
color_mapper=color_mapper,
|
|
||||||
linestyle=prediction_linestyle,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
@ -121,6 +115,9 @@ def plot_false_positive_match(
|
|||||||
ax: Optional[Axes] = None,
|
ax: Optional[Axes] = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
|
use_score: bool = True,
|
||||||
|
add_spectrogram: bool = True,
|
||||||
|
add_text: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
@ -141,34 +138,36 @@ def plot_false_positive_match(
|
|||||||
recording=match.clip.recording,
|
recording=match.clip.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot_clip(
|
if add_spectrogram:
|
||||||
clip,
|
ax = plot_clip(
|
||||||
audio_loader=audio_loader,
|
clip,
|
||||||
preprocessor=preprocessor,
|
audio_loader=audio_loader,
|
||||||
figsize=figsize,
|
preprocessor=preprocessor,
|
||||||
ax=ax,
|
figsize=figsize,
|
||||||
audio_dir=audio_dir,
|
ax=ax,
|
||||||
spec_cmap=spec_cmap,
|
audio_dir=audio_dir,
|
||||||
)
|
spec_cmap=spec_cmap,
|
||||||
|
)
|
||||||
|
|
||||||
plot.plot_geometry(
|
ax = plot.plot_geometry(
|
||||||
match.pred_geometry,
|
match.pred_geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=match.pred_score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.text(
|
if add_text:
|
||||||
start_time,
|
plt.text(
|
||||||
high_freq,
|
start_time,
|
||||||
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
high_freq,
|
||||||
va="top",
|
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||||
ha="right",
|
va="top",
|
||||||
color=color,
|
ha="right",
|
||||||
fontsize=fontsize,
|
color=color,
|
||||||
)
|
fontsize=fontsize,
|
||||||
|
)
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
@ -181,7 +180,9 @@ def plot_false_negative_match(
|
|||||||
ax: Optional[Axes] = None,
|
ax: Optional[Axes] = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
|
add_spectrogram: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
|
add_text: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||||
@ -203,17 +204,18 @@ def plot_false_negative_match(
|
|||||||
recording=sound_event.recording,
|
recording=sound_event.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot_clip(
|
if add_spectrogram:
|
||||||
clip,
|
ax = plot_clip(
|
||||||
audio_loader=audio_loader,
|
clip,
|
||||||
preprocessor=preprocessor,
|
audio_loader=audio_loader,
|
||||||
figsize=figsize,
|
preprocessor=preprocessor,
|
||||||
ax=ax,
|
figsize=figsize,
|
||||||
audio_dir=audio_dir,
|
ax=ax,
|
||||||
spec_cmap=spec_cmap,
|
audio_dir=audio_dir,
|
||||||
)
|
spec_cmap=spec_cmap,
|
||||||
|
)
|
||||||
|
|
||||||
plot.plot_annotation(
|
ax = plot.plot_annotation(
|
||||||
match.sound_event_annotation,
|
match.sound_event_annotation,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
time_offset=0.001,
|
||||||
@ -224,15 +226,16 @@ def plot_false_negative_match(
|
|||||||
color=color,
|
color=color,
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.text(
|
if add_text:
|
||||||
start_time,
|
plt.text(
|
||||||
high_freq,
|
start_time,
|
||||||
f"False Negative \nClass: {match.gt_class} ",
|
high_freq,
|
||||||
va="top",
|
f"False Negative \nClass: {match.gt_class} ",
|
||||||
ha="right",
|
va="top",
|
||||||
color=color,
|
ha="right",
|
||||||
fontsize=fontsize,
|
color=color,
|
||||||
)
|
fontsize=fontsize,
|
||||||
|
)
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
@ -245,7 +248,10 @@ def plot_true_positive_match(
|
|||||||
ax: Optional[Axes] = None,
|
ax: Optional[Axes] = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
|
use_score: bool = True,
|
||||||
|
add_spectrogram: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
|
add_text: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||||
@ -269,17 +275,18 @@ def plot_true_positive_match(
|
|||||||
recording=sound_event.recording,
|
recording=sound_event.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot_clip(
|
if add_spectrogram:
|
||||||
clip,
|
ax = plot_clip(
|
||||||
audio_loader=audio_loader,
|
clip,
|
||||||
preprocessor=preprocessor,
|
audio_loader=audio_loader,
|
||||||
figsize=figsize,
|
preprocessor=preprocessor,
|
||||||
ax=ax,
|
figsize=figsize,
|
||||||
audio_dir=audio_dir,
|
ax=ax,
|
||||||
spec_cmap=spec_cmap,
|
audio_dir=audio_dir,
|
||||||
)
|
spec_cmap=spec_cmap,
|
||||||
|
)
|
||||||
|
|
||||||
plot.plot_annotation(
|
ax = plot.plot_annotation(
|
||||||
match.sound_event_annotation,
|
match.sound_event_annotation,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
time_offset=0.001,
|
||||||
@ -296,20 +303,21 @@ def plot_true_positive_match(
|
|||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=match.pred_score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
linestyle=prediction_linestyle,
|
linestyle=prediction_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.text(
|
if add_text:
|
||||||
start_time,
|
plt.text(
|
||||||
high_freq,
|
start_time,
|
||||||
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
high_freq,
|
||||||
va="top",
|
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||||
ha="right",
|
va="top",
|
||||||
color=color,
|
ha="right",
|
||||||
fontsize=fontsize,
|
color=color,
|
||||||
)
|
fontsize=fontsize,
|
||||||
|
)
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
@ -322,7 +330,10 @@ def plot_cross_trigger_match(
|
|||||||
ax: Optional[Axes] = None,
|
ax: Optional[Axes] = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
|
use_score: bool = True,
|
||||||
|
add_spectrogram: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
|
add_text: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
||||||
@ -346,17 +357,18 @@ def plot_cross_trigger_match(
|
|||||||
recording=sound_event.recording,
|
recording=sound_event.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot_clip(
|
if add_spectrogram:
|
||||||
clip,
|
ax = plot_clip(
|
||||||
audio_loader=audio_loader,
|
clip,
|
||||||
preprocessor=preprocessor,
|
audio_loader=audio_loader,
|
||||||
figsize=figsize,
|
preprocessor=preprocessor,
|
||||||
ax=ax,
|
figsize=figsize,
|
||||||
audio_dir=audio_dir,
|
ax=ax,
|
||||||
spec_cmap=spec_cmap,
|
audio_dir=audio_dir,
|
||||||
)
|
spec_cmap=spec_cmap,
|
||||||
|
)
|
||||||
|
|
||||||
plot.plot_annotation(
|
ax = plot.plot_annotation(
|
||||||
match.sound_event_annotation,
|
match.sound_event_annotation,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
time_offset=0.001,
|
||||||
@ -368,24 +380,25 @@ def plot_cross_trigger_match(
|
|||||||
linestyle=annotation_linestyle,
|
linestyle=annotation_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
plot.plot_geometry(
|
ax = plot.plot_geometry(
|
||||||
match.pred_geometry,
|
match.pred_geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=match.pred_score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
linestyle=prediction_linestyle,
|
linestyle=prediction_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.text(
|
if add_text:
|
||||||
start_time,
|
plt.text(
|
||||||
high_freq,
|
start_time,
|
||||||
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
high_freq,
|
||||||
va="top",
|
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||||
ha="right",
|
va="top",
|
||||||
color=color,
|
ha="right",
|
||||||
fontsize=fontsize,
|
color=color,
|
||||||
)
|
fontsize=fontsize,
|
||||||
|
)
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
DEFAULT_AUGMENTATION_CONFIG,
|
DEFAULT_AUGMENTATION_CONFIG,
|
||||||
AugmentationsConfig,
|
AugmentationsConfig,
|
||||||
@ -82,6 +83,7 @@ class TrainingConfig(BaseConfig):
|
|||||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||||
|
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||||
|
|
||||||
|
|
||||||
def load_train_config(
|
def load_train_config(
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
@ -13,8 +15,14 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
|
from lightning.pytorch.loggers import (
|
||||||
|
CSVLogger,
|
||||||
|
Logger,
|
||||||
|
MLFlowLogger,
|
||||||
|
TensorBoardLogger,
|
||||||
|
)
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from matplotlib.figure import Figure
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -231,18 +239,17 @@ def build_logger(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_image_plotter(logger: Logger):
|
Plotter = Callable[[str, Figure, int], None]
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_plotter(logger: Logger) -> Optional[Plotter]:
|
||||||
if isinstance(logger, TensorBoardLogger):
|
if isinstance(logger, TensorBoardLogger):
|
||||||
|
return logger.experiment.add_figure
|
||||||
def plot_figure(name, figure, step):
|
|
||||||
return logger.experiment.add_figure(name, figure, step)
|
|
||||||
|
|
||||||
return plot_figure
|
|
||||||
|
|
||||||
if isinstance(logger, MLFlowLogger):
|
if isinstance(logger, MLFlowLogger):
|
||||||
|
|
||||||
def plot_figure(name, figure, step):
|
def plot_figure(name, figure, step):
|
||||||
image = _convert_figure_to_image(figure)
|
image = _convert_figure_to_array(figure)
|
||||||
return logger.experiment.log_image(
|
return logger.experiment.log_image(
|
||||||
logger.run_id,
|
logger.run_id,
|
||||||
image,
|
image,
|
||||||
@ -252,8 +259,20 @@ def get_image_plotter(logger: Logger):
|
|||||||
|
|
||||||
return plot_figure
|
return plot_figure
|
||||||
|
|
||||||
|
if isinstance(logger, CSVLogger):
|
||||||
|
return partial(save_figure, dir=Path(logger.log_dir))
|
||||||
|
|
||||||
def _convert_figure_to_image(figure):
|
|
||||||
|
def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
|
||||||
|
path = dir / "plots" / f"{name}_step_{step}.png"
|
||||||
|
|
||||||
|
if not path.parent.exists():
|
||||||
|
path.parent.mkdir(parents=True)
|
||||||
|
|
||||||
|
fig.savefig(path, transparent=True, bbox_inches="tight")
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
|
||||||
with io.BytesIO() as buff:
|
with io.BytesIO() as buff:
|
||||||
figure.savefig(buff, format="raw")
|
figure.savefig(buff, format="raw")
|
||||||
buff.seek(0)
|
buff.seek(0)
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
|||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.train.callbacks import ValidationMetrics
|
from batdetect2.train.callbacks import ValidationMetrics
|
||||||
from batdetect2.train.config import TrainingConfig
|
|
||||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.lightning import build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
@ -103,9 +102,9 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
trainer = trainer or build_trainer(
|
trainer = trainer or build_trainer(
|
||||||
config.train,
|
config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
evaluator=build_evaluator(config.evaluation, targets=targets),
|
evaluator=build_evaluator(config.train.validation, targets=targets),
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
@ -151,7 +150,7 @@ def build_trainer_callbacks(
|
|||||||
|
|
||||||
|
|
||||||
def build_trainer(
|
def build_trainer(
|
||||||
conf: TrainingConfig,
|
conf: "BatDetect2Config",
|
||||||
targets: "TargetProtocol",
|
targets: "TargetProtocol",
|
||||||
evaluator: Optional[Evaluator] = None,
|
evaluator: Optional[Evaluator] = None,
|
||||||
checkpoint_dir: Optional[Path] = None,
|
checkpoint_dir: Optional[Path] = None,
|
||||||
@ -159,13 +158,13 @@ def build_trainer(
|
|||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
) -> Trainer:
|
) -> Trainer:
|
||||||
trainer_conf = conf.trainer
|
trainer_conf = conf.train.trainer
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building trainer with config: \n{config}",
|
"Building trainer with config: \n{config}",
|
||||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||||
)
|
)
|
||||||
train_logger = build_logger(
|
train_logger = build_logger(
|
||||||
conf.logger,
|
conf.train.logger,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
|
|||||||
@ -50,6 +50,26 @@ class MatchEvaluation:
|
|||||||
|
|
||||||
return self.pred_class_scores[pred_class]
|
return self.pred_class_scores[pred_class]
|
||||||
|
|
||||||
|
def is_true_positive(self, threshold: float = 0) -> bool:
|
||||||
|
return (
|
||||||
|
self.gt_det
|
||||||
|
and self.pred_score > threshold
|
||||||
|
and self.gt_class == self.pred_class
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_false_positive(self, threshold: float = 0) -> bool:
|
||||||
|
return self.gt_det is None and self.pred_score > threshold
|
||||||
|
|
||||||
|
def is_false_negative(self, threshold: float = 0) -> bool:
|
||||||
|
return self.gt_det and self.pred_score <= threshold
|
||||||
|
|
||||||
|
def is_cross_trigger(self, threshold: float = 0) -> bool:
|
||||||
|
return (
|
||||||
|
self.gt_det
|
||||||
|
and self.pred_score > threshold
|
||||||
|
and self.gt_class != self.pred_class
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClipEvaluation:
|
class ClipEvaluation:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user