mirror of https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00

Update augmentations

This commit is contained in:
parent 76dda0a0e9
commit 0bb0caddea
@@ -1,6 +1,11 @@
 from batdetect2.plotting.clip_annotations import plot_clip_annotation
 from batdetect2.plotting.clip_predictions import plot_clip_prediction
 from batdetect2.plotting.clips import plot_clip
+from batdetect2.plotting.common import plot_spectrogram
+from batdetect2.plotting.heatmaps import (
+    plot_classification_heatmap,
+    plot_detection_heatmap,
+)
 from batdetect2.plotting.matches import (
     plot_cross_trigger_match,
     plot_false_negative_match,
@@ -13,9 +18,12 @@ __all__ = [
     "plot_clip",
     "plot_clip_annotation",
     "plot_clip_prediction",
-    "plot_matches",
-    "plot_false_positive_match",
-    "plot_true_positive_match",
-    "plot_false_negative_match",
     "plot_cross_trigger_match",
+    "plot_false_negative_match",
+    "plot_false_positive_match",
+    "plot_matches",
+    "plot_spectrogram",
+    "plot_true_positive_match",
+    "plot_detection_heatmap",
+    "plot_classification_heatmap",
 ]
@@ -3,6 +3,7 @@
 from typing import Optional, Tuple

 import matplotlib.pyplot as plt
+import torch
 from matplotlib import axes

 __all__ = [
@@ -12,7 +13,7 @@ __all__ = [

 def create_ax(
     ax: Optional[axes.Axes] = None,
-    figsize: Tuple[int, int] = (10, 10),
+    figsize: Optional[Tuple[int, int]] = None,
     **kwargs,
 ) -> axes.Axes:
     """Create a new axis if none is provided"""
@@ -20,3 +21,14 @@ def create_ax(
         _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore

     return ax  # type: ignore
+
+
+def plot_spectrogram(
+    spec: torch.Tensor,
+    ax: Optional[axes.Axes] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    cmap="gray",
+) -> axes.Axes:
+    ax = create_ax(ax=ax, figsize=figsize)
+    ax.pcolormesh(spec.numpy(), cmap=cmap)
+    return ax
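The new `plot_spectrogram` helper is a thin wrapper around `create_ax` and `pcolormesh`. A minimal usage sketch (the tensor shape and file name are hypothetical; the input is assumed to be a detached CPU tensor, since the helper calls `spec.numpy()`):

    import torch
    from batdetect2.plotting.common import plot_spectrogram

    spec = torch.rand(128, 512)  # [freq_bins, time_bins], hypothetical data
    ax = plot_spectrogram(spec, cmap="gray")
    ax.figure.savefig("spec.png")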
@@ -1,26 +1,115 @@
 """Plot heatmaps"""

-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union

-import xarray as xr
-from matplotlib import axes
+import numpy as np
+import torch
+from matplotlib import axes, patches
+from matplotlib.cm import get_cmap
+from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba

 from batdetect2.plotting.common import create_ax


-def plot_heatmap(
-    heatmap: xr.DataArray,
+def plot_detection_heatmap(
+    heatmap: Union[torch.Tensor, np.ndarray],
     ax: Optional[axes.Axes] = None,
     figsize: Tuple[int, int] = (10, 10),
+    threshold: Optional[float] = None,
+    alpha: float = 1,
+    cmap: Union[str, Colormap] = "jet",
+    color: Optional[str] = None,
 ) -> axes.Axes:
     ax = create_ax(ax, figsize=figsize)

+    if isinstance(heatmap, torch.Tensor):
+        heatmap = heatmap.numpy()
+
+    if threshold is not None:
+        heatmap = np.ma.masked_where(
+            heatmap < threshold,
+            heatmap,
+        )
+
+    if color is not None:
+        cmap = create_colormap(color)
+
     ax.pcolormesh(
-        heatmap.time,
-        heatmap.frequency,
         heatmap,
         vmax=1,
         vmin=0,
+        cmap=cmap,
+        alpha=alpha,
     )

     return ax
+
+
+def plot_classification_heatmap(
+    heatmap: Union[torch.Tensor, np.ndarray],
+    ax: Optional[axes.Axes] = None,
+    figsize: Tuple[int, int] = (10, 10),
+    class_names: Optional[List[str]] = None,
+    threshold: Optional[float] = 0.1,
+    alpha: float = 1,
+    cmap: Union[str, Colormap] = "tab20",
+):
+    ax = create_ax(ax, figsize=figsize)
+
+    if isinstance(heatmap, torch.Tensor):
+        heatmap = heatmap.numpy()
+
+    if heatmap.ndim == 4:
+        heatmap = heatmap[0]
+
+    if heatmap.ndim != 3:
+        raise ValueError("Expecting a 3-dimensional array")
+
+    num_classes = heatmap.shape[0]
+
+    if class_names is None:
+        class_names = [f"class_{i}" for i in range(num_classes)]
+
+    if len(class_names) != num_classes:
+        raise ValueError("Inconsistent number of class names")
+
+    if not isinstance(cmap, Colormap):
+        cmap = get_cmap(cmap)
+
+    handles = []
+
+    for index, class_heatmap in enumerate(heatmap):
+        class_name = class_names[index]
+
+        color = cmap(index / num_classes)
+
+        max = class_heatmap.max()
+
+        if max == 0:
+            continue
+
+        if threshold is not None:
+            class_heatmap = np.ma.masked_where(
+                class_heatmap < threshold,
+                class_heatmap,
+            )
+
+        ax.pcolormesh(
+            class_heatmap,
+            vmax=1,
+            vmin=0,
+            cmap=create_colormap(color),  # type: ignore
+            alpha=alpha,
+        )
+
+        handles.append(patches.Patch(color=color, label=class_name))
+
+    ax.legend(handles=handles)
+    return ax
+
+
+def create_colormap(color: str) -> Colormap:
+    (r, g, b, a) = to_rgba(color)
+    return LinearSegmentedColormap.from_list(
+        "cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
+    )
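`create_colormap` is the piece that makes `plot_classification_heatmap` work: it builds a two-point colormap that ramps from fully transparent to the requested color, so one `pcolormesh` layer per class can be stacked on a single axis without earlier layers being painted over. A standalone sketch of the same idea (array contents are hypothetical):

    import numpy as np
    from matplotlib import pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap, to_rgba

    def transparent_cmap(color: str) -> LinearSegmentedColormap:
        # Alpha ramps from 0 to 1, so low values vanish instead of
        # occluding layers drawn before this one.
        r, g, b, a = to_rgba(color)
        return LinearSegmentedColormap.from_list(
            "cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
        )

    heatmap = np.zeros((64, 64))
    heatmap[20:30, 20:30] = 1.0  # one hypothetical "detection" blob
    plt.pcolormesh(heatmap, cmap=transparent_cmap("red"), vmin=0, vmax=1)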
(File diff suppressed because it is too large.)
@@ -1,12 +1,12 @@
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple

 import numpy as np
-import xarray as xr
 from loguru import logger
-from soundevent import arrays

 from batdetect2.configs import BaseConfig
 from batdetect2.typing import ClipperProtocol
+from batdetect2.typing.train import PreprocessedExample
+from batdetect2.utils.arrays import adjust_width

 DEFAULT_TRAIN_CLIP_DURATION = 0.513
 DEFAULT_MAX_EMPTY_CLIP = 0.1
@@ -32,40 +32,23 @@ class Clipper(ClipperProtocol):
         self.max_empty = max_empty

     def extract_clip(
-        self, example: xr.Dataset
-    ) -> Tuple[xr.Dataset, float, float]:
-        step = arrays.get_dim_step(
-            example.spectrogram,
-            dim=arrays.Dimensions.time.value,
-        )
-        duration = (
-            arrays.get_dim_width(
-                example.spectrogram,
-                dim=arrays.Dimensions.time.value,
-            )
-            + step
-        )
-
+        self, example: PreprocessedExample
+    ) -> Tuple[PreprocessedExample, float, float]:
         start_time = 0
+        duration = example.audio.shape[-1] / self.samplerate
+
         if self.random:
             start_time = np.random.uniform(
                 -self.max_empty,
                 duration - self.duration + self.max_empty,
             )

-        subclip = select_subclip(
-            example,
-            start=start_time,
-            span=self.duration,
-            dim="time",
-        )
-
         return (
             select_subclip(
-                subclip,
+                example,
                 start=start_time,
-                span=self.duration,
-                dim="audio_time",
+                duration=self.duration,
+                samplerate=self.samplerate,
             ),
             start_time,
             start_time + self.duration,
@@ -73,6 +56,7 @@ class Clipper(ClipperProtocol):


 def build_clipper(
+    samplerate: int,
     config: Optional[ClipingConfig] = None,
     random: Optional[bool] = None,
 ) -> ClipperProtocol:
@@ -82,6 +66,7 @@ def build_clipper(
         lambda: config.to_yaml_string(),
     )
     return Clipper(
+        samplerate=samplerate,
         duration=config.duration,
         max_empty=config.max_empty,
         random=config.random if random else False,
@@ -89,106 +74,43 @@ def build_clipper(


 def select_subclip(
-    dataset: xr.Dataset,
-    span: float,
+    example: PreprocessedExample,
     start: float,
+    duration: float,
+    samplerate: float,
     fill_value: float = 0,
-    dim: str = "time",
-) -> xr.Dataset:
-    width = _compute_expected_width(
-        dataset,  # type: ignore
-        span,
-        dim=dim,
-    )
-
-    coord = dataset.coords[dim]
-
-    if len(coord) == width:
-        return dataset
-
-    new_coords, start_pad, end_pad, dim_slice = _extract_coordinate(
-        coord, start, span
-    )
-
-    data_vars = {}
-    for name, data_array in dataset.data_vars.items():
-        if dim not in data_array.dims:
-            data_vars[name] = data_array
-            continue
-
-        if width == data_array.sizes[dim]:
-            data_vars[name] = data_array
-            continue
-
-        sliced = data_array.isel({dim: dim_slice}).data
-
-        if start_pad > 0 or end_pad > 0:
-            padding = [
-                [0, 0] if other_dim != dim else [start_pad, end_pad]
-                for other_dim in data_array.dims
-            ]
-            sliced = np.pad(sliced, padding, constant_values=fill_value)
-
-        data_vars[name] = xr.DataArray(
-            data=sliced,
-            dims=data_array.dims,
-            coords={**data_array.coords, dim: new_coords},
-            attrs=data_array.attrs,
-        )
-
-    return xr.Dataset(data_vars=data_vars, attrs=dataset.attrs)
-
-
-def _extract_coordinate(
-    coord: xr.DataArray,
-    start: float,
-    span: float,
-) -> Tuple[xr.Variable, int, int, slice]:
-    step = arrays.get_dim_step(coord, str(coord.name))
-
-    current_width = len(coord)
-    expected_width = int(np.floor(span / step))
-
-    coord_start = float(coord[0])
-    offset = start - coord_start
-
-    start_index = int(np.floor(offset / step))
-    end_index = start_index + expected_width
-
-    if start_index > current_width:
-        raise ValueError("Requested span does not overlap with current range")
-
-    if end_index < 0:
-        raise ValueError("Requested span does not overlap with current range")
-
-    corrected_start = float(start_index * step)
-    corrected_end = float(end_index * step)
-
-    start_index_offset = max(0, -start_index)
-    end_index_offset = max(0, end_index - current_width)
-
-    sl = slice(
-        start_index if start_index >= 0 else None,
-        end_index if end_index < current_width else None,
-    )
-
-    return (
-        arrays.create_range_dim(
-            str(coord.name),
-            start=corrected_start,
-            stop=corrected_end,
-            step=step,
-        ),
-        start_index_offset,
-        end_index_offset,
-        sl,
-    )
-
-
-def _compute_expected_width(
-    array: Union[xr.DataArray, xr.Dataset],
-    duration: float,
-    dim: str,
-) -> int:
-    step = arrays.get_dim_step(array, dim)  # type: ignore
-    return int(np.floor(duration / step))
+) -> PreprocessedExample:
+    audio_width = int(np.floor(duration * samplerate))
+    audio_start = int(np.floor(start * samplerate))
+
+    audio = adjust_width(
+        example.audio[audio_start : audio_start + audio_width],
+        audio_width,
+        value=fill_value,
+    )
+
+    audio_duration = example.audio.shape[-1] / samplerate
+    spec_sr = example.spectrogram.shape[-1] / audio_duration
+
+    spec_start = int(np.floor(start * spec_sr))
+    spec_width = int(np.floor(duration * spec_sr))
+
+    return PreprocessedExample(
+        audio=audio,
+        spectrogram=adjust_width(
+            example.spectrogram[:, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+        class_heatmap=adjust_width(
+            example.class_heatmap[:, :, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+        detection_heatmap=adjust_width(
+            example.detection_heatmap[:, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+        size_heatmap=adjust_width(
+            example.size_heatmap[:, :, spec_start : spec_start + spec_width],
+            spec_width,
+        ),
+    )
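The rewritten `select_subclip` replaces the xarray coordinate bookkeeping with plain sample-rate arithmetic: audio indices come straight from `samplerate`, and spectrogram indices from a frame rate derived from the spectrogram width. A worked example of the index math (all numbers hypothetical):

    import numpy as np

    samplerate = 256_000
    start, duration = 0.1, 0.2
    n_samples, n_frames = 128_000, 512  # a 0.5 s clip

    audio_start = int(np.floor(start * samplerate))     # 25_600
    audio_width = int(np.floor(duration * samplerate))  # 51_200

    audio_duration = n_samples / samplerate             # 0.5 s
    spec_sr = n_frames / audio_duration                 # 1024 frames/s
    spec_start = int(np.floor(start * spec_sr))         # 102
    spec_width = int(np.floor(duration * spec_sr))      # 204

`adjust_width` then pads or trims each slice to the computed width, so the negative start times that `Clipper` can draw when `max_empty > 0` still yield fixed-size outputs.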
@@ -22,7 +22,7 @@ includes utilities for parallel processing using `multiprocessing`.

 import os
 from pathlib import Path
-from typing import Callable, Dict, Optional, Sequence
+from typing import Callable, Optional, Sequence, TypedDict

 import numpy as np
 import torch
@@ -98,17 +98,25 @@ def preprocess_dataset(
     )


+class Example(TypedDict):
+    audio: torch.Tensor
+    spectrogram: torch.Tensor
+    detection_heatmap: torch.Tensor
+    class_heatmap: torch.Tensor
+    size_heatmap: torch.Tensor
+
+
 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
     audio_loader: AudioLoader,
     preprocessor: PreprocessorProtocol,
     labeller: ClipLabeller,
-) -> Dict[str, torch.Tensor]:
+) -> PreprocessedExample:
     """Generate a complete training example for one annotation."""
     wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
     spectrogram = preprocessor(wave)
     heatmaps = labeller(clip_annotation, spectrogram)
-    return dict(
+    return PreprocessedExample(
         audio=wave,
         spectrogram=spectrogram,
         detection_heatmap=heatmaps.detection,
@@ -138,8 +146,14 @@ class PreprocessingDataset(torch.utils.data.Dataset):
             preprocessor=self.preprocessor,
             labeller=self.labeller,
         )
-        example["idx"] = idx
-        return example
+        return {
+            "idx": idx,
+            "spectrogram": example.spectrogram,
+            "audio": example.audio,
+            "class_heatmap": example.class_heatmap,
+            "size_heatmap": example.size_heatmap,
+            "detection_heatmap": example.detection_heatmap,
+        }

     def __len__(self) -> int:
         return len(self.clips)
@@ -147,16 +161,17 @@ class PreprocessingDataset(torch.utils.data.Dataset):

 def _save_example_to_file(
     example: PreprocessedExample,
+    clip_annotation: data.ClipAnnotation,
     path: data.PathLike,
 ) -> None:
     np.savez_compressed(
         path,
-        audio=example.audio,
-        spectrogram=example.spectrogram,
-        detection_heatmap=example.detection_heatmap,
-        class_heatmap=example.class_heatmap,
-        size_heatmap=example.size_heatmap,
-        clip_annotation=example.clip_annotation,
+        audio=example.audio.numpy(),
+        spectrogram=example.spectrogram.numpy(),
+        detection_heatmap=example.detection_heatmap.numpy(),
+        class_heatmap=example.class_heatmap.numpy(),
+        size_heatmap=example.size_heatmap.numpy(),
+        clip_annotation=clip_annotation,
     )
@@ -211,11 +226,10 @@ def preprocess_annotations(
         filename = filename_fn(clip_annotation)
         path = output_dir / filename
         example = PreprocessedExample(
-            clip_annotation=clip_annotation,
-            spectrogram=batch["spectrogram"].numpy(),
-            audio=batch["audio"].numpy(),
-            class_heatmap=batch["class_heatmap"].numpy(),
-            size_heatmap=batch["size_heatmap"].numpy(),
-            detection_heatmap=batch["detection_heatmap"].numpy(),
+            spectrogram=batch["spectrogram"],
+            audio=batch["audio"],
+            class_heatmap=batch["class_heatmap"],
+            size_heatmap=batch["size_heatmap"],
+            detection_heatmap=batch["detection_heatmap"],
         )
-        _save_example_to_file(example, path)
+        _save_example_to_file(example, clip_annotation, path)
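`__getitem__` now unpacks the `PreprocessedExample` NamedTuple into a plain dict so that PyTorch's default collation can batch it. A minimal sketch of that collation step (shapes hypothetical):

    import torch
    from torch.utils.data import default_collate

    # Two examples shaped like the dicts __getitem__ now returns:
    examples = [
        {"idx": i, "spectrogram": torch.rand(128, 256)} for i in range(2)
    ]
    batch = default_collate(examples)
    assert batch["spectrogram"].shape == (2, 128, 256)
    assert batch["idx"].tolist() == [0, 1]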
@@ -148,6 +148,8 @@ class PreprocessorProtocol(Protocol):

     min_freq: float

+    samplerate: int
+
     audio_pipeline: AudioPipeline

     spectrogram_pipeline: SpectrogramPipeline
@@ -155,4 +157,4 @@ class PreprocessorProtocol(Protocol):
     def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...

     def process_numpy(self, wav: np.ndarray) -> np.ndarray:
-        return self(torch.tensor(wav)).numpy()[0, 0]
+        return self(torch.tensor(wav)).numpy()
@@ -1,8 +1,6 @@
 from typing import Callable, NamedTuple, Protocol, Tuple

-import numpy as np
 import torch
-import xarray as xr
 from soundevent import data

 from batdetect2.typing.models import ModelOutput
@@ -19,24 +17,7 @@ __all__ = [


 class Heatmaps(NamedTuple):
-    """Structure holding the generated heatmap targets.
-
-    Attributes
-    ----------
-    detection : xr.DataArray
-        Heatmap indicating the probability of sound event presence. Typically
-        smoothed with a Gaussian kernel centered on event reference points.
-        Shape matches the input spectrogram. Values normalized [0, 1].
-    classes : xr.DataArray
-        Heatmap indicating the probability of specific class presence. Has an
-        additional 'category' dimension corresponding to the target class
-        names. Each category slice is typically smoothed with a Gaussian
-        kernel. Values normalized [0, 1] per category.
-    size : xr.DataArray
-        Heatmap encoding the size (width, height) of detected events. Has an
-        additional 'dimension' coordinate ('width', 'height'). Values represent
-        scaled dimensions placed at the event reference points.
-    """
+    """Structure holding the generated heatmap targets."""

     detection: torch.Tensor
     classes: torch.Tensor
@@ -44,12 +25,20 @@ class Heatmaps(NamedTuple):


 class PreprocessedExample(NamedTuple):
-    audio: np.ndarray
-    spectrogram: np.ndarray
-    detection_heatmap: np.ndarray
-    class_heatmap: np.ndarray
-    size_heatmap: np.ndarray
-    clip_annotation: data.ClipAnnotation
+    audio: torch.Tensor
+    spectrogram: torch.Tensor
+    detection_heatmap: torch.Tensor
+    class_heatmap: torch.Tensor
+    size_heatmap: torch.Tensor
+
+    def copy(self):
+        return PreprocessedExample(
+            audio=self.audio.clone(),
+            spectrogram=self.spectrogram.clone(),
+            detection_heatmap=self.detection_heatmap.clone(),
+            size_heatmap=self.size_heatmap.clone(),
+            class_heatmap=self.class_heatmap.clone(),
+        )


 ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
@@ -60,7 +49,7 @@ spectrogram, applies all configured filtering, transformation, and encoding
 steps, and returns the final `Heatmaps` used for model training.
 """

-Augmentation = Callable[[xr.Dataset], xr.Dataset]
+Augmentation = Callable[[PreprocessedExample], PreprocessedExample]


 class TrainExample(NamedTuple):
@@ -108,5 +97,5 @@ class LossProtocol(Protocol):

 class ClipperProtocol(Protocol):
     def extract_clip(
-        self, example: xr.Dataset
-    ) -> Tuple[xr.Dataset, float, float]: ...
+        self, example: PreprocessedExample
+    ) -> Tuple[PreprocessedExample, float, float]: ...
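The new `PreprocessedExample.copy()` clones every tensor field, which matters now that `Augmentation` operates on `PreprocessedExample`: an augmentation can mutate its copy in place without corrupting the cached original. The underlying `clone()` semantics, as a minimal check:

    import torch

    original = torch.zeros(4)
    copy = original.clone()
    copy += 1.0  # in-place edit on the clone only

    assert original.sum().item() == 0  # source tensor untouched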
@@ -1,4 +1,5 @@
 import numpy as np
+import torch
 import xarray as xr


@@ -35,77 +36,40 @@ def spec_to_xarray(
     )


-def audio_to_xarray(
-    wav: np.ndarray,
-    start_time: float,
-    end_time: float,
-    time_axis: str = "time",
-) -> xr.DataArray:
-    if wav.ndim != 1:
-        raise ValueError("Input numpy audio array should be 1-dimensional")
-
-    return xr.DataArray(
-        data=wav,
-        dims=[time_axis],
-        coords={
-            time_axis: np.linspace(
-                start_time,
-                end_time,
-                len(wav),
-                endpoint=False,
-            ),
-        },
-    )
-
-
 def extend_width(
-    array: np.ndarray,
+    tensor: torch.Tensor,
     extra: int,
     axis: int = -1,
     value: float = 0,
-) -> np.ndarray:
-    dims = len(array.shape)
-    axis = axis % dims
-    pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)]
-    return np.pad(
-        array,
+) -> torch.Tensor:
+    dims = len(tensor.shape)
+    axis = dims - axis % dims - 1
+    pad = [0 for _ in range(2 * dims)]
+    pad[2 * axis + 1] = extra
+    return torch.nn.functional.pad(
+        tensor,
         pad,
         mode="constant",
-        constant_values=value,
+        value=value,
    )


-def make_width_divisible(
-    array: np.ndarray,
-    factor: int,
-    axis: int = -1,
-    value: float = 0,
-) -> np.ndarray:
-    width = array.shape[axis]
-
-    if width % factor == 0:
-        return array
-
-    extra = (-width) % factor
-    return extend_width(array, extra, axis=axis, value=value)
-
-
 def adjust_width(
-    array: np.ndarray,
+    tensor: torch.Tensor,
     width: int,
     axis: int = -1,
     value: float = 0,
-) -> np.ndarray:
-    dims = len(array.shape)
+) -> torch.Tensor:
+    dims = len(tensor.shape)
     axis = axis % dims
-    current_width = array.shape[axis]
+    current_width = tensor.shape[axis]

     if current_width == width:
-        return array
+        return tensor

     if current_width < width:
         return extend_width(
-            array,
+            tensor,
             extra=width - current_width,
             axis=axis,
             value=value,
@@ -115,11 +79,4 @@ def adjust_width(
         slice(None, None) if index != axis else slice(None, width)
         for index in range(dims)
     ]
-    return array[tuple(slices)]
-
-
-def iterate_over_array(array: xr.DataArray):
-    dim_name = array.dims[0]
-    coords = array.coords[dim_name]
-    for value, coord in zip(array.values, coords.values):
-        yield coord, float(value)
+    return tensor[tuple(slices)]
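The ported `extend_width` has to translate a NumPy-style axis into the convention of `torch.nn.functional.pad`, whose flat pad list runs from the last dimension backwards in (before, after) pairs; that is what the `dims - axis % dims - 1` flip computes. A worked check of the arithmetic for a 4-D tensor (shapes hypothetical):

    import torch
    import torch.nn.functional as F

    tensor = torch.rand(1, 1, 128, 100)
    dims = tensor.dim()        # 4

    axis = -1 % dims           # 3: the time axis in NumPy terms
    flipped = dims - axis - 1  # 0: F.pad counts from the last dim
    pad = [0] * (2 * dims)
    pad[2 * flipped + 1] = 50  # pad 50 frames *after* along time

    assert F.pad(tensor, pad).shape == (1, 1, 128, 150)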
@@ -431,8 +431,13 @@ def sample_targets(
 @pytest.fixture
 def sample_labeller(
     sample_targets: TargetProtocol,
+    sample_preprocessor: PreprocessorProtocol,
 ) -> ClipLabeller:
-    return build_clip_labeler(sample_targets)
+    return build_clip_labeler(
+        sample_targets,
+        min_freq=sample_preprocessor.min_freq,
+        max_freq=sample_preprocessor.max_freq,
+    )


 @pytest.fixture
@@ -2,6 +2,7 @@ from collections.abc import Callable

 import numpy as np
 import pytest
+import torch
 import xarray as xr
 from soundevent import arrays, data

@@ -42,12 +43,17 @@ def test_mix_examples(
         labeller=sample_labeller,
     )

-    mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
+    mixed = mix_examples(
+        example1,
+        example2,
+        weight=0.3,
+        preprocessor=sample_preprocessor,
+    )

-    assert mixed["spectrogram"].shape == example1["spectrogram"].shape
-    assert mixed["detection"].shape == example1["detection"].shape
-    assert mixed["size"].shape == example1["size"].shape
-    assert mixed["class"].shape == example1["class"].shape
+    assert mixed.spectrogram.shape == example1.spectrogram.shape
+    assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
+    assert mixed.size_heatmap.shape == example1.size_heatmap.shape
+    assert mixed.class_heatmap.shape == example1.class_heatmap.shape


 @pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
@@ -82,13 +88,17 @@ def test_mix_examples_of_different_durations(
         labeller=sample_labeller,
     )

-    mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
+    mixed = mix_examples(
+        example1,
+        example2,
+        weight=0.3,
+        preprocessor=sample_preprocessor,
+    )

-    # Check the spectrogram has the expected duration
-    step = arrays.get_dim_step(mixed["spectrogram"], "time")
-    start, stop = arrays.get_dim_range(mixed["spectrogram"], "time")
-    assert start == 0
-    assert np.isclose(stop + step, duration1, atol=2 * step)
+    assert mixed.spectrogram.shape == example1.spectrogram.shape
+    assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
+    assert mixed.size_heatmap.shape == example1.size_heatmap.shape
+    assert mixed.class_heatmap.shape == example1.class_heatmap.shape


 def test_add_echo(
@@ -107,12 +117,32 @@ def test_add_echo(
         preprocessor=sample_preprocessor,
         labeller=sample_labeller,
     )
-    with_echo = add_echo(original, preprocessor=sample_preprocessor)
+    with_echo = add_echo(
+        original,
+        preprocessor=sample_preprocessor,
+        delay=0.1,
+        weight=0.3,
+    )

-    assert with_echo["spectrogram"].shape == original["spectrogram"].shape
-    xr.testing.assert_identical(with_echo["size"], original["size"])
-    xr.testing.assert_identical(with_echo["class"], original["class"])
-    xr.testing.assert_identical(with_echo["detection"], original["detection"])
+    assert with_echo.spectrogram.shape == original.spectrogram.shape
+    torch.testing.assert_close(
+        with_echo.size_heatmap,
+        original.size_heatmap,
+        atol=0,
+        rtol=0,
+    )
+    torch.testing.assert_close(
+        with_echo.class_heatmap,
+        original.class_heatmap,
+        atol=0,
+        rtol=0,
+    )
+    torch.testing.assert_close(
+        with_echo.detection_heatmap,
+        original.detection_heatmap,
+        atol=0,
+        rtol=0,
+    )


 def test_selected_random_subclip_has_the_correct_width(
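The `torch.testing.assert_close(..., atol=0, rtol=0)` calls are the tensor counterpart of the old `xr.testing.assert_identical`: with both tolerances set to zero, the assertion only passes on exact equality. A tiny illustration:

    import torch

    a = torch.tensor([1.0, 2.0])
    torch.testing.assert_close(a, a.clone(), atol=0, rtol=0)  # exact match passes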
@@ -3,7 +3,6 @@ import pytest
 import xarray as xr

 from batdetect2.train.clips import (
-    Clipper,
     _compute_expected_width,
     select_subclip,
 )
@@ -322,145 +321,3 @@ def test_select_subclip_no_overlap_raises_error(long_dataset):
         start=-1.0 * CLIP_DURATION - 1.0,
         dim="time",
     )
-
-
-def test_clipper_non_random(long_dataset, exact_dataset, short_dataset):
-    clipper = Clipper(duration=CLIP_DURATION, random=False)
-
-    for ds in [long_dataset, exact_dataset, short_dataset]:
-        clip, _, _ = clipper.extract_clip(ds)
-        expected_spec_width = _compute_expected_width(
-            ds, CLIP_DURATION, "time"
-        )
-        expected_audio_width = _compute_expected_width(
-            ds, CLIP_DURATION, "audio_time"
-        )
-
-        assert clip.dims["time"] == expected_spec_width
-        assert clip.dims["audio_time"] == expected_audio_width
-        assert clip.spectrogram.shape[1] == expected_spec_width
-        assert clip.audio.shape[0] == expected_audio_width
-
-        assert clip.time.min() >= -1 / SPEC_SAMPLERATE
-        assert clip.audio_time.min() >= -1 / AUDIO_SAMPLERATE
-
-        time_span = clip.time.max() - clip.time.min()
-        audio_span = clip.audio_time.max() - clip.audio_time.min()
-        assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
-        assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
-
-
-def test_clipper_random(long_dataset):
-    seed = 42
-    np.random.seed(seed)
-    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
-    clip1, _, _ = clipper.extract_clip(long_dataset)
-
-    np.random.seed(seed + 1)
-    clip2, _, _ = clipper.extract_clip(long_dataset)
-
-    expected_spec_width = _compute_expected_width(
-        long_dataset, CLIP_DURATION, "time"
-    )
-    expected_audio_width = _compute_expected_width(
-        long_dataset, CLIP_DURATION, "audio_time"
-    )
-
-    for clip in [clip1, clip2]:
-        assert clip.dims["time"] == expected_spec_width
-        assert clip.dims["audio_time"] == expected_audio_width
-        assert clip.spectrogram.shape[1] == expected_spec_width
-        assert clip.audio.shape[0] == expected_audio_width
-
-    assert not np.isclose(clip1.time.min(), clip2.time.min())
-    assert not np.isclose(clip1.audio_time.min(), clip2.audio_time.min())
-
-    for clip in [clip1, clip2]:
-        time_span = clip.time.max() - clip.time.min()
-        audio_span = clip.audio_time.max() - clip.audio_time.min()
-        assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
-        assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
-
-    max_start_time = (
-        (long_dataset.time.max() - long_dataset.time.min())
-        - CLIP_DURATION
-        + MAX_EMPTY
-    )
-    assert clip1.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
-    assert clip2.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
-
-
-def test_clipper_random_max_empty_effect(long_dataset):
-    """Check that max_empty influences the possible start times."""
-    seed = 123
-    data_duration = long_dataset.time.max() - long_dataset.time.min()
-
-    np.random.seed(seed)
-    clipper0 = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.0)
-    max_start_time0 = data_duration - CLIP_DURATION
-    start_times0 = []
-
-    for _ in range(20):
-        clip, _, _ = clipper0.extract_clip(long_dataset)
-        start_times0.append(clip.time.min().item())
-
-    assert all(
-        st <= max_start_time0 + 1 / SPEC_SAMPLERATE for st in start_times0
-    )
-    assert any(st > 0.1 for st in start_times0)
-
-    np.random.seed(seed)
-    clipper_pos = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.2)
-    max_start_time_pos = data_duration - CLIP_DURATION + 0.2
-    start_times_pos = []
-    for _ in range(20):
-        clip, _, _ = clipper_pos.extract_clip(long_dataset)
-        start_times_pos.append(clip.time.min().item())
-    assert all(
-        st <= max_start_time_pos + 1 / SPEC_SAMPLERATE
-        for st in start_times_pos
-    )
-
-    assert any(st > max_start_time0 + 1e-6 for st in start_times_pos)
-
-
-def test_clipper_short_dataset_random(short_dataset):
-    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
-    clip, _, _ = clipper.extract_clip(short_dataset)
-
-    expected_spec_width = _compute_expected_width(
-        short_dataset, CLIP_DURATION, "time"
-    )
-    expected_audio_width = _compute_expected_width(
-        short_dataset, CLIP_DURATION, "audio_time"
-    )
-
-    assert clip.sizes["time"] == expected_spec_width
-    assert clip.sizes["audio_time"] == expected_audio_width
-    assert clip["spectrogram"].shape[1] == expected_spec_width
-    assert clip["audio"].shape[0] == expected_audio_width
-
-    assert np.any(clip.spectrogram == 0)
-    assert np.any(clip.audio == 0)
-
-
-def test_clipper_exact_dataset_random(exact_dataset):
-    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
-    clip, _, _ = clipper.extract_clip(exact_dataset)
-
-    expected_spec_width = _compute_expected_width(
-        exact_dataset, CLIP_DURATION, "time"
-    )
-    expected_audio_width = _compute_expected_width(
-        exact_dataset, CLIP_DURATION, "audio_time"
-    )
-
-    assert clip.dims["time"] == expected_spec_width
-    assert clip.dims["audio_time"] == expected_audio_width
-    assert clip.spectrogram.shape[1] == expected_spec_width
-    assert clip.audio.shape[0] == expected_audio_width
-
-    time_span = clip.time.max() - clip.time.min()
-    audio_span = clip.audio_time.max() - clip.audio_time.min()
-    assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
-    assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
@@ -1,6 +1,4 @@
 import pytest
-import torch
-import xarray as xr
 from soundevent import data
 from soundevent.terms import get_term

@@ -10,6 +8,7 @@ from batdetect2.targets import build_targets, load_target_config
 from batdetect2.train.labels import build_clip_labeler, load_label_config
 from batdetect2.train.preprocess import generate_train_example
 from batdetect2.typing import ModelOutput
+from batdetect2.typing.preprocess import AudioLoader


 @pytest.fixture
@@ -35,6 +34,8 @@ def build_from_config(
     labeller = build_clip_labeler(
         targets=targets,
         config=labels_config,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
     )
     postprocessor = build_postprocessor(
         targets,
@@ -48,62 +49,8 @@ def build_from_config(
     return build


-# TODO: better name
-def test_generated_train_example_has_expected_outputs(
-    build_from_config,
-    recording,
-):
-    yaml_content = """
-    labels:
-    targets:
-      roi:
-        name: anchor_bbox
-        anchor: bottom-left
-      classes:
-        classes:
-          - name: pippip
-            tags:
-              - key: species
-                value: Pipistrellus pipistrellus
-        generic_class:
-          - key: order
-            value: Chiroptera
-    preprocessing:
-    postprocessing:
-    """
-    _, preprocessor, labeller, _ = build_from_config(yaml_content)
-
-    geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
-    se1 = data.SoundEventAnnotation(
-        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
-        tags=[
-            data.Tag(key="species", value="Pipistrellus pipistrellus"),  # type: ignore
-        ],
-    )
-    clip_annotation = data.ClipAnnotation(
-        clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
-        sound_events=[se1],
-    )
-
-    encoded = generate_train_example(clip_annotation, preprocessor, labeller)
-
-    assert isinstance(encoded, xr.Dataset)
-    assert "audio" in encoded
-    assert "spectrogram" in encoded
-    assert "detection" in encoded
-    assert "class" in encoded
-    assert "size" in encoded
-
-    spec_shape = encoded["spectrogram"].shape
-    assert len(spec_shape) == 2
-
-    height, width = spec_shape
-    assert encoded["detection"].shape == (height, width)
-    assert encoded["class"].shape == (1, height, width)
-    assert encoded["size"].shape == (2, height, width)
-
-
 def test_encoding_decoding_roundtrip_recovers_object(
+    sample_audio_loader: AudioLoader,
     build_from_config,
     recording,
 ):
@@ -136,13 +83,17 @@ def test_encoding_decoding_roundtrip_recovers_object(
     clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
     clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

-    encoded = generate_train_example(clip_annotation, preprocessor, labeller)
+    encoded = generate_train_example(
+        clip_annotation, sample_audio_loader, preprocessor, labeller
+    )
     predictions = postprocessor.get_predictions(
         ModelOutput(
-            detection_probs=torch.tensor([[encoded["detection"].data]]),
-            size_preds=torch.tensor([encoded["size"].data]),
-            class_probs=torch.tensor([encoded["class"].data]),
-            features=torch.tensor([[encoded["spectrogram"].data]]),
+            detection_probs=encoded["detection_heatmap"]
+            .unsqueeze(0)
+            .unsqueeze(0),
+            size_preds=encoded["size_heatmap"].unsqueeze(0),
+            class_probs=encoded["class_heatmap"].unsqueeze(0),
+            features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
         ),
         [clip],
     )[0]
@@ -185,6 +136,7 @@ def test_encoding_decoding_roundtrip_recovers_object(


 def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
+    sample_audio_loader: AudioLoader,
     build_from_config,
     recording,
 ):
@@ -222,13 +174,20 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
     clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
     clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

-    encoded = generate_train_example(clip_annotation, preprocessor, labeller)
+    encoded = generate_train_example(
+        clip_annotation,
+        sample_audio_loader,
+        preprocessor,
+        labeller,
+    )
     predictions = postprocessor.get_predictions(
         ModelOutput(
-            detection_probs=torch.tensor([[encoded["detection"].data]]),
-            size_preds=torch.tensor([encoded["size"].data]),
-            class_probs=torch.tensor([encoded["class"].data]),
-            features=torch.tensor([[encoded["spectrogram"].data]]),
+            detection_probs=encoded["detection_heatmap"]
+            .unsqueeze(0)
+            .unsqueeze(0),
+            size_preds=encoded["size_heatmap"].unsqueeze(0),
+            class_probs=encoded["class_heatmap"].unsqueeze(0),
+            features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
         ),
         [clip],
     )[0]
@@ -1,23 +1,59 @@
-import numpy as np
+import torch

 from batdetect2.utils.arrays import adjust_width, extend_width


 def test_extend_width():
-    array = np.random.random([1, 1, 128, 100])
+    array = torch.rand([1, 1, 128, 100])

     extended = extend_width(array, 100)

     assert extended.shape == (1, 1, 128, 200)

+    extended = extend_width(array, 100, axis=0)
+    assert extended.shape == (101, 1, 128, 100)
+
+    extended = extend_width(array, 100, axis=1)
+    assert extended.shape == (1, 101, 128, 100)
+
+    extended = extend_width(array, 100, axis=2)
+    assert extended.shape == (1, 1, 228, 100)
+
+    extended = extend_width(array, 100, axis=3)
+    assert extended.shape == (1, 1, 128, 200)
+
+    extended = extend_width(array, 100, axis=-2)
+    assert extended.shape == (1, 1, 228, 100)
+
+
+def test_extends_with_value():
+    array = torch.rand([1, 1, 128, 100])
+    extended = extend_width(array, 100, value=-1)
+    torch.testing.assert_close(
+        extended[:, :, :, 100:],
+        torch.ones_like(array) * -1,
+        rtol=0,
+        atol=0,
+    )
+

 def test_can_adjust_short_width():
-    array = np.random.random([1, 1, 128, 100])
+    array = torch.rand([1, 1, 128, 100])
     extended = adjust_width(array, 512)
     assert extended.shape == (1, 1, 128, 512)

+    extended = adjust_width(array, 512, axis=0)
+    assert extended.shape == (512, 1, 128, 100)
+
+    extended = adjust_width(array, 512, axis=1)
+    assert extended.shape == (1, 512, 128, 100)
+
+    extended = adjust_width(array, 512, axis=2)
+    assert extended.shape == (1, 1, 512, 100)
+
+    extended = adjust_width(array, 512, axis=3)
+    assert extended.shape == (1, 1, 128, 512)
+

 def test_can_adjust_long_width():
-    array = np.random.random([1, 1, 128, 512])
+    array = torch.rand([1, 1, 128, 512])
     extended = adjust_width(array, 256)
     assert extended.shape == (1, 1, 128, 256)