Working towards training code

mbsantiago 2024-05-23 13:04:55 +01:00
parent c66d14b7c7
commit 17cf958cd3
18 changed files with 624 additions and 1127 deletions

.gitignore (vendored)
View File

@ -110,3 +110,4 @@ experiments/*
!batdetect2_notebook.ipynb
!batdetect2/models/*.pth.tar
!tests/data/*.wav
notebooks/lightning_logs

View File

@ -1,304 +0,0 @@
from functools import wraps
from typing import Callable, List, Optional, Tuple
import numpy as np
import xarray as xr
from soundevent import data
from soundevent.geometry import compute_bounds
ClipAugmentation = Callable[[data.ClipAnnotation], data.ClipAnnotation]
AudioAugmentation = Callable[
[xr.DataArray, data.ClipAnnotation],
Tuple[xr.DataArray, data.ClipAnnotation],
]
SpecAugmentation = Callable[
[xr.DataArray, data.ClipAnnotation],
Tuple[xr.DataArray, data.ClipAnnotation],
]
ClipProvider = Callable[
[data.ClipAnnotation], Tuple[xr.DataArray, data.ClipAnnotation]
]
"""A function that provides some clip and its annotation.
Usually this function loads a random clip from a dataset. Takes
as input a clip annotation that can be used to filter the clips
to load (in case you want to avoid loading the same clip multiple times).
"""
AUGMENTATION_PROBABILITY = 0.2
MAX_DELAY = 0.005
STRETCH_SQUEEZE_DELTA = 0.04
MASK_MAX_TIME_PERC: float = 0.05
MASK_MAX_FREQ_PERC: float = 0.10
def maybe_apply(
augmentation: Callable,
prob: float = AUGMENTATION_PROBABILITY,
) -> Callable:
"""Apply an augmentation with a given probability."""
@wraps(augmentation)
def _augmentation(x):
if np.random.rand() > prob:
return x
return augmentation(x)
return _augmentation
def select_random_subclip(
clip_annotation: data.ClipAnnotation,
duration: Optional[float] = None,
proportion: float = 0.9,
) -> data.ClipAnnotation:
"""Select a random subclip from a clip."""
clip = clip_annotation.clip
if duration is None:
clip_duration = clip.end_time - clip.start_time
duration = clip_duration * proportion
start_time = np.random.uniform(clip.start_time, clip.end_time - duration)
return clip_annotation.model_copy(
update=dict(
clip=clip.model_copy(
update=dict(
start_time=start_time,
end_time=start_time + duration,
)
)
)
)
def combine_audio(
audio1: xr.DataArray,
audio2: xr.DataArray,
alpha: Optional[float] = None,
min_alpha: float = 0.3,
max_alpha: float = 0.7,
) -> xr.DataArray:
"""Combine two audio clips."""
if alpha is None:
alpha = np.random.uniform(min_alpha, max_alpha)
return alpha * audio1 + (1 - alpha) * audio2.data
def random_mix(
audio: xr.DataArray,
clip: data.ClipAnnotation,
provider: Optional[ClipProvider] = None,
alpha: Optional[float] = None,
min_alpha: float = 0.3,
max_alpha: float = 0.7,
join_annotations: bool = True,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Mix two audio clips."""
if provider is None:
raise ValueError("No audio provider given.")
try:
other_audio, other_clip = provider(clip)
except (StopIteration, ValueError):
raise ValueError("No more audio sources available.")
new_audio = combine_audio(
audio,
other_audio,
alpha=alpha,
min_alpha=min_alpha,
max_alpha=max_alpha,
)
if join_annotations:
clip = clip.model_copy(
update=dict(
sound_events=clip.sound_events + other_clip.sound_events,
)
)
return new_audio, clip
def add_echo(
audio: xr.DataArray,
clip: data.ClipAnnotation,
delay: Optional[float] = None,
alpha: Optional[float] = None,
min_alpha: float = 0.0,
max_alpha: float = 1.0,
max_delay: float = MAX_DELAY,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Add a delay to the audio."""
if delay is None:
delay = np.random.uniform(0, max_delay)
if alpha is None:
alpha = np.random.uniform(min_alpha, max_alpha)
samplerate = audio.attrs["samplerate"]
offset = int(delay * samplerate)
# NOTE: We use the copy method to avoid modifying the original audio
# data.
new_audio = audio.copy()
new_audio[offset:] += alpha * audio.data[:-offset]
return new_audio, clip
def scale_volume(
spec: xr.DataArray,
clip: data.ClipAnnotation,
factor: Optional[float] = None,
max_scaling: float = 2,
min_scaling: float = 0,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Scale the volume of a spectrogram."""
if factor is None:
factor = np.random.uniform(min_scaling, max_scaling)
return spec * factor, clip
def scale_sound_event_annotation(
sound_event_annotation: data.SoundEventAnnotation,
time_factor: float = 1,
frequency_factor: float = 1,
) -> data.SoundEventAnnotation:
sound_event = sound_event_annotation.sound_event
geometry = sound_event.geometry
if geometry is None:
return sound_event_annotation
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
new_geometry = data.BoundingBox(
coordinates=[
start_time * time_factor,
low_freq * frequency_factor,
end_time * time_factor,
high_freq * frequency_factor,
]
)
return sound_event_annotation.model_copy(
update=dict(
sound_event=sound_event.model_copy(
update=dict(
geometry=new_geometry,
)
)
)
)
def warp_spectrogram(
spec: xr.DataArray,
clip: data.ClipAnnotation,
factor: Optional[float] = None,
delta: float = STRETCH_SQUEEZE_DELTA,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Warp a spectrogram."""
if factor is None:
factor = np.random.uniform(1 - delta, 1 + delta)
start_time = clip.clip.start_time
end_time = clip.clip.end_time
duration = end_time - start_time
new_time = np.linspace(
start_time,
start_time + duration * factor,
spec.time.size,
)
scaled_spec = spec.interp(
time=new_time,
method="linear",
kwargs={"fill_value": 0},
)
scaled_spec.coords["time"] = spec.time
scaled_clip = clip.model_copy(
update=dict(
sound_events=[
scale_sound_event_annotation(
sound_event_annotation,
time_factor=1 / factor,
)
for sound_event_annotation in clip.sound_events
]
)
)
return scaled_spec, scaled_clip
def mask_axis(
array: xr.DataArray,
axis: str,
start: float,
end: float,
mask_value: float = 0,
) -> xr.DataArray:
if axis not in array.dims:
raise ValueError(f"Axis {axis} not found in array")
coord = array[axis]
return array.where((coord < start) | (coord > end), mask_value)
def mask_time(
spec: xr.DataArray,
clip: data.ClipAnnotation,
max_time_mask: float = MASK_MAX_TIME_PERC,
max_num_masks: int = 3,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Mask a random section of the time axis."""
num_masks = np.random.randint(1, max_num_masks + 1)
for _ in range(num_masks):
mask_size = np.random.uniform(0, max_time_mask)
start = np.random.uniform(0, spec.time[-1] - mask_size)
end = start + mask_size
spec = mask_axis(spec, "time", start, end)
return spec, clip
def mask_frequency(
spec: xr.DataArray,
clip: data.ClipAnnotation,
max_freq_mask: float = MASK_MAX_FREQ_PERC,
max_num_masks: int = 3,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Mask a random section of the frequency axis."""
num_masks = np.random.randint(1, max_num_masks + 1)
for _ in range(num_masks):
mask_size = np.random.uniform(0, max_freq_mask)
start = np.random.uniform(0, spec.frequency[-1] - mask_size)
end = start + mask_size
spec = mask_axis(spec, "frequency", start, end)
return spec, clip
CLIP_AUGMENTATIONS: List[ClipAugmentation] = [
select_random_subclip,
]
AUDIO_AUGMENTATIONS: List[AudioAugmentation] = [
add_echo,
random_mix,
]
SPEC_AUGMENTATIONS: List[SpecAugmentation] = [
scale_volume,
warp_spectrogram,
mask_time,
mask_frequency,
]

View File

@ -11,7 +11,7 @@ from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2 import types
from batdetect2.data.labels import LabelFn
from batdetect2.data.labels import ClassMapper
PathLike = Union[Path, str, os.PathLike]
@ -54,7 +54,7 @@ def get_annotation_notes(annotation: data.ClipAnnotation) -> str:
def convert_to_annotation_group(
annotation: data.ClipAnnotation,
label_fn: LabelFn = lambda _: None,
class_mapper: ClassMapper,
event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
class_fn: ClassFn = lambda _: 0,
individual_fn: IndividualFn = lambda _: 0,
@ -80,8 +80,8 @@ def convert_to_annotation_group(
continue
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
class_id = label_fn(sound_event) or -1
event = event_fn(sound_event)
class_id = class_mapper.transform(sound_event) or -1
event = event_fn(sound_event) or ""
individual_id = individual_fn(sound_event) or -1
start_times.append(start_time)

View File

@ -4,7 +4,6 @@ from soundevent import data
from torch.utils.data import Dataset
__all__ = [
"ClipAnnotationDataset",
"ClipDataset",
]
@ -12,31 +11,7 @@ __all__ = [
E = TypeVar("E")
class ClipAnnotationDataset(Dataset, Generic[E]):
clip_annotations: List[data.ClipAnnotation]
transform: Callable[[data.ClipAnnotation], E]
def __init__(
self,
clip_annotations: Iterable[data.ClipAnnotation],
transform: Callable[[data.ClipAnnotation], E],
name: str = "ClipAnnotationDataset",
):
self.clip_annotations = list(clip_annotations)
self.transform = transform
self.name = name
def __len__(self) -> int:
return len(self.clip_annotations)
def __getitem__(self, idx: int) -> E:
return self.transform(self.clip_annotations[idx])
class ClipDataset(Dataset, Generic[E]):
clips: List[data.Clip]
transform: Callable[[data.Clip], E]

View File

@ -1,113 +1,29 @@
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Tuple
import numpy as np
import xarray as xr
from scipy.ndimage import gaussian_filter
from soundevent import data, geometry
from soundevent import data, geometry, arrays
from soundevent.geometry.operations import Positions
from soundevent.types import ClassMapper
__all__ = [
"ClassMapper",
"generate_heatmaps",
]
PositionFn = Callable[[data.SoundEvent], Tuple[float, float]]
"""Convert a sound event to a single position in time-frequency space."""
SizeFn = Callable[[data.SoundEvent, float, float], np.ndarray]
"""Compute the size of a sound event in time-frequency space.
The time and frequency scales are provided as arguments to allow
modifying the size of the sound event based on the spectrogram
parameters.
"""
LabelFn = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Convert a sound event annotation to a label.
When the label is None, this indicates that the sound event does not
belong to any of the classes of interest.
"""
TARGET_SIGMA = 3.0
GENERIC_LABEL = "__UNKNOWN__"
def get_lower_left_position(
sound_event: data.SoundEvent,
) -> Tuple[float, float]:
if sound_event.geometry is None:
raise ValueError("Sound event has no geometry.")
start_time, low_freq, _, _ = geometry.compute_bounds(sound_event.geometry)
return start_time, low_freq
def get_bbox_size(
sound_event: data.SoundEvent,
time_scale: float = 1.0,
frequency_scale: float = 1.0,
) -> np.ndarray:
if sound_event.geometry is None:
raise ValueError("Sound event has no geometry.")
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
sound_event.geometry
)
return np.array(
[
time_scale * (end_time - start_time),
frequency_scale * (high_freq - low_freq),
]
)
def _tag_key(tag: data.Tag) -> Tuple[str, str]:
return (tag.key, tag.value)
def set_value_at_position(
array: xr.DataArray,
value: Any,
**query,
) -> xr.DataArray:
dims = {dim: n for n, dim in enumerate(array.dims)}
indexer: List[Union[slice, int]] = [slice(None) for _ in range(array.ndim)]
for key, coord in query.items():
if key not in dims:
raise ValueError(f"Dimension {key} not found in array.")
coordinates = array.indexes[key]
indexer[dims[key]] = coordinates.get_loc(coordinates.asof(coord))
if isinstance(value, (tuple, list)):
value = np.array(value)
array.data[tuple(indexer)] = value
return array
def generate_heatmaps(
clip_annotation: data.ClipAnnotation,
spec: xr.DataArray,
num_classes: int = 1,
label_fn: LabelFn = lambda _: None,
class_mapper: ClassMapper,
target_sigma: float = TARGET_SIGMA,
size_fn: SizeFn = get_bbox_size,
position_fn: PositionFn = get_lower_left_position,
class_labels: Optional[List[str]] = None,
position: Positions = "bottom-left",
dtype=np.float32,
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
if class_labels is None:
class_labels = [str(i) for i in range(num_classes)]
if len(class_labels) != num_classes:
raise ValueError(
"Number of class labels must match the number of classes."
)
shape = dict(zip(spec.dims, spec.shape))
if "time" not in shape or "frequency" not in shape:
@ -115,8 +31,8 @@ def generate_heatmaps(
"Spectrogram must have time and frequency dimensions."
)
time_duration = spec.time.attrs["max"] - spec.time.attrs["min"]
freq_bandwidth = spec.frequency.attrs["max"] - spec.frequency.attrs["min"]
time_duration = arrays.get_dim_width(spec, dim="time")
freq_bandwidth = arrays.get_dim_width(spec, dim="frequency")
# Compute the size factors
time_scale = 1 / time_duration
@ -125,10 +41,10 @@ def generate_heatmaps(
# Initialize heatmaps
detection_heatmap = xr.zeros_like(spec, dtype=dtype)
class_heatmap = xr.DataArray(
data=np.zeros((num_classes, *spec.shape), dtype=dtype),
data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype),
dims=["category", *spec.dims],
coords={
"category": class_labels,
"category": class_mapper.class_labels,
**spec.coords,
},
)
@ -142,11 +58,17 @@ def generate_heatmaps(
)
for sound_event_annotation in clip_annotation.sound_events:
geom = sound_event_annotation.sound_event.geometry
if geom is None:
continue
# Get the position of the sound event
time, frequency = position_fn(sound_event_annotation.sound_event)
time, frequency = geometry.get_geometry_point(geom, position=position)
print(time, frequency)
# Set 1.0 at the position of the sound event in the detection heatmap
detection_heatmap = set_value_at_position(
detection_heatmap = arrays.set_value_at_pos(
detection_heatmap,
1.0,
time=time,
@ -154,35 +76,37 @@ def generate_heatmaps(
)
# Set the size of the sound event at the position in the size heatmap
size = size_fn(
sound_event_annotation.sound_event,
time_scale,
frequency_scale,
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
geom
)
size_heatmap = set_value_at_position(
size = np.array(
[
(end_time - start_time) * time_scale,
(high_freq - low_freq) * frequency_scale,
]
)
size_heatmap = arrays.set_value_at_pos(
size_heatmap,
size,
time=time,
frequency=frequency,
)
# Get the label id for the sound event
label = label_fn(sound_event_annotation)
# Get the class name of the sound event
class_name = class_mapper.transform(sound_event_annotation)
if label is None or label not in class_labels:
# If the label is None or not in the class labels, we skip the
# sound event
if class_name is None:
# If the label is None skip the sound event
continue
# Set 1.0 at the position and category of the sound event in the class
# heatmap
class_heatmap = set_value_at_position(
class_heatmap = arrays.set_value_at_pos(
class_heatmap,
1.0,
time=time,
frequency=frequency,
category=label,
category=class_name,
)
# Apply gaussian filters
@ -207,25 +131,3 @@ def generate_heatmaps(
).fillna(0.0)
return detection_heatmap, class_heatmap, size_heatmap
class Labeler:
def __init__(self, tags: List[data.Tag]):
"""Create a labeler from a list of tags.
Each tag is assigned a unique label. The labeler can then be used
to convert sound event annotations to labels.
"""
self.tags = tags
self._label_map = {_tag_key(tag): i for i, tag in enumerate(tags)}
self._inverse_label_map = {v: k for k, v in self._label_map.items()}
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> Optional[int]:
for tag in sound_event_annotation.tags:
key = _tag_key(tag)
if key in self._label_map:
return self._label_map[key]
return None
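The hunks above swap the old label_fn / Labeler machinery for a soundevent ClassMapper, which is now responsible for turning annotations into class names and class indices back into tags. As a rough, duck-typed sketch of the interface this file relies on (transform, inverse_transform, class_labels, num_classes); this is not soundevent's actual ClassMapper base class, and the Tag construction is an assumption about the soundevent version in use:

from typing import List, Optional

from soundevent import data

class TwoSpeciesMapper:
    """Hypothetical mapper: species tags to class names and back."""

    class_labels: List[str] = ["Myotis myotis", "Pipistrellus pipistrellus"]

    @property
    def num_classes(self) -> int:
        return len(self.class_labels)

    def transform(self, annotation: data.SoundEventAnnotation) -> Optional[str]:
        # Return the class name for an annotation, or None if it belongs to no
        # class of interest (generate_heatmaps skips those sound events).
        for tag in annotation.tags:
            if tag.key == "species" and tag.value in self.class_labels:
                return tag.value
        return None

    def inverse_transform(self, class_index: int) -> List[data.Tag]:
        # Map a class index back to tags (used when building predictions).
        # Assumes Tag can be built from key/value in this soundevent version.
        return [data.Tag(key="species", value=self.class_labels[class_index])]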

View File

@ -1,15 +1,22 @@
"""Module containing functions for preprocessing audio clips."""
import random
from typing import List, Optional, Tuple
from typing import Optional
import librosa
import librosa.core.spectrum
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from pydantic import BaseModel, Field
from scipy.signal import resample_poly
from soundevent import audio, data
from soundevent import audio, data, arrays
from soundevent.arrays import operations as ops
__all__ = [
"PreprocessingConfig",
"preprocess_audio_clip",
]
TARGET_SAMPLERATE_HZ = 256000
SCALE_RAW_AUDIO = False
@ -26,20 +33,37 @@ DENOISE_SPEC_AVG = True
MAX_SCALE_SPEC = False
class PreprocessingConfig(BaseModel):
"""Configuration for preprocessing data."""
target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
scale_audio: bool = Field(default=SCALE_RAW_AUDIO)
fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
spec_scale: str = Field(default=SPEC_SCALE)
denoise_spec_avg: bool = DENOISE_SPEC_AVG
max_scale_spec: bool = MAX_SCALE_SPEC
duration: Optional[float] = DEFAULT_DURATION
spec_height: int = SPEC_HEIGHT
spec_time_period: float = SPEC_TIME_PERIOD
def preprocess_audio_clip(
clip: data.Clip,
target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
scale_audio: bool = SCALE_RAW_AUDIO,
fft_win_length: float = FFT_WIN_LENGTH_S,
fft_overlap: float = FFT_OVERLAP,
max_freq: int = MAX_FREQ_HZ,
min_freq: int = MIN_FREQ_HZ,
spec_scale: str = SPEC_SCALE,
denoise_spec_avg: bool = True,
max_scale_spec: bool = False,
duration: Optional[float] = DEFAULT_DURATION,
spec_height: int = SPEC_HEIGHT,
spec_time_period: float = SPEC_TIME_PERIOD,
config: PreprocessingConfig = PreprocessingConfig(),
) -> xr.DataArray:
"""Preprocesses audio clip to generate spectrogram.
@ -47,45 +71,8 @@ def preprocess_audio_clip(
----------
clip
The audio clip to preprocess.
target_sampling_rate
Target sampling rate for the audio. If the audio has a different
sampling rate, it will be resampled to this rate.
scale_audio
Whether to scale the audio amplitudes to a range of [-1, 1].
By default, the audio is not scaled.
fft_win_length
Length of the FFT window in seconds.
fft_overlap
Amount of overlap between FFT windows as a fraction of the window
length.
max_freq
Maximum frequency for spectrogram. Anything above this frequency will
be cropped.
min_freq
Minimum frequency for spectrogram. Anything below this frequency will
be cropped.
spec_scale
Scaling method for the spectrogram. Can be "pcen", "log" or
"amplitude".
denoise_spec_avg
Whether to denoise the spectrogram. Denoising is done by subtracting
the average of the spectrogram from the spectrogram and clipping
negative values to 0.
max_scale_spec
Whether to max scale the spectrogram. Max scaling is done by dividing
the spectrogram by its maximum value thus scaling values to [0, 1].
duration
Duration of the spectrogram in seconds. If the clip duration is
different from this value, the spectrogram will be cropped or extended
to match this duration. If None, the spectrogram will have the same
duration as the clip.
spec_height
Number of frequency bins for the spectrogram. This is the height of
the final spectrogram.
spec_time_period
Time period for each spectrogram bin in seconds. The spectrogram array
will be resized (using bilinear interpolation) to have this time
period.
config
Configuration for preprocessing.
Returns
-------
@ -95,35 +82,29 @@ def preprocess_audio_clip(
"""
wav = load_clip_audio(
clip,
target_sampling_rate=target_sampling_rate,
scale=scale_audio,
)
wav = wav.assign_attrs(
recording_id=str(wav.attrs["recording_id"]),
clip_id=str(wav.attrs["clip_id"]),
path=str(wav.attrs["path"]),
target_sampling_rate=config.target_samplerate,
scale=config.scale_audio,
)
spec = compute_spectrogram(
wav,
fft_win_length=fft_win_length,
fft_overlap=fft_overlap,
max_freq=max_freq,
min_freq=min_freq,
spec_scale=spec_scale,
denoise_spec_avg=denoise_spec_avg,
max_scale_spec=max_scale_spec,
fft_win_length=config.fft_win_length,
fft_overlap=config.fft_overlap,
max_freq=config.max_freq,
min_freq=config.min_freq,
spec_scale=config.spec_scale,
denoise_spec_avg=config.denoise_spec_avg,
max_scale_spec=config.max_scale_spec,
)
if duration is not None:
spec = adjust_spec_duration(clip, spec, duration)
if config.duration is not None:
spec = adjust_spec_duration(clip, spec, config.duration)
duration = get_dim_width(spec, dim="time")
return resize_spectrogram(
duration = arrays.get_dim_width(spec, dim="time")
return ops.resize(
spec,
time_bins=int(np.ceil(duration / spec_time_period)),
freq_bins=spec_height,
time=int(np.ceil(duration / config.spec_time_period)),
frequency=config.spec_height,
)
@ -138,18 +119,18 @@ def adjust_spec_duration(
return spec
if current_duration > duration:
return crop_axis(
return arrays.crop_dim(
spec,
dim="time",
start=clip.start_time,
end=clip.start_time + duration,
stop=clip.start_time + duration,
)
return extend_axis(
return arrays.extend_dim(
spec,
dim="time",
start=clip.start_time,
end=clip.start_time + duration,
stop=clip.start_time + duration,
)
@ -159,21 +140,15 @@ def load_clip_audio(
scale: bool = SCALE_RAW_AUDIO,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
wav = audio.load_clip(clip).sel(channel=0)
wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
wav = resample_audio(wav, target_sampling_rate, dtype=dtype)
if scale:
wav = scale_audio(wav)
wav = ops.center(wav)
wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
wav.coords["time"] = wav.time.assign_attrs(
unit="s",
long_name="Seconds since start of recording",
min=clip.start_time,
max=clip.end_time,
)
return wav
return wav.astype(dtype)
def resample_audio(
@ -181,14 +156,14 @@ def resample_audio(
target_samplerate: int = TARGET_SAMPLERATE_HZ,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
if "samplerate" not in wav.attrs:
raise ValueError("Audio must have a 'samplerate' attribute")
if "time" not in wav.dims:
raise ValueError("Audio must have a time dimension")
time_axis: int = wav.get_axis_num("time") # type: ignore
original_samplerate = wav.attrs["samplerate"]
start, stop = arrays.get_dim_range(wav, dim="time")
step = arrays.get_dim_step(wav, dim="time")
original_samplerate = int(1 / step)
if original_samplerate == target_samplerate:
return wav.astype(dtype)
@ -202,8 +177,8 @@ def resample_audio(
)
resampled_times = np.linspace(
wav.time[0],
wav.time[-1],
start,
stop + step,
len(resampled),
endpoint=False,
dtype=dtype,
@ -214,23 +189,15 @@ def resample_audio(
dims=wav.dims,
coords={
**wav.coords,
"time": resampled_times,
},
attrs={
**wav.attrs,
"samplerate": target_samplerate,
"time": arrays.create_time_dim_from_array(
resampled_times,
samplerate=target_samplerate,
),
},
attrs=wav.attrs,
)
def scale_audio(
audio: xr.DataArray,
eps: float = 10e-6,
) -> xr.DataArray:
audio = audio - audio.mean()
return audio / np.add(np.abs(audio).max(), eps, dtype=audio.dtype)
def compute_spectrogram(
wav: xr.DataArray,
fft_win_length: float = FFT_WIN_LENGTH_S,
@ -249,12 +216,12 @@ def compute_spectrogram(
dtype=dtype,
)
spec = crop_axis(
spec = arrays.crop_dim(
spec,
dim="frequency",
start=min_freq,
end=max_freq,
)
stop=max_freq,
).astype(dtype)
spec = scale_spectrogram(spec, scale=spec_scale)
@ -262,172 +229,67 @@ def compute_spectrogram(
spec = denoise_spectrogram(spec)
if max_scale_spec:
spec = max_scale_spectrogram(spec)
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
return spec
def crop_axis(
arr: xr.DataArray,
dim: str,
start: float,
end: float,
right_closed: bool = False,
left_closed: bool = True,
eps: float = 10e-6,
) -> xr.DataArray:
coord = arr.coords[dim]
if not all(attr in coord.attrs for attr in ["min", "max"]):
raise ValueError(
f"Coordinate '{dim}' must have 'min' and 'max' attributes"
)
current_min = coord.attrs["min"]
current_max = coord.attrs["max"]
if start < current_min or end > current_max:
raise ValueError(
f"Cannot select axis '{dim}' from {start} to {end}. "
f"Axis range is {current_min} to {current_max}"
)
slice_end = end
if not right_closed:
slice_end = end - eps
slice_start = start
if not left_closed:
slice_start = start + eps
arr = arr.sel({dim: slice(slice_start, slice_end)})
arr.coords[dim].attrs.update(
min=start,
max=end,
)
return arr
def extend_axis(
arr: xr.DataArray,
dim: str,
start: float,
end: float,
fill_value: float = 0,
) -> xr.DataArray:
coord = arr.coords[dim]
if not all(attr in coord.attrs for attr in ["min", "max", "period"]):
raise ValueError(
f"Coordinate '{dim}' must have 'min', 'max' and 'period' attributes"
" to extend axis"
)
current_min = coord.attrs["min"]
current_max = coord.attrs["max"]
period = coord.attrs["period"]
coords = coord.data
if start < current_min:
new_coords = np.arange(
current_min,
start,
-period,
dtype=coord.dtype,
)[1:][::-1]
coords = np.concatenate([new_coords, coords])
if end > current_max:
new_coords = np.arange(
current_max,
end,
period,
dtype=coord.dtype,
)[1:]
coords = np.concatenate([coords, new_coords])
arr = arr.reindex(
{dim: coords},
fill_value=fill_value, # type: ignore
)
arr.coords[dim].attrs.update(
min=start,
max=end,
)
return arr
return spec.astype(dtype)
def gen_mag_spectrogram(
audio: xr.DataArray,
wave: xr.DataArray,
window_len: float,
overlap_perc: float,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
sampling_rate = audio.attrs["samplerate"]
start_time, end_time = arrays.get_dim_range(wave, dim="time")
step = arrays.get_dim_step(wave, dim="time")
sampling_rate = 1 / step
hop_len = window_len * (1 - overlap_perc)
nfft = int(window_len * sampling_rate)
noverlap = int(overlap_perc * nfft)
start_time = audio.time.attrs["min"]
end_time = audio.time.attrs["max"]
# compute spec
spec, _ = librosa.core.spectrum._spectrogram(
y=audio.data,
y=wave.data,
power=1,
n_fft=nfft,
hop_length=nfft - noverlap,
center=False,
)
spec = xr.DataArray(
return xr.DataArray(
data=spec.astype(dtype),
dims=["frequency", "time"],
coords={
"frequency": np.linspace(
"frequency": arrays.create_frequency_dim_from_array(
np.linspace(
0,
sampling_rate / 2,
spec.shape[0],
endpoint=False,
dtype=dtype,
),
"time": np.linspace(
step=sampling_rate / nfft,
),
"time": arrays.create_time_dim_from_array(
np.linspace(
start_time,
end_time - (window_len - hop_len),
spec.shape[1],
endpoint=False,
dtype=dtype,
),
step=hop_len,
),
},
attrs={
**audio.attrs,
**wave.attrs,
"original_samplerate": sampling_rate,
"nfft": nfft,
"noverlap": noverlap,
},
)
# Add metadata to coordinates
spec.coords["time"].attrs.update(
unit="s",
long_name="Time",
min=start_time,
max=end_time - (window_len - hop_len),
period=(nfft - noverlap) / sampling_rate,
)
spec.coords["frequency"].attrs.update(
unit="Hz",
long_name="Frequency",
period=(sampling_rate / nfft),
min=0,
max=sampling_rate / 2,
)
return spec
def denoise_spectrogram(
spec: xr.DataArray,
@ -436,10 +298,7 @@ def denoise_spectrogram(
data=(spec - spec.mean("time")).clip(0),
dims=spec.dims,
coords=spec.coords,
attrs={
**spec.attrs,
"denoised": 1,
},
attrs=spec.attrs,
)
@ -448,8 +307,14 @@ def scale_spectrogram(
scale: str = SPEC_SCALE,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
samplerate = spec.attrs["original_samplerate"]
if scale == "pcen":
return pcen(spec, dtype=dtype)
smoothing_constant = get_pcen_smoothing_constant(samplerate / 10)
return audio.pcen(
spec * (2**31),
smooth=smoothing_constant,
).astype(dtype)
if scale == "log":
return log_scale(spec, dtype=dtype)
@ -461,126 +326,25 @@ def log_scale(
spec: xr.DataArray,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
samplerate = spec.attrs["original_samplerate"]
nfft = spec.attrs["nfft"]
sampling_rate = spec.attrs["samplerate"]
log_scaling = (
2.0
* (1.0 / sampling_rate)
* (1.0 / samplerate)
* (1.0 / (np.abs(np.hanning(nfft)) ** 2).sum())
)
return xr.DataArray(
data=np.log1p(log_scaling * spec).astype(dtype),
dims=spec.dims,
coords=spec.coords,
attrs={
**spec.attrs,
"scale": "log",
},
attrs=spec.attrs,
)
def pcen(spec: xr.DataArray, dtype: DTypeLike = np.float32) -> xr.DataArray:
sampling_rate = spec.attrs["samplerate"]
data = librosa.pcen(
spec.data * (2**31),
sr=sampling_rate / 10,
)
return xr.DataArray(
data=data.astype(dtype),
dims=spec.dims,
coords=spec.coords,
attrs={
**spec.attrs,
"scale": "pcen",
},
)
def max_scale_spectrogram(spec: xr.DataArray, eps=10e-6) -> xr.DataArray:
return xr.DataArray(
data=spec / np.add(spec.max(), eps, dtype=spec.dtype),
dims=spec.dims,
coords=spec.coords,
attrs={
**spec.attrs,
"max_scaled": 1,
},
)
def resize_spectrogram(
spec: xr.DataArray,
time_bins: int,
freq_bins: int,
) -> xr.DataArray:
new_times = np.linspace(
spec.time[0],
spec.time[-1],
time_bins,
dtype=spec.time.dtype,
endpoint=True,
)
new_frequencies = np.linspace(
spec.frequency[0],
spec.frequency[-1],
freq_bins,
dtype=spec.frequency.dtype,
endpoint=True,
)
return spec.interp(
coords=dict(
time=new_times,
frequency=new_frequencies,
),
method="linear",
)
def get_dim_width(arr: xr.DataArray, dim: str) -> float:
coord = arr.coords[dim]
attrs = coord.attrs
if "min" in attrs and "max" in attrs:
return attrs["max"] - attrs["min"]
coord_min = coord.min()
coord_max = coord.max()
return float(coord_max - coord_min)
class RandomClipProvider:
def __init__(
self,
clip_annotations: List[data.ClipAnnotation],
target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
scale_audio: bool = SCALE_RAW_AUDIO,
):
self.target_sampling_rate = target_sampling_rate
self.scale_audio = scale_audio
self.clip_annotations = clip_annotations
def get_next_clip(self, clip: data.ClipAnnotation) -> data.ClipAnnotation:
tries = 0
while True:
random_clip = random.choice(self.clip_annotations)
if random_clip.clip != clip.clip:
return random_clip
tries += 1
if tries > 4:
raise ValueError("Could not find a different clip")
def __call__(
self,
clip: data.ClipAnnotation,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
random_clip = self.get_next_clip(clip)
wav = load_clip_audio(
random_clip.clip,
target_sampling_rate=self.target_sampling_rate,
scale=self.scale_audio,
)
return wav, random_clip
def get_pcen_smoothing_constant(
sr: int,
time_constant: float = 0.4,
hop_length: int = 512,
) -> float:
t_frames = time_constant * sr / float(hop_length)
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
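To make the new config-driven entry point concrete, here is a hedged usage sketch; the wav path and clip bounds are hypothetical, and the module path follows the imports used elsewhere in this commit:

from soundevent import data

from batdetect2.data.preprocessing import (
    PreprocessingConfig,
    preprocess_audio_clip,
)

config = PreprocessingConfig(
    target_samplerate=256_000,
    spec_scale="pcen",
    duration=0.5,
)

# Hypothetical recording; Recording.from_file reads the audio metadata.
recording = data.Recording.from_file("example.wav")
clip = data.Clip(recording=recording, start_time=0.0, end_time=0.5)

spec = preprocess_audio_clip(clip, config=config)
# spec is an xarray.DataArray with a "frequency" axis of spec_height bins and
# a "time" axis of roughly duration / spec_time_period bins.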

View File

@ -68,7 +68,6 @@ def run_nms(
params["fft_win_length"],
params["fft_overlap"],
)
print("duration", duration)
top_k = int(duration * params["nms_top_k_per_sec"])
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)

View File

@ -1,91 +1,11 @@
import os
from typing import Tuple, Union
import torch
from batdetect2.models.encoders import (
from batdetect2.models.feature_extractors import (
Net2DFast,
Net2DFastNoAttn,
Net2DFastNoCoordConv,
)
from batdetect2.models.typing import DetectionModel
__all__ = [
"load_model",
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
]
DEFAULT_MODEL_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"models",
"checkpoints",
"Net2DFast_UK_same.pth.tar",
)
def load_model(
model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True,
device: Union[torch.device, str, None] = None,
) -> Tuple[DetectionModel, dict]:
"""Load model from file.
Args:
model_path (str): Path to model file. Defaults to DEFAULT_MODEL_PATH.
load_weights (bool, optional): Load weights. Defaults to True.
Returns:
model, params: Model and parameters.
Raises:
FileNotFoundError: Model file not found.
ValueError: Unknown model name.
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not os.path.isfile(model_path):
raise FileNotFoundError("Model file not found.")
net_params = torch.load(model_path, map_location=device)
params = net_params["params"]
model: DetectionModel
if params["model_name"] == "Net2DFast":
model = Net2DFast(
params["num_filters"],
num_classes=len(params["class_names"]),
emb_dim=params["emb_dim"],
ip_height=params["ip_height"],
resize_factor=params["resize_factor"],
)
elif params["model_name"] == "Net2DFastNoAttn":
model = Net2DFastNoAttn(
params["num_filters"],
num_classes=len(params["class_names"]),
emb_dim=params["emb_dim"],
ip_height=params["ip_height"],
resize_factor=params["resize_factor"],
)
elif params["model_name"] == "Net2DFastNoCoordConv":
model = Net2DFastNoCoordConv(
params["num_filters"],
num_classes=len(params["class_names"]),
emb_dim=params["emb_dim"],
ip_height=params["ip_height"],
resize_factor=params["resize_factor"],
)
else:
raise ValueError("Unknown model.")
if load_weights:
model.load_state_dict(net_params["state_dict"])
model = model.to(device)
model.eval()
return model, params

View File

@ -1,100 +1,104 @@
from typing import Type
import pytorch_lightning as L
import torch
import xarray as xr
from soundevent import data
from torch import nn, optim
from batdetect2.data.preprocessing import preprocess_audio_clip
from batdetect2.models.typing import EncoderModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample
from batdetect2.data.preprocessing import (
preprocess_audio_clip,
PreprocessingConfig,
)
from batdetect2.data.labels import ClassMapper
from batdetect2.models.feature_extractors import Net2DFast
from batdetect2.models.post_process import (
PostprocessConfig,
postprocess_model_outputs,
)
from batdetect2.train.preprocess import PreprocessingConfig
from batdetect2.models.typing import FeatureExtractorModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample
class DetectorModel(L.LightningModule):
def __init__(
self,
encoder: EncoderModel,
num_classes: int,
class_mapper: ClassMapper,
feature_extractor_class: Type[FeatureExtractorModel] = Net2DFast,
learning_rate: float = 1e-3,
input_height: int = 128,
num_features: int = 32,
preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
postprocessing_config: PostprocessConfig = PostprocessConfig(),
):
super().__init__()
self.save_hyperparameters()
self.preprocessing_config = preprocessing_config
self.postprocessing_config = postprocessing_config
self.num_classes = num_classes
self.class_mapper = class_mapper
self.learning_rate = learning_rate
self.input_height = input_height
self.num_features = num_features
self.num_classes = class_mapper.num_classes
self.encoder = encoder
self.feature_extractor = feature_extractor_class(
input_height=input_height,
num_features=num_features,
)
self.classifier = nn.Conv2d(
self.encoder.num_filts // 4,
self.feature_extractor.num_features // 4,
self.num_classes + 1,
kernel_size=1,
padding=0,
)
self.bbox = nn.Conv2d(
self.encoder.num_filts // 4,
self.feature_extractor.num_features // 4,
2,
kernel_size=1,
padding=0,
)
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
features = self.encoder(spec)
features = self.feature_extractor(spec)
classification_logits = self.classifier(features)
classification_probs = torch.softmax(classification_logits, dim=1)
detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
return ModelOutput(
detection_probs=detection_probs,
size_preds=self.bbox(features),
class_probs=classification_probs,
class_probs=classification_probs[:, :-1],
features=features,
)
def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
config = self.preprocessing_config
return preprocess_audio_clip(
clip,
target_sampling_rate=config.target_samplerate,
scale_audio=config.scale_audio,
fft_win_length=config.fft_win_length,
fft_overlap=config.fft_overlap,
max_freq=config.max_freq,
min_freq=config.min_freq,
spec_scale=config.spec_scale,
denoise_spec_avg=config.denoise_spec_avg,
max_scale_spec=config.max_scale_spec,
config=self.preprocessing_config,
)
def process_clip(self, clip: data.Clip):
def compute_clip_features(self, clip: data.Clip) -> torch.Tensor:
spectrogram = self.compute_spectrogram(clip)
return self.feature_extractor(
torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
)
def compute_clip_predictions(self, clip: data.Clip) -> data.ClipPrediction:
spectrogram = self.compute_spectrogram(clip)
spec_tensor = (
torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
)
outputs = self(spec_tensor)
config = self.postprocessing_config
return postprocess_model_outputs(
outputs,
[clip],
nms_kernel_size=config.nms_kernel_size,
detection_threshold=config.detection_threshold,
min_freq=config.min_freq,
max_freq=config.max_freq,
top_k_per_sec=config.top_k_per_sec,
)
class_mapper=self.class_mapper,
config=self.postprocessing_config,
)[0]
def compute_loss(
self,
@ -124,21 +128,8 @@ class DetectorModel(L.LightningModule):
self,
batch: TrainExample,
):
features = self.encoder(batch.spec)
classification_logits = self.classifier(features)
classification_probs = torch.softmax(classification_logits, dim=1)
detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
loss = self.compute_loss(
ModelOutput(
detection_probs=detection_probs,
size_preds=self.bbox(features),
class_probs=classification_probs,
features=features,
),
batch,
)
outputs = self.forward(batch.spec)
loss = self.compute_loss(outputs, batch)
self.log("train_loss", loss)
return loss
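The refactored training_step now reuses forward(), whose classifier head predicts num_classes + 1 channels, with the extra channel acting as a background class. A minimal stand-alone illustration of that design (not code from the repository):

import torch

num_classes = 2
logits = torch.randn(1, num_classes + 1, 8, 8)          # (batch, classes + background, H, W)
probs = torch.softmax(logits, dim=1)

class_probs = probs[:, :-1]                              # per-class heatmaps
detection_probs = class_probs.sum(dim=1, keepdim=True)   # P(any class present)

# Summing the non-background channels is the same as 1 - P(background).
assert torch.allclose(detection_probs, 1.0 - probs[:, -1:])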

View File

@ -5,7 +5,6 @@ import torch.fft
import torch.nn.functional as F
from torch import nn
from batdetect2.models.typing import EncoderModel
from batdetect2.models.blocks import (
ConvBlockDownCoordF,
ConvBlockDownStandard,
@ -13,6 +12,7 @@ from batdetect2.models.blocks import (
ConvBlockUpStandard,
SelfAttention,
)
from batdetect2.models.typing import FeatureExtractorModel
__all__ = [
"Net2DFast",
@ -21,84 +21,84 @@ __all__ = [
]
class Net2DFast(EncoderModel):
class Net2DFast(FeatureExtractorModel):
def __init__(
self,
num_filts: int,
num_features: int,
input_height: int = 128,
):
super().__init__()
self.num_filts = num_filts
self.num_features = num_features
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
# encoder
self.conv_dn_0 = ConvBlockDownCoordF(
1,
self.num_filts // 4,
self.num_features // 4,
self.input_height,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownCoordF(
self.num_filts // 4,
self.num_filts // 2,
self.num_features // 4,
self.num_features // 2,
self.input_height // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownCoordF(
self.num_filts // 2,
self.num_filts,
self.num_features // 2,
self.num_features,
self.input_height // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_filts,
self.num_filts * 2,
self.num_features,
self.num_features * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
# bottleneck
self.conv_1d = nn.Conv2d(
self.num_filts * 2,
self.num_filts * 2,
self.num_features * 2,
self.num_features * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)
self.att = SelfAttention(self.num_filts * 2, self.num_filts * 2)
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
self.att = SelfAttention(self.num_features * 2, self.num_features * 2)
# decoder
self.conv_up_2 = ConvBlockUpF(
self.num_filts * 2,
self.num_filts // 2,
self.num_features * 2,
self.num_features // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpF(
self.num_filts // 2,
self.num_filts // 4,
self.num_features // 2,
self.num_features // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpF(
self.num_filts // 4,
self.num_filts // 4,
self.num_features // 4,
self.num_features // 4,
self.input_height // 2,
)
self.conv_op = nn.Conv2d(
self.num_filts // 4,
self.num_filts // 4,
self.num_features // 4,
self.num_features // 4,
kernel_size=3,
padding=1,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
h, w = spec.shape[2:]
@ -135,81 +135,81 @@ class Net2DFast(EncoderModel):
return F.relu_(self.conv_op_bn(self.conv_op(x)))
class Net2DFastNoAttn(EncoderModel):
class Net2DFastNoAttn(FeatureExtractorModel):
def __init__(
self,
num_filts: int,
num_features: int,
input_height: int = 128,
):
super().__init__()
self.num_filts = num_filts
self.num_features = num_features
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
self.conv_dn_0 = ConvBlockDownCoordF(
1,
self.num_filts // 4,
self.num_features // 4,
self.input_height,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownCoordF(
self.num_filts // 4,
self.num_filts // 2,
self.num_features // 4,
self.num_features // 2,
self.input_height // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownCoordF(
self.num_filts // 2,
self.num_filts,
self.num_features // 2,
self.num_features,
self.input_height // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_filts,
self.num_filts * 2,
self.num_features,
self.num_features * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
self.conv_1d = nn.Conv2d(
self.num_filts * 2,
self.num_filts * 2,
self.num_features * 2,
self.num_features * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
self.conv_up_2 = ConvBlockUpF(
self.num_filts * 2,
self.num_filts // 2,
self.num_features * 2,
self.num_features // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpF(
self.num_filts // 2,
self.num_filts // 4,
self.num_features // 2,
self.num_features // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpF(
self.num_filts // 4,
self.num_filts // 4,
self.num_features // 4,
self.num_features // 4,
self.input_height // 2,
)
self.conv_op = nn.Conv2d(
self.num_filts // 4,
self.num_filts // 4,
self.num_features // 4,
self.num_features // 4,
kernel_size=3,
padding=1,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
x1 = self.conv_dn_0(spec)
@ -227,80 +227,80 @@ class Net2DFastNoAttn(EncoderModel):
return F.relu_(self.conv_op_bn(self.conv_op(x)))
class Net2DFastNoCoordConv(EncoderModel):
class Net2DFastNoCoordConv(FeatureExtractorModel):
def __init__(
self,
num_filts: int,
num_features: int,
input_height: int = 128,
):
super().__init__()
self.num_filts = num_filts
self.num_features = num_features
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
self.conv_dn_0 = ConvBlockDownStandard(
1,
self.num_filts // 4,
self.num_features // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownStandard(
self.num_filts // 4,
self.num_filts // 2,
self.num_features // 4,
self.num_features // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownStandard(
self.num_filts // 2,
self.num_filts,
self.num_features // 2,
self.num_features,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_filts,
self.num_filts * 2,
self.num_features,
self.num_features * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
self.conv_1d = nn.Conv2d(
self.num_filts * 2,
self.num_filts * 2,
self.num_features * 2,
self.num_features * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
self.att = SelfAttention(self.num_filts * 2, self.num_filts * 2)
self.att = SelfAttention(self.num_features * 2, self.num_features * 2)
self.conv_up_2 = ConvBlockUpStandard(
self.num_filts * 2,
self.num_filts // 2,
self.num_features * 2,
self.num_features // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpStandard(
self.num_filts // 2,
self.num_filts // 4,
self.num_features // 2,
self.num_features // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpStandard(
self.num_filts // 4,
self.num_filts // 4,
self.num_features // 4,
self.num_features // 4,
self.input_height // 2,
)
self.conv_op = nn.Conv2d(
self.num_filts // 4,
self.num_filts // 4,
self.num_features // 4,
self.num_features // 4,
kernel_size=3,
padding=1,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
x1 = self.conv_dn_0(spec)

View File

@ -8,6 +8,7 @@ import torch
from soundevent import data
from torch import nn
from batdetect2.data.labels import ClassMapper
from batdetect2.models.typing import ModelOutput
__all__ = [
@ -36,11 +37,8 @@ TagFunction = Callable[[int], List[data.Tag]]
def postprocess_model_outputs(
outputs: ModelOutput,
clips: List[data.Clip],
nms_kernel_size: int = NMS_KERNEL_SIZE,
detection_threshold: float = DETECTION_THRESHOLD,
min_freq: int = 10000,
max_freq: int = 120000,
top_k_per_sec: int = TOP_K_PER_SEC,
class_mapper: ClassMapper,
config: PostprocessConfig,
) -> List[data.ClipPrediction]:
"""Postprocesses model outputs to generate clip predictions.
@ -57,16 +55,8 @@ def postprocess_model_outputs(
clips
List of clips for which predictions are made. The number of clips
must match the batch dimension of the model outputs.
nms_kernel_size
Size of the non-maximum suppression kernel. Default is 9.
detection_threshold
Detection threshold. Default is 0.01.
min_freq
Minimum frequency. Default is 10000.
max_freq
Maximum frequency. Default is 120000.
top_k_per_sec
Top k per second. Default is 200.
config
Configuration for postprocessing model outputs.
Returns
-------
@ -90,14 +80,14 @@ def postprocess_model_outputs(
detection_probs = non_max_suppression(
outputs.detection_probs,
kernel_size=nms_kernel_size,
kernel_size=config.nms_kernel_size,
)
duration = clips[0].end_time - clips[0].start_time
scores_batch, y_pos_batch, x_pos_batch = get_topk_scores(
detection_probs,
int(top_k_per_sec * duration / 2),
int(config.top_k_per_sec * duration / 2),
)
predictions: List[data.ClipPrediction] = []
@ -118,9 +108,10 @@ def postprocess_model_outputs(
size_preds,
class_probs,
features,
min_freq=min_freq,
max_freq=max_freq,
detection_threshold=detection_threshold,
class_mapper=class_mapper,
min_freq=config.min_freq,
max_freq=config.max_freq,
detection_threshold=config.detection_threshold,
)
predictions.append(
@ -141,7 +132,7 @@ def compute_sound_events_from_outputs(
size_preds: torch.Tensor,
class_probs: torch.Tensor,
features: torch.Tensor,
tag_fn: TagFunction = lambda _: [],
class_mapper: ClassMapper,
min_freq: int = 10000,
max_freq: int = 120000,
detection_threshold: float = DETECTION_THRESHOLD,
@ -160,7 +151,6 @@ def compute_sound_events_from_outputs(
predictions: List[data.SoundEventPrediction] = []
for score, x, y in zip(scores, x_pos, y_pos):
width, height = size_preds[:, y, x]
print(width, height)
class_prob = class_probs[:, y, x]
feature = features[:, y, x]
@ -191,7 +181,7 @@ def compute_sound_events_from_outputs(
predicted_tags: List[data.PredictedTag] = []
for label_id, class_score in enumerate(class_prob):
corresponding_tags = tag_fn(label_id)
corresponding_tags = class_mapper.inverse_transform(label_id)
predicted_tags.extend(
[
data.PredictedTag(

View File

@ -4,6 +4,11 @@ from typing import NamedTuple
import torch
import torch.nn as nn
__all__ = [
"ModelOutput",
"FeatureExtractorModel",
]
class ModelOutput(NamedTuple):
"""Output of the detection model.
@ -36,12 +41,11 @@ class ModelOutput(NamedTuple):
"""Tensor with intermediate features."""
class EncoderModel(ABC, nn.Module):
class FeatureExtractorModel(ABC, nn.Module):
input_height: int
"""Height of the input spectrogram."""
num_filts: int
num_features: int
"""Dimension of the feature tensor."""
@abstractmethod

View File

@ -0,0 +1,244 @@
from functools import wraps
from typing import Callable, List, Optional, Tuple
import numpy as np
import xarray as xr
from soundevent import data
from soundevent.geometry import compute_bounds
Augmentation = Callable[[xr.Dataset], xr.Dataset]
AUGMENTATION_PROBABILITY = 0.2
MAX_DELAY = 0.005
STRETCH_SQUEEZE_DELTA = 0.04
MASK_MAX_TIME_PERC: float = 0.05
MASK_MAX_FREQ_PERC: float = 0.10
def maybe_apply(
augmentation: Callable,
prob: float = AUGMENTATION_PROBABILITY,
) -> Callable:
"""Apply an augmentation with a given probability."""
@wraps(augmentation)
def _augmentation(x):
if np.random.rand() > prob:
return x
return augmentation(x)
return _augmentation
def select_random_subclip(
train_example: xr.Dataset,
duration: Optional[float] = None,
proportion: float = 0.9,
) -> xr.Dataset:
"""Select a random subclip from a clip."""
time_coords = train_example.coords["time"]
start_time = time_coords.attrs.get("min", time_coords.min())
end_time = time_coords.attrs.get("max", time_coords.max())
if duration is None:
duration = (end_time - start_time) * proportion
start_time = np.random.uniform(start_time, end_time - duration)
return train_example.sel(time=slice(start_time, start_time + duration))
def combine_audio(
audio1: xr.DataArray,
audio2: xr.DataArray,
alpha: Optional[float] = None,
min_alpha: float = 0.3,
max_alpha: float = 0.7,
) -> xr.DataArray:
"""Combine two audio clips."""
if alpha is None:
alpha = np.random.uniform(min_alpha, max_alpha)
return alpha * audio1 + (1 - alpha) * audio2.data
# def random_mix(
# audio: xr.DataArray,
# clip: data.ClipAnnotation,
# provider: Optional[ClipProvider] = None,
# alpha: Optional[float] = None,
# min_alpha: float = 0.3,
# max_alpha: float = 0.7,
# join_annotations: bool = True,
# ) -> Tuple[xr.DataArray, data.ClipAnnotation]:
# """Mix two audio clips."""
# if provider is None:
# raise ValueError("No audio provider given.")
#
# try:
# other_audio, other_clip = provider(clip)
# except (StopIteration, ValueError):
# raise ValueError("No more audio sources available.")
#
# new_audio = combine_audio(
# audio,
# other_audio,
# alpha=alpha,
# min_alpha=min_alpha,
# max_alpha=max_alpha,
# )
#
# if join_annotations:
# clip = clip.model_copy(
# update=dict(
# sound_events=clip.sound_events + other_clip.sound_events,
# )
# )
#
# return new_audio, clip
def add_echo(
train_example: xr.Dataset,
delay: Optional[float] = None,
alpha: Optional[float] = None,
min_alpha: float = 0.0,
max_alpha: float = 1.0,
max_delay: float = MAX_DELAY,
) -> xr.Dataset:
"""Add a delay to the audio."""
if delay is None:
delay = np.random.uniform(0, max_delay)
if alpha is None:
alpha = np.random.uniform(min_alpha, max_alpha)
spec = train_example["spectrogram"]
time_coords = spec.coords["time"]
start_time = time_coords.attrs["min"]
end_time = time_coords.attrs["max"]
step = (end_time - start_time) / time_coords.size
spec_delay = spec.shift(time=int(delay / step), fill_value=0)
return train_example.assign(spectrogram=spec + alpha * spec_delay)
def scale_volume(
train_example: xr.Dataset,
factor: Optional[float] = None,
max_scaling: float = 2,
min_scaling: float = 0,
) -> xr.Dataset:
"""Scale the volume of a spectrogram."""
if factor is None:
factor = np.random.uniform(min_scaling, max_scaling)
return train_example.assign(
spectrogram=train_example["spectrogram"] * factor
)
def warp_spectrogram(
train_example: xr.Dataset,
factor: Optional[float] = None,
delta: float = STRETCH_SQUEEZE_DELTA,
) -> xr.Dataset:
"""Warp a spectrogram."""
if factor is None:
factor = np.random.uniform(1 - delta, 1 + delta)
time_coords = train_example.coords["time"]
start_time = time_coords.attrs["min"]
end_time = time_coords.attrs["max"]
duration = end_time - start_time
new_time = np.linspace(
start_time,
start_time + duration * factor,
train_example.time.size,
)
return train_example.interp(time=new_time)
def mask_axis(
train_example: xr.Dataset,
dim: str,
start: float,
end: float,
mask_all: bool = False,
mask_value: float = 0,
) -> xr.Dataset:
if dim not in train_example.dims:
raise ValueError(f"Axis {dim} not found in array")
coord = train_example.coords[dim]
condition = (coord < start) | (coord > end)
if mask_all:
return train_example.where(condition, other=mask_value)
return train_example.assign(
spectrogram=train_example.spectrogram.where(
condition, other=mask_value
)
)
def mask_time(
train_example: xr.Dataset,
max_time_mask: float = MASK_MAX_TIME_PERC,
max_num_masks: int = 3,
) -> xr.Dataset:
"""Mask a random section of the time axis."""
num_masks = np.random.randint(1, max_num_masks + 1)
time_coord = train_example.coords["time"]
start_time = time_coord.attrs.get("min", time_coord.min())
end_time = time_coord.attrs.get("max", time_coord.max())
for _ in range(num_masks):
mask_size = np.random.uniform(0, max_time_mask)
start = np.random.uniform(start_time, end_time - mask_size)
end = start + mask_size
train_example = mask_axis(train_example, "time", start, end)
return train_example
def mask_frequency(
train_example: xr.Dataset,
max_freq_mask: float = MASK_MAX_FREQ_PERC,
max_num_masks: int = 3,
) -> xr.Dataset:
"""Mask a random section of the frequency axis."""
num_masks = np.random.randint(1, max_num_masks + 1)
freq_coord = train_example.coords["frequency"]
min_freq = freq_coord.min()
max_freq = freq_coord.max()
for _ in range(num_masks):
mask_size = np.random.uniform(0, max_freq_mask)
start = np.random.uniform(min_freq, max_freq - mask_size)
end = start + mask_size
train_example = mask_axis(train_example, "frequency", start, end)
return train_example
AUGMENTATIONS: List[Augmentation] = [
select_random_subclip,
add_echo,
scale_volume,
mask_time,
mask_frequency,
]
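Every augmentation in this rewritten module is a Dataset -> Dataset function, so the list composes directly with maybe_apply. A small sketch (the build_augmentation helper is hypothetical, not part of this commit) of collapsing the list into a single transform, e.g. for the LabeledDataset below:

def build_augmentation(prob: float = AUGMENTATION_PROBABILITY) -> Augmentation:
    """Chain all augmentations, each applied with probability `prob`."""

    def _transform(example: xr.Dataset) -> xr.Dataset:
        for augmentation in AUGMENTATIONS:
            example = maybe_apply(augmentation, prob=prob)(example)
        return example

    return _transform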

View File

@ -1,16 +1,14 @@
import os
from typing import NamedTuple
from pathlib import Path
from typing import Sequence, Union, Dict
from soundevent import data
from typing import Callable, Dict, NamedTuple, Optional, Sequence, Union
from torch.utils.data import Dataset
import torch
import xarray as xr
from soundevent import data
from torch.utils.data import Dataset
from batdetect2.train.preprocess import PreprocessingConfig
__all__ = [
"TrainExample",
"LabeledDataset",
@ -33,8 +31,13 @@ def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
class LabeledDataset(Dataset):
def __init__(self, filenames: Sequence[PathLike]):
def __init__(
self,
filenames: Sequence[PathLike],
transform: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
):
self.filenames = filenames
self.transform = transform
def __len__(self):
return len(self.filenames)
@ -54,7 +57,7 @@ class LabeledDataset(Dataset):
return cls(get_files(directory, extension))
def load(self, filename: PathLike) -> Dict[str, torch.Tensor]:
dataset = xr.open_dataset(filename)
dataset = self.get_dataset(filename)
spectrogram = torch.tensor(dataset["spectrogram"].values).unsqueeze(0)
return {
"spectrogram": spectrogram,
@ -63,6 +66,15 @@ class LabeledDataset(Dataset):
"size": torch.tensor(dataset["size"].values),
}
def apply_augmentation(self, dataset: xr.Dataset) -> xr.Dataset:
if self.transform is not None:
return self.transform(dataset)
return dataset
def get_dataset(self, idx):
return xr.open_dataset(self.filenames[idx])
def get_spectrogram(self, idx):
return xr.open_dataset(self.filenames[idx])["spectrogram"]
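A hypothetical wiring of the pieces above (these hunks do not show whether __getitem__ calls apply_augmentation, so treat this purely as a sketch): preprocessed .nc examples are listed with get_files, and a transform such as the build_augmentation sketch from the augmentations module is applied on the fly.

train_files = get_files("train_examples/")  # hypothetical directory of .nc files
dataset = LabeledDataset(
    train_files,
    transform=build_augmentation(),  # sketch from the augmentations module above
)
example = dataset.get_dataset(0)  # raw xr.Dataset, useful for inspection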

View File

@ -9,21 +9,12 @@ from tqdm.auto import tqdm
from multiprocessing import Pool
import xarray as xr
from pydantic import BaseModel, Field
from soundevent import data
from batdetect2.data.labels import TARGET_SIGMA, LabelFn, generate_heatmaps
from batdetect2.data.labels import TARGET_SIGMA, ClassMapper, generate_heatmaps
from batdetect2.data.preprocessing import (
DENOISE_SPEC_AVG,
FFT_OVERLAP,
FFT_WIN_LENGTH_S,
MAX_FREQ_HZ,
MAX_SCALE_SPEC,
MIN_FREQ_HZ,
SCALE_RAW_AUDIO,
SPEC_SCALE,
TARGET_SAMPLERATE_HZ,
preprocess_audio_clip,
PreprocessingConfig,
)
PathLike = Union[Path, str, os.PathLike]
@ -34,61 +25,24 @@ __all__ = [
]
class PreprocessingConfig(BaseModel):
"""Configuration for preprocessing data."""
target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
scale_audio: bool = Field(default=SCALE_RAW_AUDIO)
fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
spec_scale: str = Field(default=SPEC_SCALE)
denoise_spec_avg: bool = DENOISE_SPEC_AVG
max_scale_spec: bool = MAX_SCALE_SPEC
target_sigma: float = Field(default=TARGET_SIGMA, gt=0)
class_labels: Sequence[str] = ["bat"]
def generate_train_example(
clip_annotation: data.ClipAnnotation,
label_fn: LabelFn = lambda _: None,
config: Optional[PreprocessingConfig] = None,
class_mapper: ClassMapper,
preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
target_sigma: float = TARGET_SIGMA,
) -> xr.Dataset:
"""Generate a training example."""
if config is None:
config = PreprocessingConfig()
spectrogram = preprocess_audio_clip(
clip_annotation.clip,
target_sampling_rate=config.target_samplerate,
scale_audio=config.scale_audio,
fft_win_length=config.fft_win_length,
fft_overlap=config.fft_overlap,
max_freq=config.max_freq,
min_freq=config.min_freq,
spec_scale=config.spec_scale,
denoise_spec_avg=config.denoise_spec_avg,
max_scale_spec=config.max_scale_spec,
config=preprocessing_config,
)
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
clip_annotation,
spectrogram,
target_sigma=config.target_sigma,
num_classes=len(config.class_labels),
class_labels=list(config.class_labels),
label_fn=label_fn,
class_mapper,
target_sigma=target_sigma,
)
dataset = xr.Dataset(
@ -102,7 +56,8 @@ def generate_train_example(
return dataset.assign_attrs(
title=f"Training example for {clip_annotation.uuid}",
configuration=config.model_dump_json(),
preprocessing_configuration=preprocessing_config.model_dump_json(),
target_sigma=target_sigma,
clip_annotation=clip_annotation.model_dump_json(),
)
@ -148,9 +103,10 @@ def preprocess_single_annotation(
clip_annotation: data.ClipAnnotation,
output_dir: PathLike,
config: PreprocessingConfig,
class_mapper: ClassMapper,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
label_fn: LabelFn = lambda _: None,
target_sigma: float = TARGET_SIGMA,
) -> None:
output_dir = Path(output_dir)
@ -162,8 +118,9 @@ def preprocess_single_annotation(
sample = generate_train_example(
clip_annotation,
label_fn=label_fn,
config=config,
class_mapper,
preprocessing_config=config,
target_sigma=target_sigma,
)
save_to_file(sample, path)
@ -172,10 +129,11 @@ def preprocess_single_annotation(
def preprocess_annotations(
clip_annotations: Sequence[data.ClipAnnotation],
output_dir: PathLike,
class_mapper: ClassMapper,
target_sigma: float = TARGET_SIGMA,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
config_file: Optional[PathLike] = None,
label_fn: LabelFn = lambda _: None,
max_workers: Optional[int] = None,
**kwargs,
) -> None:
@ -198,9 +156,10 @@ def preprocess_annotations(
preprocess_single_annotation,
output_dir=output_dir,
config=config,
class_mapper=class_mapper,
filename_fn=filename_fn,
replace=replace,
label_fn=label_fn,
target_sigma=target_sigma,
),
clip_annotations,
),

View File

@ -28,7 +28,7 @@ dependencies = [
"torch>=1.13.1",
"torchaudio",
"torchvision",
"soundevent[audio,geometry,plot]>=1.3.5",
"soundevent[audio,geometry,plot]>=2.0",
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",

View File

@ -10,6 +10,8 @@
-e file:.
absl-py==2.1.0
# via tensorboard
affine==2.4.0
# via rasterio
aiobotocore==2.12.3
# via s3fs
aiohttp==3.9.5
@ -37,6 +39,7 @@ async-timeout==4.0.3
# via redis
attrs==23.2.0
# via aiohttp
# via rasterio
audioread==3.0.1
# via librosa
backcall==0.2.0
@ -57,6 +60,7 @@ botocore==1.34.69
# via s3transfer
certifi==2024.2.2
# via netcdf4
# via rasterio
# via requests
cf-xarray==0.9.0
# via batdetect2
@ -68,9 +72,16 @@ charset-normalizer==3.3.2
# via requests
click==8.1.7
# via batdetect2
# via click-plugins
# via cligj
# via lightning
# via lightning-cloud
# via rasterio
# via uvicorn
click-plugins==1.1.1
# via rasterio
cligj==0.7.2
# via rasterio
comm==0.2.2
# via ipykernel
contourpy==1.1.1
@ -136,6 +147,7 @@ idna==3.7
importlib-metadata==7.1.0
# via jupyter-client
# via markdown
# via rasterio
importlib-resources==6.4.0
# via matplotlib
# via typeshed-client
@ -229,9 +241,11 @@ numpy==1.24.4
# via onnx
# via pandas
# via pytorch-lightning
# via rasterio
# via scikit-learn
# via scipy
# via shapely
# via snuggs
# via soxr
# via tensorboard
# via tensorboardx
@ -335,6 +349,7 @@ pyjwt==2.8.0
# via lightning-cloud
pyparsing==3.1.2
# via matplotlib
# via snuggs
pytest==8.1.1
python-dateutil==2.9.0.post0
# via arrow
@ -361,6 +376,8 @@ pyyaml==6.0.1
pyzmq==26.0.0
# via ipykernel
# via jupyter-client
rasterio==1.3.10
# via soundevent
readchar==4.0.6
# via inquirer
redis==5.0.4
@ -390,6 +407,7 @@ scipy==1.10.1
# via soundevent
setuptools==69.5.1
# via lightning-utilities
# via rasterio
# via readchar
# via tensorboard
shapely==2.0.3
@ -402,7 +420,9 @@ six==1.16.0
# via tensorboard
sniffio==1.3.1
# via anyio
soundevent==1.3.5
snuggs==1.4.7
# via rasterio
soundevent==2.0.0
# via batdetect2
soundfile==0.12.1
# via librosa

View File

@ -10,6 +10,8 @@
-e file:.
absl-py==2.1.0
# via tensorboard
affine==2.4.0
# via rasterio
aiobotocore==2.12.3
# via s3fs
aiohttp==3.9.5
@ -35,6 +37,7 @@ async-timeout==4.0.3
# via redis
attrs==23.2.0
# via aiohttp
# via rasterio
audioread==3.0.1
# via librosa
backoff==2.2.1
@ -53,6 +56,7 @@ botocore==1.34.69
# via s3transfer
certifi==2024.2.2
# via netcdf4
# via rasterio
# via requests
cf-xarray==0.9.0
# via batdetect2
@ -64,9 +68,16 @@ charset-normalizer==3.3.2
# via requests
click==8.1.7
# via batdetect2
# via click-plugins
# via cligj
# via lightning
# via lightning-cloud
# via rasterio
# via uvicorn
click-plugins==1.1.1
# via rasterio
cligj==0.7.2
# via rasterio
contourpy==1.1.1
# via matplotlib
croniter==1.4.1
@ -123,6 +134,7 @@ idna==3.7
# via yarl
importlib-metadata==7.1.0
# via markdown
# via rasterio
importlib-resources==6.4.0
# via matplotlib
# via typeshed-client
@ -199,9 +211,11 @@ numpy==1.24.4
# via onnx
# via pandas
# via pytorch-lightning
# via rasterio
# via scikit-learn
# via scipy
# via shapely
# via snuggs
# via soxr
# via tensorboard
# via tensorboardx
@ -286,6 +300,7 @@ pyjwt==2.8.0
# via lightning-cloud
pyparsing==3.1.2
# via matplotlib
# via snuggs
python-dateutil==2.9.0.post0
# via arrow
# via botocore
@ -307,6 +322,8 @@ pyyaml==6.0.1
# via lightning
# via omegaconf
# via pytorch-lightning
rasterio==1.3.10
# via soundevent
readchar==4.0.6
# via inquirer
redis==5.0.4
@ -336,6 +353,7 @@ scipy==1.10.1
# via soundevent
setuptools==69.5.1
# via lightning-utilities
# via rasterio
# via readchar
# via tensorboard
shapely==2.0.3
@ -347,7 +365,9 @@ six==1.16.0
# via tensorboard
sniffio==1.3.1
# via anyio
soundevent==1.3.5
snuggs==1.4.7
# via rasterio
soundevent==2.0.0
# via batdetect2
soundfile==0.12.1
# via librosa