diff --git a/src/batdetect2/plotting/clip_annotations.py b/src/batdetect2/plotting/clip_annotations.py index ca4665b..b0b33a5 100644 --- a/src/batdetect2/plotting/clip_annotations.py +++ b/src/batdetect2/plotting/clip_annotations.py @@ -4,7 +4,9 @@ from matplotlib.axes import Axes from soundevent import data, plot from batdetect2.plotting.clips import plot_clip +from batdetect2.plotting.common import create_ax from batdetect2.typing.preprocess import PreprocessorProtocol +from batdetect2.typing.targets import TargetProtocol __all__ = [ "plot_clip_annotation", @@ -43,3 +45,31 @@ def plot_clip_annotation( facecolor="none" if not fill else None, ) return ax + + +def plot_anchor_points( + clip_annotation: data.ClipAnnotation, + targets: TargetProtocol, + figsize: Optional[Tuple[int, int]] = None, + ax: Optional[Axes] = None, + size: int = 1, + color: str = "red", + marker: str = "x", + alpha: float = 1, +) -> Axes: + ax = create_ax(ax=ax, figsize=figsize) + + positions = [] + + for sound_event in clip_annotation.sound_events: + if not targets.filter(sound_event): + continue + + sound_event = targets.transform(sound_event) + + position, _ = targets.encode_roi(sound_event) + positions.append(position) + + X, Y = zip(*positions) + ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha) + return ax diff --git a/src/batdetect2/plotting/common.py b/src/batdetect2/plotting/common.py index 9a8b930..f9459a9 100644 --- a/src/batdetect2/plotting/common.py +++ b/src/batdetect2/plotting/common.py @@ -1,6 +1,6 @@ """General plotting utilities.""" -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -25,7 +25,7 @@ def create_ax( def plot_spectrogram( - spec: torch.Tensor, + spec: Union[torch.Tensor, np.ndarray], start_time: float, end_time: float, min_freq: float, @@ -34,12 +34,15 @@ def plot_spectrogram( figsize: Optional[Tuple[int, int]] = None, cmap="gray", ) -> axes.Axes: + if isinstance(spec, torch.Tensor): + spec = spec.numpy() + ax = create_ax(ax=ax, figsize=figsize) ax.pcolormesh( - np.linspace(start_time, end_time, spec.shape[-1], endpoint=False), - np.linspace(min_freq, max_freq, spec.shape[-2], endpoint=False), - spec.numpy(), + np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True), + np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True), + spec, cmap=cmap, ) return ax diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index 3dd9da6..176ee81 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -62,7 +62,7 @@ class LabelConfig(BaseConfig): diffuse targets. """ - sigma: float = 3.0 + sigma: float = 2.0 def build_clip_labeler( @@ -174,7 +174,7 @@ def generate_clip_label( def map_to_pixels(x, size, min_val, max_val) -> int: - return int(np.floor(np.interp(x, [min_val, max_val], [0, size]))) + return int(np.interp(x, [min_val, max_val], [0, size])) def generate_heatmaps(