Fix plotting after update

This commit is contained in:
mbsantiago 2025-08-26 11:48:06 +01:00
parent 3043230d4f
commit d25efdad10
4 changed files with 54 additions and 5 deletions

View File

@ -1,4 +1,5 @@
from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple
import numpy as np
@ -340,3 +341,36 @@ def match_predictions_and_annotations(
)
return matches
@dataclass
class ClassExamples:
false_positives: List[MatchEvaluation] = field(default_factory=list)
false_negatives: List[MatchEvaluation] = field(default_factory=list)
true_positives: List[MatchEvaluation] = field(default_factory=list)
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def group_matches(matches: List[MatchEvaluation]) -> ClassExamples:
class_examples = ClassExamples()
for match in matches:
gt_class = match.gt_class
pred_class = match.pred_class
if pred_class is None:
class_examples.false_negatives.append(match)
continue
if gt_class is None:
class_examples.false_positives.append(match)
continue
if gt_class != pred_class:
class_examples.cross_triggers.append(match)
class_examples.cross_triggers.append(match)
continue
class_examples.true_positives.append(match)
return class_examples

View File

@ -37,6 +37,10 @@ def plot_clip(
plot_spectrogram(
spec,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
ax=ax,
cmap=spec_cmap,
)

View File

@ -3,6 +3,7 @@
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import axes
@ -25,10 +26,20 @@ def create_ax(
def plot_spectrogram(
spec: torch.Tensor,
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
cmap="gray",
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax.pcolormesh(spec.numpy(), cmap=cmap)
ax.pcolormesh(
np.linspace(start_time, end_time, spec.shape[-1], endpoint=False),
np.linspace(min_freq, max_freq, spec.shape[-2], endpoint=False),
spec.numpy(),
cmap=cmap,
)
return ax

View File

@ -100,7 +100,7 @@ def plot_class_examples(
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError):
except (ValueError, AssertionError, RuntimeError):
continue
for index, match in enumerate(false_positives[:n_examples]):
@ -112,7 +112,7 @@ def plot_class_examples(
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError):
except (ValueError, AssertionError, RuntimeError):
continue
for index, match in enumerate(false_negatives[:n_examples]):
@ -124,7 +124,7 @@ def plot_class_examples(
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError):
except (ValueError, AssertionError, RuntimeError):
continue
for index, match in enumerate(cross_triggers[:n_examples]):
@ -136,7 +136,7 @@ def plot_class_examples(
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError):
except (ValueError, AssertionError, RuntimeError):
continue
return fig