Update augmentations

mbsantiago 2025-08-25 17:06:17 +01:00
parent 76dda0a0e9
commit 0bb0caddea
14 changed files with 620 additions and 1072 deletions

View File

@@ -1,6 +1,11 @@
 from batdetect2.plotting.clip_annotations import plot_clip_annotation
 from batdetect2.plotting.clip_predictions import plot_clip_prediction
 from batdetect2.plotting.clips import plot_clip
+from batdetect2.plotting.common import plot_spectrogram
+from batdetect2.plotting.heatmaps import (
+    plot_classification_heatmap,
+    plot_detection_heatmap,
+)
 from batdetect2.plotting.matches import (
     plot_cross_trigger_match,
     plot_false_negative_match,
@@ -13,9 +18,12 @@ __all__ = [
     "plot_clip",
     "plot_clip_annotation",
     "plot_clip_prediction",
-    "plot_matches",
-    "plot_false_positive_match",
-    "plot_true_positive_match",
-    "plot_false_negative_match",
     "plot_cross_trigger_match",
+    "plot_false_negative_match",
+    "plot_false_positive_match",
+    "plot_matches",
+    "plot_spectrogram",
+    "plot_true_positive_match",
+    "plot_detection_heatmap",
+    "plot_classification_heatmap",
 ]

View File

@@ -3,6 +3,7 @@
 from typing import Optional, Tuple

 import matplotlib.pyplot as plt
+import torch
 from matplotlib import axes

 __all__ = [
@@ -12,7 +13,7 @@ __all__ = [

 def create_ax(
     ax: Optional[axes.Axes] = None,
-    figsize: Tuple[int, int] = (10, 10),
+    figsize: Optional[Tuple[int, int]] = None,
     **kwargs,
 ) -> axes.Axes:
     """Create a new axis if none is provided"""
@@ -20,3 +21,14 @@ def create_ax(
         _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore
     return ax  # type: ignore
+
+
+def plot_spectrogram(
+    spec: torch.Tensor,
+    ax: Optional[axes.Axes] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    cmap="gray",
+) -> axes.Axes:
+    ax = create_ax(ax=ax, figsize=figsize)
+    ax.pcolormesh(spec.numpy(), cmap=cmap)
+    return ax
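
A minimal usage sketch for the new plot_spectrogram helper (the tensor shape and axis labels below are illustrative assumptions, not part of the commit):

# Plot a dummy spectrogram; plot_spectrogram converts the tensor to
# numpy and draws it with pcolormesh on a fresh or supplied axis.
import torch

from batdetect2.plotting.common import plot_spectrogram

spec = torch.rand(128, 512)  # (frequency bins, time frames) -- assumed shape
ax = plot_spectrogram(spec, figsize=(10, 4))
ax.set_xlabel("time frame")
ax.set_ylabel("frequency bin")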

View File

@@ -1,26 +1,115 @@
 """Plot heatmaps"""

-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union

-import xarray as xr
-from matplotlib import axes
+import numpy as np
+import torch
+from matplotlib import axes, patches
+from matplotlib.cm import get_cmap
+from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba

 from batdetect2.plotting.common import create_ax


-def plot_heatmap(
-    heatmap: xr.DataArray,
+def plot_detection_heatmap(
+    heatmap: Union[torch.Tensor, np.ndarray],
     ax: Optional[axes.Axes] = None,
     figsize: Tuple[int, int] = (10, 10),
+    threshold: Optional[float] = None,
+    alpha: float = 1,
+    cmap: Union[str, Colormap] = "jet",
+    color: Optional[str] = None,
 ) -> axes.Axes:
     ax = create_ax(ax, figsize=figsize)

+    if isinstance(heatmap, torch.Tensor):
+        heatmap = heatmap.numpy()
+
+    if threshold is not None:
+        heatmap = np.ma.masked_where(
+            heatmap < threshold,
+            heatmap,
+        )
+
+    if color is not None:
+        cmap = create_colormap(color)
+
     ax.pcolormesh(
-        heatmap.time,
-        heatmap.frequency,
         heatmap,
         vmax=1,
         vmin=0,
+        cmap=cmap,
+        alpha=alpha,
     )

     return ax
+
+
+def plot_classification_heatmap(
+    heatmap: Union[torch.Tensor, np.ndarray],
+    ax: Optional[axes.Axes] = None,
+    figsize: Tuple[int, int] = (10, 10),
+    class_names: Optional[List[str]] = None,
+    threshold: Optional[float] = 0.1,
+    alpha: float = 1,
+    cmap: Union[str, Colormap] = "tab20",
+):
+    ax = create_ax(ax, figsize=figsize)
+
+    if isinstance(heatmap, torch.Tensor):
+        heatmap = heatmap.numpy()
+
+    if heatmap.ndim == 4:
+        heatmap = heatmap[0]
+
+    if heatmap.ndim != 3:
+        raise ValueError("Expecting a 3-dimensional array")
+
+    num_classes = heatmap.shape[0]
+
+    if class_names is None:
+        class_names = [f"class_{i}" for i in range(num_classes)]
+
+    if len(class_names) != num_classes:
+        raise ValueError("Inconsistent number of class names")
+
+    if not isinstance(cmap, Colormap):
+        cmap = get_cmap(cmap)
+
+    handles = []
+    for index, class_heatmap in enumerate(heatmap):
+        class_name = class_names[index]
+        color = cmap(index / num_classes)
+
+        max = class_heatmap.max()
+        if max == 0:
+            continue
+
+        if threshold is not None:
+            class_heatmap = np.ma.masked_where(
+                class_heatmap < threshold,
+                class_heatmap,
+            )
+
+        ax.pcolormesh(
+            class_heatmap,
+            vmax=1,
+            vmin=0,
+            cmap=create_colormap(color),  # type: ignore
+            alpha=alpha,
+        )
+        handles.append(patches.Patch(color=color, label=class_name))
+
+    ax.legend(handles=handles)
+    return ax
+
+
+def create_colormap(color: str) -> Colormap:
+    (r, g, b, a) = to_rgba(color)
+    return LinearSegmentedColormap.from_list(
+        "cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
+    )
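
A sketch of how the two new heatmap plotters compose with plot_spectrogram; create_colormap gives each overlay a transparent-to-color map, so the spectrogram stays visible underneath. Shapes, threshold, and class names here are assumptions for illustration:

# Overlay detection and per-class heatmaps on a spectrogram. Masking
# values below the threshold leaves the background unobscured.
import torch

from batdetect2.plotting.common import plot_spectrogram
from batdetect2.plotting.heatmaps import (
    plot_classification_heatmap,
    plot_detection_heatmap,
)

spec = torch.rand(128, 512)        # (freq, time)
detection = torch.rand(128, 512)   # same grid as the spectrogram
classes = torch.rand(3, 128, 512)  # (num_classes, freq, time)

ax = plot_spectrogram(spec)
plot_detection_heatmap(detection, ax=ax, threshold=0.5, alpha=0.5, color="red")
plot_classification_heatmap(
    classes,
    ax=ax,
    class_names=["class_a", "class_b", "class_c"],
    threshold=0.5,
    alpha=0.5,
)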

File diff suppressed because it is too large

View File

@@ -1,12 +1,12 @@
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple

 import numpy as np
-import xarray as xr
 from loguru import logger
-from soundevent import arrays

 from batdetect2.configs import BaseConfig
 from batdetect2.typing import ClipperProtocol
+from batdetect2.typing.train import PreprocessedExample
+from batdetect2.utils.arrays import adjust_width

 DEFAULT_TRAIN_CLIP_DURATION = 0.513
 DEFAULT_MAX_EMPTY_CLIP = 0.1
@@ -32,40 +32,23 @@ class Clipper(ClipperProtocol):
         self.max_empty = max_empty

     def extract_clip(
-        self, example: xr.Dataset
-    ) -> Tuple[xr.Dataset, float, float]:
-        step = arrays.get_dim_step(
-            example.spectrogram,
-            dim=arrays.Dimensions.time.value,
-        )
-        duration = (
-            arrays.get_dim_width(
-                example.spectrogram,
-                dim=arrays.Dimensions.time.value,
-            )
-            + step
-        )
+        self, example: PreprocessedExample
+    ) -> Tuple[PreprocessedExample, float, float]:
         start_time = 0
+        duration = example.audio.shape[-1] / self.samplerate

         if self.random:
             start_time = np.random.uniform(
                 -self.max_empty,
                 duration - self.duration + self.max_empty,
             )

-        subclip = select_subclip(
-            example,
-            start=start_time,
-            span=self.duration,
-            dim="time",
-        )
-
         return (
             select_subclip(
-                subclip,
+                example,
                 start=start_time,
-                span=self.duration,
-                dim="audio_time",
+                duration=self.duration,
+                samplerate=self.samplerate,
             ),
             start_time,
             start_time + self.duration,
@@ -73,6 +56,7 @@ class Clipper(ClipperProtocol):

 def build_clipper(
+    samplerate: int,
     config: Optional[ClipingConfig] = None,
     random: Optional[bool] = None,
 ) -> ClipperProtocol:
@@ -82,6 +66,7 @@ def build_clipper(
         lambda: config.to_yaml_string(),
     )
     return Clipper(
+        samplerate=samplerate,
         duration=config.duration,
         max_empty=config.max_empty,
         random=config.random if random else False,
@@ -89,106 +74,43 @@ def build_clipper(

 def select_subclip(
-    dataset: xr.Dataset,
-    span: float,
+    example: PreprocessedExample,
     start: float,
-    fill_value: float = 0,
-    dim: str = "time",
-) -> xr.Dataset:
-    width = _compute_expected_width(
-        dataset,  # type: ignore
-        span,
-        dim=dim,
-    )
-
-    coord = dataset.coords[dim]
-
-    if len(coord) == width:
-        return dataset
-
-    new_coords, start_pad, end_pad, dim_slice = _extract_coordinate(
-        coord, start, span
-    )
-
-    data_vars = {}
-    for name, data_array in dataset.data_vars.items():
-        if dim not in data_array.dims:
-            data_vars[name] = data_array
-            continue
-
-        if width == data_array.sizes[dim]:
-            data_vars[name] = data_array
-            continue
-
-        sliced = data_array.isel({dim: dim_slice}).data
-
-        if start_pad > 0 or end_pad > 0:
-            padding = [
-                [0, 0] if other_dim != dim else [start_pad, end_pad]
-                for other_dim in data_array.dims
-            ]
-            sliced = np.pad(sliced, padding, constant_values=fill_value)
-
-        data_vars[name] = xr.DataArray(
-            data=sliced,
-            dims=data_array.dims,
-            coords={**data_array.coords, dim: new_coords},
-            attrs=data_array.attrs,
-        )
-
-    return xr.Dataset(data_vars=data_vars, attrs=dataset.attrs)
-
-
-def _extract_coordinate(
-    coord: xr.DataArray,
-    start: float,
-    span: float,
-) -> Tuple[xr.Variable, int, int, slice]:
-    step = arrays.get_dim_step(coord, str(coord.name))
-    current_width = len(coord)
-    expected_width = int(np.floor(span / step))
-
-    coord_start = float(coord[0])
-    offset = start - coord_start
-    start_index = int(np.floor(offset / step))
-    end_index = start_index + expected_width
-
-    if start_index > current_width:
-        raise ValueError("Requested span does not overlap with current range")
-
-    if end_index < 0:
-        raise ValueError("Requested span does not overlap with current range")
-
-    corrected_start = float(start_index * step)
-    corrected_end = float(end_index * step)
-
-    start_index_offset = max(0, -start_index)
-    end_index_offset = max(0, end_index - current_width)
-
-    sl = slice(
-        start_index if start_index >= 0 else None,
-        end_index if end_index < current_width else None,
-    )
-
-    return (
-        arrays.create_range_dim(
-            str(coord.name),
-            start=corrected_start,
-            stop=corrected_end,
-            step=step,
-        ),
-        start_index_offset,
-        end_index_offset,
-        sl,
-    )
-
-
-def _compute_expected_width(
-    array: Union[xr.DataArray, xr.Dataset],
     duration: float,
-    dim: str,
-) -> int:
-    step = arrays.get_dim_step(array, dim)  # type: ignore
-    return int(np.floor(duration / step))
+    samplerate: float,
+    fill_value: float = 0,
+) -> PreprocessedExample:
+    audio_width = int(np.floor(duration * samplerate))
+    audio_start = int(np.floor(start * samplerate))
+    audio = adjust_width(
+        example.audio[audio_start : audio_start + audio_width],
+        audio_width,
+        value=fill_value,
+    )
+
+    audio_duration = example.audio.shape[-1] / samplerate
+    spec_sr = example.spectrogram.shape[-1] / audio_duration
+    spec_start = int(np.floor(start * spec_sr))
+    spec_width = int(np.floor(duration * spec_sr))
+
+    return PreprocessedExample(
+        audio=audio,
+        spectrogram=adjust_width(
+            example.spectrogram[:, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+        class_heatmap=adjust_width(
+            example.class_heatmap[:, :, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+        detection_heatmap=adjust_width(
+            example.detection_heatmap[:, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+        size_heatmap=adjust_width(
+            example.size_heatmap[:, :, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+    )
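
The rewritten select_subclip works purely in sample and frame indices instead of xarray coordinates. A worked sketch of the arithmetic, using made-up rates and shapes (256 kHz audio and 2048 spectrogram frames over 4 s of audio):

import numpy as np

samplerate = 256_000
audio_samples = 4 * samplerate  # example.audio.shape[-1]
spec_frames = 2048              # example.spectrogram.shape[-1]

start, duration = 1.0, 0.513

audio_start = int(np.floor(start * samplerate))     # 256_000
audio_width = int(np.floor(duration * samplerate))  # 131_328

audio_duration = audio_samples / samplerate  # 4.0 s
spec_sr = spec_frames / audio_duration       # 512 frames per second
spec_start = int(np.floor(start * spec_sr))     # 512
spec_width = int(np.floor(duration * spec_sr))  # 262

# adjust_width then pads or trims every slice to exactly audio_width /
# spec_width, so clips taken near the end of a recording come back
# zero-padded rather than short.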

View File

@@ -22,7 +22,7 @@ includes utilities for parallel processing using `multiprocessing`.

 import os
 from pathlib import Path
-from typing import Callable, Dict, Optional, Sequence
+from typing import Callable, Optional, Sequence, TypedDict

 import numpy as np
 import torch
@@ -98,17 +98,25 @@ def preprocess_dataset(
     )


+class Example(TypedDict):
+    audio: torch.Tensor
+    spectrogram: torch.Tensor
+    detection_heatmap: torch.Tensor
+    class_heatmap: torch.Tensor
+    size_heatmap: torch.Tensor
+
+
 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
     audio_loader: AudioLoader,
     preprocessor: PreprocessorProtocol,
     labeller: ClipLabeller,
-) -> Dict[str, torch.Tensor]:
+) -> PreprocessedExample:
     """Generate a complete training example for one annotation."""
     wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
     spectrogram = preprocessor(wave)
     heatmaps = labeller(clip_annotation, spectrogram)

-    return dict(
+    return PreprocessedExample(
         audio=wave,
         spectrogram=spectrogram,
         detection_heatmap=heatmaps.detection,
@@ -138,8 +146,14 @@ class PreprocessingDataset(torch.utils.data.Dataset):
             preprocessor=self.preprocessor,
             labeller=self.labeller,
         )
-        example["idx"] = idx
-        return example
+        return {
+            "idx": idx,
+            "spectrogram": example.spectrogram,
+            "audio": example.audio,
+            "class_heatmap": example.class_heatmap,
+            "size_heatmap": example.size_heatmap,
+            "detection_heatmap": example.detection_heatmap,
+        }

     def __len__(self) -> int:
         return len(self.clips)
@@ -147,16 +161,17 @@ class PreprocessingDataset(torch.utils.data.Dataset):

 def _save_example_to_file(
     example: PreprocessedExample,
+    clip_annotation: data.ClipAnnotation,
     path: data.PathLike,
 ) -> None:
     np.savez_compressed(
         path,
-        audio=example.audio,
-        spectrogram=example.spectrogram,
-        detection_heatmap=example.detection_heatmap,
-        class_heatmap=example.class_heatmap,
-        size_heatmap=example.size_heatmap,
-        clip_annotation=example.clip_annotation,
+        audio=example.audio.numpy(),
+        spectrogram=example.spectrogram.numpy(),
+        detection_heatmap=example.detection_heatmap.numpy(),
+        class_heatmap=example.class_heatmap.numpy(),
+        size_heatmap=example.size_heatmap.numpy(),
+        clip_annotation=clip_annotation,
     )
@@ -211,11 +226,10 @@ def preprocess_annotations(
             filename = filename_fn(clip_annotation)
             path = output_dir / filename
             example = PreprocessedExample(
-                clip_annotation=clip_annotation,
-                spectrogram=batch["spectrogram"].numpy(),
-                audio=batch["audio"].numpy(),
-                class_heatmap=batch["class_heatmap"].numpy(),
-                size_heatmap=batch["size_heatmap"].numpy(),
-                detection_heatmap=batch["detection_heatmap"].numpy(),
+                spectrogram=batch["spectrogram"],
+                audio=batch["audio"],
+                class_heatmap=batch["class_heatmap"],
+                size_heatmap=batch["size_heatmap"],
+                detection_heatmap=batch["detection_heatmap"],
             )
-            _save_example_to_file(example, path)
+            _save_example_to_file(example, clip_annotation, path)
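
Since clip_annotation is now passed separately and the tensors are converted with .numpy() before saving, reading a saved example back needs allow_pickle for the annotation object, which is stored as a 0-d object array. A hedged sketch (the file name is hypothetical):

import numpy as np

saved = np.load("example_0.npz", allow_pickle=True)
spectrogram = saved["spectrogram"]                 # plain np.ndarray
detection = saved["detection_heatmap"]
clip_annotation = saved["clip_annotation"].item()  # unwrap the object array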

View File

@@ -148,6 +148,8 @@ class PreprocessorProtocol(Protocol):

     min_freq: float

+    samplerate: int
+
     audio_pipeline: AudioPipeline

     spectrogram_pipeline: SpectrogramPipeline
@@ -155,4 +157,4 @@ class PreprocessorProtocol(Protocol):
     def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...

     def process_numpy(self, wav: np.ndarray) -> np.ndarray:
-        return self(torch.tensor(wav)).numpy()[0, 0]
+        return self(torch.tensor(wav)).numpy()

View File

@@ -1,8 +1,6 @@
 from typing import Callable, NamedTuple, Protocol, Tuple

-import numpy as np
 import torch
-import xarray as xr
 from soundevent import data

 from batdetect2.typing.models import ModelOutput
@@ -19,24 +17,7 @@ __all__ = [

 class Heatmaps(NamedTuple):
-    """Structure holding the generated heatmap targets.
-
-    Attributes
-    ----------
-    detection : xr.DataArray
-        Heatmap indicating the probability of sound event presence. Typically
-        smoothed with a Gaussian kernel centered on event reference points.
-        Shape matches the input spectrogram. Values normalized [0, 1].
-    classes : xr.DataArray
-        Heatmap indicating the probability of specific class presence. Has an
-        additional 'category' dimension corresponding to the target class
-        names. Each category slice is typically smoothed with a Gaussian
-        kernel. Values normalized [0, 1] per category.
-    size : xr.DataArray
-        Heatmap encoding the size (width, height) of detected events. Has an
-        additional 'dimension' coordinate ('width', 'height'). Values represent
-        scaled dimensions placed at the event reference points.
-    """
+    """Structure holding the generated heatmap targets."""

     detection: torch.Tensor
     classes: torch.Tensor
@@ -44,12 +25,20 @@ class Heatmaps(NamedTuple):

 class PreprocessedExample(NamedTuple):
-    audio: np.ndarray
-    spectrogram: np.ndarray
-    detection_heatmap: np.ndarray
-    class_heatmap: np.ndarray
-    size_heatmap: np.ndarray
-    clip_annotation: data.ClipAnnotation
+    audio: torch.Tensor
+    spectrogram: torch.Tensor
+    detection_heatmap: torch.Tensor
+    class_heatmap: torch.Tensor
+    size_heatmap: torch.Tensor
+
+    def copy(self):
+        return PreprocessedExample(
+            audio=self.audio.clone(),
+            spectrogram=self.spectrogram.clone(),
+            detection_heatmap=self.detection_heatmap.clone(),
+            size_heatmap=self.size_heatmap.clone(),
+            class_heatmap=self.class_heatmap.clone(),
+        )


 ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
@@ -60,7 +49,7 @@ spectrogram, applies all configured filtering, transformation, and encoding
 steps, and returns the final `Heatmaps` used for model training.
 """

-Augmentation = Callable[[xr.Dataset], xr.Dataset]
+Augmentation = Callable[[PreprocessedExample], PreprocessedExample]


 class TrainExample(NamedTuple):
@@ -108,5 +97,5 @@ class LossProtocol(Protocol):

 class ClipperProtocol(Protocol):
     def extract_clip(
-        self, example: xr.Dataset
-    ) -> Tuple[xr.Dataset, float, float]: ...
+        self, example: PreprocessedExample
+    ) -> Tuple[PreprocessedExample, float, float]: ...
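
With Augmentation now a Callable[[PreprocessedExample], PreprocessedExample], a custom augmentation is a plain function over tensors. A sketch (add_gaussian_noise is an invented example, not part of this commit):

import torch

from batdetect2.typing.train import PreprocessedExample


def add_gaussian_noise(example: PreprocessedExample) -> PreprocessedExample:
    # copy() clones every tensor, so the caller's example stays untouched;
    # _replace is the standard NamedTuple update method.
    noisy = example.copy()
    return noisy._replace(
        spectrogram=noisy.spectrogram
        + 0.01 * torch.randn_like(noisy.spectrogram)
    )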

View File

@@ -1,4 +1,5 @@
 import numpy as np
+import torch
 import xarray as xr
@@ -35,77 +36,40 @@ def spec_to_xarray(
     )


-def audio_to_xarray(
-    wav: np.ndarray,
-    start_time: float,
-    end_time: float,
-    time_axis: str = "time",
-) -> xr.DataArray:
-    if wav.ndim != 1:
-        raise ValueError("Input numpy audio array should be 1-dimensional")
-
-    return xr.DataArray(
-        data=wav,
-        dims=[time_axis],
-        coords={
-            time_axis: np.linspace(
-                start_time,
-                end_time,
-                len(wav),
-                endpoint=False,
-            ),
-        },
-    )
-
-
 def extend_width(
-    array: np.ndarray,
+    tensor: torch.Tensor,
     extra: int,
     axis: int = -1,
     value: float = 0,
-) -> np.ndarray:
-    dims = len(array.shape)
-    axis = axis % dims
-    pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)]
-    return np.pad(
-        array,
+) -> torch.Tensor:
+    dims = len(tensor.shape)
+    axis = dims - axis % dims - 1
+    pad = [0 for _ in range(2 * dims)]
+    pad[2 * axis + 1] = extra
+    return torch.nn.functional.pad(
+        tensor,
         pad,
         mode="constant",
-        constant_values=value,
+        value=value,
     )


-def make_width_divisible(
-    array: np.ndarray,
-    factor: int,
-    axis: int = -1,
-    value: float = 0,
-) -> np.ndarray:
-    width = array.shape[axis]
-
-    if width % factor == 0:
-        return array
-
-    extra = (-width) % factor
-    return extend_width(array, extra, axis=axis, value=value)
-
-
 def adjust_width(
-    array: np.ndarray,
+    tensor: torch.Tensor,
     width: int,
     axis: int = -1,
     value: float = 0,
-) -> np.ndarray:
-    dims = len(array.shape)
+) -> torch.Tensor:
+    dims = len(tensor.shape)
     axis = axis % dims
-    current_width = array.shape[axis]
+    current_width = tensor.shape[axis]

     if current_width == width:
-        return array
+        return tensor

     if current_width < width:
         return extend_width(
-            array,
+            tensor,
             extra=width - current_width,
             axis=axis,
             value=value,
@@ -115,11 +79,4 @@ def adjust_width(
         slice(None, None) if index != axis else slice(None, width)
         for index in range(dims)
     ]
-    return array[tuple(slices)]
-
-
-def iterate_over_array(array: xr.DataArray):
-    dim_name = array.dims[0]
-    coords = array.coords[dim_name]
-    for value, coord in zip(array.values, coords.values):
-        yield coord, float(value)
+    return tensor[tuple(slices)]
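
The axis remapping in the torch port of extend_width exists because torch.nn.functional.pad orders its pad list from the last dimension backwards, i.e. (last-dim left, last-dim right, second-to-last left, ...). A sketch verifying the mapping (shapes are arbitrary):

import torch
import torch.nn.functional as F

tensor = torch.zeros(2, 3, 4)
dims = tensor.ndim

for axis in (-1, 0, 1, 2):
    flipped = dims - axis % dims - 1  # position counted from the end
    pad = [0] * (2 * dims)
    pad[2 * flipped + 1] = 5          # pad 5 on the right side of `axis`
    padded = F.pad(tensor, pad)

    expected = list(tensor.shape)
    expected[axis % dims] += 5
    assert list(padded.shape) == expected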

View File

@@ -431,8 +431,13 @@ def sample_targets(

 @pytest.fixture
 def sample_labeller(
     sample_targets: TargetProtocol,
+    sample_preprocessor: PreprocessorProtocol,
 ) -> ClipLabeller:
-    return build_clip_labeler(sample_targets)
+    return build_clip_labeler(
+        sample_targets,
+        min_freq=sample_preprocessor.min_freq,
+        max_freq=sample_preprocessor.max_freq,
+    )


 @pytest.fixture

View File

@@ -2,6 +2,7 @@ from collections.abc import Callable

 import numpy as np
 import pytest
+import torch
 import xarray as xr
 from soundevent import arrays, data
@@ -42,12 +43,17 @@ def test_mix_examples(
         labeller=sample_labeller,
     )

-    mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
+    mixed = mix_examples(
+        example1,
+        example2,
+        weight=0.3,
+        preprocessor=sample_preprocessor,
+    )

-    assert mixed["spectrogram"].shape == example1["spectrogram"].shape
-    assert mixed["detection"].shape == example1["detection"].shape
-    assert mixed["size"].shape == example1["size"].shape
-    assert mixed["class"].shape == example1["class"].shape
+    assert mixed.spectrogram.shape == example1.spectrogram.shape
+    assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
+    assert mixed.size_heatmap.shape == example1.size_heatmap.shape
+    assert mixed.class_heatmap.shape == example1.class_heatmap.shape


 @pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
@@ -82,13 +88,17 @@ def test_mix_examples_of_different_durations(
         labeller=sample_labeller,
     )

-    mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
+    mixed = mix_examples(
+        example1,
+        example2,
+        weight=0.3,
+        preprocessor=sample_preprocessor,
+    )

-    # Check the spectrogram has the expected duration
-    step = arrays.get_dim_step(mixed["spectrogram"], "time")
-    start, stop = arrays.get_dim_range(mixed["spectrogram"], "time")
-    assert start == 0
-    assert np.isclose(stop + step, duration1, atol=2 * step)
+    assert mixed.spectrogram.shape == example1.spectrogram.shape
+    assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
+    assert mixed.size_heatmap.shape == example1.size_heatmap.shape
+    assert mixed.class_heatmap.shape == example1.class_heatmap.shape


 def test_add_echo(
@@ -107,12 +117,32 @@ def test_add_echo(
         preprocessor=sample_preprocessor,
         labeller=sample_labeller,
     )

-    with_echo = add_echo(original, preprocessor=sample_preprocessor)
+    with_echo = add_echo(
+        original,
+        preprocessor=sample_preprocessor,
+        delay=0.1,
+        weight=0.3,
+    )

-    assert with_echo["spectrogram"].shape == original["spectrogram"].shape
-    xr.testing.assert_identical(with_echo["size"], original["size"])
-    xr.testing.assert_identical(with_echo["class"], original["class"])
-    xr.testing.assert_identical(with_echo["detection"], original["detection"])
+    assert with_echo.spectrogram.shape == original.spectrogram.shape
+    torch.testing.assert_close(
+        with_echo.size_heatmap,
+        original.size_heatmap,
+        atol=0,
+        rtol=0,
+    )
+    torch.testing.assert_close(
+        with_echo.class_heatmap,
+        original.class_heatmap,
+        atol=0,
+        rtol=0,
+    )
+    torch.testing.assert_close(
+        with_echo.detection_heatmap,
+        original.detection_heatmap,
+        atol=0,
+        rtol=0,
+    )


 def test_selected_random_subclip_has_the_correct_width(

View File

@@ -3,7 +3,6 @@ import pytest
 import xarray as xr

 from batdetect2.train.clips import (
-    Clipper,
     _compute_expected_width,
     select_subclip,
 )
@@ -322,145 +321,3 @@ def test_select_subclip_no_overlap_raises_error(long_dataset):
             start=-1.0 * CLIP_DURATION - 1.0,
             dim="time",
         )
-
-
-def test_clipper_non_random(long_dataset, exact_dataset, short_dataset):
-    clipper = Clipper(duration=CLIP_DURATION, random=False)
-
-    for ds in [long_dataset, exact_dataset, short_dataset]:
-        clip, _, _ = clipper.extract_clip(ds)
-
-        expected_spec_width = _compute_expected_width(
-            ds, CLIP_DURATION, "time"
-        )
-        expected_audio_width = _compute_expected_width(
-            ds, CLIP_DURATION, "audio_time"
-        )
-
-        assert clip.dims["time"] == expected_spec_width
-        assert clip.dims["audio_time"] == expected_audio_width
-        assert clip.spectrogram.shape[1] == expected_spec_width
-        assert clip.audio.shape[0] == expected_audio_width
-
-        assert clip.time.min() >= -1 / SPEC_SAMPLERATE
-        assert clip.audio_time.min() >= -1 / AUDIO_SAMPLERATE
-
-        time_span = clip.time.max() - clip.time.min()
-        audio_span = clip.audio_time.max() - clip.audio_time.min()
-        assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
-        assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
-
-
-def test_clipper_random(long_dataset):
-    seed = 42
-    np.random.seed(seed)
-    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
-    clip1, _, _ = clipper.extract_clip(long_dataset)
-
-    np.random.seed(seed + 1)
-    clip2, _, _ = clipper.extract_clip(long_dataset)
-
-    expected_spec_width = _compute_expected_width(
-        long_dataset, CLIP_DURATION, "time"
-    )
-    expected_audio_width = _compute_expected_width(
-        long_dataset, CLIP_DURATION, "audio_time"
-    )
-
-    for clip in [clip1, clip2]:
-        assert clip.dims["time"] == expected_spec_width
-        assert clip.dims["audio_time"] == expected_audio_width
-        assert clip.spectrogram.shape[1] == expected_spec_width
-        assert clip.audio.shape[0] == expected_audio_width
-
-    assert not np.isclose(clip1.time.min(), clip2.time.min())
-    assert not np.isclose(clip1.audio_time.min(), clip2.audio_time.min())
-
-    for clip in [clip1, clip2]:
-        time_span = clip.time.max() - clip.time.min()
-        audio_span = clip.audio_time.max() - clip.audio_time.min()
-        assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
-        assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
-
-    max_start_time = (
-        (long_dataset.time.max() - long_dataset.time.min())
-        - CLIP_DURATION
-        + MAX_EMPTY
-    )
-    assert clip1.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
-    assert clip2.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
-
-
-def test_clipper_random_max_empty_effect(long_dataset):
-    """Check that max_empty influences the possible start times."""
-    seed = 123
-    data_duration = long_dataset.time.max() - long_dataset.time.min()
-
-    np.random.seed(seed)
-    clipper0 = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.0)
-    max_start_time0 = data_duration - CLIP_DURATION
-    start_times0 = []
-    for _ in range(20):
-        clip, _, _ = clipper0.extract_clip(long_dataset)
-        start_times0.append(clip.time.min().item())
-
-    assert all(
-        st <= max_start_time0 + 1 / SPEC_SAMPLERATE for st in start_times0
-    )
-    assert any(st > 0.1 for st in start_times0)
-
-    np.random.seed(seed)
-    clipper_pos = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.2)
-    max_start_time_pos = data_duration - CLIP_DURATION + 0.2
-    start_times_pos = []
-    for _ in range(20):
-        clip, _, _ = clipper_pos.extract_clip(long_dataset)
-        start_times_pos.append(clip.time.min().item())
-
-    assert all(
-        st <= max_start_time_pos + 1 / SPEC_SAMPLERATE
-        for st in start_times_pos
-    )
-    assert any(st > max_start_time0 + 1e-6 for st in start_times_pos)
-
-
-def test_clipper_short_dataset_random(short_dataset):
-    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
-    clip, _, _ = clipper.extract_clip(short_dataset)
-
-    expected_spec_width = _compute_expected_width(
-        short_dataset, CLIP_DURATION, "time"
-    )
-    expected_audio_width = _compute_expected_width(
-        short_dataset, CLIP_DURATION, "audio_time"
-    )
-
-    assert clip.sizes["time"] == expected_spec_width
-    assert clip.sizes["audio_time"] == expected_audio_width
-    assert clip["spectrogram"].shape[1] == expected_spec_width
-    assert clip["audio"].shape[0] == expected_audio_width
-
-    assert np.any(clip.spectrogram == 0)
-    assert np.any(clip.audio == 0)
-
-
-def test_clipper_exact_dataset_random(exact_dataset):
-    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
-    clip, _, _ = clipper.extract_clip(exact_dataset)
-
-    expected_spec_width = _compute_expected_width(
-        exact_dataset, CLIP_DURATION, "time"
-    )
-    expected_audio_width = _compute_expected_width(
-        exact_dataset, CLIP_DURATION, "audio_time"
-    )
-
-    assert clip.dims["time"] == expected_spec_width
-    assert clip.dims["audio_time"] == expected_audio_width
-    assert clip.spectrogram.shape[1] == expected_spec_width
-    assert clip.audio.shape[0] == expected_audio_width
-
-    time_span = clip.time.max() - clip.time.min()
-    audio_span = clip.audio_time.max() - clip.audio_time.min()
-    assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
-    assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)

View File

@@ -1,6 +1,4 @@
 import pytest
-import torch
-import xarray as xr
 from soundevent import data
 from soundevent.terms import get_term
@@ -10,6 +8,7 @@ from batdetect2.targets import build_targets, load_target_config
 from batdetect2.train.labels import build_clip_labeler, load_label_config
 from batdetect2.train.preprocess import generate_train_example
 from batdetect2.typing import ModelOutput
+from batdetect2.typing.preprocess import AudioLoader


 @pytest.fixture
@@ -35,6 +34,8 @@ def build_from_config(
     labeller = build_clip_labeler(
         targets=targets,
         config=labels_config,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
     )
     postprocessor = build_postprocessor(
         targets,
@@ -48,62 +49,8 @@ def build_from_config(
     return build


-# TODO: better name
-def test_generated_train_example_has_expected_outputs(
-    build_from_config,
-    recording,
-):
-    yaml_content = """
-    labels:
-    targets:
-        roi:
-            name: anchor_bbox
-            anchor: bottom-left
-        classes:
-            classes:
-                - name: pippip
-                  tags:
-                      - key: species
-                        value: Pipistrellus pipistrellus
-            generic_class:
-                - key: order
-                  value: Chiroptera
-    preprocessing:
-    postprocessing:
-    """
-
-    _, preprocessor, labeller, _ = build_from_config(yaml_content)
-
-    geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
-
-    se1 = data.SoundEventAnnotation(
-        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
-        tags=[
-            data.Tag(key="species", value="Pipistrellus pipistrellus"),  # type: ignore
-        ],
-    )
-
-    clip_annotation = data.ClipAnnotation(
-        clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
-        sound_events=[se1],
-    )
-
-    encoded = generate_train_example(clip_annotation, preprocessor, labeller)
-
-    assert isinstance(encoded, xr.Dataset)
-    assert "audio" in encoded
-    assert "spectrogram" in encoded
-    assert "detection" in encoded
-    assert "class" in encoded
-    assert "size" in encoded
-
-    spec_shape = encoded["spectrogram"].shape
-    assert len(spec_shape) == 2
-
-    height, width = spec_shape
-    assert encoded["detection"].shape == (height, width)
-    assert encoded["class"].shape == (1, height, width)
-    assert encoded["size"].shape == (2, height, width)
-
-
 def test_encoding_decoding_roundtrip_recovers_object(
+    sample_audio_loader: AudioLoader,
     build_from_config,
     recording,
 ):
@@ -136,13 +83,17 @@ def test_encoding_decoding_roundtrip_recovers_object(
     clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
     clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

-    encoded = generate_train_example(clip_annotation, preprocessor, labeller)
+    encoded = generate_train_example(
+        clip_annotation, sample_audio_loader, preprocessor, labeller
+    )

     predictions = postprocessor.get_predictions(
         ModelOutput(
-            detection_probs=torch.tensor([[encoded["detection"].data]]),
-            size_preds=torch.tensor([encoded["size"].data]),
-            class_probs=torch.tensor([encoded["class"].data]),
-            features=torch.tensor([[encoded["spectrogram"].data]]),
+            detection_probs=encoded["detection_heatmap"]
+            .unsqueeze(0)
+            .unsqueeze(0),
+            size_preds=encoded["size_heatmap"].unsqueeze(0),
+            class_probs=encoded["class_heatmap"].unsqueeze(0),
+            features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
        ),
         [clip],
     )[0]
@@ -185,6 +136,7 @@ def test_encoding_decoding_roundtrip_recovers_object(

 def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
+    sample_audio_loader: AudioLoader,
     build_from_config,
     recording,
 ):
@@ -222,13 +174,20 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
     clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
     clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

-    encoded = generate_train_example(clip_annotation, preprocessor, labeller)
+    encoded = generate_train_example(
+        clip_annotation,
+        sample_audio_loader,
+        preprocessor,
+        labeller,
+    )

     predictions = postprocessor.get_predictions(
         ModelOutput(
-            detection_probs=torch.tensor([[encoded["detection"].data]]),
-            size_preds=torch.tensor([encoded["size"].data]),
-            class_probs=torch.tensor([encoded["class"].data]),
-            features=torch.tensor([[encoded["spectrogram"].data]]),
+            detection_probs=encoded["detection_heatmap"]
+            .unsqueeze(0)
+            .unsqueeze(0),
+            size_preds=encoded["size_heatmap"].unsqueeze(0),
+            class_probs=encoded["class_heatmap"].unsqueeze(0),
+            features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
         ),
         [clip],
     )[0]
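
The unsqueeze calls above replace the old torch.tensor([[...]]) wrapping: they add the batch (and, for 2-D maps, channel) dimensions that ModelOutput expects. A shape sketch under assumed heatmap shapes:

import torch

h, w = 128, 512
detection = torch.rand(h, w)  # assumed (height, width) detection heatmap
size = torch.rand(2, h, w)    # assumed (2, height, width) size heatmap

assert detection.unsqueeze(0).unsqueeze(0).shape == (1, 1, h, w)
assert size.unsqueeze(0).shape == (1, 2, h, w)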

View File

@@ -1,23 +1,59 @@
-import numpy as np
+import torch

 from batdetect2.utils.arrays import adjust_width, extend_width


 def test_extend_width():
-    array = np.random.random([1, 1, 128, 100])
+    array = torch.rand([1, 1, 128, 100])
     extended = extend_width(array, 100)
     assert extended.shape == (1, 1, 128, 200)

+    extended = extend_width(array, 100, axis=0)
+    assert extended.shape == (101, 1, 128, 100)
+
+    extended = extend_width(array, 100, axis=1)
+    assert extended.shape == (1, 101, 128, 100)
+
+    extended = extend_width(array, 100, axis=2)
+    assert extended.shape == (1, 1, 228, 100)
+
+    extended = extend_width(array, 100, axis=3)
+    assert extended.shape == (1, 1, 128, 200)
+
+    extended = extend_width(array, 100, axis=-2)
+    assert extended.shape == (1, 1, 228, 100)
+
+
+def test_extends_with_value():
+    array = torch.rand([1, 1, 128, 100])
+    extended = extend_width(array, 100, value=-1)
+    torch.testing.assert_close(
+        extended[:, :, :, 100:],
+        torch.ones_like(array) * -1,
+        rtol=0,
+        atol=0,
+    )
+

 def test_can_adjust_short_width():
-    array = np.random.random([1, 1, 128, 100])
+    array = torch.rand([1, 1, 128, 100])
     extended = adjust_width(array, 512)
     assert extended.shape == (1, 1, 128, 512)

+    extended = adjust_width(array, 512, axis=0)
+    assert extended.shape == (512, 1, 128, 100)
+
+    extended = adjust_width(array, 512, axis=1)
+    assert extended.shape == (1, 512, 128, 100)
+
+    extended = adjust_width(array, 512, axis=2)
+    assert extended.shape == (1, 1, 512, 100)
+
+    extended = adjust_width(array, 512, axis=3)
+    assert extended.shape == (1, 1, 128, 512)
+

 def test_can_adjust_long_width():
-    array = np.random.random([1, 1, 128, 512])
+    array = torch.rand([1, 1, 128, 512])
     extended = adjust_width(array, 256)
     assert extended.shape == (1, 1, 128, 256)