Compare commits

..

No commits in common. "65d13a32b76fc5868be273d5dabaed7eaabab1e8" and "2cc0bd59d4615f4916c02dc6dea50348aa79b48a" have entirely different histories.

View File

@ -100,7 +100,7 @@ def plot_class_examples(
preprocessor=preprocessor, preprocessor=preprocessor,
duration=duration, duration=duration,
) )
except (ValueError, AssertionError): except ValueError:
continue continue
for index, match in enumerate(false_positives[:n_examples]): for index, match in enumerate(false_positives[:n_examples]):
@ -112,7 +112,7 @@ def plot_class_examples(
preprocessor=preprocessor, preprocessor=preprocessor,
duration=duration, duration=duration,
) )
except (ValueError, AssertionError): except ValueError:
continue continue
for index, match in enumerate(false_negatives[:n_examples]): for index, match in enumerate(false_negatives[:n_examples]):
@ -124,11 +124,11 @@ def plot_class_examples(
preprocessor=preprocessor, preprocessor=preprocessor,
duration=duration, duration=duration,
) )
except (ValueError, AssertionError): except ValueError:
continue continue
for index, match in enumerate(cross_triggers[:n_examples]): for index, match in enumerate(cross_triggers[:n_examples]):
ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1) ax = plt.subplot(4, n_examples, 4 * n_examples + index + 1)
try: try:
plotting.plot_cross_trigger_match( plotting.plot_cross_trigger_match(
match, match,
@ -136,7 +136,7 @@ def plot_class_examples(
preprocessor=preprocessor, preprocessor=preprocessor,
duration=duration, duration=duration,
) )
except (ValueError, AssertionError): except ValueError:
continue continue
return fig return fig
@ -154,7 +154,7 @@ def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
] ]
) )
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop") bins = pd.qcut(pred_scores, q=n_examples, labels=False)
df = pd.DataFrame({"indices": indices, "bins": bins}) df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").apply(lambda x: x.sample(1)) sample = df.groupby("bins").apply(lambda x: x.sample(1))
return [matches[ind] for ind in sample["indices"]] return [matches[ind] for ind in sample["indices"]]