Update augmentations
This commit is contained in:
parent 76dda0a0e9
commit 0bb0caddea
@@ -1,6 +1,11 @@
 from batdetect2.plotting.clip_annotations import plot_clip_annotation
 from batdetect2.plotting.clip_predictions import plot_clip_prediction
 from batdetect2.plotting.clips import plot_clip
 from batdetect2.plotting.common import plot_spectrogram
+from batdetect2.plotting.heatmaps import (
+    plot_classification_heatmap,
+    plot_detection_heatmap,
+)
 from batdetect2.plotting.matches import (
     plot_cross_trigger_match,
     plot_false_negative_match,
@@ -13,9 +18,12 @@ __all__ = [
     "plot_clip",
     "plot_clip_annotation",
     "plot_clip_prediction",
-    "plot_matches",
-    "plot_false_positive_match",
-    "plot_true_positive_match",
-    "plot_false_negative_match",
     "plot_cross_trigger_match",
+    "plot_false_negative_match",
+    "plot_false_positive_match",
+    "plot_matches",
+    "plot_spectrogram",
+    "plot_true_positive_match",
+    "plot_detection_heatmap",
+    "plot_classification_heatmap",
 ]
@@ -3,6 +3,7 @@
 from typing import Optional, Tuple

 import matplotlib.pyplot as plt
+import torch
 from matplotlib import axes

 __all__ = [
@@ -12,7 +13,7 @@ __all__ = [

 def create_ax(
     ax: Optional[axes.Axes] = None,
-    figsize: Tuple[int, int] = (10, 10),
+    figsize: Optional[Tuple[int, int]] = None,
     **kwargs,
 ) -> axes.Axes:
     """Create a new axis if none is provided"""
@@ -20,3 +21,14 @@ def create_ax(
         _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore

     return ax  # type: ignore
+
+
+def plot_spectrogram(
+    spec: torch.Tensor,
+    ax: Optional[axes.Axes] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    cmap="gray",
+) -> axes.Axes:
+    ax = create_ax(ax=ax, figsize=figsize)
+    ax.pcolormesh(spec.numpy(), cmap=cmap)
+    return ax
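For context, the new `create_ax`/`plot_spectrogram` pair can be exercised on its own. A minimal sketch, assuming the `batdetect2` package is importable; the random tensor and its (freq_bins, time_bins) shape merely stand in for real preprocessor output:

```python
import matplotlib.pyplot as plt
import torch

from batdetect2.plotting.common import plot_spectrogram

# Fake spectrogram standing in for real model input; shape is illustrative.
spec = torch.rand(128, 512)

# With ax=None and figsize=None, create_ax now falls back to matplotlib's
# default figure size instead of a hard-coded (10, 10).
ax = plot_spectrogram(spec, cmap="gray")
ax.set_xlabel("time bin")
ax.set_ylabel("frequency bin")
plt.show()
```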
@@ -1,26 +1,115 @@
 """Plot heatmaps"""

-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union

-import xarray as xr
-from matplotlib import axes
+import numpy as np
+import torch
+from matplotlib import axes, patches
+from matplotlib.cm import get_cmap
+from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba

 from batdetect2.plotting.common import create_ax


-def plot_heatmap(
-    heatmap: xr.DataArray,
+def plot_detection_heatmap(
+    heatmap: Union[torch.Tensor, np.ndarray],
     ax: Optional[axes.Axes] = None,
     figsize: Tuple[int, int] = (10, 10),
+    threshold: Optional[float] = None,
+    alpha: float = 1,
+    cmap: Union[str, Colormap] = "jet",
+    color: Optional[str] = None,
 ) -> axes.Axes:
     ax = create_ax(ax, figsize=figsize)

+    if isinstance(heatmap, torch.Tensor):
+        heatmap = heatmap.numpy()
+
+    if threshold is not None:
+        heatmap = np.ma.masked_where(
+            heatmap < threshold,
+            heatmap,
+        )
+
+    if color is not None:
+        cmap = create_colormap(color)
+
     ax.pcolormesh(
-        heatmap.time,
-        heatmap.frequency,
         heatmap,
         vmax=1,
         vmin=0,
+        cmap=cmap,
+        alpha=alpha,
     )

     return ax
+
+
+def plot_classification_heatmap(
+    heatmap: Union[torch.Tensor, np.ndarray],
+    ax: Optional[axes.Axes] = None,
+    figsize: Tuple[int, int] = (10, 10),
+    class_names: Optional[List[str]] = None,
+    threshold: Optional[float] = 0.1,
+    alpha: float = 1,
+    cmap: Union[str, Colormap] = "tab20",
+):
+    ax = create_ax(ax, figsize=figsize)
+
+    if isinstance(heatmap, torch.Tensor):
+        heatmap = heatmap.numpy()
+
+    if heatmap.ndim == 4:
+        heatmap = heatmap[0]
+
+    if heatmap.ndim != 3:
+        raise ValueError("Expecting a 3-dimensional array")
+
+    num_classes = heatmap.shape[0]
+
+    if class_names is None:
+        class_names = [f"class_{i}" for i in range(num_classes)]
+
+    if len(class_names) != num_classes:
+        raise ValueError("Inconsistent number of class names")
+
+    if not isinstance(cmap, Colormap):
+        cmap = get_cmap(cmap)
+
+    handles = []
+
+    for index, class_heatmap in enumerate(heatmap):
+        class_name = class_names[index]
+
+        color = cmap(index / num_classes)
+
+        max = class_heatmap.max()
+
+        if max == 0:
+            continue
+
+        if threshold is not None:
+            class_heatmap = np.ma.masked_where(
+                class_heatmap < threshold,
+                class_heatmap,
+            )
+
+        ax.pcolormesh(
+            class_heatmap,
+            vmax=1,
+            vmin=0,
+            cmap=create_colormap(color),  # type: ignore
+            alpha=alpha,
+        )
+
+        handles.append(patches.Patch(color=color, label=class_name))
+
+    ax.legend(handles=handles)
+    return ax
+
+
+def create_colormap(color: str) -> Colormap:
+    (r, g, b, a) = to_rgba(color)
+    return LinearSegmentedColormap.from_list(
+        "cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
+    )
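The transparent-to-opaque colormap built by `create_colormap` is what lets several class heatmaps share one axes without low-probability regions of each layer painting over the others. A standalone sketch of the same idea; the data, colors, and threshold here are invented:

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap, to_rgba


def transparent_cmap(color: str) -> LinearSegmentedColormap:
    # Value 0 maps to fully transparent, value 1 to the solid color,
    # mirroring create_colormap in the diff above.
    r, g, b, a = to_rgba(color)
    return LinearSegmentedColormap.from_list(
        "cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
    )


rng = np.random.default_rng(0)
fig, ax = plt.subplots()
for color in ["tab:red", "tab:blue"]:
    heatmap = rng.random((64, 128))
    # masked_where hides everything below the threshold entirely,
    # so only confident regions of each layer get drawn at all.
    masked = np.ma.masked_where(heatmap < 0.95, heatmap)
    ax.pcolormesh(masked, vmin=0, vmax=1, cmap=transparent_cmap(color))
plt.show()
```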
(File diff suppressed because it is too large.)
@@ -1,12 +1,12 @@
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple

 import numpy as np
-import xarray as xr
 from loguru import logger
-from soundevent import arrays

 from batdetect2.configs import BaseConfig
 from batdetect2.typing import ClipperProtocol
+from batdetect2.typing.train import PreprocessedExample
+from batdetect2.utils.arrays import adjust_width

 DEFAULT_TRAIN_CLIP_DURATION = 0.513
 DEFAULT_MAX_EMPTY_CLIP = 0.1
@@ -32,40 +32,23 @@ class Clipper(ClipperProtocol):
         self.max_empty = max_empty

     def extract_clip(
-        self, example: xr.Dataset
-    ) -> Tuple[xr.Dataset, float, float]:
-        step = arrays.get_dim_step(
-            example.spectrogram,
-            dim=arrays.Dimensions.time.value,
-        )
-        duration = (
-            arrays.get_dim_width(
-                example.spectrogram,
-                dim=arrays.Dimensions.time.value,
-            )
-            + step
-        )
-
+        self, example: PreprocessedExample
+    ) -> Tuple[PreprocessedExample, float, float]:
         start_time = 0
+        duration = example.audio.shape[-1] / self.samplerate

         if self.random:
            start_time = np.random.uniform(
                -self.max_empty,
                duration - self.duration + self.max_empty,
            )

-        subclip = select_subclip(
-            example,
-            start=start_time,
-            span=self.duration,
-            dim="time",
-        )
-
         return (
             select_subclip(
-                subclip,
+                example,
                 start=start_time,
-                span=self.duration,
-                dim="audio_time",
+                duration=self.duration,
+                samplerate=self.samplerate,
             ),
             start_time,
             start_time + self.duration,
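When `random=True`, the start time is drawn from `[-max_empty, duration - clip_duration + max_empty]`, so a clip may hang over either edge of the recording by at most `max_empty` seconds; the overhang is later filled with `fill_value` by `select_subclip`. A toy illustration of that sampling range, with invented numbers:

```python
import numpy as np

duration = 2.0        # length of the full example, in seconds
clip_duration = 0.513
max_empty = 0.1

# The same distribution extract_clip uses when random=True: the clip can
# start slightly before 0 or end slightly after the recording, leaving at
# most max_empty seconds of padding on one side.
starts = np.random.uniform(
    -max_empty,
    duration - clip_duration + max_empty,
    size=5,
)
for start in starts:
    print(f"clip spans [{start:.3f}, {start + clip_duration:.3f}]")
```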
@@ -73,6 +56,7 @@ class Clipper(ClipperProtocol):


 def build_clipper(
+    samplerate: int,
     config: Optional[ClipingConfig] = None,
     random: Optional[bool] = None,
 ) -> ClipperProtocol:
@@ -82,6 +66,7 @@ def build_clipper(
         lambda: config.to_yaml_string(),
     )
     return Clipper(
+        samplerate=samplerate,
         duration=config.duration,
         max_empty=config.max_empty,
         random=config.random if random else False,
@@ -89,106 +74,43 @@


 def select_subclip(
-    dataset: xr.Dataset,
-    span: float,
+    example: PreprocessedExample,
     start: float,
-    fill_value: float = 0,
-    dim: str = "time",
-) -> xr.Dataset:
-    width = _compute_expected_width(
-        dataset,  # type: ignore
-        span,
-        dim=dim,
-    )
-
-    coord = dataset.coords[dim]
-
-    if len(coord) == width:
-        return dataset
-
-    new_coords, start_pad, end_pad, dim_slice = _extract_coordinate(
-        coord, start, span
-    )
-
-    data_vars = {}
-    for name, data_array in dataset.data_vars.items():
-        if dim not in data_array.dims:
-            data_vars[name] = data_array
-            continue
-
-        if width == data_array.sizes[dim]:
-            data_vars[name] = data_array
-            continue
-
-        sliced = data_array.isel({dim: dim_slice}).data
-
-        if start_pad > 0 or end_pad > 0:
-            padding = [
-                [0, 0] if other_dim != dim else [start_pad, end_pad]
-                for other_dim in data_array.dims
-            ]
-            sliced = np.pad(sliced, padding, constant_values=fill_value)
-
-        data_vars[name] = xr.DataArray(
-            data=sliced,
-            dims=data_array.dims,
-            coords={**data_array.coords, dim: new_coords},
-            attrs=data_array.attrs,
-        )
-
-    return xr.Dataset(data_vars=data_vars, attrs=dataset.attrs)
-
-
-def _extract_coordinate(
-    coord: xr.DataArray,
-    start: float,
-    span: float,
-) -> Tuple[xr.Variable, int, int, slice]:
-    step = arrays.get_dim_step(coord, str(coord.name))
-
-    current_width = len(coord)
-    expected_width = int(np.floor(span / step))
-
-    coord_start = float(coord[0])
-    offset = start - coord_start
-
-    start_index = int(np.floor(offset / step))
-    end_index = start_index + expected_width
-
-    if start_index > current_width:
-        raise ValueError("Requested span does not overlap with current range")
-
-    if end_index < 0:
-        raise ValueError("Requested span does not overlap with current range")
-
-    corrected_start = float(start_index * step)
-    corrected_end = float(end_index * step)
-
-    start_index_offset = max(0, -start_index)
-    end_index_offset = max(0, end_index - current_width)
-
-    sl = slice(
-        start_index if start_index >= 0 else None,
-        end_index if end_index < current_width else None,
-    )
-
-    return (
-        arrays.create_range_dim(
-            str(coord.name),
-            start=corrected_start,
-            stop=corrected_end,
-            step=step,
-        ),
-        start_index_offset,
-        end_index_offset,
-        sl,
-    )
-
-
-def _compute_expected_width(
-    array: Union[xr.DataArray, xr.Dataset],
     duration: float,
-    dim: str,
-) -> int:
-    step = arrays.get_dim_step(array, dim)  # type: ignore
-    return int(np.floor(duration / step))
+    samplerate: float,
+    fill_value: float = 0,
+) -> PreprocessedExample:
+    audio_width = int(np.floor(duration * samplerate))
+    audio_start = int(np.floor(start * samplerate))
+
+    audio = adjust_width(
+        example.audio[audio_start : audio_start + audio_width],
+        audio_width,
+        value=fill_value,
+    )
+
+    audio_duration = example.audio.shape[-1] / samplerate
+    spec_sr = example.spectrogram.shape[-1] / audio_duration
+
+    spec_start = int(np.floor(start * spec_sr))
+    spec_width = int(np.floor(duration * spec_sr))
+
+    return PreprocessedExample(
+        audio=audio,
+        spectrogram=adjust_width(
+            example.spectrogram[:, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+        class_heatmap=adjust_width(
+            example.class_heatmap[:, :, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+        detection_heatmap=adjust_width(
+            example.detection_heatmap[:, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+        size_heatmap=adjust_width(
+            example.size_heatmap[:, :, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+    )
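Since the clip is now cut by sample index rather than by xarray coordinate, the spectrogram slice has to be re-derived from the audio slice. A worked example of the arithmetic, with invented sizes:

```python
import numpy as np

samplerate = 256_000   # audio sample rate
num_samples = 512_000  # example.audio.shape[-1]
num_frames = 1_024     # example.spectrogram.shape[-1]

start, duration = 0.25, 0.513

# Audio indices come straight from the sample rate.
audio_start = int(np.floor(start * samplerate))     # 64000
audio_width = int(np.floor(duration * samplerate))  # 131328

# The spectrogram "sample rate" is recovered from how many frames the
# full audio produced, so both slices cover the same time span.
audio_duration = num_samples / samplerate  # 2.0 s
spec_sr = num_frames / audio_duration      # 512 frames/s
spec_start = int(np.floor(start * spec_sr))     # 128
spec_width = int(np.floor(duration * spec_sr))  # 262

print(audio_start, audio_width, spec_start, spec_width)
```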
@@ -22,7 +22,7 @@ includes utilities for parallel processing using `multiprocessing`.

 import os
 from pathlib import Path
-from typing import Callable, Dict, Optional, Sequence
+from typing import Callable, Optional, Sequence, TypedDict

 import numpy as np
 import torch
@@ -98,17 +98,25 @@ def preprocess_dataset(
     )


+class Example(TypedDict):
+    audio: torch.Tensor
+    spectrogram: torch.Tensor
+    detection_heatmap: torch.Tensor
+    class_heatmap: torch.Tensor
+    size_heatmap: torch.Tensor
+
+
 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
+    audio_loader: AudioLoader,
     preprocessor: PreprocessorProtocol,
     labeller: ClipLabeller,
-) -> Dict[str, torch.Tensor]:
+) -> PreprocessedExample:
     """Generate a complete training example for one annotation."""
     wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
     spectrogram = preprocessor(wave)
     heatmaps = labeller(clip_annotation, spectrogram)
-    return dict(
+    return PreprocessedExample(
         audio=wave,
         spectrogram=spectrogram,
         detection_heatmap=heatmaps.detection,
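The new `Example` TypedDict only documents the dictionary shape flowing out of the dataset; it is a plain dict at runtime. A sketch of how such items pass through torch's default collation; all shapes below are placeholders:

```python
from typing import TypedDict

import torch


class Example(TypedDict):
    audio: torch.Tensor
    spectrogram: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor


# The annotation lets type checkers verify keys and value types,
# while the object itself remains an ordinary dict.
item: Example = {
    "audio": torch.zeros(1, 1000),
    "spectrogram": torch.zeros(128, 100),
    "detection_heatmap": torch.zeros(128, 100),
    "class_heatmap": torch.zeros(2, 128, 100),
    "size_heatmap": torch.zeros(2, 128, 100),
}

# default_collate stacks same-shaped tensors key by key, so a list of
# such dicts becomes a single dict of batched tensors.
batch = torch.utils.data.default_collate([item, item])
print(batch["spectrogram"].shape)  # torch.Size([2, 128, 100])
```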
@@ -138,8 +146,14 @@ class PreprocessingDataset(torch.utils.data.Dataset):
             preprocessor=self.preprocessor,
             labeller=self.labeller,
         )
-        example["idx"] = idx
-        return example
+        return {
+            "idx": idx,
+            "spectrogram": example.spectrogram,
+            "audio": example.audio,
+            "class_heatmap": example.class_heatmap,
+            "size_heatmap": example.size_heatmap,
+            "detection_heatmap": example.detection_heatmap,
+        }

     def __len__(self) -> int:
         return len(self.clips)
@@ -147,16 +161,17 @@ class PreprocessingDataset(torch.utils.data.Dataset):

 def _save_example_to_file(
     example: PreprocessedExample,
+    clip_annotation: data.ClipAnnotation,
     path: data.PathLike,
 ) -> None:
     np.savez_compressed(
         path,
-        audio=example.audio,
-        spectrogram=example.spectrogram,
-        detection_heatmap=example.detection_heatmap,
-        class_heatmap=example.class_heatmap,
-        size_heatmap=example.size_heatmap,
-        clip_annotation=example.clip_annotation,
+        audio=example.audio.numpy(),
+        spectrogram=example.spectrogram.numpy(),
+        detection_heatmap=example.detection_heatmap.numpy(),
+        class_heatmap=example.class_heatmap.numpy(),
+        size_heatmap=example.size_heatmap.numpy(),
+        clip_annotation=clip_annotation,
     )

@@ -211,11 +226,10 @@ def preprocess_annotations(
         filename = filename_fn(clip_annotation)
         path = output_dir / filename
         example = PreprocessedExample(
-            clip_annotation=clip_annotation,
-            spectrogram=batch["spectrogram"].numpy(),
-            audio=batch["audio"].numpy(),
-            class_heatmap=batch["class_heatmap"].numpy(),
-            size_heatmap=batch["size_heatmap"].numpy(),
-            detection_heatmap=batch["detection_heatmap"].numpy(),
+            spectrogram=batch["spectrogram"],
+            audio=batch["audio"],
+            class_heatmap=batch["class_heatmap"],
+            size_heatmap=batch["size_heatmap"],
+            detection_heatmap=batch["detection_heatmap"],
         )
-        _save_example_to_file(example, path)
+        _save_example_to_file(example, clip_annotation, path)
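Because `np.savez_compressed` only stores arrays, the tensors are converted with `.numpy()` before saving and the `ClipAnnotation` rides along as a pickled object. A hedged sketch of a matching load side; `load_example` is a hypothetical helper, and the key names simply follow the save call above:

```python
import numpy as np
import torch


def load_example(path):
    # allow_pickle is required because clip_annotation is stored as a
    # pickled Python object, not a numeric array.
    with np.load(path, allow_pickle=True) as data:
        return {
            "audio": torch.from_numpy(data["audio"]),
            "spectrogram": torch.from_numpy(data["spectrogram"]),
            "detection_heatmap": torch.from_numpy(data["detection_heatmap"]),
            "class_heatmap": torch.from_numpy(data["class_heatmap"]),
            "size_heatmap": torch.from_numpy(data["size_heatmap"]),
            # .item() unwraps the 0-d object array back into the object.
            "clip_annotation": data["clip_annotation"].item(),
        }
```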
@@ -148,6 +148,8 @@ class PreprocessorProtocol(Protocol):

     min_freq: float

+    samplerate: int
+
     audio_pipeline: AudioPipeline

     spectrogram_pipeline: SpectrogramPipeline
@@ -155,4 +157,4 @@ class PreprocessorProtocol(Protocol):
     def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...

     def process_numpy(self, wav: np.ndarray) -> np.ndarray:
-        return self(torch.tensor(wav)).numpy()[0, 0]
+        return self(torch.tensor(wav)).numpy()
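Dropping the `[0, 0]` indexing means `process_numpy` now hands back whatever array shape the pipeline produces rather than stripping what were presumably batch and channel axes. A toy class conforming to this corner of the protocol, just to make the contract concrete; the identity `__call__` is an assumption, the real preprocessor computes spectrograms:

```python
import numpy as np
import torch


class IdentityPreprocessor:
    # Minimal stand-in for the __call__/process_numpy surface of
    # PreprocessorProtocol; every value here is illustrative.
    samplerate = 256_000

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        return wav

    def process_numpy(self, wav: np.ndarray) -> np.ndarray:
        # Matches the updated default: no [0, 0] squeeze, so the caller
        # sees the pipeline output's full shape.
        return self(torch.tensor(wav)).numpy()


out = IdentityPreprocessor().process_numpy(np.zeros((1, 1, 100)))
print(out.shape)  # (1, 1, 100)
```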
@@ -1,8 +1,6 @@
 from typing import Callable, NamedTuple, Protocol, Tuple

-import numpy as np
 import torch
-import xarray as xr
 from soundevent import data

 from batdetect2.typing.models import ModelOutput
@@ -19,24 +17,7 @@ __all__ = [


 class Heatmaps(NamedTuple):
-    """Structure holding the generated heatmap targets.
-
-    Attributes
-    ----------
-    detection : xr.DataArray
-        Heatmap indicating the probability of sound event presence. Typically
-        smoothed with a Gaussian kernel centered on event reference points.
-        Shape matches the input spectrogram. Values normalized [0, 1].
-    classes : xr.DataArray
-        Heatmap indicating the probability of specific class presence. Has an
-        additional 'category' dimension corresponding to the target class
-        names. Each category slice is typically smoothed with a Gaussian
-        kernel. Values normalized [0, 1] per category.
-    size : xr.DataArray
-        Heatmap encoding the size (width, height) of detected events. Has an
-        additional 'dimension' coordinate ('width', 'height'). Values represent
-        scaled dimensions placed at the event reference points.
-    """
+    """Structure holding the generated heatmap targets."""

     detection: torch.Tensor
     classes: torch.Tensor
@@ -44,12 +25,20 @@ class Heatmaps(NamedTuple):


 class PreprocessedExample(NamedTuple):
-    audio: np.ndarray
-    spectrogram: np.ndarray
-    detection_heatmap: np.ndarray
-    class_heatmap: np.ndarray
-    size_heatmap: np.ndarray
-    clip_annotation: data.ClipAnnotation
+    audio: torch.Tensor
+    spectrogram: torch.Tensor
+    detection_heatmap: torch.Tensor
+    class_heatmap: torch.Tensor
+    size_heatmap: torch.Tensor
+
+    def copy(self):
+        return PreprocessedExample(
+            audio=self.audio.clone(),
+            spectrogram=self.spectrogram.clone(),
+            detection_heatmap=self.detection_heatmap.clone(),
+            size_heatmap=self.size_heatmap.clone(),
+            class_heatmap=self.class_heatmap.clone(),
+        )


 ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
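NamedTuples are immutable, but the tensors inside them are not, so an augmentation that writes into a tensor in place would corrupt the cached original; `copy()` clones every field first. A small demonstration of the hazard it guards against, using a made-up two-field tuple:

```python
from typing import NamedTuple

import torch


class Pair(NamedTuple):
    a: torch.Tensor

    def copy(self) -> "Pair":
        # Cloning gives the augmentation its own storage to write into.
        return Pair(a=self.a.clone())


original = Pair(a=torch.zeros(3))

aliased = Pair(a=original.a)  # shares storage with original
aliased.a.add_(1)             # in-place edit leaks into original
print(original.a)             # tensor([1., 1., 1.])

original = Pair(a=torch.zeros(3))
safe = original.copy()
safe.a.add_(1)                # only the copy changes
print(original.a)             # tensor([0., 0., 0.])
```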
@@ -60,7 +49,7 @@ spectrogram, applies all configured filtering, transformation, and encoding
 steps, and returns the final `Heatmaps` used for model training.
 """

-Augmentation = Callable[[xr.Dataset], xr.Dataset]
+Augmentation = Callable[[PreprocessedExample], PreprocessedExample]


 class TrainExample(NamedTuple):
@@ -108,5 +97,5 @@ class LossProtocol(Protocol):

 class ClipperProtocol(Protocol):
     def extract_clip(
-        self, example: xr.Dataset
-    ) -> Tuple[xr.Dataset, float, float]: ...
+        self, example: PreprocessedExample
+    ) -> Tuple[PreprocessedExample, float, float]: ...
@@ -1,4 +1,5 @@
 import numpy as np
+import torch
 import xarray as xr


@@ -35,77 +36,40 @@ def spec_to_xarray(
     )


-def audio_to_xarray(
-    wav: np.ndarray,
-    start_time: float,
-    end_time: float,
-    time_axis: str = "time",
-) -> xr.DataArray:
-    if wav.ndim != 1:
-        raise ValueError("Input numpy audio array should be 1-dimensional")
-
-    return xr.DataArray(
-        data=wav,
-        dims=[time_axis],
-        coords={
-            time_axis: np.linspace(
-                start_time,
-                end_time,
-                len(wav),
-                endpoint=False,
-            ),
-        },
-    )
-
-
 def extend_width(
-    array: np.ndarray,
+    tensor: torch.Tensor,
     extra: int,
     axis: int = -1,
     value: float = 0,
-) -> np.ndarray:
-    dims = len(array.shape)
-    axis = axis % dims
-    pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)]
-    return np.pad(
-        array,
+) -> torch.Tensor:
+    dims = len(tensor.shape)
+    axis = dims - axis % dims - 1
+    pad = [0 for _ in range(2 * dims)]
+    pad[2 * axis + 1] = extra
+    return torch.nn.functional.pad(
+        tensor,
         pad,
         mode="constant",
-        constant_values=value,
+        value=value,
     )


-def make_width_divisible(
-    array: np.ndarray,
-    factor: int,
-    axis: int = -1,
-    value: float = 0,
-) -> np.ndarray:
-    width = array.shape[axis]
-
-    if width % factor == 0:
-        return array
-
-    extra = (-width) % factor
-    return extend_width(array, extra, axis=axis, value=value)
-
-
 def adjust_width(
-    array: np.ndarray,
+    tensor: torch.Tensor,
     width: int,
     axis: int = -1,
     value: float = 0,
-) -> np.ndarray:
-    dims = len(array.shape)
+) -> torch.Tensor:
+    dims = len(tensor.shape)
     axis = axis % dims
-    current_width = array.shape[axis]
+    current_width = tensor.shape[axis]

     if current_width == width:
-        return array
+        return tensor

     if current_width < width:
         return extend_width(
-            array,
+            tensor,
             extra=width - current_width,
             axis=axis,
             value=value,
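`torch.nn.functional.pad` takes a flat pad list ordered from the last dimension backwards, which is why the torch version remaps `axis` with `dims - axis % dims - 1` instead of using it directly as `np.pad` could. A sketch showing the two conventions agree on an invented 3-d shape:

```python
import numpy as np
import torch
import torch.nn.functional as F

x = torch.zeros(2, 3, 4)
axis, extra = 1, 5  # pad 5 entries at the end of dim 1

dims = x.dim()
# F.pad's list is (last_dim_before, last_dim_after, ..., first_dim_...),
# so dim 1 of a 3-d tensor lands in slot 1 counted from the end.
flipped = dims - axis % dims - 1
pad = [0] * (2 * dims)
pad[2 * flipped + 1] = extra
padded = F.pad(x, pad, mode="constant", value=0)
print(padded.shape)  # torch.Size([2, 8, 4])

# np.pad, by contrast, takes per-dimension (before, after) pairs in
# natural order.
y = np.zeros((2, 3, 4))
padded_np = np.pad(y, [(0, 0), (0, extra), (0, 0)], constant_values=0)
print(padded_np.shape)  # (2, 8, 4)
```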
@@ -115,11 +79,4 @@ def adjust_width(
         slice(None, None) if index != axis else slice(None, width)
         for index in range(dims)
     ]
-    return array[tuple(slices)]
-
-
-def iterate_over_array(array: xr.DataArray):
-    dim_name = array.dims[0]
-    coords = array.coords[dim_name]
-    for value, coord in zip(array.values, coords.values):
-        yield coord, float(value)
+    return tensor[tuple(slices)]
@@ -431,8 +431,13 @@ def sample_targets(
 @pytest.fixture
 def sample_labeller(
     sample_targets: TargetProtocol,
+    sample_preprocessor: PreprocessorProtocol,
 ) -> ClipLabeller:
-    return build_clip_labeler(sample_targets)
+    return build_clip_labeler(
+        sample_targets,
+        min_freq=sample_preprocessor.min_freq,
+        max_freq=sample_preprocessor.max_freq,
+    )


 @pytest.fixture
@@ -2,6 +2,7 @@ from collections.abc import Callable

 import numpy as np
 import pytest
+import torch
 import xarray as xr
 from soundevent import arrays, data

@@ -42,12 +43,17 @@ def test_mix_examples(
         labeller=sample_labeller,
     )

-    mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
+    mixed = mix_examples(
+        example1,
+        example2,
+        weight=0.3,
+        preprocessor=sample_preprocessor,
+    )

-    assert mixed["spectrogram"].shape == example1["spectrogram"].shape
-    assert mixed["detection"].shape == example1["detection"].shape
-    assert mixed["size"].shape == example1["size"].shape
-    assert mixed["class"].shape == example1["class"].shape
+    assert mixed.spectrogram.shape == example1.spectrogram.shape
+    assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
+    assert mixed.size_heatmap.shape == example1.size_heatmap.shape
+    assert mixed.class_heatmap.shape == example1.class_heatmap.shape


 @pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
@@ -82,13 +88,17 @@ def test_mix_examples_of_different_durations(
         labeller=sample_labeller,
     )

-    mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
+    mixed = mix_examples(
+        example1,
+        example2,
+        weight=0.3,
+        preprocessor=sample_preprocessor,
+    )

-    # Check the spectrogram has the expected duration
-    step = arrays.get_dim_step(mixed["spectrogram"], "time")
-    start, stop = arrays.get_dim_range(mixed["spectrogram"], "time")
-    assert start == 0
-    assert np.isclose(stop + step, duration1, atol=2 * step)
+    assert mixed.spectrogram.shape == example1.spectrogram.shape
+    assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
+    assert mixed.size_heatmap.shape == example1.size_heatmap.shape
+    assert mixed.class_heatmap.shape == example1.class_heatmap.shape


 def test_add_echo(
@@ -107,12 +117,32 @@ def test_add_echo(
         preprocessor=sample_preprocessor,
         labeller=sample_labeller,
     )
-    with_echo = add_echo(original, preprocessor=sample_preprocessor)
+    with_echo = add_echo(
+        original,
+        preprocessor=sample_preprocessor,
+        delay=0.1,
+        weight=0.3,
+    )

-    assert with_echo["spectrogram"].shape == original["spectrogram"].shape
-    xr.testing.assert_identical(with_echo["size"], original["size"])
-    xr.testing.assert_identical(with_echo["class"], original["class"])
-    xr.testing.assert_identical(with_echo["detection"], original["detection"])
+    assert with_echo.spectrogram.shape == original.spectrogram.shape
+    torch.testing.assert_close(
+        with_echo.size_heatmap,
+        original.size_heatmap,
+        atol=0,
+        rtol=0,
+    )
+    torch.testing.assert_close(
+        with_echo.class_heatmap,
+        original.class_heatmap,
+        atol=0,
+        rtol=0,
+    )
+    torch.testing.assert_close(
+        with_echo.detection_heatmap,
+        original.detection_heatmap,
+        atol=0,
+        rtol=0,
+    )


 def test_selected_random_subclip_has_the_correct_width(
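The updated test pins down the `add_echo` contract: only the waveform and spectrogram change while every heatmap stays bit-identical (`atol=0, rtol=0`). The underlying augmentation idea is a delayed, scaled copy mixed back into the audio; a standalone sketch of that idea, not the repo's `add_echo` itself, with the test's delay and weight and everything else assumed:

```python
import torch


def echo_waveform(
    audio: torch.Tensor,
    samplerate: int,
    delay: float = 0.1,
    weight: float = 0.3,
) -> torch.Tensor:
    # Shift the waveform right by `delay` seconds and mix it back in,
    # attenuated by `weight`; the training targets stay untouched.
    shift = int(delay * samplerate)
    delayed = torch.zeros_like(audio)
    delayed[..., shift:] = audio[..., : audio.shape[-1] - shift]
    return audio + weight * delayed


wav = torch.randn(1, 2560)
out = echo_waveform(wav, samplerate=2560, delay=0.1, weight=0.3)
print(out.shape)  # torch.Size([1, 2560])
```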
@@ -3,7 +3,6 @@ import pytest
 import xarray as xr

 from batdetect2.train.clips import (
     Clipper,
-    _compute_expected_width,
     select_subclip,
 )
@@ -322,145 +321,3 @@ def test_select_subclip_no_overlap_raises_error(long_dataset):
         start=-1.0 * CLIP_DURATION - 1.0,
         dim="time",
     )
-
-
-def test_clipper_non_random(long_dataset, exact_dataset, short_dataset):
-    clipper = Clipper(duration=CLIP_DURATION, random=False)
-
-    for ds in [long_dataset, exact_dataset, short_dataset]:
-        clip, _, _ = clipper.extract_clip(ds)
-        expected_spec_width = _compute_expected_width(
-            ds, CLIP_DURATION, "time"
-        )
-        expected_audio_width = _compute_expected_width(
-            ds, CLIP_DURATION, "audio_time"
-        )
-
-        assert clip.dims["time"] == expected_spec_width
-        assert clip.dims["audio_time"] == expected_audio_width
-        assert clip.spectrogram.shape[1] == expected_spec_width
-        assert clip.audio.shape[0] == expected_audio_width
-
-        assert clip.time.min() >= -1 / SPEC_SAMPLERATE
-        assert clip.audio_time.min() >= -1 / AUDIO_SAMPLERATE
-
-        time_span = clip.time.max() - clip.time.min()
-        audio_span = clip.audio_time.max() - clip.audio_time.min()
-        assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
-        assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
-
-
-def test_clipper_random(long_dataset):
-    seed = 42
-    np.random.seed(seed)
-    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
-    clip1, _, _ = clipper.extract_clip(long_dataset)
-
-    np.random.seed(seed + 1)
-    clip2, _, _ = clipper.extract_clip(long_dataset)
-
-    expected_spec_width = _compute_expected_width(
-        long_dataset, CLIP_DURATION, "time"
-    )
-    expected_audio_width = _compute_expected_width(
-        long_dataset, CLIP_DURATION, "audio_time"
-    )
-
-    for clip in [clip1, clip2]:
-        assert clip.dims["time"] == expected_spec_width
-        assert clip.dims["audio_time"] == expected_audio_width
-        assert clip.spectrogram.shape[1] == expected_spec_width
-        assert clip.audio.shape[0] == expected_audio_width
-
-    assert not np.isclose(clip1.time.min(), clip2.time.min())
-    assert not np.isclose(clip1.audio_time.min(), clip2.audio_time.min())
-
-    for clip in [clip1, clip2]:
-        time_span = clip.time.max() - clip.time.min()
-        audio_span = clip.audio_time.max() - clip.audio_time.min()
-        assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
-        assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
-
-    max_start_time = (
-        (long_dataset.time.max() - long_dataset.time.min())
-        - CLIP_DURATION
-        + MAX_EMPTY
-    )
-    assert clip1.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
-    assert clip2.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
-
-
-def test_clipper_random_max_empty_effect(long_dataset):
-    """Check that max_empty influences the possible start times."""
-    seed = 123
-    data_duration = long_dataset.time.max() - long_dataset.time.min()
-
-    np.random.seed(seed)
-    clipper0 = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.0)
-    max_start_time0 = data_duration - CLIP_DURATION
-    start_times0 = []
-
-    for _ in range(20):
-        clip, _, _ = clipper0.extract_clip(long_dataset)
-        start_times0.append(clip.time.min().item())
-
-    assert all(
-        st <= max_start_time0 + 1 / SPEC_SAMPLERATE for st in start_times0
-    )
-    assert any(st > 0.1 for st in start_times0)
-
-    np.random.seed(seed)
-    clipper_pos = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.2)
-    max_start_time_pos = data_duration - CLIP_DURATION + 0.2
-    start_times_pos = []
-    for _ in range(20):
-        clip, _, _ = clipper_pos.extract_clip(long_dataset)
-        start_times_pos.append(clip.time.min().item())
-    assert all(
-        st <= max_start_time_pos + 1 / SPEC_SAMPLERATE
-        for st in start_times_pos
-    )
-
-    assert any(st > max_start_time0 + 1e-6 for st in start_times_pos)
-
-
-def test_clipper_short_dataset_random(short_dataset):
-    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
-    clip, _, _ = clipper.extract_clip(short_dataset)
-
-    expected_spec_width = _compute_expected_width(
-        short_dataset, CLIP_DURATION, "time"
-    )
-    expected_audio_width = _compute_expected_width(
-        short_dataset, CLIP_DURATION, "audio_time"
-    )
-
-    assert clip.sizes["time"] == expected_spec_width
-    assert clip.sizes["audio_time"] == expected_audio_width
-    assert clip["spectrogram"].shape[1] == expected_spec_width
-    assert clip["audio"].shape[0] == expected_audio_width
-
-    assert np.any(clip.spectrogram == 0)
-    assert np.any(clip.audio == 0)
-
-
-def test_clipper_exact_dataset_random(exact_dataset):
-    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
-    clip, _, _ = clipper.extract_clip(exact_dataset)
-
-    expected_spec_width = _compute_expected_width(
-        exact_dataset, CLIP_DURATION, "time"
-    )
-    expected_audio_width = _compute_expected_width(
-        exact_dataset, CLIP_DURATION, "audio_time"
-    )
-
-    assert clip.dims["time"] == expected_spec_width
-    assert clip.dims["audio_time"] == expected_audio_width
-    assert clip.spectrogram.shape[1] == expected_spec_width
-    assert clip.audio.shape[0] == expected_audio_width
-
-    time_span = clip.time.max() - clip.time.min()
-    audio_span = clip.audio_time.max() - clip.audio_time.min()
-    assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
-    assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
@@ -1,6 +1,4 @@
 import pytest
-import torch
-import xarray as xr
 from soundevent import data
 from soundevent.terms import get_term

@@ -10,6 +8,7 @@ from batdetect2.targets import build_targets, load_target_config
 from batdetect2.train.labels import build_clip_labeler, load_label_config
 from batdetect2.train.preprocess import generate_train_example
 from batdetect2.typing import ModelOutput
+from batdetect2.typing.preprocess import AudioLoader


 @pytest.fixture
@@ -35,6 +34,8 @@ def build_from_config(
     labeller = build_clip_labeler(
         targets=targets,
         config=labels_config,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
     )
     postprocessor = build_postprocessor(
         targets,
@@ -48,62 +49,8 @@ def build_from_config(
     return build


-# TODO: better name
-def test_generated_train_example_has_expected_outputs(
-    build_from_config,
-    recording,
-):
-    yaml_content = """
-    labels:
-    targets:
-      roi:
-        name: anchor_bbox
-        anchor: bottom-left
-      classes:
-        classes:
-          - name: pippip
-            tags:
-              - key: species
-                value: Pipistrellus pipistrellus
-        generic_class:
-          - key: order
-            value: Chiroptera
-    preprocessing:
-    postprocessing:
-    """
-    _, preprocessor, labeller, _ = build_from_config(yaml_content)
-
-    geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
-    se1 = data.SoundEventAnnotation(
-        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
-        tags=[
-            data.Tag(key="species", value="Pipistrellus pipistrellus"),  # type: ignore
-        ],
-    )
-    clip_annotation = data.ClipAnnotation(
-        clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
-        sound_events=[se1],
-    )
-
-    encoded = generate_train_example(clip_annotation, preprocessor, labeller)
-
-    assert isinstance(encoded, xr.Dataset)
-    assert "audio" in encoded
-    assert "spectrogram" in encoded
-    assert "detection" in encoded
-    assert "class" in encoded
-    assert "size" in encoded
-
-    spec_shape = encoded["spectrogram"].shape
-    assert len(spec_shape) == 2
-
-    height, width = spec_shape
-    assert encoded["detection"].shape == (height, width)
-    assert encoded["class"].shape == (1, height, width)
-    assert encoded["size"].shape == (2, height, width)
-
-
 def test_encoding_decoding_roundtrip_recovers_object(
+    sample_audio_loader: AudioLoader,
     build_from_config,
     recording,
 ):
@@ -136,13 +83,17 @@ def test_encoding_decoding_roundtrip_recovers_object(
     clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
     clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

-    encoded = generate_train_example(clip_annotation, preprocessor, labeller)
+    encoded = generate_train_example(
+        clip_annotation, sample_audio_loader, preprocessor, labeller
+    )
     predictions = postprocessor.get_predictions(
         ModelOutput(
-            detection_probs=torch.tensor([[encoded["detection"].data]]),
-            size_preds=torch.tensor([encoded["size"].data]),
-            class_probs=torch.tensor([encoded["class"].data]),
-            features=torch.tensor([[encoded["spectrogram"].data]]),
+            detection_probs=encoded["detection_heatmap"]
+            .unsqueeze(0)
+            .unsqueeze(0),
+            size_preds=encoded["size_heatmap"].unsqueeze(0),
+            class_probs=encoded["class_heatmap"].unsqueeze(0),
+            features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
         ),
         [clip],
     )[0]
@@ -185,6 +136,7 @@ def test_encoding_decoding_roundtrip_recovers_object(


 def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
+    sample_audio_loader: AudioLoader,
     build_from_config,
     recording,
 ):
@@ -222,13 +174,20 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
     clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
     clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

-    encoded = generate_train_example(clip_annotation, preprocessor, labeller)
+    encoded = generate_train_example(
+        clip_annotation,
+        sample_audio_loader,
+        preprocessor,
+        labeller,
+    )
     predictions = postprocessor.get_predictions(
         ModelOutput(
-            detection_probs=torch.tensor([[encoded["detection"].data]]),
-            size_preds=torch.tensor([encoded["size"].data]),
-            class_probs=torch.tensor([encoded["class"].data]),
-            features=torch.tensor([[encoded["spectrogram"].data]]),
+            detection_probs=encoded["detection_heatmap"]
+            .unsqueeze(0)
+            .unsqueeze(0),
+            size_preds=encoded["size_heatmap"].unsqueeze(0),
+            class_probs=encoded["class_heatmap"].unsqueeze(0),
+            features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
         ),
         [clip],
     )[0]
@@ -1,23 +1,59 @@
-import numpy as np
+import torch

 from batdetect2.utils.arrays import adjust_width, extend_width


 def test_extend_width():
-    array = np.random.random([1, 1, 128, 100])
-
+    array = torch.rand([1, 1, 128, 100])
     extended = extend_width(array, 100)

     assert extended.shape == (1, 1, 128, 200)

+    extended = extend_width(array, 100, axis=0)
+    assert extended.shape == (101, 1, 128, 100)
+
+    extended = extend_width(array, 100, axis=1)
+    assert extended.shape == (1, 101, 128, 100)
+
+    extended = extend_width(array, 100, axis=2)
+    assert extended.shape == (1, 1, 228, 100)
+
+    extended = extend_width(array, 100, axis=3)
+    assert extended.shape == (1, 1, 128, 200)
+
+    extended = extend_width(array, 100, axis=-2)
+    assert extended.shape == (1, 1, 228, 100)
+
+
+def test_extends_with_value():
+    array = torch.rand([1, 1, 128, 100])
+    extended = extend_width(array, 100, value=-1)
+    torch.testing.assert_close(
+        extended[:, :, :, 100:],
+        torch.ones_like(array) * -1,
+        rtol=0,
+        atol=0,
+    )
+
+
 def test_can_adjust_short_width():
-    array = np.random.random([1, 1, 128, 100])
+    array = torch.rand([1, 1, 128, 100])
     extended = adjust_width(array, 512)
     assert extended.shape == (1, 1, 128, 512)

+    extended = adjust_width(array, 512, axis=0)
+    assert extended.shape == (512, 1, 128, 100)
+
+    extended = adjust_width(array, 512, axis=1)
+    assert extended.shape == (1, 512, 128, 100)
+
+    extended = adjust_width(array, 512, axis=2)
+    assert extended.shape == (1, 1, 512, 100)
+
+    extended = adjust_width(array, 512, axis=3)
+    assert extended.shape == (1, 1, 128, 512)
+
+
 def test_can_adjust_long_width():
-    array = np.random.random([1, 1, 128, 512])
+    array = torch.rand([1, 1, 128, 512])
     extended = adjust_width(array, 256)
     assert extended.shape == (1, 1, 128, 256)