mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Added plotting example gallery function
This commit is contained in:
parent
d877d383a4
commit
3cfceb76b4
148
src/batdetect2/plotting/evaluation.py
Normal file
148
src/batdetect2/plotting/evaluation.py
Normal file
@ -0,0 +1,148 @@
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
from batdetect2 import plotting
|
||||
from batdetect2.evaluate.types import MatchEvaluation
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassExamples:
|
||||
false_positives: List[MatchEvaluation] = field(default_factory=list)
|
||||
false_negatives: List[MatchEvaluation] = field(default_factory=list)
|
||||
true_positives: List[MatchEvaluation] = field(default_factory=list)
|
||||
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
||||
|
||||
|
||||
def plot_examples(
|
||||
matches: List[MatchEvaluation],
|
||||
preprocessor: PreprocessorProtocol,
|
||||
n_examples: int = 5,
|
||||
):
|
||||
class_examples = defaultdict(ClassExamples)
|
||||
|
||||
for match in matches:
|
||||
gt_class = match.gt_class
|
||||
pred_class = match.pred_class
|
||||
|
||||
if pred_class is None:
|
||||
class_examples[gt_class].false_negatives.append(match)
|
||||
continue
|
||||
|
||||
if gt_class is None:
|
||||
class_examples[pred_class].false_positives.append(match)
|
||||
continue
|
||||
|
||||
if gt_class != pred_class:
|
||||
class_examples[gt_class].cross_triggers.append(match)
|
||||
class_examples[pred_class].cross_triggers.append(match)
|
||||
continue
|
||||
|
||||
class_examples[gt_class].true_positives.append(match)
|
||||
|
||||
for class_name, examples in class_examples.items():
|
||||
true_positives = get_binned_sample(
|
||||
examples.true_positives,
|
||||
n_examples=n_examples,
|
||||
)
|
||||
|
||||
false_positives = get_binned_sample(
|
||||
examples.false_positives,
|
||||
n_examples=n_examples,
|
||||
)
|
||||
|
||||
false_negatives = random.sample(
|
||||
examples.false_negatives,
|
||||
k=min(n_examples, len(examples.false_negatives)),
|
||||
)
|
||||
|
||||
cross_triggers = get_binned_sample(
|
||||
examples.cross_triggers,
|
||||
n_examples=n_examples,
|
||||
)
|
||||
|
||||
fig = plot_class_examples(
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
cross_triggers,
|
||||
preprocessor=preprocessor,
|
||||
n_examples=n_examples,
|
||||
)
|
||||
|
||||
yield class_name, fig
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def plot_class_examples(
|
||||
true_positives: List[MatchEvaluation],
|
||||
false_positives: List[MatchEvaluation],
|
||||
false_negatives: List[MatchEvaluation],
|
||||
cross_triggers: List[MatchEvaluation],
|
||||
preprocessor: PreprocessorProtocol,
|
||||
n_examples: int = 5,
|
||||
duration: float = 0.1,
|
||||
):
|
||||
fig = plt.figure(figsize=(20, 20))
|
||||
|
||||
for index, match in enumerate(true_positives):
|
||||
ax = plt.subplot(4, n_examples, index + 1)
|
||||
plotting.plot_true_positive_match(
|
||||
match,
|
||||
ax=ax,
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
for index, match in enumerate(false_positives):
|
||||
ax = plt.subplot(4, n_examples, n_examples + index + 1)
|
||||
plotting.plot_false_positive_match(
|
||||
match,
|
||||
ax=ax,
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
for index, match in enumerate(false_negatives):
|
||||
ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
|
||||
plotting.plot_false_negative_match(
|
||||
match,
|
||||
ax=ax,
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
for index, match in enumerate(cross_triggers):
|
||||
ax = plt.subplot(4, n_examples, 4 * n_examples + index + 1)
|
||||
plotting.plot_cross_trigger_match(
|
||||
match,
|
||||
ax=ax,
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
|
||||
if len(matches) < n_examples:
|
||||
return matches
|
||||
|
||||
indices, pred_scores = zip(
|
||||
*[
|
||||
(index, match.pred_class_scores[pred_class])
|
||||
for index, match in enumerate(matches)
|
||||
if (pred_class := match.pred_class) is not None
|
||||
]
|
||||
)
|
||||
|
||||
bins = pd.qcut(pred_scores, q=n_examples, labels=False)
|
||||
df = pd.DataFrame({"indices": indices, "bins": bins})
|
||||
sample = df.groupby("bins").apply(lambda x: x.sample(1))
|
||||
return [matches[ind] for ind in sample["indices"]]
|
||||
Loading…
Reference in New Issue
Block a user