Update augmentations

mbsantiago 2025-08-25 17:06:17 +01:00
parent 76dda0a0e9
commit 0bb0caddea
14 changed files with 620 additions and 1072 deletions

View File

@@ -1,6 +1,11 @@
from batdetect2.plotting.clip_annotations import plot_clip_annotation
from batdetect2.plotting.clip_predictions import plot_clip_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.common import plot_spectrogram
from batdetect2.plotting.heatmaps import (
plot_classification_heatmap,
plot_detection_heatmap,
)
from batdetect2.plotting.matches import (
plot_cross_trigger_match,
plot_false_negative_match,
@@ -13,9 +18,12 @@ __all__ = [
"plot_clip",
"plot_clip_annotation",
"plot_clip_prediction",
"plot_matches",
"plot_false_positive_match",
"plot_true_positive_match",
"plot_false_negative_match",
"plot_cross_trigger_match",
"plot_false_negative_match",
"plot_false_positive_match",
"plot_matches",
"plot_spectrogram",
"plot_true_positive_match",
"plot_detection_heatmap",
"plot_classification_heatmap",
]

View File

@@ -3,6 +3,7 @@
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import torch
from matplotlib import axes
__all__ = [
@@ -12,7 +13,7 @@ __all__ = [
def create_ax(
ax: Optional[axes.Axes] = None,
figsize: Tuple[int, int] = (10, 10),
figsize: Optional[Tuple[int, int]] = None,
**kwargs,
) -> axes.Axes:
"""Create a new axis if none is provided"""
@@ -20,3 +21,14 @@ def create_ax(
_, ax = plt.subplots(figsize=figsize, **kwargs) # type: ignore
return ax # type: ignore
def plot_spectrogram(
spec: torch.Tensor,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
cmap="gray",
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax.pcolormesh(spec.numpy(), cmap=cmap)
return ax
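
Usage sketch (not part of the commit): plot_spectrogram now takes a torch tensor directly, and with figsize=None, create_ax falls back to matplotlib's default figure size.

import matplotlib.pyplot as plt
import torch

from batdetect2.plotting.common import plot_spectrogram

# Illustrative (freq, time) tensor; any 2-D float tensor works with pcolormesh.
spec = torch.rand(128, 512)

ax = plot_spectrogram(spec, cmap="viridis")
ax.set_xlabel("time bin")
ax.set_ylabel("frequency bin")
plt.show()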

View File

@@ -1,26 +1,115 @@
"""Plot heatmaps"""
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import xarray as xr
from matplotlib import axes
import numpy as np
import torch
from matplotlib import axes, patches
from matplotlib.cm import get_cmap
from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba
from batdetect2.plotting.common import create_ax
def plot_heatmap(
heatmap: xr.DataArray,
def plot_detection_heatmap(
heatmap: Union[torch.Tensor, np.ndarray],
ax: Optional[axes.Axes] = None,
figsize: Tuple[int, int] = (10, 10),
threshold: Optional[float] = None,
alpha: float = 1,
cmap: Union[str, Colormap] = "jet",
color: Optional[str] = None,
) -> axes.Axes:
ax = create_ax(ax, figsize=figsize)
if isinstance(heatmap, torch.Tensor):
heatmap = heatmap.numpy()
if threshold is not None:
heatmap = np.ma.masked_where(
heatmap < threshold,
heatmap,
)
if color is not None:
cmap = create_colormap(color)
ax.pcolormesh(
heatmap.time,
heatmap.frequency,
heatmap,
vmax=1,
vmin=0,
cmap=cmap,
alpha=alpha,
)
return ax
def plot_classification_heatmap(
heatmap: Union[torch.Tensor, np.ndarray],
ax: Optional[axes.Axes] = None,
figsize: Tuple[int, int] = (10, 10),
class_names: Optional[List[str]] = None,
threshold: Optional[float] = 0.1,
alpha: float = 1,
cmap: Union[str, Colormap] = "tab20",
):
ax = create_ax(ax, figsize=figsize)
if isinstance(heatmap, torch.Tensor):
heatmap = heatmap.numpy()
if heatmap.ndim == 4:
heatmap = heatmap[0]
if heatmap.ndim != 3:
raise ValueError("Expecting a 3-dimensional array")
num_classes = heatmap.shape[0]
if class_names is None:
class_names = [f"class_{i}" for i in range(num_classes)]
if len(class_names) != num_classes:
raise ValueError("Inconsistent number of class names")
if not isinstance(cmap, Colormap):
cmap = get_cmap(cmap)
handles = []
for index, class_heatmap in enumerate(heatmap):
class_name = class_names[index]
color = cmap(index / num_classes)
max = class_heatmap.max()
if max == 0:
continue
if threshold is not None:
class_heatmap = np.ma.masked_where(
class_heatmap < threshold,
class_heatmap,
)
ax.pcolormesh(
class_heatmap,
vmax=1,
vmin=0,
cmap=create_colormap(color), # type: ignore
alpha=alpha,
)
handles.append(patches.Patch(color=color, label=class_name))
ax.legend(handles=handles)
return ax
def create_colormap(color: str) -> Colormap:
(r, g, b, a) = to_rgba(color)
return LinearSegmentedColormap.from_list(
"cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
)
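
Overlay sketch (assuming (freq, time) and (class, freq, time) tensor layouts, as the indexing above implies): values below the threshold are masked out entirely, and create_colormap builds a transparent-to-opaque ramp so a single-color overlay never hides the spectrogram underneath.

import torch

from batdetect2.plotting import (
    plot_classification_heatmap,
    plot_detection_heatmap,
    plot_spectrogram,
)

spec = torch.rand(128, 512)        # illustrative spectrogram
detection = torch.rand(128, 512)   # detection scores in [0, 1]
classes = torch.rand(3, 128, 512)  # one heatmap per class

ax = plot_spectrogram(spec)
# Cells below the threshold are masked, not drawn at zero value.
plot_detection_heatmap(detection, ax=ax, threshold=0.5, color="red", alpha=0.8)
plot_classification_heatmap(classes, ax=ax, class_names=["a", "b", "c"])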

File diff suppressed because it is too large

View File

@@ -1,12 +1,12 @@
from typing import Optional, Tuple, Union
from typing import Optional, Tuple
import numpy as np
import xarray as xr
from loguru import logger
from soundevent import arrays
from batdetect2.configs import BaseConfig
from batdetect2.typing import ClipperProtocol
from batdetect2.typing.train import PreprocessedExample
from batdetect2.utils.arrays import adjust_width
DEFAULT_TRAIN_CLIP_DURATION = 0.513
DEFAULT_MAX_EMPTY_CLIP = 0.1
@@ -32,40 +32,23 @@ class Clipper(ClipperProtocol):
self.max_empty = max_empty
def extract_clip(
self, example: xr.Dataset
) -> Tuple[xr.Dataset, float, float]:
step = arrays.get_dim_step(
example.spectrogram,
dim=arrays.Dimensions.time.value,
)
duration = (
arrays.get_dim_width(
example.spectrogram,
dim=arrays.Dimensions.time.value,
)
+ step
)
self, example: PreprocessedExample
) -> Tuple[PreprocessedExample, float, float]:
start_time = 0
duration = example.audio.shape[-1] / self.samplerate
if self.random:
start_time = np.random.uniform(
-self.max_empty,
duration - self.duration + self.max_empty,
)
subclip = select_subclip(
example,
start=start_time,
span=self.duration,
dim="time",
)
return (
select_subclip(
subclip,
example,
start=start_time,
span=self.duration,
dim="audio_time",
duration=self.duration,
samplerate=self.samplerate,
),
start_time,
start_time + self.duration,
@@ -73,6 +56,7 @@ class Clipper(ClipperProtocol):
def build_clipper(
samplerate: int,
config: Optional[ClipingConfig] = None,
random: Optional[bool] = None,
) -> ClipperProtocol:
@@ -82,6 +66,7 @@
lambda: config.to_yaml_string(),
)
return Clipper(
samplerate=samplerate,
duration=config.duration,
max_empty=config.max_empty,
random=config.random if random else False,
@@ -89,106 +74,43 @@
def select_subclip(
dataset: xr.Dataset,
span: float,
example: PreprocessedExample,
start: float,
fill_value: float = 0,
dim: str = "time",
) -> xr.Dataset:
width = _compute_expected_width(
dataset, # type: ignore
span,
dim=dim,
)
coord = dataset.coords[dim]
if len(coord) == width:
return dataset
new_coords, start_pad, end_pad, dim_slice = _extract_coordinate(
coord, start, span
)
data_vars = {}
for name, data_array in dataset.data_vars.items():
if dim not in data_array.dims:
data_vars[name] = data_array
continue
if width == data_array.sizes[dim]:
data_vars[name] = data_array
continue
sliced = data_array.isel({dim: dim_slice}).data
if start_pad > 0 or end_pad > 0:
padding = [
[0, 0] if other_dim != dim else [start_pad, end_pad]
for other_dim in data_array.dims
]
sliced = np.pad(sliced, padding, constant_values=fill_value)
data_vars[name] = xr.DataArray(
data=sliced,
dims=data_array.dims,
coords={**data_array.coords, dim: new_coords},
attrs=data_array.attrs,
)
return xr.Dataset(data_vars=data_vars, attrs=dataset.attrs)
def _extract_coordinate(
coord: xr.DataArray,
start: float,
span: float,
) -> Tuple[xr.Variable, int, int, slice]:
step = arrays.get_dim_step(coord, str(coord.name))
current_width = len(coord)
expected_width = int(np.floor(span / step))
coord_start = float(coord[0])
offset = start - coord_start
start_index = int(np.floor(offset / step))
end_index = start_index + expected_width
if start_index > current_width:
raise ValueError("Requested span does not overlap with current range")
if end_index < 0:
raise ValueError("Requested span does not overlap with current range")
corrected_start = float(start_index * step)
corrected_end = float(end_index * step)
start_index_offset = max(0, -start_index)
end_index_offset = max(0, end_index - current_width)
sl = slice(
start_index if start_index >= 0 else None,
end_index if end_index < current_width else None,
)
return (
arrays.create_range_dim(
str(coord.name),
start=corrected_start,
stop=corrected_end,
step=step,
),
start_index_offset,
end_index_offset,
sl,
)
def _compute_expected_width(
array: Union[xr.DataArray, xr.Dataset],
duration: float,
dim: str,
) -> int:
step = arrays.get_dim_step(array, dim) # type: ignore
return int(np.floor(duration / step))
samplerate: float,
fill_value: float = 0,
) -> PreprocessedExample:
audio_width = int(np.floor(duration * samplerate))
audio_start = int(np.floor(start * samplerate))
audio = adjust_width(
example.audio[audio_start : audio_start + audio_width],
audio_width,
value=fill_value,
)
audio_duration = example.audio.shape[-1] / samplerate
spec_sr = example.spectrogram.shape[-1] / audio_duration
spec_start = int(np.floor(start * spec_sr))
spec_width = int(np.floor(duration * spec_sr))
return PreprocessedExample(
audio=audio,
spectrogram=adjust_width(
example.spectrogram[:, spec_start : spec_start + spec_width],
spec_width,
),
class_heatmap=adjust_width(
example.class_heatmap[:, :, spec_start : spec_start + spec_width],
spec_width,
),
detection_heatmap=adjust_width(
example.detection_heatmap[:, spec_start : spec_start + spec_width],
spec_width,
),
size_heatmap=adjust_width(
example.size_heatmap[:, :, spec_start : spec_start + spec_width],
spec_width,
),
)
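
Worked numbers for the index arithmetic above (illustrative values): the example no longer carries time coordinates, so the spectrogram's effective frame rate is recovered from the ratio of frames to audio duration, and adjust_width then pads or trims each slice to the computed width.

import numpy as np

samplerate = 256_000
audio_samples, spec_frames = 131_072, 512      # a 0.512 s example
start, duration = 0.1, 0.256

audio_start = int(np.floor(start * samplerate))     # 25_600
audio_width = int(np.floor(duration * samplerate))  # 65_536
audio_duration = audio_samples / samplerate         # 0.512 s
spec_sr = spec_frames / audio_duration              # 1_000 frames per second
spec_start = int(np.floor(start * spec_sr))         # 100
spec_width = int(np.floor(duration * spec_sr))      # 256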

View File

@@ -22,7 +22,7 @@ includes utilities for parallel processing using `multiprocessing`.
import os
from pathlib import Path
from typing import Callable, Dict, Optional, Sequence
from typing import Callable, Optional, Sequence, TypedDict
import numpy as np
import torch
@@ -98,17 +98,25 @@ def preprocess_dataset(
)
class Example(TypedDict):
audio: torch.Tensor
spectrogram: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
def generate_train_example(
clip_annotation: data.ClipAnnotation,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
) -> Dict[str, torch.Tensor]:
) -> PreprocessedExample:
"""Generate a complete training example for one annotation."""
wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
spectrogram = preprocessor(wave)
heatmaps = labeller(clip_annotation, spectrogram)
return dict(
return PreprocessedExample(
audio=wave,
spectrogram=spectrogram,
detection_heatmap=heatmaps.detection,
@@ -138,8 +146,14 @@ class PreprocessingDataset(torch.utils.data.Dataset):
preprocessor=self.preprocessor,
labeller=self.labeller,
)
example["idx"] = idx
return example
return {
"idx": idx,
"spectrogram": example.spectrogram,
"audio": example.audio,
"class_heatmap": example.class_heatmap,
"size_heatmap": example.size_heatmap,
"detection_heatmap": example.detection_heatmap,
}
def __len__(self) -> int:
return len(self.clips)
@@ -147,16 +161,17 @@
def _save_example_to_file(
example: PreprocessedExample,
clip_annotation: data.ClipAnnotation,
path: data.PathLike,
) -> None:
np.savez_compressed(
path,
audio=example.audio,
spectrogram=example.spectrogram,
detection_heatmap=example.detection_heatmap,
class_heatmap=example.class_heatmap,
size_heatmap=example.size_heatmap,
clip_annotation=example.clip_annotation,
audio=example.audio.numpy(),
spectrogram=example.spectrogram.numpy(),
detection_heatmap=example.detection_heatmap.numpy(),
class_heatmap=example.class_heatmap.numpy(),
size_heatmap=example.size_heatmap.numpy(),
clip_annotation=clip_annotation,
)
@@ -211,11 +226,10 @@ def preprocess_annotations(
filename = filename_fn(clip_annotation)
path = output_dir / filename
example = PreprocessedExample(
clip_annotation=clip_annotation,
spectrogram=batch["spectrogram"].numpy(),
audio=batch["audio"].numpy(),
class_heatmap=batch["class_heatmap"].numpy(),
size_heatmap=batch["size_heatmap"].numpy(),
detection_heatmap=batch["detection_heatmap"].numpy(),
spectrogram=batch["spectrogram"],
audio=batch["audio"],
class_heatmap=batch["class_heatmap"],
size_heatmap=batch["size_heatmap"],
detection_heatmap=batch["detection_heatmap"],
)
_save_example_to_file(example, path)
_save_example_to_file(example, clip_annotation, path)
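
Loading sketch (hypothetical example.npz path): tensors are converted to numpy before saving, and the ClipAnnotation rides along as a pickled object, so reading it back needs allow_pickle.

import numpy as np

data = np.load("example.npz", allow_pickle=True)
spectrogram = data["spectrogram"]                 # plain numpy array on disk
clip_annotation = data["clip_annotation"].item()  # unwrap the 0-d object array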

View File

@@ -148,6 +148,8 @@ class PreprocessorProtocol(Protocol):
min_freq: float
samplerate: int
audio_pipeline: AudioPipeline
spectrogram_pipeline: SpectrogramPipeline
@@ -155,4 +157,4 @@
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
return self(torch.tensor(wav)).numpy()[0, 0]
return self(torch.tensor(wav)).numpy()

View File

@@ -1,8 +1,6 @@
from typing import Callable, NamedTuple, Protocol, Tuple
import numpy as np
import torch
import xarray as xr
from soundevent import data
from batdetect2.typing.models import ModelOutput
@@ -19,24 +17,7 @@ __all__ = [
class Heatmaps(NamedTuple):
"""Structure holding the generated heatmap targets.
Attributes
----------
detection : xr.DataArray
Heatmap indicating the probability of sound event presence. Typically
smoothed with a Gaussian kernel centered on event reference points.
Shape matches the input spectrogram. Values normalized [0, 1].
classes : xr.DataArray
Heatmap indicating the probability of specific class presence. Has an
additional 'category' dimension corresponding to the target class
names. Each category slice is typically smoothed with a Gaussian
kernel. Values normalized [0, 1] per category.
size : xr.DataArray
Heatmap encoding the size (width, height) of detected events. Has an
additional 'dimension' coordinate ('width', 'height'). Values represent
scaled dimensions placed at the event reference points.
"""
"""Structure holding the generated heatmap targets."""
detection: torch.Tensor
classes: torch.Tensor
@@ -44,12 +25,20 @@
class PreprocessedExample(NamedTuple):
audio: np.ndarray
spectrogram: np.ndarray
detection_heatmap: np.ndarray
class_heatmap: np.ndarray
size_heatmap: np.ndarray
clip_annotation: data.ClipAnnotation
audio: torch.Tensor
spectrogram: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
def copy(self):
return PreprocessedExample(
audio=self.audio.clone(),
spectrogram=self.spectrogram.clone(),
detection_heatmap=self.detection_heatmap.clone(),
size_heatmap=self.size_heatmap.clone(),
class_heatmap=self.class_heatmap.clone(),
)
ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
@@ -60,7 +49,7 @@ spectrogram, applies all configured filtering, transformation, and encoding
steps, and returns the final `Heatmaps` used for model training.
"""
Augmentation = Callable[[xr.Dataset], xr.Dataset]
Augmentation = Callable[[PreprocessedExample], PreprocessedExample]
class TrainExample(NamedTuple):
@@ -108,5 +97,5 @@ class LossProtocol(Protocol):
class ClipperProtocol(Protocol):
def extract_clip(
self, example: xr.Dataset
) -> Tuple[xr.Dataset, float, float]: ...
self, example: PreprocessedExample
) -> Tuple[PreprocessedExample, float, float]: ...
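
With Augmentation now Callable[[PreprocessedExample], PreprocessedExample], a custom augmentation is a plain function over tensors. A hypothetical gain jitter, not part of the library:

import torch

from batdetect2.typing.train import PreprocessedExample

def gain_jitter(example: PreprocessedExample) -> PreprocessedExample:
    # copy() clones every tensor, so a cached example is never mutated.
    out = example.copy()
    gain = 1 + 0.1 * (2 * torch.rand(1).item() - 1)
    # NamedTuple._replace swaps one field and keeps the rest as-is.
    return out._replace(audio=out.audio * gain)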

View File

@@ -1,4 +1,5 @@
import numpy as np
import torch
import xarray as xr
@@ -35,77 +36,40 @@ def spec_to_xarray(
)
def audio_to_xarray(
wav: np.ndarray,
start_time: float,
end_time: float,
time_axis: str = "time",
) -> xr.DataArray:
if wav.ndim != 1:
raise ValueError("Input numpy audio array should be 1-dimensional")
return xr.DataArray(
data=wav,
dims=[time_axis],
coords={
time_axis: np.linspace(
start_time,
end_time,
len(wav),
endpoint=False,
),
},
)
def extend_width(
array: np.ndarray,
tensor: torch.Tensor,
extra: int,
axis: int = -1,
value: float = 0,
) -> np.ndarray:
dims = len(array.shape)
axis = axis % dims
pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)]
return np.pad(
array,
) -> torch.Tensor:
dims = len(tensor.shape)
axis = dims - axis % dims - 1
pad = [0 for _ in range(2 * dims)]
pad[2 * axis + 1] = extra
return torch.nn.functional.pad(
tensor,
pad,
mode="constant",
constant_values=value,
value=value,
)
def make_width_divisible(
array: np.ndarray,
factor: int,
axis: int = -1,
value: float = 0,
) -> np.ndarray:
width = array.shape[axis]
if width % factor == 0:
return array
extra = (-width) % factor
return extend_width(array, extra, axis=axis, value=value)
def adjust_width(
array: np.ndarray,
tensor: torch.Tensor,
width: int,
axis: int = -1,
value: float = 0,
) -> np.ndarray:
dims = len(array.shape)
) -> torch.Tensor:
dims = len(tensor.shape)
axis = axis % dims
current_width = array.shape[axis]
current_width = tensor.shape[axis]
if current_width == width:
return array
return tensor
if current_width < width:
return extend_width(
array,
tensor,
extra=width - current_width,
axis=axis,
value=value,
@@ -115,11 +79,4 @@ def adjust_width(
slice(None, None) if index != axis else slice(None, width)
for index in range(dims)
]
return array[tuple(slices)]
def iterate_over_array(array: xr.DataArray):
dim_name = array.dims[0]
coords = array.coords[dim_name]
for value, coord in zip(array.values, coords.values):
yield coord, float(value)
return tensor[tuple(slices)]
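
The subtle part of the torch rewrite is that torch.nn.functional.pad lists pad widths from the last dimension backwards, so a forward axis index has to be mirrored. A worked check with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 128, 100)
axis, extra, dims = 2, 100, x.dim()

# F.pad expects (last_left, last_right, 2nd_last_left, 2nd_last_right, ...),
# so forward axis 2 of a 4-D tensor maps to pad-pair index 1.
pad_index = dims - axis % dims - 1
pad = [0] * (2 * dims)
pad[2 * pad_index + 1] = extra  # pad the right-hand side only
assert F.pad(x, pad).shape == (1, 1, 228, 100)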

View File

@@ -431,8 +431,13 @@ def sample_targets(
@pytest.fixture
def sample_labeller(
sample_targets: TargetProtocol,
sample_preprocessor: PreprocessorProtocol,
) -> ClipLabeller:
return build_clip_labeler(sample_targets)
return build_clip_labeler(
sample_targets,
min_freq=sample_preprocessor.min_freq,
max_freq=sample_preprocessor.max_freq,
)
@pytest.fixture

View File

@@ -2,6 +2,7 @@ from collections.abc import Callable
import numpy as np
import pytest
import torch
import xarray as xr
from soundevent import arrays, data
@@ -42,12 +43,17 @@ def test_mix_examples(
labeller=sample_labeller,
)
mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
mixed = mix_examples(
example1,
example2,
weight=0.3,
preprocessor=sample_preprocessor,
)
assert mixed["spectrogram"].shape == example1["spectrogram"].shape
assert mixed["detection"].shape == example1["detection"].shape
assert mixed["size"].shape == example1["size"].shape
assert mixed["class"].shape == example1["class"].shape
assert mixed.spectrogram.shape == example1.spectrogram.shape
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
@@ -82,13 +88,17 @@ def test_mix_examples_of_different_durations(
labeller=sample_labeller,
)
mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
mixed = mix_examples(
example1,
example2,
weight=0.3,
preprocessor=sample_preprocessor,
)
# Check the spectrogram has the expected duration
step = arrays.get_dim_step(mixed["spectrogram"], "time")
start, stop = arrays.get_dim_range(mixed["spectrogram"], "time")
assert start == 0
assert np.isclose(stop + step, duration1, atol=2 * step)
assert mixed.spectrogram.shape == example1.spectrogram.shape
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
def test_add_echo(
@@ -107,12 +117,32 @@ def test_add_echo(
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
with_echo = add_echo(original, preprocessor=sample_preprocessor)
with_echo = add_echo(
original,
preprocessor=sample_preprocessor,
delay=0.1,
weight=0.3,
)
assert with_echo["spectrogram"].shape == original["spectrogram"].shape
xr.testing.assert_identical(with_echo["size"], original["size"])
xr.testing.assert_identical(with_echo["class"], original["class"])
xr.testing.assert_identical(with_echo["detection"], original["detection"])
assert with_echo.spectrogram.shape == original.spectrogram.shape
torch.testing.assert_close(
with_echo.size_heatmap,
original.size_heatmap,
atol=0,
rtol=0,
)
torch.testing.assert_close(
with_echo.class_heatmap,
original.class_heatmap,
atol=0,
rtol=0,
)
torch.testing.assert_close(
with_echo.detection_heatmap,
original.detection_heatmap,
atol=0,
rtol=0,
)
def test_selected_random_subclip_has_the_correct_width(

View File

@@ -3,7 +3,6 @@ import pytest
import xarray as xr
from batdetect2.train.clips import (
Clipper,
_compute_expected_width,
select_subclip,
)
@@ -322,145 +321,3 @@ def test_select_subclip_no_overlap_raises_error(long_dataset):
start=-1.0 * CLIP_DURATION - 1.0,
dim="time",
)
def test_clipper_non_random(long_dataset, exact_dataset, short_dataset):
clipper = Clipper(duration=CLIP_DURATION, random=False)
for ds in [long_dataset, exact_dataset, short_dataset]:
clip, _, _ = clipper.extract_clip(ds)
expected_spec_width = _compute_expected_width(
ds, CLIP_DURATION, "time"
)
expected_audio_width = _compute_expected_width(
ds, CLIP_DURATION, "audio_time"
)
assert clip.dims["time"] == expected_spec_width
assert clip.dims["audio_time"] == expected_audio_width
assert clip.spectrogram.shape[1] == expected_spec_width
assert clip.audio.shape[0] == expected_audio_width
assert clip.time.min() >= -1 / SPEC_SAMPLERATE
assert clip.audio_time.min() >= -1 / AUDIO_SAMPLERATE
time_span = clip.time.max() - clip.time.min()
audio_span = clip.audio_time.max() - clip.audio_time.min()
assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
def test_clipper_random(long_dataset):
seed = 42
np.random.seed(seed)
clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
clip1, _, _ = clipper.extract_clip(long_dataset)
np.random.seed(seed + 1)
clip2, _, _ = clipper.extract_clip(long_dataset)
expected_spec_width = _compute_expected_width(
long_dataset, CLIP_DURATION, "time"
)
expected_audio_width = _compute_expected_width(
long_dataset, CLIP_DURATION, "audio_time"
)
for clip in [clip1, clip2]:
assert clip.dims["time"] == expected_spec_width
assert clip.dims["audio_time"] == expected_audio_width
assert clip.spectrogram.shape[1] == expected_spec_width
assert clip.audio.shape[0] == expected_audio_width
assert not np.isclose(clip1.time.min(), clip2.time.min())
assert not np.isclose(clip1.audio_time.min(), clip2.audio_time.min())
for clip in [clip1, clip2]:
time_span = clip.time.max() - clip.time.min()
audio_span = clip.audio_time.max() - clip.audio_time.min()
assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
max_start_time = (
(long_dataset.time.max() - long_dataset.time.min())
- CLIP_DURATION
+ MAX_EMPTY
)
assert clip1.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
assert clip2.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
def test_clipper_random_max_empty_effect(long_dataset):
"""Check that max_empty influences the possible start times."""
seed = 123
data_duration = long_dataset.time.max() - long_dataset.time.min()
np.random.seed(seed)
clipper0 = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.0)
max_start_time0 = data_duration - CLIP_DURATION
start_times0 = []
for _ in range(20):
clip, _, _ = clipper0.extract_clip(long_dataset)
start_times0.append(clip.time.min().item())
assert all(
st <= max_start_time0 + 1 / SPEC_SAMPLERATE for st in start_times0
)
assert any(st > 0.1 for st in start_times0)
np.random.seed(seed)
clipper_pos = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.2)
max_start_time_pos = data_duration - CLIP_DURATION + 0.2
start_times_pos = []
for _ in range(20):
clip, _, _ = clipper_pos.extract_clip(long_dataset)
start_times_pos.append(clip.time.min().item())
assert all(
st <= max_start_time_pos + 1 / SPEC_SAMPLERATE
for st in start_times_pos
)
assert any(st > max_start_time0 + 1e-6 for st in start_times_pos)
def test_clipper_short_dataset_random(short_dataset):
clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
clip, _, _ = clipper.extract_clip(short_dataset)
expected_spec_width = _compute_expected_width(
short_dataset, CLIP_DURATION, "time"
)
expected_audio_width = _compute_expected_width(
short_dataset, CLIP_DURATION, "audio_time"
)
assert clip.sizes["time"] == expected_spec_width
assert clip.sizes["audio_time"] == expected_audio_width
assert clip["spectrogram"].shape[1] == expected_spec_width
assert clip["audio"].shape[0] == expected_audio_width
assert np.any(clip.spectrogram == 0)
assert np.any(clip.audio == 0)
def test_clipper_exact_dataset_random(exact_dataset):
clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
clip, _, _ = clipper.extract_clip(exact_dataset)
expected_spec_width = _compute_expected_width(
exact_dataset, CLIP_DURATION, "time"
)
expected_audio_width = _compute_expected_width(
exact_dataset, CLIP_DURATION, "audio_time"
)
assert clip.dims["time"] == expected_spec_width
assert clip.dims["audio_time"] == expected_audio_width
assert clip.spectrogram.shape[1] == expected_spec_width
assert clip.audio.shape[0] == expected_audio_width
time_span = clip.time.max() - clip.time.min()
audio_span = clip.audio_time.max() - clip.audio_time.min()
assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
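
The deleted tests targeted the xarray Dataset API. An equivalent check against the tensor-based Clipper might look like this; shapes are illustrative, and the Clipper(samplerate=...) signature is inferred from build_clipper above.

import torch

from batdetect2.train.clips import Clipper
from batdetect2.typing.train import PreprocessedExample

SAMPLERATE, SPEC_FRAMES = 256_000, 1_024   # 1 s of audio, 1_024 spec frames

example = PreprocessedExample(
    audio=torch.rand(SAMPLERATE),
    spectrogram=torch.rand(128, SPEC_FRAMES),
    detection_heatmap=torch.rand(128, SPEC_FRAMES),
    class_heatmap=torch.rand(2, 128, SPEC_FRAMES),
    size_heatmap=torch.rand(2, 128, SPEC_FRAMES),
)

clipper = Clipper(samplerate=SAMPLERATE, duration=0.513, random=False)
clip, start, end = clipper.extract_clip(example)

assert clip.audio.shape[-1] == int(0.513 * SAMPLERATE)
assert clip.spectrogram.shape[-1] == int(0.513 * SPEC_FRAMES)  # spec_sr is 1_024
assert (start, end) == (0, 0.513)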

View File

@@ -1,6 +1,4 @@
import pytest
import torch
import xarray as xr
from soundevent import data
from soundevent.terms import get_term
@@ -10,6 +8,7 @@ from batdetect2.targets import build_targets, load_target_config
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.preprocess import generate_train_example
from batdetect2.typing import ModelOutput
from batdetect2.typing.preprocess import AudioLoader
@pytest.fixture
@@ -35,6 +34,8 @@ def build_from_config(
labeller = build_clip_labeler(
targets=targets,
config=labels_config,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
postprocessor = build_postprocessor(
targets,
@@ -48,62 +49,8 @@
return build
# TODO: better name
def test_generated_train_example_has_expected_outputs(
build_from_config,
recording,
):
yaml_content = """
labels:
targets:
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
generic_class:
- key: order
value: Chiroptera
preprocessing:
postprocessing:
"""
_, preprocessor, labeller, _ = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
],
)
clip_annotation = data.ClipAnnotation(
clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
sound_events=[se1],
)
encoded = generate_train_example(clip_annotation, preprocessor, labeller)
assert isinstance(encoded, xr.Dataset)
assert "audio" in encoded
assert "spectrogram" in encoded
assert "detection" in encoded
assert "class" in encoded
assert "size" in encoded
spec_shape = encoded["spectrogram"].shape
assert len(spec_shape) == 2
height, width = spec_shape
assert encoded["detection"].shape == (height, width)
assert encoded["class"].shape == (1, height, width)
assert encoded["size"].shape == (2, height, width)
def test_encoding_decoding_roundtrip_recovers_object(
sample_audio_loader: AudioLoader,
build_from_config,
recording,
):
@@ -136,13 +83,17 @@ def test_encoding_decoding_roundtrip_recovers_object(
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
encoded = generate_train_example(clip_annotation, preprocessor, labeller)
encoded = generate_train_example(
clip_annotation, sample_audio_loader, preprocessor, labeller
)
predictions = postprocessor.get_predictions(
ModelOutput(
detection_probs=torch.tensor([[encoded["detection"].data]]),
size_preds=torch.tensor([encoded["size"].data]),
class_probs=torch.tensor([encoded["class"].data]),
features=torch.tensor([[encoded["spectrogram"].data]]),
detection_probs=encoded["detection_heatmap"]
.unsqueeze(0)
.unsqueeze(0),
size_preds=encoded["size_heatmap"].unsqueeze(0),
class_probs=encoded["class_heatmap"].unsqueeze(0),
features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
),
[clip],
)[0]
@@ -185,6 +136,7 @@
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
sample_audio_loader: AudioLoader,
build_from_config,
recording,
):
@@ -222,13 +174,20 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
encoded = generate_train_example(clip_annotation, preprocessor, labeller)
encoded = generate_train_example(
clip_annotation,
sample_audio_loader,
preprocessor,
labeller,
)
predictions = postprocessor.get_predictions(
ModelOutput(
detection_probs=torch.tensor([[encoded["detection"].data]]),
size_preds=torch.tensor([encoded["size"].data]),
class_probs=torch.tensor([encoded["class"].data]),
features=torch.tensor([[encoded["spectrogram"].data]]),
detection_probs=encoded["detection_heatmap"]
.unsqueeze(0)
.unsqueeze(0),
size_preds=encoded["size_heatmap"].unsqueeze(0),
class_probs=encoded["class_heatmap"].unsqueeze(0),
features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
),
[clip],
)[0]

View File

@@ -1,23 +1,59 @@
import numpy as np
import torch
from batdetect2.utils.arrays import adjust_width, extend_width
def test_extend_width():
array = np.random.random([1, 1, 128, 100])
array = torch.rand([1, 1, 128, 100])
extended = extend_width(array, 100)
assert extended.shape == (1, 1, 128, 200)
extended = extend_width(array, 100, axis=0)
assert extended.shape == (101, 1, 128, 100)
extended = extend_width(array, 100, axis=1)
assert extended.shape == (1, 101, 128, 100)
extended = extend_width(array, 100, axis=2)
assert extended.shape == (1, 1, 228, 100)
extended = extend_width(array, 100, axis=3)
assert extended.shape == (1, 1, 128, 200)
extended = extend_width(array, 100, axis=-2)
assert extended.shape == (1, 1, 228, 100)
def test_extends_with_value():
array = torch.rand([1, 1, 128, 100])
extended = extend_width(array, 100, value=-1)
torch.testing.assert_close(
extended[:, :, :, 100:],
torch.ones_like(array) * -1,
rtol=0,
atol=0,
)
def test_can_adjust_short_width():
array = np.random.random([1, 1, 128, 100])
array = torch.rand([1, 1, 128, 100])
extended = adjust_width(array, 512)
assert extended.shape == (1, 1, 128, 512)
extended = adjust_width(array, 512, axis=0)
assert extended.shape == (512, 1, 128, 100)
extended = adjust_width(array, 512, axis=1)
assert extended.shape == (1, 512, 128, 100)
extended = adjust_width(array, 512, axis=2)
assert extended.shape == (1, 1, 512, 100)
extended = adjust_width(array, 512, axis=3)
assert extended.shape == (1, 1, 128, 512)
def test_can_adjust_long_width():
array = np.random.random([1, 1, 128, 512])
array = torch.rand([1, 1, 128, 512])
extended = adjust_width(array, 256)
assert extended.shape == (1, 1, 128, 256)