Fix plotting after update

This commit is contained in:
mbsantiago 2025-08-26 11:48:06 +01:00
parent 3043230d4f
commit d25efdad10
4 changed files with 54 additions and 5 deletions

View File

@ -1,4 +1,5 @@
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple from typing import List, Literal, Optional, Tuple
import numpy as np import numpy as np
@ -340,3 +341,36 @@ def match_predictions_and_annotations(
) )
return matches return matches
@dataclass
class ClassExamples:
false_positives: List[MatchEvaluation] = field(default_factory=list)
false_negatives: List[MatchEvaluation] = field(default_factory=list)
true_positives: List[MatchEvaluation] = field(default_factory=list)
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def group_matches(matches: List[MatchEvaluation]) -> ClassExamples:
class_examples = ClassExamples()
for match in matches:
gt_class = match.gt_class
pred_class = match.pred_class
if pred_class is None:
class_examples.false_negatives.append(match)
continue
if gt_class is None:
class_examples.false_positives.append(match)
continue
if gt_class != pred_class:
class_examples.cross_triggers.append(match)
class_examples.cross_triggers.append(match)
continue
class_examples.true_positives.append(match)
return class_examples

View File

@ -37,6 +37,10 @@ def plot_clip(
plot_spectrogram( plot_spectrogram(
spec, spec,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
ax=ax, ax=ax,
cmap=spec_cmap, cmap=spec_cmap,
) )

View File

@ -3,6 +3,7 @@
from typing import Optional, Tuple from typing import Optional, Tuple
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
import torch import torch
from matplotlib import axes from matplotlib import axes
@ -25,10 +26,20 @@ def create_ax(
def plot_spectrogram( def plot_spectrogram(
spec: torch.Tensor, spec: torch.Tensor,
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
ax: Optional[axes.Axes] = None, ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Optional[Tuple[int, int]] = None,
cmap="gray", cmap="gray",
) -> axes.Axes: ) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
ax.pcolormesh(spec.numpy(), cmap=cmap)
ax.pcolormesh(
np.linspace(start_time, end_time, spec.shape[-1], endpoint=False),
np.linspace(min_freq, max_freq, spec.shape[-2], endpoint=False),
spec.numpy(),
cmap=cmap,
)
return ax return ax

View File

@ -100,7 +100,7 @@ def plot_class_examples(
preprocessor=preprocessor, preprocessor=preprocessor,
duration=duration, duration=duration,
) )
except (ValueError, AssertionError): except (ValueError, AssertionError, RuntimeError):
continue continue
for index, match in enumerate(false_positives[:n_examples]): for index, match in enumerate(false_positives[:n_examples]):
@ -112,7 +112,7 @@ def plot_class_examples(
preprocessor=preprocessor, preprocessor=preprocessor,
duration=duration, duration=duration,
) )
except (ValueError, AssertionError): except (ValueError, AssertionError, RuntimeError):
continue continue
for index, match in enumerate(false_negatives[:n_examples]): for index, match in enumerate(false_negatives[:n_examples]):
@ -124,7 +124,7 @@ def plot_class_examples(
preprocessor=preprocessor, preprocessor=preprocessor,
duration=duration, duration=duration,
) )
except (ValueError, AssertionError): except (ValueError, AssertionError, RuntimeError):
continue continue
for index, match in enumerate(cross_triggers[:n_examples]): for index, match in enumerate(cross_triggers[:n_examples]):
@ -136,7 +136,7 @@ def plot_class_examples(
preprocessor=preprocessor, preprocessor=preprocessor,
duration=duration, duration=duration,
) )
except (ValueError, AssertionError): except (ValueError, AssertionError, RuntimeError):
continue continue
return fig return fig