mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Fix plotting after update
This commit is contained in:
parent
3043230d4f
commit
d25efdad10
@ -1,4 +1,5 @@
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Literal, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -340,3 +341,36 @@ def match_predictions_and_annotations(
|
||||
)
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassExamples:
|
||||
false_positives: List[MatchEvaluation] = field(default_factory=list)
|
||||
false_negatives: List[MatchEvaluation] = field(default_factory=list)
|
||||
true_positives: List[MatchEvaluation] = field(default_factory=list)
|
||||
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
||||
|
||||
|
||||
def group_matches(matches: List[MatchEvaluation]) -> ClassExamples:
|
||||
class_examples = ClassExamples()
|
||||
|
||||
for match in matches:
|
||||
gt_class = match.gt_class
|
||||
pred_class = match.pred_class
|
||||
|
||||
if pred_class is None:
|
||||
class_examples.false_negatives.append(match)
|
||||
continue
|
||||
|
||||
if gt_class is None:
|
||||
class_examples.false_positives.append(match)
|
||||
continue
|
||||
|
||||
if gt_class != pred_class:
|
||||
class_examples.cross_triggers.append(match)
|
||||
class_examples.cross_triggers.append(match)
|
||||
continue
|
||||
|
||||
class_examples.true_positives.append(match)
|
||||
|
||||
return class_examples
|
||||
|
||||
@ -37,6 +37,10 @@ def plot_clip(
|
||||
|
||||
plot_spectrogram(
|
||||
spec,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
ax=ax,
|
||||
cmap=spec_cmap,
|
||||
)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import axes
|
||||
|
||||
@ -25,10 +26,20 @@ def create_ax(
|
||||
|
||||
def plot_spectrogram(
|
||||
spec: torch.Tensor,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
min_freq: float,
|
||||
max_freq: float,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
cmap="gray",
|
||||
) -> axes.Axes:
|
||||
ax = create_ax(ax=ax, figsize=figsize)
|
||||
ax.pcolormesh(spec.numpy(), cmap=cmap)
|
||||
|
||||
ax.pcolormesh(
|
||||
np.linspace(start_time, end_time, spec.shape[-1], endpoint=False),
|
||||
np.linspace(min_freq, max_freq, spec.shape[-2], endpoint=False),
|
||||
spec.numpy(),
|
||||
cmap=cmap,
|
||||
)
|
||||
return ax
|
||||
|
||||
@ -100,7 +100,7 @@ def plot_class_examples(
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
except (ValueError, AssertionError):
|
||||
except (ValueError, AssertionError, RuntimeError):
|
||||
continue
|
||||
|
||||
for index, match in enumerate(false_positives[:n_examples]):
|
||||
@ -112,7 +112,7 @@ def plot_class_examples(
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
except (ValueError, AssertionError):
|
||||
except (ValueError, AssertionError, RuntimeError):
|
||||
continue
|
||||
|
||||
for index, match in enumerate(false_negatives[:n_examples]):
|
||||
@ -124,7 +124,7 @@ def plot_class_examples(
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
except (ValueError, AssertionError):
|
||||
except (ValueError, AssertionError, RuntimeError):
|
||||
continue
|
||||
|
||||
for index, match in enumerate(cross_triggers[:n_examples]):
|
||||
@ -136,7 +136,7 @@ def plot_class_examples(
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
except (ValueError, AssertionError):
|
||||
except (ValueError, AssertionError, RuntimeError):
|
||||
continue
|
||||
|
||||
return fig
|
||||
|
||||
Loading…
Reference in New Issue
Block a user