Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 14:41:58 +02:00)

Commit 17cf958cd3 (parent c66d14b7c7): "Working towards training code"
.gitignore (vendored)

@@ -110,3 +110,4 @@ experiments/*
 !batdetect2_notebook.ipynb
 !batdetect2/models/*.pth.tar
 !tests/data/*.wav
+notebooks/lightning_logs
@@ -1,304 +0,0 @@
-from functools import wraps
-from typing import Callable, List, Optional, Tuple
-
-import numpy as np
-import xarray as xr
-from soundevent import data
-from soundevent.geometry import compute_bounds
-
-ClipAugmentation = Callable[[data.ClipAnnotation], data.ClipAnnotation]
-AudioAugmentation = Callable[
-    [xr.DataArray, data.ClipAnnotation],
-    Tuple[xr.DataArray, data.ClipAnnotation],
-]
-SpecAugmentation = Callable[
-    [xr.DataArray, data.ClipAnnotation],
-    Tuple[xr.DataArray, data.ClipAnnotation],
-]
-
-ClipProvider = Callable[
-    [data.ClipAnnotation], Tuple[xr.DataArray, data.ClipAnnotation]
-]
-"""A function that provides some clip and its annotation.
-
-Usually this function loads a random clip from a dataset. Takes
-as input a clip annotation that can be used to filter the clips
-to load (in case you want to avoid loading the same clip multiple times).
-"""
-
-
-AUGMENTATION_PROBABILITY = 0.2
-MAX_DELAY = 0.005
-STRETCH_SQUEEZE_DELTA = 0.04
-MASK_MAX_TIME_PERC: float = 0.05
-MASK_MAX_FREQ_PERC: float = 0.10
-
-
-def maybe_apply(
-    augmentation: Callable,
-    prob: float = AUGMENTATION_PROBABILITY,
-) -> Callable:
-    """Apply an augmentation with a given probability."""
-
-    @wraps(augmentation)
-    def _augmentation(x):
-        if np.random.rand() > prob:
-            return x
-        return augmentation(x)
-
-    return _augmentation
-
-
-def select_random_subclip(
-    clip_annotation: data.ClipAnnotation,
-    duration: Optional[float] = None,
-    proportion: float = 0.9,
-) -> data.ClipAnnotation:
-    """Select a random subclip from a clip."""
-    clip = clip_annotation.clip
-
-    if duration is None:
-        clip_duration = clip.end_time - clip.start_time
-        duration = clip_duration * proportion
-
-    start_time = np.random.uniform(clip.start_time, clip.end_time - duration)
-    return clip_annotation.model_copy(
-        update=dict(
-            clip=clip.model_copy(
-                update=dict(
-                    start_time=start_time,
-                    end_time=start_time + duration,
-                )
-            )
-        )
-    )
-
-
-def combine_audio(
-    audio1: xr.DataArray,
-    audio2: xr.DataArray,
-    alpha: Optional[float] = None,
-    min_alpha: float = 0.3,
-    max_alpha: float = 0.7,
-) -> xr.DataArray:
-    """Combine two audio clips."""
-
-    if alpha is None:
-        alpha = np.random.uniform(min_alpha, max_alpha)
-
-    return alpha * audio1 + (1 - alpha) * audio2.data
-
-
-def random_mix(
-    audio: xr.DataArray,
-    clip: data.ClipAnnotation,
-    provider: Optional[ClipProvider] = None,
-    alpha: Optional[float] = None,
-    min_alpha: float = 0.3,
-    max_alpha: float = 0.7,
-    join_annotations: bool = True,
-) -> Tuple[xr.DataArray, data.ClipAnnotation]:
-    """Mix two audio clips."""
-    if provider is None:
-        raise ValueError("No audio provider given.")
-
-    try:
-        other_audio, other_clip = provider(clip)
-    except (StopIteration, ValueError):
-        raise ValueError("No more audio sources available.")
-
-    new_audio = combine_audio(
-        audio,
-        other_audio,
-        alpha=alpha,
-        min_alpha=min_alpha,
-        max_alpha=max_alpha,
-    )
-
-    if join_annotations:
-        clip = clip.model_copy(
-            update=dict(
-                sound_events=clip.sound_events + other_clip.sound_events,
-            )
-        )
-
-    return new_audio, clip
-
-
-def add_echo(
-    audio: xr.DataArray,
-    clip: data.ClipAnnotation,
-    delay: Optional[float] = None,
-    alpha: Optional[float] = None,
-    min_alpha: float = 0.0,
-    max_alpha: float = 1.0,
-    max_delay: float = MAX_DELAY,
-) -> Tuple[xr.DataArray, data.ClipAnnotation]:
-    """Add a delay to the audio."""
-    if delay is None:
-        delay = np.random.uniform(0, max_delay)
-
-    if alpha is None:
-        alpha = np.random.uniform(min_alpha, max_alpha)
-
-    samplerate = audio.attrs["samplerate"]
-    offset = int(delay * samplerate)
-
-    # NOTE: We use the copy method to avoid modifying the original audio
-    # data.
-    new_audio = audio.copy()
-    new_audio[offset:] += alpha * audio.data[:-offset]
-    return new_audio, clip
-
-
-def scale_volume(
-    spec: xr.DataArray,
-    clip: data.ClipAnnotation,
-    factor: Optional[float] = None,
-    max_scaling: float = 2,
-    min_scaling: float = 0,
-) -> Tuple[xr.DataArray, data.ClipAnnotation]:
-    """Scale the volume of a spectrogram."""
-    if factor is None:
-        factor = np.random.uniform(min_scaling, max_scaling)
-
-    return spec * factor, clip
-
-
-def scale_sound_event_annotation(
-    sound_event_annotation: data.SoundEventAnnotation,
-    time_factor: float = 1,
-    frequency_factor: float = 1,
-) -> data.SoundEventAnnotation:
-    sound_event = sound_event_annotation.sound_event
-    geometry = sound_event.geometry
-
-    if geometry is None:
-        return sound_event_annotation
-
-    start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
-    new_geometry = data.BoundingBox(
-        coordinates=[
-            start_time * time_factor,
-            low_freq * frequency_factor,
-            end_time * time_factor,
-            high_freq * frequency_factor,
-        ]
-    )
-
-    return sound_event_annotation.model_copy(
-        update=dict(
-            sound_event=sound_event.model_copy(
-                update=dict(
-                    geometry=new_geometry,
-                )
-            )
-        )
-    )
-
-
-def warp_spectrogram(
-    spec: xr.DataArray,
-    clip: data.ClipAnnotation,
-    factor: Optional[float] = None,
-    delta: float = STRETCH_SQUEEZE_DELTA,
-) -> Tuple[xr.DataArray, data.ClipAnnotation]:
-    """Warp a spectrogram."""
-    if factor is None:
-        factor = np.random.uniform(1 - delta, 1 + delta)
-
-    start_time = clip.clip.start_time
-    end_time = clip.clip.end_time
-    duration = end_time - start_time
-    new_time = np.linspace(
-        start_time,
-        start_time + duration * factor,
-        spec.time.size,
-    )
-
-    scaled_spec = spec.interp(
-        time=new_time,
-        method="linear",
-        kwargs={"fill_value": 0},
-    )
-    scaled_spec.coords["time"] = spec.time
-
-    scaled_clip = clip.model_copy(
-        update=dict(
-            sound_events=[
-                scale_sound_event_annotation(
-                    sound_event_annotation,
-                    time_factor=1 / factor,
-                )
-                for sound_event_annotation in clip.sound_events
-            ]
-        )
-    )
-    return scaled_spec, scaled_clip
-
-
-def mask_axis(
-    array: xr.DataArray,
-    axis: str,
-    start: float,
-    end: float,
-    mask_value: float = 0,
-) -> xr.DataArray:
-    if axis not in array.dims:
-        raise ValueError(f"Axis {axis} not found in array")
-
-    coord = array[axis]
-    return array.where((coord < start) | (coord > end), mask_value)
-
-
-def mask_time(
-    spec: xr.DataArray,
-    clip: data.ClipAnnotation,
-    max_time_mask: float = MASK_MAX_TIME_PERC,
-    max_num_masks: int = 3,
-) -> Tuple[xr.DataArray, data.ClipAnnotation]:
-    """Mask a random section of the time axis."""
-
-    num_masks = np.random.randint(1, max_num_masks + 1)
-    for _ in range(num_masks):
-        mask_size = np.random.uniform(0, max_time_mask)
-        start = np.random.uniform(0, spec.time[-1] - mask_size)
-        end = start + mask_size
-        spec = mask_axis(spec, "time", start, end)
-
-    return spec, clip
-
-
-def mask_frequency(
-    spec: xr.DataArray,
-    clip: data.ClipAnnotation,
-    max_freq_mask: float = MASK_MAX_FREQ_PERC,
-    max_num_masks: int = 3,
-) -> Tuple[xr.DataArray, data.ClipAnnotation]:
-    """Mask a random section of the frequency axis."""
-
-    num_masks = np.random.randint(1, max_num_masks + 1)
-    for _ in range(num_masks):
-        mask_size = np.random.uniform(0, max_freq_mask)
-        start = np.random.uniform(0, spec.frequency[-1] - mask_size)
-        end = start + mask_size
-        spec = mask_axis(spec, "frequency", start, end)
-
-    return spec, clip
-
-
-CLIP_AUGMENTATIONS: List[ClipAugmentation] = [
-    select_random_subclip,
-]
-
-AUDIO_AUGMENTATIONS: List[AudioAugmentation] = [
-    add_echo,
-    random_mix,
-]
-
-SPEC_AUGMENTATIONS: List[SpecAugmentation] = [
-    scale_volume,
-    warp_spectrogram,
-    mask_time,
-    mask_frequency,
-]
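Note: the deleted module above applied each augmentation stochastically via maybe_apply. A minimal sketch of how the SPEC_AUGMENTATIONS registry was presumably consumed at training time (the augment_spectrogram helper is illustrative; the actual training loop is not part of this diff):

import numpy as np

def augment_spectrogram(spec, clip, prob=AUGMENTATION_PROBABILITY):
    # Apply each spectrogram augmentation independently with the module's
    # default probability, mirroring what maybe_apply did one level up.
    for augmentation in SPEC_AUGMENTATIONS:
        if np.random.rand() <= prob:
            spec, clip = augmentation(spec, clip)
    return spec, clip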
@@ -11,7 +11,7 @@ from soundevent import data
 from soundevent.geometry import compute_bounds
 
 from batdetect2 import types
-from batdetect2.data.labels import LabelFn
+from batdetect2.data.labels import ClassMapper
 
 PathLike = Union[Path, str, os.PathLike]
 
@@ -54,7 +54,7 @@ def get_annotation_notes(annotation: data.ClipAnnotation) -> str:
 
 def convert_to_annotation_group(
     annotation: data.ClipAnnotation,
-    label_fn: LabelFn = lambda _: None,
+    class_mapper: ClassMapper,
     event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
     class_fn: ClassFn = lambda _: 0,
     individual_fn: IndividualFn = lambda _: 0,
@@ -80,8 +80,8 @@ def convert_to_annotation_group(
             continue
 
         start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
-        class_id = label_fn(sound_event) or -1
-        event = event_fn(sound_event)
+        class_id = class_mapper.transform(sound_event) or -1
+        event = event_fn(sound_event) or ""
        individual_id = individual_fn(sound_event) or -1
 
         start_times.append(start_time)
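The label_fn callback is replaced throughout this commit by a ClassMapper object. Its actual definition lives in soundevent (it is re-exported from soundevent.types later in this diff); from the call sites in the commit it needs at least transform, num_classes, and class_labels. A duck-typed stand-in for experimentation (hypothetical, for illustration only):

from typing import List, Optional
from soundevent import data

class SpeciesMapper:
    """Illustrative ClassMapper-like object: maps a 'species' tag to a class name."""

    def __init__(self, species: List[str]):
        self.class_labels = species

    @property
    def num_classes(self) -> int:
        return len(self.class_labels)

    def transform(self, annotation: data.SoundEventAnnotation) -> Optional[str]:
        for tag in annotation.tags:
            if tag.key == "species" and tag.value in self.class_labels:
                return tag.value
        return None  # None marks a sound event outside the classes of interest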
@@ -4,7 +4,6 @@ from soundevent import data
 from torch.utils.data import Dataset
 
 __all__ = [
-    "ClipAnnotationDataset",
     "ClipDataset",
 ]
 
@@ -12,31 +11,7 @@ __all__ = [
 E = TypeVar("E")
 
 
-class ClipAnnotationDataset(Dataset, Generic[E]):
-
-    clip_annotations: List[data.ClipAnnotation]
-
-    transform: Callable[[data.ClipAnnotation], E]
-
-    def __init__(
-        self,
-        clip_annotations: Iterable[data.ClipAnnotation],
-        transform: Callable[[data.ClipAnnotation], E],
-        name: str = "ClipAnnotationDataset",
-    ):
-        self.clip_annotations = list(clip_annotations)
-        self.transform = transform
-        self.name = name
-
-    def __len__(self) -> int:
-        return len(self.clip_annotations)
-
-    def __getitem__(self, idx: int) -> E:
-        return self.transform(self.clip_annotations[idx])
-
-
 class ClipDataset(Dataset, Generic[E]):
 
     clips: List[data.Clip]
 
     transform: Callable[[data.Clip], E]
 
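The surviving ClipDataset keeps the same generic pattern: store items, apply a transform in __getitem__. A usage sketch; the constructor arguments are assumed to mirror the removed ClipAnnotationDataset, since ClipDataset.__init__ falls outside this hunk:

from torch.utils.data import DataLoader

# Assumed constructor: ClipDataset(clips, transform, name=...).
dataset = ClipDataset(
    clips,                            # iterable of data.Clip
    transform=preprocess_audio_clip,  # data.Clip -> spectrogram DataArray
)

# xarray outputs need a custom collate_fn before batching with DataLoader.
loader = DataLoader(dataset, batch_size=8, collate_fn=lambda xs: xs)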
@@ -1,113 +1,29 @@
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Tuple
 
 import numpy as np
 import xarray as xr
 from scipy.ndimage import gaussian_filter
-from soundevent import data, geometry
+from soundevent import data, geometry, arrays
+from soundevent.geometry.operations import Positions
+from soundevent.types import ClassMapper
 
 __all__ = [
+    "ClassMapper",
     "generate_heatmaps",
 ]
 
-PositionFn = Callable[[data.SoundEvent], Tuple[float, float]]
-"""Convert a sound event to a single position in time-frequency space."""
-
-SizeFn = Callable[[data.SoundEvent, float, float], np.ndarray]
-"""Compute the size of a sound event in time-frequency space.
-
-The time and frequency scales are provided as arguments to allow
-modifying the size of the sound event based on the spectrogram
-parameters.
-"""
-
-LabelFn = Callable[[data.SoundEventAnnotation], Optional[str]]
-"""Convert a sound event annotation to a label.
-
-When the label is None, this indicates that the sound event does not
-belong to any of the classes of interest.
-"""
-
 TARGET_SIGMA = 3.0
 
-GENERIC_LABEL = "__UNKNOWN__"
-
-
-def get_lower_left_position(
-    sound_event: data.SoundEvent,
-) -> Tuple[float, float]:
-    if sound_event.geometry is None:
-        raise ValueError("Sound event has no geometry.")
-
-    start_time, low_freq, _, _ = geometry.compute_bounds(sound_event.geometry)
-    return start_time, low_freq
-
-
-def get_bbox_size(
-    sound_event: data.SoundEvent,
-    time_scale: float = 1.0,
-    frequency_scale: float = 1.0,
-) -> np.ndarray:
-    if sound_event.geometry is None:
-        raise ValueError("Sound event has no geometry.")
-
-    start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
-        sound_event.geometry
-    )
-
-    return np.array(
-        [
-            time_scale * (end_time - start_time),
-            frequency_scale * (high_freq - low_freq),
-        ]
-    )
-
-
-def _tag_key(tag: data.Tag) -> Tuple[str, str]:
-    return (tag.key, tag.value)
-
-
-def set_value_at_position(
-    array: xr.DataArray,
-    value: Any,
-    **query,
-) -> xr.DataArray:
-    dims = {dim: n for n, dim in enumerate(array.dims)}
-    indexer: List[Union[slice, int]] = [slice(None) for _ in range(array.ndim)]
-
-    for key, coord in query.items():
-        if key not in dims:
-            raise ValueError(f"Dimension {key} not found in array.")
-
-        coordinates = array.indexes[key]
-        indexer[dims[key]] = coordinates.get_loc(coordinates.asof(coord))
-
-    if isinstance(value, (tuple, list)):
-        value = np.array(value)
-
-    array.data[tuple(indexer)] = value
-    return array
-
-
 def generate_heatmaps(
     clip_annotation: data.ClipAnnotation,
     spec: xr.DataArray,
-    num_classes: int = 1,
-    label_fn: LabelFn = lambda _: None,
+    class_mapper: ClassMapper,
     target_sigma: float = TARGET_SIGMA,
-    size_fn: SizeFn = get_bbox_size,
-    position_fn: PositionFn = get_lower_left_position,
-    class_labels: Optional[List[str]] = None,
+    position: Positions = "bottom-left",
     dtype=np.float32,
 ) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
-    if class_labels is None:
-        class_labels = [str(i) for i in range(num_classes)]
-
-    if len(class_labels) != num_classes:
-        raise ValueError(
-            "Number of class labels must match the number of classes."
-        )
-
     shape = dict(zip(spec.dims, spec.shape))
 
     if "time" not in shape or "frequency" not in shape:
@@ -115,8 +31,8 @@ def generate_heatmaps(
             "Spectrogram must have time and frequency dimensions."
         )
 
-    time_duration = spec.time.attrs["max"] - spec.time.attrs["min"]
-    freq_bandwidth = spec.frequency.attrs["max"] - spec.frequency.attrs["min"]
+    time_duration = arrays.get_dim_width(spec, dim="time")
+    freq_bandwidth = arrays.get_dim_width(spec, dim="frequency")
 
     # Compute the size factors
     time_scale = 1 / time_duration
@@ -125,10 +41,10 @@ def generate_heatmaps(
     # Initialize heatmaps
     detection_heatmap = xr.zeros_like(spec, dtype=dtype)
     class_heatmap = xr.DataArray(
-        data=np.zeros((num_classes, *spec.shape), dtype=dtype),
+        data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype),
         dims=["category", *spec.dims],
         coords={
-            "category": class_labels,
+            "category": class_mapper.class_labels,
             **spec.coords,
         },
     )
@@ -142,11 +58,17 @@ def generate_heatmaps(
     )
 
     for sound_event_annotation in clip_annotation.sound_events:
+        geom = sound_event_annotation.sound_event.geometry
+
+        if geom is None:
+            continue
+
         # Get the position of the sound event
-        time, frequency = position_fn(sound_event_annotation.sound_event)
+        time, frequency = geometry.get_geometry_point(geom, position=position)
+        print(time, frequency)
 
         # Set 1.0 at the position of the sound event in the detection heatmap
-        detection_heatmap = set_value_at_position(
+        detection_heatmap = arrays.set_value_at_pos(
             detection_heatmap,
             1.0,
             time=time,
@@ -154,35 +76,37 @@ def generate_heatmaps(
         )
 
         # Set the size of the sound event at the position in the size heatmap
-        size = size_fn(
-            sound_event_annotation.sound_event,
-            time_scale,
-            frequency_scale,
+        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
+            geom
         )
-        size_heatmap = set_value_at_position(
+        size = np.array(
+            [
+                (end_time - start_time) * time_scale,
+                (high_freq - low_freq) * frequency_scale,
+            ]
+        )
+        size_heatmap = arrays.set_value_at_pos(
             size_heatmap,
             size,
             time=time,
             frequency=frequency,
         )
 
-        # Get the label id for the sound event
-        label = label_fn(sound_event_annotation)
+        # Get the class name of the sound event
+        class_name = class_mapper.transform(sound_event_annotation)
 
-        if label is None or label not in class_labels:
-            # If the label is None or not in the class labels, we skip the
-            # sound event
+        if class_name is None:
+            # If the label is None skip the sound event
             continue
 
         # Set 1.0 at the position and category of the sound event in the class
         # heatmap
-        class_heatmap = set_value_at_position(
+        class_heatmap = arrays.set_value_at_pos(
             class_heatmap,
             1.0,
             time=time,
             frequency=frequency,
-            category=label,
+            category=class_name,
         )
 
     # Apply gaussian filters
@@ -207,25 +131,3 @@ def generate_heatmaps(
     ).fillna(0.0)
 
     return detection_heatmap, class_heatmap, size_heatmap
-
-
-class Labeler:
-    def __init__(self, tags: List[data.Tag]):
-        """Create a labeler from a list of tags.
-
-        Each tag is assigned a unique label. The labeler can then be used
-        to convert sound event annotations to labels.
-        """
-        self.tags = tags
-        self._label_map = {_tag_key(tag): i for i, tag in enumerate(tags)}
-        self._inverse_label_map = {v: k for k, v in self._label_map.items()}
-
-    def __call__(
-        self, sound_event_annotation: data.SoundEventAnnotation
-    ) -> Optional[int]:
-        for tag in sound_event_annotation.tags:
-            key = _tag_key(tag)
-            if key in self._label_map:
-                return self._label_map[key]
-
-        return None
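The targets built here are CenterNet-style heatmaps: a unit impulse at each event's anchor point, blurred with a Gaussian of width TARGET_SIGMA. A standalone sketch of the idea on a bare numpy grid (illustrative only; generate_heatmaps above does the same on labeled xarray axes):

import numpy as np
from scipy.ndimage import gaussian_filter

heatmap = np.zeros((128, 512), dtype=np.float32)  # (frequency, time) bins
heatmap[40, 100] = 1.0                            # one sound-event anchor
heatmap = gaussian_filter(heatmap, sigma=3.0)     # TARGET_SIGMA
heatmap /= heatmap.max()                          # peak normalized back to 1.0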
@@ -1,15 +1,22 @@
 """Module containing functions for preprocessing audio clips."""
 
-import random
-from typing import List, Optional, Tuple
+from typing import Optional
 
 import librosa
 import librosa.core.spectrum
 import numpy as np
 import xarray as xr
 from numpy.typing import DTypeLike
+from pydantic import BaseModel, Field
 from scipy.signal import resample_poly
-from soundevent import audio, data
+from soundevent import audio, data, arrays
+from soundevent.arrays import operations as ops
+
+__all__ = [
+    "PreprocessingConfig",
+    "preprocess_audio_clip",
+]
 
 TARGET_SAMPLERATE_HZ = 256000
 SCALE_RAW_AUDIO = False
@@ -26,20 +33,37 @@ DENOISE_SPEC_AVG = True
 MAX_SCALE_SPEC = False
 
 
+class PreprocessingConfig(BaseModel):
+    """Configuration for preprocessing data."""
+
+    target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
+
+    scale_audio: bool = Field(default=SCALE_RAW_AUDIO)
+
+    fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
+
+    fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
+
+    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
+
+    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
+
+    spec_scale: str = Field(default=SPEC_SCALE)
+
+    denoise_spec_avg: bool = DENOISE_SPEC_AVG
+
+    max_scale_spec: bool = MAX_SCALE_SPEC
+
+    duration: Optional[float] = DEFAULT_DURATION
+
+    spec_height: int = SPEC_HEIGHT
+
+    spec_time_period: float = SPEC_TIME_PERIOD
+
+
 def preprocess_audio_clip(
     clip: data.Clip,
-    target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
-    scale_audio: bool = SCALE_RAW_AUDIO,
-    fft_win_length: float = FFT_WIN_LENGTH_S,
-    fft_overlap: float = FFT_OVERLAP,
-    max_freq: int = MAX_FREQ_HZ,
-    min_freq: int = MIN_FREQ_HZ,
-    spec_scale: str = SPEC_SCALE,
-    denoise_spec_avg: bool = True,
-    max_scale_spec: bool = False,
-    duration: Optional[float] = DEFAULT_DURATION,
-    spec_height: int = SPEC_HEIGHT,
-    spec_time_period: float = SPEC_TIME_PERIOD,
+    config: PreprocessingConfig = PreprocessingConfig(),
 ) -> xr.DataArray:
     """Preprocesses audio clip to generate spectrogram.
 
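Collapsing a dozen keyword arguments into a pydantic model also buys validation and easy serialization. A usage sketch with the field names from the hunk above (the override values are illustrative; clip is any soundevent data.Clip):

config = PreprocessingConfig(
    target_samplerate=192_000,  # overrides are validated (gt=0)
    spec_scale="log",
)
spec = preprocess_audio_clip(clip, config=config)

# Pydantic (v2) models serialize cleanly, e.g. for experiment tracking:
print(config.model_dump_json())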
@@ -47,45 +71,8 @@ def preprocess_audio_clip(
     ----------
     clip
         The audio clip to preprocess.
-    target_sampling_rate
-        Target sampling rate for the audio. If the audio has a different
-        sampling rate, it will be resampled to this rate.
-    scale_audio
-        Whether to scale the audio amplitudes to a range of [-1, 1].
-        By default, the audio is not scaled.
-    fft_win_length
-        Length of the FFT window in seconds.
-    fft_overlap
-        Amount of overlap between FFT windows as a fraction of the window
-        length.
-    max_freq
-        Maximum frequency for spectrogram. Anything above this frequency will
-        be cropped.
-    min_freq
-        Minimum frequency for spectrogram. Anything below this frequency will
-        be cropped.
-    spec_scale
-        Scaling method for the spectrogram. Can be "pcen", "log" or
-        "amplitude".
-    denoise_spec_avg
-        Whether to denoise the spectrogram. Denoising is done by subtracting
-        the average of the spectrogram from the spectrogram and clipping
-        negative values to 0.
-    max_scale_spec
-        Whether to max scale the spectrogram. Max scaling is done by dividing
-        the spectrogram by its maximum value thus scaling values to [0, 1].
-    duration
-        Duration of the spectrogram in seconds. If the clip duration is
-        different from this value, the spectrogram will be cropped or extended
-        to match this duration. If None, the spectrogram will have the same
-        duration as the clip.
-    spec_height
-        Number of frequency bins for the spectrogram. This is the height of
-        the final spectrogram.
-    spec_time_period
-        Time period for each spectrogram bin in seconds. The spectrogram array
-        will be resized (using bilinear interpolation) to have this time
-        period.
+    config
+        Configuration for preprocessing.
 
     Returns
     -------
@@ -95,35 +82,29 @@ def preprocess_audio_clip(
     """
     wav = load_clip_audio(
         clip,
-        target_sampling_rate=target_sampling_rate,
-        scale=scale_audio,
-    )
-
-    wav = wav.assign_attrs(
-        recording_id=str(wav.attrs["recording_id"]),
-        clip_id=str(wav.attrs["clip_id"]),
-        path=str(wav.attrs["path"]),
+        target_sampling_rate=config.target_samplerate,
+        scale=config.scale_audio,
     )
 
     spec = compute_spectrogram(
         wav,
-        fft_win_length=fft_win_length,
-        fft_overlap=fft_overlap,
-        max_freq=max_freq,
-        min_freq=min_freq,
-        spec_scale=spec_scale,
-        denoise_spec_avg=denoise_spec_avg,
-        max_scale_spec=max_scale_spec,
+        fft_win_length=config.fft_win_length,
+        fft_overlap=config.fft_overlap,
+        max_freq=config.max_freq,
+        min_freq=config.min_freq,
+        spec_scale=config.spec_scale,
+        denoise_spec_avg=config.denoise_spec_avg,
+        max_scale_spec=config.max_scale_spec,
     )
 
-    if duration is not None:
-        spec = adjust_spec_duration(clip, spec, duration)
+    if config.duration is not None:
+        spec = adjust_spec_duration(clip, spec, config.duration)
 
-    duration = get_dim_width(spec, dim="time")
-    return resize_spectrogram(
+    duration = arrays.get_dim_width(spec, dim="time")
+    return ops.resize(
         spec,
-        time_bins=int(np.ceil(duration / spec_time_period)),
-        freq_bins=spec_height,
+        time=int(np.ceil(duration / config.spec_time_period)),
+        frequency=config.spec_height,
     )
 
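Taken together, preprocess_audio_clip now reads as a four-stage pipeline: load and resample audio, compute and scale a spectrogram, crop or extend it to a fixed duration, then resize it onto a fixed grid. The final grid size follows directly from the config; a sketch of the arithmetic (the default values of SPEC_HEIGHT and SPEC_TIME_PERIOD are not shown in this diff, so the numbers below are assumptions):

import numpy as np

duration = 0.5              # seconds, after adjust_spec_duration
spec_time_period = 0.001    # assumed default SPEC_TIME_PERIOD
spec_height = 128           # assumed default SPEC_HEIGHT

time_bins = int(np.ceil(duration / spec_time_period))  # -> 500 columns
print(spec_height, time_bins)  # final spectrogram grid is (128, 500)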
@@ -138,18 +119,18 @@ def adjust_spec_duration(
         return spec
 
     if current_duration > duration:
-        return crop_axis(
+        return arrays.crop_dim(
             spec,
             dim="time",
             start=clip.start_time,
-            end=clip.start_time + duration,
+            stop=clip.start_time + duration,
         )
 
-    return extend_axis(
+    return arrays.extend_dim(
         spec,
         dim="time",
         start=clip.start_time,
-        end=clip.start_time + duration,
+        stop=clip.start_time + duration,
     )
 
@@ -159,21 +140,15 @@ def load_clip_audio(
     scale: bool = SCALE_RAW_AUDIO,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
-    wav = audio.load_clip(clip).sel(channel=0)
+    wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
 
     wav = resample_audio(wav, target_sampling_rate, dtype=dtype)
 
     if scale:
-        wav = scale_audio(wav)
+        wav = ops.center(wav)
+        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
 
-    wav.coords["time"] = wav.time.assign_attrs(
-        unit="s",
-        long_name="Seconds since start of recording",
-        min=clip.start_time,
-        max=clip.end_time,
-    )
-
-    return wav
+    return wav.astype(dtype)
 
 
 def resample_audio(
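The bespoke scale_audio helper (removed further down in this file) is replaced by two composable array ops: mean-centering, then peak normalization with a small epsilon. Numerically the two formulations should agree; a quick equivalence check in plain numpy (a sketch, assuming ops.center and ops.scale behave as their names suggest):

import numpy as np

x = np.random.randn(1024).astype(np.float32)

# Old helper: subtract the mean, divide by (max(|centered|) + eps).
centered = x - x.mean()
old = centered / (np.abs(centered).max() + 10e-6)

# New composition: center, then scale by 1 / (eps + max(|centered|)).
new = centered * (1 / (10e-6 + np.max(np.abs(centered))))

assert np.allclose(old, new, atol=1e-6)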
@@ -181,14 +156,14 @@ def resample_audio(
     target_samplerate: int = TARGET_SAMPLERATE_HZ,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
-    if "samplerate" not in wav.attrs:
-        raise ValueError("Audio must have a 'samplerate' attribute")
-
     if "time" not in wav.dims:
         raise ValueError("Audio must have a time dimension")
 
     time_axis: int = wav.get_axis_num("time")  # type: ignore
-    original_samplerate = wav.attrs["samplerate"]
+    start, stop = arrays.get_dim_range(wav, dim="time")
+    step = arrays.get_dim_step(wav, dim="time")
+    original_samplerate = int(1 / step)
 
     if original_samplerate == target_samplerate:
         return wav.astype(dtype)
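The sample rate is no longer carried as a free-floating attribute; it is derived from the time coordinate's step, which keeps it consistent with the data by construction. The idea in isolation (a sketch; rounding guards against floating-point drift in the step):

import numpy as np

times = np.arange(0, 1, 1 / 256_000)   # regularly spaced time coordinate
step = times[1] - times[0]             # what get_dim_step would return
samplerate = int(round(1 / step))      # 256000, recovered from the axis itself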
@@ -202,8 +177,8 @@ def resample_audio(
     )
 
     resampled_times = np.linspace(
-        wav.time[0],
-        wav.time[-1],
+        start,
+        stop + step,
         len(resampled),
         endpoint=False,
         dtype=dtype,
@@ -214,23 +189,15 @@ def resample_audio(
         dims=wav.dims,
         coords={
             **wav.coords,
-            "time": resampled_times,
-        },
-        attrs={
-            **wav.attrs,
-            "samplerate": target_samplerate,
+            "time": arrays.create_time_dim_from_array(
+                resampled_times,
+                samplerate=target_samplerate,
+            ),
         },
+        attrs=wav.attrs,
     )
 
 
-def scale_audio(
-    audio: xr.DataArray,
-    eps: float = 10e-6,
-) -> xr.DataArray:
-    audio = audio - audio.mean()
-    return audio / np.add(np.abs(audio).max(), eps, dtype=audio.dtype)
-
-
 def compute_spectrogram(
     wav: xr.DataArray,
     fft_win_length: float = FFT_WIN_LENGTH_S,
@@ -249,12 +216,12 @@ def compute_spectrogram(
         dtype=dtype,
     )
 
-    spec = crop_axis(
+    spec = arrays.crop_dim(
         spec,
         dim="frequency",
         start=min_freq,
-        end=max_freq,
-    )
+        stop=max_freq,
+    ).astype(dtype)
 
     spec = scale_spectrogram(spec, scale=spec_scale)
 
@@ -262,172 +229,67 @@ def compute_spectrogram(
         spec = denoise_spectrogram(spec)
 
     if max_scale_spec:
-        spec = max_scale_spectrogram(spec)
+        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
 
-    return spec
-
-
-def crop_axis(
-    arr: xr.DataArray,
-    dim: str,
-    start: float,
-    end: float,
-    right_closed: bool = False,
-    left_closed: bool = True,
-    eps: float = 10e-6,
-) -> xr.DataArray:
-    coord = arr.coords[dim]
-
-    if not all(attr in coord.attrs for attr in ["min", "max"]):
-        raise ValueError(
-            f"Coordinate '{dim}' must have 'min' and 'max' attributes"
-        )
-
-    current_min = coord.attrs["min"]
-    current_max = coord.attrs["max"]
-
-    if start < current_min or end > current_max:
-        raise ValueError(
-            f"Cannot select axis '{dim}' from {start} to {end}. "
-            f"Axis range is {current_min} to {current_max}"
-        )
-
-    slice_end = end
-    if not right_closed:
-        slice_end = end - eps
-
-    slice_start = start
-    if not left_closed:
-        slice_start = start + eps
-
-    arr = arr.sel({dim: slice(slice_start, slice_end)})
-
-    arr.coords[dim].attrs.update(
-        min=start,
-        max=end,
-    )
-
-    return arr
-
-
-def extend_axis(
-    arr: xr.DataArray,
-    dim: str,
-    start: float,
-    end: float,
-    fill_value: float = 0,
-) -> xr.DataArray:
-    coord = arr.coords[dim]
-
-    if not all(attr in coord.attrs for attr in ["min", "max", "period"]):
-        raise ValueError(
-            f"Coordinate '{dim}' must have 'min', 'max' and 'period' attributes"
-            " to extend axis"
-        )
-
-    current_min = coord.attrs["min"]
-    current_max = coord.attrs["max"]
-    period = coord.attrs["period"]
-
-    coords = coord.data
-
-    if start < current_min:
-        new_coords = np.arange(
-            current_min,
-            start,
-            -period,
-            dtype=coord.dtype,
-        )[1:][::-1]
-        coords = np.concatenate([new_coords, coords])
-
-    if end > current_max:
-        new_coords = np.arange(
-            current_max,
-            end,
-            period,
-            dtype=coord.dtype,
-        )[1:]
-        coords = np.concatenate([coords, new_coords])
-
-    arr = arr.reindex(
-        {dim: coords},
-        fill_value=fill_value,  # type: ignore
-    )
-
-    arr.coords[dim].attrs.update(
-        min=start,
-        max=end,
-    )
-
-    return arr
+    return spec.astype(dtype)
 
 
 def gen_mag_spectrogram(
-    audio: xr.DataArray,
+    wave: xr.DataArray,
     window_len: float,
     overlap_perc: float,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
-    sampling_rate = audio.attrs["samplerate"]
+    start_time, end_time = arrays.get_dim_range(wave, dim="time")
+    step = arrays.get_dim_step(wave, dim="time")
+    sampling_rate = 1 / step
 
     hop_len = window_len * (1 - overlap_perc)
     nfft = int(window_len * sampling_rate)
     noverlap = int(overlap_perc * nfft)
-    start_time = audio.time.attrs["min"]
-    end_time = audio.time.attrs["max"]
 
     # compute spec
     spec, _ = librosa.core.spectrum._spectrogram(
-        y=audio.data,
+        y=wave.data,
         power=1,
         n_fft=nfft,
         hop_length=nfft - noverlap,
         center=False,
     )
 
-    spec = xr.DataArray(
+    return xr.DataArray(
         data=spec.astype(dtype),
         dims=["frequency", "time"],
         coords={
-            "frequency": np.linspace(
-                0,
-                sampling_rate / 2,
-                spec.shape[0],
-                endpoint=False,
-                dtype=dtype,
+            "frequency": arrays.create_frequency_dim_from_array(
+                np.linspace(
+                    0,
+                    sampling_rate / 2,
+                    spec.shape[0],
+                    endpoint=False,
+                    dtype=dtype,
+                ),
+                step=sampling_rate / nfft,
             ),
-            "time": np.linspace(
-                start_time,
-                end_time - (window_len - hop_len),
-                spec.shape[1],
-                endpoint=False,
-                dtype=dtype,
+            "time": arrays.create_time_dim_from_array(
+                np.linspace(
+                    start_time,
+                    end_time - (window_len - hop_len),
+                    spec.shape[1],
+                    endpoint=False,
+                    dtype=dtype,
+                ),
+                step=hop_len,
             ),
         },
         attrs={
-            **audio.attrs,
+            **wave.attrs,
+            "original_samplerate": sampling_rate,
             "nfft": nfft,
             "noverlap": noverlap,
         },
     )
 
-    # Add metadata to coordinates
-    spec.coords["time"].attrs.update(
-        unit="s",
-        long_name="Time",
-        min=start_time,
-        max=end_time - (window_len - hop_len),
-        period=(nfft - noverlap) / sampling_rate,
-    )
-    spec.coords["frequency"].attrs.update(
-        unit="Hz",
-        long_name="Frequency",
-        period=(sampling_rate / nfft),
-        min=0,
-        max=sampling_rate / 2,
-    )
-
-    return spec
-
 
 def denoise_spectrogram(
     spec: xr.DataArray,
@@ -436,10 +298,7 @@ def denoise_spectrogram(
         data=(spec - spec.mean("time")).clip(0),
         dims=spec.dims,
         coords=spec.coords,
-        attrs={
-            **spec.attrs,
-            "denoised": 1,
-        },
+        attrs=spec.attrs,
     )
 
@@ -448,8 +307,14 @@ def scale_spectrogram(
     scale: str = SPEC_SCALE,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
+    samplerate = spec.attrs["original_samplerate"]
+
     if scale == "pcen":
-        return pcen(spec, dtype=dtype)
+        smoothing_constant = get_pcen_smoothing_constant(samplerate / 10)
+        return audio.pcen(
+            spec * (2**31),
+            smooth=smoothing_constant,
+        ).astype(dtype)
 
     if scale == "log":
         return log_scale(spec, dtype=dtype)
@@ -461,126 +326,25 @@ def log_scale(
     spec: xr.DataArray,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
+    samplerate = spec.attrs["original_samplerate"]
     nfft = spec.attrs["nfft"]
-    sampling_rate = spec.attrs["samplerate"]
     log_scaling = (
         2.0
-        * (1.0 / sampling_rate)
+        * (1.0 / samplerate)
         * (1.0 / (np.abs(np.hanning(nfft)) ** 2).sum())
     )
     return xr.DataArray(
         data=np.log1p(log_scaling * spec).astype(dtype),
         dims=spec.dims,
         coords=spec.coords,
-        attrs={
-            **spec.attrs,
-            "scale": "log",
-        },
+        attrs=spec.attrs,
    )
 
 
-def pcen(spec: xr.DataArray, dtype: DTypeLike = np.float32) -> xr.DataArray:
-    sampling_rate = spec.attrs["samplerate"]
-    data = librosa.pcen(
-        spec.data * (2**31),
-        sr=sampling_rate / 10,
-    )
-    return xr.DataArray(
-        data=data.astype(dtype),
-        dims=spec.dims,
-        coords=spec.coords,
-        attrs={
-            **spec.attrs,
-            "scale": "pcen",
-        },
-    )
+def get_pcen_smoothing_constant(
+    sr: int,
+    time_constant: float = 0.4,
+    hop_length: int = 512,
+) -> float:
+    t_frames = time_constant * sr / float(hop_length)
+    return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
-
-
-def max_scale_spectrogram(spec: xr.DataArray, eps=10e-6) -> xr.DataArray:
-    return xr.DataArray(
-        data=spec / np.add(spec.max(), eps, dtype=spec.dtype),
-        dims=spec.dims,
-        coords=spec.coords,
-        attrs={
-            **spec.attrs,
-            "max_scaled": 1,
-        },
-    )
-
-
-def resize_spectrogram(
-    spec: xr.DataArray,
-    time_bins: int,
-    freq_bins: int,
-) -> xr.DataArray:
-    new_times = np.linspace(
-        spec.time[0],
-        spec.time[-1],
-        time_bins,
-        dtype=spec.time.dtype,
-        endpoint=True,
-    )
-    new_frequencies = np.linspace(
-        spec.frequency[0],
-        spec.frequency[-1],
-        freq_bins,
-        dtype=spec.frequency.dtype,
-        endpoint=True,
-    )
-
-    return spec.interp(
-        coords=dict(
-            time=new_times,
-            frequency=new_frequencies,
-        ),
-        method="linear",
-    )
-
-
-def get_dim_width(arr: xr.DataArray, dim: str) -> float:
-    coord = arr.coords[dim]
-    attrs = coord.attrs
-    if "min" in attrs and "max" in attrs:
-        return attrs["max"] - attrs["min"]
-
-    coord_min = coord.min()
-    coord_max = coord.max()
-    return float(coord_max - coord_min)
-
-
-class RandomClipProvider:
-    def __init__(
-        self,
-        clip_annotations: List[data.ClipAnnotation],
-        target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
-        scale_audio: bool = SCALE_RAW_AUDIO,
-    ):
-        self.target_sampling_rate = target_sampling_rate
-        self.scale_audio = scale_audio
-        self.clip_annotations = clip_annotations
-
-    def get_next_clip(self, clip: data.ClipAnnotation) -> data.ClipAnnotation:
-        tries = 0
-        while True:
-            random_clip = random.choice(self.clip_annotations)
-
-            if random_clip.clip != clip.clip:
-                return random_clip
-
-            tries += 1
-            if tries > 4:
-                raise ValueError("Could not find a different clip")
-
-    def __call__(
-        self,
-        clip: data.ClipAnnotation,
-    ) -> Tuple[xr.DataArray, data.ClipAnnotation]:
-        random_clip = self.get_next_clip(clip)
-
-        wav = load_clip_audio(
-            random_clip.clip,
-            target_sampling_rate=self.target_sampling_rate,
-            scale=self.scale_audio,
-        )
-
-        return wav, random_clip
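The PCEN change replaces an inline librosa.pcen call with soundevent's audio.pcen plus an explicitly derived smoothing coefficient. get_pcen_smoothing_constant converts a time constant into a one-pole IIR coefficient, the same conversion librosa performs internally: with t_frames = time_constant * sr / hop_length, smooth = (sqrt(1 + 4 * t_frames^2) - 1) / (2 * t_frames^2). A quick numerical check (the 256 kHz rate is this module's TARGET_SAMPLERATE_HZ, and sr / 10 mirrors the call in scale_spectrogram):

import numpy as np

def smoothing_constant(sr, time_constant=0.4, hop_length=512):
    t_frames = time_constant * sr / float(hop_length)
    return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)

print(smoothing_constant(256_000 / 10))  # ~0.049: a small coefficient, long memory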
@@ -68,7 +68,6 @@ def run_nms(
         params["fft_win_length"],
         params["fft_overlap"],
     )
-    print("duration", duration)
     top_k = int(duration * params["nms_top_k_per_sec"])
     scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
 
@@ -1,91 +1,11 @@
-import os
-from typing import Tuple, Union
-
-import torch
-
-from batdetect2.models.encoders import (
+from batdetect2.models.feature_extractors import (
     Net2DFast,
     Net2DFastNoAttn,
     Net2DFastNoCoordConv,
 )
-from batdetect2.models.typing import DetectionModel
 
 __all__ = [
-    "load_model",
     "Net2DFast",
     "Net2DFastNoAttn",
     "Net2DFastNoCoordConv",
 ]
-
-DEFAULT_MODEL_PATH = os.path.join(
-    os.path.dirname(os.path.dirname(__file__)),
-    "models",
-    "checkpoints",
-    "Net2DFast_UK_same.pth.tar",
-)
-
-
-def load_model(
-    model_path: str = DEFAULT_MODEL_PATH,
-    load_weights: bool = True,
-    device: Union[torch.device, str, None] = None,
-) -> Tuple[DetectionModel, dict]:
-    """Load model from file.
-
-    Args:
-        model_path (str): Path to model file. Defaults to DEFAULT_MODEL_PATH.
-        load_weights (bool, optional): Load weights. Defaults to True.
-
-    Returns:
-        model, params: Model and parameters.
-
-    Raises:
-        FileNotFoundError: Model file not found.
-        ValueError: Unknown model name.
-    """
-    if device is None:
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    if not os.path.isfile(model_path):
-        raise FileNotFoundError("Model file not found.")
-
-    net_params = torch.load(model_path, map_location=device)
-
-    params = net_params["params"]
-
-    model: DetectionModel
-
-    if params["model_name"] == "Net2DFast":
-        model = Net2DFast(
-            params["num_filters"],
-            num_classes=len(params["class_names"]),
-            emb_dim=params["emb_dim"],
-            ip_height=params["ip_height"],
-            resize_factor=params["resize_factor"],
-        )
-    elif params["model_name"] == "Net2DFastNoAttn":
-        model = Net2DFastNoAttn(
-            params["num_filters"],
-            num_classes=len(params["class_names"]),
-            emb_dim=params["emb_dim"],
-            ip_height=params["ip_height"],
-            resize_factor=params["resize_factor"],
-        )
-    elif params["model_name"] == "Net2DFastNoCoordConv":
-        model = Net2DFastNoCoordConv(
-            params["num_filters"],
-            num_classes=len(params["class_names"]),
-            emb_dim=params["emb_dim"],
-            ip_height=params["ip_height"],
-            resize_factor=params["resize_factor"],
-        )
-    else:
-        raise ValueError("Unknown model.")
-
-    if load_weights:
-        model.load_state_dict(net_params["state_dict"])
-
-    model = model.to(device)
-    model.eval()
-
-    return model, params

@@ -1,100 +1,104 @@
+from typing import Type
+
 import pytorch_lightning as L
 import torch
 import xarray as xr
 from soundevent import data
 from torch import nn, optim
 
-from batdetect2.data.preprocessing import preprocess_audio_clip
-from batdetect2.models.typing import EncoderModel, ModelOutput
-from batdetect2.train import losses
-from batdetect2.train.dataset import TrainExample
+from batdetect2.data.preprocessing import (
+    preprocess_audio_clip,
+    PreprocessingConfig,
+)
+from batdetect2.data.labels import ClassMapper
+from batdetect2.models.feature_extractors import Net2DFast
 from batdetect2.models.post_process import (
     PostprocessConfig,
     postprocess_model_outputs,
 )
-from batdetect2.train.preprocess import PreprocessingConfig
+from batdetect2.models.typing import FeatureExtractorModel, ModelOutput
+from batdetect2.train import losses
+from batdetect2.train.dataset import TrainExample
 
 
 class DetectorModel(L.LightningModule):
     def __init__(
         self,
-        encoder: EncoderModel,
-        num_classes: int,
+        class_mapper: ClassMapper,
+        feature_extractor_class: Type[FeatureExtractorModel] = Net2DFast,
         learning_rate: float = 1e-3,
+        input_height: int = 128,
+        num_features: int = 32,
         preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
         postprocessing_config: PostprocessConfig = PostprocessConfig(),
     ):
         super().__init__()
 
+        self.save_hyperparameters()
+
         self.preprocessing_config = preprocessing_config
         self.postprocessing_config = postprocessing_config
-        self.num_classes = num_classes
+        self.class_mapper = class_mapper
         self.learning_rate = learning_rate
+        self.input_height = input_height
+        self.num_features = num_features
+        self.num_classes = class_mapper.num_classes
 
-        self.encoder = encoder
+        self.feature_extractor = feature_extractor_class(
+            input_height=input_height,
+            num_features=num_features,
+        )
+
         self.classifier = nn.Conv2d(
-            self.encoder.num_filts // 4,
+            self.feature_extractor.num_features // 4,
             self.num_classes + 1,
             kernel_size=1,
             padding=0,
         )
 
         self.bbox = nn.Conv2d(
-            self.encoder.num_filts // 4,
+            self.feature_extractor.num_features // 4,
             2,
             kernel_size=1,
             padding=0,
         )
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
-        features = self.encoder(spec)
+        features = self.feature_extractor(spec)
 
         classification_logits = self.classifier(features)
         classification_probs = torch.softmax(classification_logits, dim=1)
         detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
 
         return ModelOutput(
             detection_probs=detection_probs,
             size_preds=self.bbox(features),
-            class_probs=classification_probs,
+            class_probs=classification_probs[:, :-1],
             features=features,
         )
 
     def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
-        config = self.preprocessing_config
-
         return preprocess_audio_clip(
             clip,
-            target_sampling_rate=config.target_samplerate,
-            scale_audio=config.scale_audio,
-            fft_win_length=config.fft_win_length,
-            fft_overlap=config.fft_overlap,
-            max_freq=config.max_freq,
-            min_freq=config.min_freq,
-            spec_scale=config.spec_scale,
-            denoise_spec_avg=config.denoise_spec_avg,
-            max_scale_spec=config.max_scale_spec,
+            config=self.preprocessing_config,
         )
 
-    def process_clip(self, clip: data.Clip):
+    def compute_clip_features(self, clip: data.Clip) -> torch.Tensor:
+        spectrogram = self.compute_spectrogram(clip)
+        return self.feature_extractor(
+            torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
+        )
+
+    def compute_clip_predictions(self, clip: data.Clip) -> data.ClipPrediction:
         spectrogram = self.compute_spectrogram(clip)
         spec_tensor = (
             torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
         )
 
         outputs = self(spec_tensor)
 
-        config = self.postprocessing_config
         return postprocess_model_outputs(
             outputs,
             [clip],
-            nms_kernel_size=config.nms_kernel_size,
-            detection_threshold=config.detection_threshold,
-            min_freq=config.min_freq,
-            max_freq=config.max_freq,
-            top_k_per_sec=config.top_k_per_sec,
-        )
+            class_mapper=self.class_mapper,
+            config=self.postprocessing_config,
+        )[0]
 
     def compute_loss(
         self,
|
|||||||
self,
|
self,
|
||||||
batch: TrainExample,
|
batch: TrainExample,
|
||||||
):
|
):
|
||||||
features = self.encoder(batch.spec)
|
outputs = self.forward(batch.spec)
|
||||||
|
loss = self.compute_loss(outputs, batch)
|
||||||
classification_logits = self.classifier(features)
|
|
||||||
classification_probs = torch.softmax(classification_logits, dim=1)
|
|
||||||
detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
|
|
||||||
|
|
||||||
loss = self.compute_loss(
|
|
||||||
ModelOutput(
|
|
||||||
detection_probs=detection_probs,
|
|
||||||
size_preds=self.bbox(features),
|
|
||||||
class_probs=classification_probs,
|
|
||||||
features=features,
|
|
||||||
),
|
|
||||||
batch,
|
|
||||||
)
|
|
||||||
self.log("train_loss", loss)
|
self.log("train_loss", loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
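Note: the forward pass above derives the detection score from the classification softmax rather than from a separate detection head. With num_classes + 1 channels, where the last channel is background, summing the non-background probabilities equals 1 - P(background). A standalone sketch (tensor shapes are illustrative, not taken from the commit):

import torch

num_classes = 3
# (batch, classes + background, freq bins, time steps)
logits = torch.randn(1, num_classes + 1, 8, 16)

class_probs = torch.softmax(logits, dim=1)
detection_probs = class_probs[:, :-1].sum(dim=1, keepdim=True)

# Equivalent formulation: one minus the background probability.
assert torch.allclose(detection_probs, 1 - class_probs[:, -1:], atol=1e-6)

This also explains why class_probs now returns classification_probs[:, :-1]: the background channel becomes an implementation detail that consumers of ModelOutput no longer see.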
@ -5,7 +5,6 @@ import torch.fft
import torch.nn.functional as F
from torch import nn

-from batdetect2.models.typing import EncoderModel
from batdetect2.models.blocks import (
    ConvBlockDownCoordF,
    ConvBlockDownStandard,
@ -13,6 +12,7 @@ from batdetect2.models.blocks import (
    ConvBlockUpStandard,
    SelfAttention,
)
+from batdetect2.models.typing import FeatureExtractorModel

__all__ = [
    "Net2DFast",
@ -21,84 +21,84 @@ __all__ = [
]


-class Net2DFast(EncoderModel):
+class Net2DFast(FeatureExtractorModel):
    def __init__(
        self,
-        num_filts: int,
+        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()
-        self.num_filts = num_filts
+        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        # encoder
        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
-            self.num_filts // 4,
+            self.num_features // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
-            self.num_filts // 4,
-            self.num_filts // 2,
+            self.num_features // 4,
+            self.num_features // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
-            self.num_filts // 2,
-            self.num_filts,
+            self.num_features // 2,
+            self.num_features,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
-            self.num_filts,
-            self.num_filts * 2,
+            self.num_features,
+            self.num_features * 2,
            3,
            padding=1,
        )
-        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)
+        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        # bottleneck
        self.conv_1d = nn.Conv2d(
-            self.num_filts * 2,
-            self.num_filts * 2,
+            self.num_features * 2,
+            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
-        self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)
-        self.att = SelfAttention(self.num_filts * 2, self.num_filts * 2)
+        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
+        self.att = SelfAttention(self.num_features * 2, self.num_features * 2)

        # decoder
        self.conv_up_2 = ConvBlockUpF(
-            self.num_filts * 2,
-            self.num_filts // 2,
+            self.num_features * 2,
+            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
-            self.num_filts // 2,
-            self.num_filts // 4,
+            self.num_features // 2,
+            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
-            self.num_filts // 4,
-            self.num_filts // 4,
+            self.num_features // 4,
+            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
-            self.num_filts // 4,
-            self.num_filts // 4,
+            self.num_features // 4,
+            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
-        self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)
+        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        h, w = spec.shape[2:]
@ -135,81 +135,81 @@ class Net2DFast(EncoderModel):
        return F.relu_(self.conv_op_bn(self.conv_op(x)))


-class Net2DFastNoAttn(EncoderModel):
+class Net2DFastNoAttn(FeatureExtractorModel):
    def __init__(
        self,
-        num_filts: int,
+        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()

-        self.num_filts = num_filts
+        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
-            self.num_filts // 4,
+            self.num_features // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
-            self.num_filts // 4,
-            self.num_filts // 2,
+            self.num_features // 4,
+            self.num_features // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
-            self.num_filts // 2,
-            self.num_filts,
+            self.num_features // 2,
+            self.num_features,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
-            self.num_filts,
-            self.num_filts * 2,
+            self.num_features,
+            self.num_features * 2,
            3,
            padding=1,
        )
-        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)
+        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_1d = nn.Conv2d(
-            self.num_filts * 2,
-            self.num_filts * 2,
+            self.num_features * 2,
+            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
-        self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)
+        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_up_2 = ConvBlockUpF(
-            self.num_filts * 2,
-            self.num_filts // 2,
+            self.num_features * 2,
+            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
-            self.num_filts // 2,
-            self.num_filts // 4,
+            self.num_features // 2,
+            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
-            self.num_filts // 4,
-            self.num_filts // 4,
+            self.num_features // 4,
+            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
-            self.num_filts // 4,
-            self.num_filts // 4,
+            self.num_features // 4,
+            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
-        self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)
+        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
@ -227,80 +227,80 @@ class Net2DFastNoAttn(EncoderModel):
        return F.relu_(self.conv_op_bn(self.conv_op(x)))


-class Net2DFastNoCoordConv(EncoderModel):
+class Net2DFastNoCoordConv(FeatureExtractorModel):
    def __init__(
        self,
-        num_filts: int,
+        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()

-        self.num_filts = num_filts
+        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownStandard(
            1,
-            self.num_filts // 4,
+            self.num_features // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownStandard(
-            self.num_filts // 4,
-            self.num_filts // 2,
+            self.num_features // 4,
+            self.num_features // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownStandard(
-            self.num_filts // 2,
-            self.num_filts,
+            self.num_features // 2,
+            self.num_features,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
-            self.num_filts,
-            self.num_filts * 2,
+            self.num_features,
+            self.num_features * 2,
            3,
            padding=1,
        )
-        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)
+        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_1d = nn.Conv2d(
-            self.num_filts * 2,
-            self.num_filts * 2,
+            self.num_features * 2,
+            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
-        self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)
+        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)

-        self.att = SelfAttention(self.num_filts * 2, self.num_filts * 2)
+        self.att = SelfAttention(self.num_features * 2, self.num_features * 2)

        self.conv_up_2 = ConvBlockUpStandard(
-            self.num_filts * 2,
-            self.num_filts // 2,
+            self.num_features * 2,
+            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpStandard(
-            self.num_filts // 2,
-            self.num_filts // 4,
+            self.num_features // 2,
+            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpStandard(
-            self.num_filts // 4,
-            self.num_filts // 4,
+            self.num_features // 4,
+            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
-            self.num_filts // 4,
-            self.num_filts // 4,
+            self.num_features // 4,
+            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
-        self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)
+        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
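Note: a minimal smoke test for the renamed constructor argument (num_filts is now num_features). The import path is an assumption, not verified against the commit:

import torch

from batdetect2.models.net import Net2DFast  # import path assumed

model = Net2DFast(num_features=32, input_height=128)
spec = torch.randn(1, 1, 128, 512)  # (batch, channel, freq bins, time steps)

features = model(spec)
# conv_op has num_features // 4 output channels, so features.shape[1] == 8;
# the spatial resolution depends on the up/downsampling blocks in blocks.py.
print(features.shape)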
@ -8,6 +8,7 @@ import torch
from soundevent import data
from torch import nn

+from batdetect2.data.labels import ClassMapper
from batdetect2.models.typing import ModelOutput

__all__ = [
@ -36,11 +37,8 @@ TagFunction = Callable[[int], List[data.Tag]]
def postprocess_model_outputs(
    outputs: ModelOutput,
    clips: List[data.Clip],
-    nms_kernel_size: int = NMS_KERNEL_SIZE,
-    detection_threshold: float = DETECTION_THRESHOLD,
-    min_freq: int = 10000,
-    max_freq: int = 120000,
-    top_k_per_sec: int = TOP_K_PER_SEC,
+    class_mapper: ClassMapper,
+    config: PostprocessConfig,
) -> List[data.ClipPrediction]:
    """Postprocesses model outputs to generate clip predictions.

@ -57,16 +55,8 @@ def postprocess_model_outputs(
    clips
        List of clips for which predictions are made. The number of clips
        must match the batch dimension of the model outputs.
-    nms_kernel_size
-        Size of the non-maximum suppression kernel. Default is 9.
-    detection_threshold
-        Detection threshold. Default is 0.01.
-    min_freq
-        Minimum frequency. Default is 10000.
-    max_freq
-        Maximum frequency. Default is 120000.
-    top_k_per_sec
-        Top k per second. Default is 200.
+    config
+        Configuration for postprocessing model outputs.

    Returns
    -------
@ -90,14 +80,14 @@ def postprocess_model_outputs(

    detection_probs = non_max_suppression(
        outputs.detection_probs,
-        kernel_size=nms_kernel_size,
+        kernel_size=config.nms_kernel_size,
    )

    duration = clips[0].end_time - clips[0].start_time

    scores_batch, y_pos_batch, x_pos_batch = get_topk_scores(
        detection_probs,
-        int(top_k_per_sec * duration / 2),
+        int(config.top_k_per_sec * duration / 2),
    )

    predictions: List[data.ClipPrediction] = []
@ -118,9 +108,10 @@ def postprocess_model_outputs(
            size_preds,
            class_probs,
            features,
-            min_freq=min_freq,
-            max_freq=max_freq,
-            detection_threshold=detection_threshold,
+            class_mapper=class_mapper,
+            min_freq=config.min_freq,
+            max_freq=config.max_freq,
+            detection_threshold=config.detection_threshold,
        )

        predictions.append(
@ -141,7 +132,7 @@ def compute_sound_events_from_outputs(
    size_preds: torch.Tensor,
    class_probs: torch.Tensor,
    features: torch.Tensor,
-    tag_fn: TagFunction = lambda _: [],
+    class_mapper: ClassMapper,
    min_freq: int = 10000,
    max_freq: int = 120000,
    detection_threshold: float = DETECTION_THRESHOLD,
@ -160,7 +151,6 @@ def compute_sound_events_from_outputs(
    predictions: List[data.SoundEventPrediction] = []
    for score, x, y in zip(scores, x_pos, y_pos):
        width, height = size_preds[:, y, x]
-        print(width, height)
        class_prob = class_probs[:, y, x]
        feature = features[:, y, x]

@ -191,7 +181,7 @@ def compute_sound_events_from_outputs(
        predicted_tags: List[data.PredictedTag] = []

        for label_id, class_score in enumerate(class_prob):
-            corresponding_tags = tag_fn(label_id)
+            corresponding_tags = class_mapper.inverse_transform(label_id)
            predicted_tags.extend(
                [
                    data.PredictedTag(
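Note: the loose keyword arguments of postprocess_model_outputs are folded into a single config object here. Below is a sketch of the PostprocessConfig implied by the config.* accesses in the hunks above; the field names follow those accesses, and the defaults are taken from the removed docstring entries (9, 0.01, 10000, 120000, 200), so treat this as a reconstruction rather than the actual class definition:

from pydantic import BaseModel, Field


class PostprocessConfig(BaseModel):
    """Configuration for postprocessing model outputs (reconstructed)."""

    nms_kernel_size: int = Field(default=9, gt=0)
    detection_threshold: float = Field(default=0.01, ge=0)
    min_freq: int = Field(default=10000, gt=0)
    max_freq: int = Field(default=120000, gt=0)
    top_k_per_sec: int = Field(default=200, gt=0)


# Usage mirroring the new call site in compute_clip_predictions:
# predictions = postprocess_model_outputs(
#     outputs, [clip], class_mapper=class_mapper, config=PostprocessConfig()
# )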
@ -4,6 +4,11 @@ from typing import NamedTuple
import torch
import torch.nn as nn

+__all__ = [
+    "ModelOutput",
+    "FeatureExtractorModel",
+]
+

class ModelOutput(NamedTuple):
    """Output of the detection model.

@ -36,12 +41,11 @@ class ModelOutput(NamedTuple):
    """Tensor with intermediate features."""


-class EncoderModel(ABC, nn.Module):
+class FeatureExtractorModel(ABC, nn.Module):

    input_height: int
    """Height of the input spectrogram."""

-    num_filts: int
+    num_features: int
    """Dimension of the feature tensor."""

    @abstractmethod
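Note: an illustrative subclass to make the FeatureExtractorModel contract concrete (the two documented attributes plus a forward that maps a spectrogram batch to a feature map). It deliberately uses nn.Module as a stand-in so the snippet runs on its own; it is not part of the commit:

import torch
from torch import nn


class TinyExtractor(nn.Module):  # stand-in for FeatureExtractorModel
    input_height: int = 128
    num_features: int = 32

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, self.num_features // 4, 3, padding=1)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # Same contract as the Net2DFast family: spectrogram in, features out.
        return torch.relu(self.conv(spec))


features = TinyExtractor()(torch.randn(1, 1, 128, 64))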
244
batdetect2/train/augmentations.py
Normal file
@ -0,0 +1,244 @@
from functools import wraps
from typing import Callable, List, Optional, Tuple

import numpy as np
import xarray as xr
from soundevent import data
from soundevent.geometry import compute_bounds


Augmentation = Callable[[xr.Dataset], xr.Dataset]


AUGMENTATION_PROBABILITY = 0.2
MAX_DELAY = 0.005
STRETCH_SQUEEZE_DELTA = 0.04
MASK_MAX_TIME_PERC: float = 0.05
MASK_MAX_FREQ_PERC: float = 0.10


def maybe_apply(
    augmentation: Callable,
    prob: float = AUGMENTATION_PROBABILITY,
) -> Callable:
    """Apply an augmentation with a given probability."""

    @wraps(augmentation)
    def _augmentation(x):
        if np.random.rand() > prob:
            return x
        return augmentation(x)

    return _augmentation


def select_random_subclip(
    train_example: xr.Dataset,
    duration: Optional[float] = None,
    proportion: float = 0.9,
) -> xr.Dataset:
    """Select a random subclip from a clip."""

    time_coords = train_example.coords["time"]

    start_time = time_coords.attrs.get("min", time_coords.min())
    end_time = time_coords.attrs.get("max", time_coords.max())

    if duration is None:
        duration = (end_time - start_time) * proportion

    start_time = np.random.uniform(start_time, end_time - duration)
    return train_example.sel(time=slice(start_time, start_time + duration))


def combine_audio(
    audio1: xr.DataArray,
    audio2: xr.DataArray,
    alpha: Optional[float] = None,
    min_alpha: float = 0.3,
    max_alpha: float = 0.7,
) -> xr.DataArray:
    """Combine two audio clips."""

    if alpha is None:
        alpha = np.random.uniform(min_alpha, max_alpha)

    return alpha * audio1 + (1 - alpha) * audio2.data


# def random_mix(
#     audio: xr.DataArray,
#     clip: data.ClipAnnotation,
#     provider: Optional[ClipProvider] = None,
#     alpha: Optional[float] = None,
#     min_alpha: float = 0.3,
#     max_alpha: float = 0.7,
#     join_annotations: bool = True,
# ) -> Tuple[xr.DataArray, data.ClipAnnotation]:
#     """Mix two audio clips."""
#     if provider is None:
#         raise ValueError("No audio provider given.")
#
#     try:
#         other_audio, other_clip = provider(clip)
#     except (StopIteration, ValueError):
#         raise ValueError("No more audio sources available.")
#
#     new_audio = combine_audio(
#         audio,
#         other_audio,
#         alpha=alpha,
#         min_alpha=min_alpha,
#         max_alpha=max_alpha,
#     )
#
#     if join_annotations:
#         clip = clip.model_copy(
#             update=dict(
#                 sound_events=clip.sound_events + other_clip.sound_events,
#             )
#         )
#
#     return new_audio, clip


def add_echo(
    train_example: xr.Dataset,
    delay: Optional[float] = None,
    alpha: Optional[float] = None,
    min_alpha: float = 0.0,
    max_alpha: float = 1.0,
    max_delay: float = MAX_DELAY,
) -> xr.Dataset:
    """Add a delay to the audio."""
    if delay is None:
        delay = np.random.uniform(0, max_delay)

    if alpha is None:
        alpha = np.random.uniform(min_alpha, max_alpha)

    spec = train_example["spectrogram"]

    time_coords = spec.coords["time"]
    start_time = time_coords.attrs["min"]
    end_time = time_coords.attrs["max"]
    step = (end_time - start_time) / time_coords.size

    spec_delay = spec.shift(time=int(delay / step), fill_value=0)

    return train_example.assign(spectrogram=spec + alpha * spec_delay)


def scale_volume(
    train_example: xr.Dataset,
    factor: Optional[float] = None,
    max_scaling: float = 2,
    min_scaling: float = 0,
) -> xr.Dataset:
    """Scale the volume of a spectrogram."""
    if factor is None:
        factor = np.random.uniform(min_scaling, max_scaling)

    return train_example.assign(
        spectrogram=train_example["spectrogram"] * factor
    )


def warp_spectrogram(
    train_example: xr.Dataset,
    factor: Optional[float] = None,
    delta: float = STRETCH_SQUEEZE_DELTA,
) -> xr.Dataset:
    """Warp a spectrogram."""
    if factor is None:
        factor = np.random.uniform(1 - delta, 1 + delta)

    time_coords = train_example.coords["time"]
    start_time = time_coords.attrs["min"]
    end_time = time_coords.attrs["max"]
    duration = end_time - start_time

    new_time = np.linspace(
        start_time,
        start_time + duration * factor,
        train_example.time.size,
    )

    return train_example.interp(time=new_time)


def mask_axis(
    train_example: xr.Dataset,
    dim: str,
    start: float,
    end: float,
    mask_all: bool = False,
    mask_value: float = 0,
) -> xr.Dataset:
    if dim not in train_example.dims:
        raise ValueError(f"Axis {dim} not found in array")

    coord = train_example.coords[dim]
    condition = (coord < start) | (coord > end)

    if mask_all:
        return train_example.where(condition, other=mask_value)

    return train_example.assign(
        spectrogram=train_example.spectrogram.where(
            condition, other=mask_value
        )
    )


def mask_time(
    train_example: xr.Dataset,
    max_time_mask: float = MASK_MAX_TIME_PERC,
    max_num_masks: int = 3,
) -> xr.Dataset:
    """Mask a random section of the time axis."""

    num_masks = np.random.randint(1, max_num_masks + 1)

    time_coord = train_example.coords["time"]
    start_time = time_coord.attrs.get("min", time_coord.min())
    end_time = time_coord.attrs.get("max", time_coord.max())

    for _ in range(num_masks):
        mask_size = np.random.uniform(0, max_time_mask)
        start = np.random.uniform(start_time, end_time - mask_size)
        end = start + mask_size
        train_example = mask_axis(train_example, "time", start, end)

    return train_example


def mask_frequency(
    train_example: xr.Dataset,
    max_freq_mask: float = MASK_MAX_FREQ_PERC,
    max_num_masks: int = 3,
) -> xr.Dataset:
    """Mask a random section of the frequency axis."""

    num_masks = np.random.randint(1, max_num_masks + 1)

    freq_coord = train_example.coords["frequency"]
    min_freq = freq_coord.min()
    max_freq = freq_coord.max()

    for _ in range(num_masks):
        mask_size = np.random.uniform(0, max_freq_mask)
        start = np.random.uniform(min_freq, max_freq - mask_size)
        end = start + mask_size
        train_example = mask_axis(train_example, "frequency", start, end)

    return train_example


AUGMENTATIONS: List[Augmentation] = [
    select_random_subclip,
    add_echo,
    scale_volume,
    mask_time,
    mask_frequency,
]
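Note: the module stops at a flat AUGMENTATIONS list. A plausible way to compose it into the single Dataset -> Dataset transform that LabeledDataset expects is to wrap each step in maybe_apply and chain them; the compose helper below is a sketch that assumes the names defined in this file, not part of the commit:

import xarray as xr


def compose_augmentations(augmentations, prob=AUGMENTATION_PROBABILITY):
    """Chain augmentations, each applied independently with probability prob."""
    steps = [maybe_apply(aug, prob=prob) for aug in augmentations]

    def transform(example: xr.Dataset) -> xr.Dataset:
        for step in steps:
            example = step(example)
        return example

    return transform


transform = compose_augmentations(AUGMENTATIONS)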
@ -1,16 +1,14 @@
import os
-from typing import NamedTuple
from pathlib import Path
-from typing import Sequence, Union, Dict
-from soundevent import data
+from typing import Callable, Dict, NamedTuple, Optional, Sequence, Union

-from torch.utils.data import Dataset
import torch
import xarray as xr
+from soundevent import data
+from torch.utils.data import Dataset

from batdetect2.train.preprocess import PreprocessingConfig


__all__ = [
    "TrainExample",
    "LabeledDataset",
@ -33,8 +31,13 @@ def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:


class LabeledDataset(Dataset):
-    def __init__(self, filenames: Sequence[PathLike]):
+    def __init__(
+        self,
+        filenames: Sequence[PathLike],
+        transform: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
+    ):
        self.filenames = filenames
+        self.transform = transform

    def __len__(self):
        return len(self.filenames)
@ -54,7 +57,7 @@ class LabeledDataset(Dataset):
        return cls(get_files(directory, extension))

    def load(self, filename: PathLike) -> Dict[str, torch.Tensor]:
-        dataset = xr.open_dataset(filename)
+        dataset = self.get_dataset(filename)
        spectrogram = torch.tensor(dataset["spectrogram"].values).unsqueeze(0)
        return {
            "spectrogram": spectrogram,
@ -63,6 +66,15 @@ class LabeledDataset(Dataset):
            "size": torch.tensor(dataset["size"].values),
        }

+    def apply_augmentation(self, dataset: xr.Dataset) -> xr.Dataset:
+        if self.transform is not None:
+            return self.transform(dataset)
+
+        return dataset
+
+    def get_dataset(self, idx):
+        return xr.open_dataset(self.filenames[idx])
+
    def get_spectrogram(self, idx):
        return xr.open_dataset(self.filenames[idx])["spectrogram"]
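Note: hypothetical wiring of the new transform hook; the directory path is a placeholder and the import path is assumed. Observe that from_directory still calls cls(get_files(directory, extension)) without forwarding a transform, so augmented datasets currently have to be built through the constructor:

from batdetect2.train.dataset import LabeledDataset, get_files  # path assumed

dataset = LabeledDataset(
    get_files("train_examples"),  # .nc files written by the preprocess step
    transform=transform,  # e.g. the composed augmentations sketched above
)
augmented = dataset.apply_augmentation(dataset.get_dataset(0))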
@ -9,21 +9,12 @@ from tqdm.auto import tqdm
from multiprocessing import Pool

import xarray as xr
-from pydantic import BaseModel, Field
from soundevent import data

-from batdetect2.data.labels import TARGET_SIGMA, LabelFn, generate_heatmaps
+from batdetect2.data.labels import TARGET_SIGMA, ClassMapper, generate_heatmaps
from batdetect2.data.preprocessing import (
-    DENOISE_SPEC_AVG,
-    FFT_OVERLAP,
-    FFT_WIN_LENGTH_S,
-    MAX_FREQ_HZ,
-    MAX_SCALE_SPEC,
-    MIN_FREQ_HZ,
-    SCALE_RAW_AUDIO,
-    SPEC_SCALE,
-    TARGET_SAMPLERATE_HZ,
    preprocess_audio_clip,
+    PreprocessingConfig,
)

PathLike = Union[Path, str, os.PathLike]
@ -34,61 +25,24 @@ __all__ = [
]


-class PreprocessingConfig(BaseModel):
-    """Configuration for preprocessing data."""
-
-    target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
-
-    scale_audio: bool = Field(default=SCALE_RAW_AUDIO)
-
-    fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
-
-    fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
-
-    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
-
-    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
-
-    spec_scale: str = Field(default=SPEC_SCALE)
-
-    denoise_spec_avg: bool = DENOISE_SPEC_AVG
-
-    max_scale_spec: bool = MAX_SCALE_SPEC
-
-    target_sigma: float = Field(default=TARGET_SIGMA, gt=0)
-
-    class_labels: Sequence[str] = ["bat"]
-
-
def generate_train_example(
    clip_annotation: data.ClipAnnotation,
-    label_fn: LabelFn = lambda _: None,
-    config: Optional[PreprocessingConfig] = None,
+    class_mapper: ClassMapper,
+    preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
+    target_sigma: float = TARGET_SIGMA,
) -> xr.Dataset:
    """Generate a training example."""
-    if config is None:
-        config = PreprocessingConfig()
-
    spectrogram = preprocess_audio_clip(
        clip_annotation.clip,
-        target_sampling_rate=config.target_samplerate,
-        scale_audio=config.scale_audio,
-        fft_win_length=config.fft_win_length,
-        fft_overlap=config.fft_overlap,
-        max_freq=config.max_freq,
-        min_freq=config.min_freq,
-        spec_scale=config.spec_scale,
-        denoise_spec_avg=config.denoise_spec_avg,
-        max_scale_spec=config.max_scale_spec,
+        config=preprocessing_config,
    )

    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
        clip_annotation,
        spectrogram,
-        target_sigma=config.target_sigma,
-        num_classes=len(config.class_labels),
-        class_labels=list(config.class_labels),
-        label_fn=label_fn,
+        class_mapper,
+        target_sigma=target_sigma,
    )

    dataset = xr.Dataset(
@ -102,7 +56,8 @@ def generate_train_example(

    return dataset.assign_attrs(
        title=f"Training example for {clip_annotation.uuid}",
-        configuration=config.model_dump_json(),
+        preprocessing_configuration=preprocessing_config.model_dump_json(),
+        target_sigma=target_sigma,
        clip_annotation=clip_annotation.model_dump_json(),
    )

@ -148,9 +103,10 @@ def preprocess_single_annotation(
    clip_annotation: data.ClipAnnotation,
    output_dir: PathLike,
    config: PreprocessingConfig,
+    class_mapper: ClassMapper,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
-    label_fn: LabelFn = lambda _: None,
+    target_sigma: float = TARGET_SIGMA,
) -> None:
    output_dir = Path(output_dir)

@ -162,8 +118,9 @@ def preprocess_single_annotation(

    sample = generate_train_example(
        clip_annotation,
-        label_fn=label_fn,
-        config=config,
+        class_mapper,
+        preprocessing_config=config,
+        target_sigma=target_sigma,
    )

    save_to_file(sample, path)
@ -172,10 +129,11 @@ def preprocess_single_annotation(
def preprocess_annotations(
    clip_annotations: Sequence[data.ClipAnnotation],
    output_dir: PathLike,
+    class_mapper: ClassMapper,
+    target_sigma: float = TARGET_SIGMA,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    config_file: Optional[PathLike] = None,
-    label_fn: LabelFn = lambda _: None,
    max_workers: Optional[int] = None,
    **kwargs,
) -> None:
@ -198,9 +156,10 @@ def preprocess_annotations(
            preprocess_single_annotation,
            output_dir=output_dir,
            config=config,
+            class_mapper=class_mapper,
            filename_fn=filename_fn,
            replace=replace,
-            label_fn=label_fn,
+            target_sigma=target_sigma,
        ),
        clip_annotations,
    ),
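Note: a hedged end-to-end sketch of the reworked preprocessing entry point. Loading the annotations and the concrete ClassMapper instance are placeholders, and target_sigma is passed only to make the new keyword visible (it defaults to TARGET_SIGMA anyway):

from batdetect2.data.labels import TARGET_SIGMA, ClassMapper
from batdetect2.train.preprocess import preprocess_annotations

class_mapper: ClassMapper = ...  # project-specific mapper instance
clip_annotations = ...  # Sequence[data.ClipAnnotation], loaded elsewhere

preprocess_annotations(
    clip_annotations,
    output_dir="train_examples",
    class_mapper=class_mapper,
    target_sigma=TARGET_SIGMA,
)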
@ -28,7 +28,7 @@ dependencies = [
    "torch>=1.13.1",
    "torchaudio",
    "torchvision",
-    "soundevent[audio,geometry,plot]>=1.3.5",
+    "soundevent[audio,geometry,plot]>=2.0",
    "click>=8.1.7",
    "netcdf4>=1.6.5",
    "tqdm>=4.66.2",
@ -10,6 +10,8 @@
-e file:.
absl-py==2.1.0
    # via tensorboard
+affine==2.4.0
+    # via rasterio
aiobotocore==2.12.3
    # via s3fs
aiohttp==3.9.5
@ -37,6 +39,7 @@ async-timeout==4.0.3
    # via redis
attrs==23.2.0
    # via aiohttp
+    # via rasterio
audioread==3.0.1
    # via librosa
backcall==0.2.0
@ -57,6 +60,7 @@ botocore==1.34.69
    # via s3transfer
certifi==2024.2.2
    # via netcdf4
+    # via rasterio
    # via requests
cf-xarray==0.9.0
    # via batdetect2
@ -68,9 +72,16 @@ charset-normalizer==3.3.2
    # via requests
click==8.1.7
    # via batdetect2
+    # via click-plugins
+    # via cligj
    # via lightning
    # via lightning-cloud
+    # via rasterio
    # via uvicorn
+click-plugins==1.1.1
+    # via rasterio
+cligj==0.7.2
+    # via rasterio
comm==0.2.2
    # via ipykernel
contourpy==1.1.1
@ -136,6 +147,7 @@ idna==3.7
importlib-metadata==7.1.0
    # via jupyter-client
    # via markdown
+    # via rasterio
importlib-resources==6.4.0
    # via matplotlib
    # via typeshed-client
@ -229,9 +241,11 @@ numpy==1.24.4
    # via onnx
    # via pandas
    # via pytorch-lightning
+    # via rasterio
    # via scikit-learn
    # via scipy
    # via shapely
+    # via snuggs
    # via soxr
    # via tensorboard
    # via tensorboardx
@ -335,6 +349,7 @@ pyjwt==2.8.0
    # via lightning-cloud
pyparsing==3.1.2
    # via matplotlib
+    # via snuggs
pytest==8.1.1
python-dateutil==2.9.0.post0
    # via arrow
@ -361,6 +376,8 @@ pyyaml==6.0.1
pyzmq==26.0.0
    # via ipykernel
    # via jupyter-client
+rasterio==1.3.10
+    # via soundevent
readchar==4.0.6
    # via inquirer
redis==5.0.4
@ -390,6 +407,7 @@ scipy==1.10.1
    # via soundevent
setuptools==69.5.1
    # via lightning-utilities
+    # via rasterio
    # via readchar
    # via tensorboard
shapely==2.0.3
@ -402,7 +420,9 @@ six==1.16.0
    # via tensorboard
sniffio==1.3.1
    # via anyio
-soundevent==1.3.5
+snuggs==1.4.7
+    # via rasterio
+soundevent==2.0.0
    # via batdetect2
soundfile==0.12.1
    # via librosa
@ -10,6 +10,8 @@
-e file:.
absl-py==2.1.0
    # via tensorboard
+affine==2.4.0
+    # via rasterio
aiobotocore==2.12.3
    # via s3fs
aiohttp==3.9.5
@ -35,6 +37,7 @@ async-timeout==4.0.3
    # via redis
attrs==23.2.0
    # via aiohttp
+    # via rasterio
audioread==3.0.1
    # via librosa
backoff==2.2.1
@ -53,6 +56,7 @@ botocore==1.34.69
    # via s3transfer
certifi==2024.2.2
    # via netcdf4
+    # via rasterio
    # via requests
cf-xarray==0.9.0
    # via batdetect2
@ -64,9 +68,16 @@ charset-normalizer==3.3.2
    # via requests
click==8.1.7
    # via batdetect2
+    # via click-plugins
+    # via cligj
    # via lightning
    # via lightning-cloud
+    # via rasterio
    # via uvicorn
+click-plugins==1.1.1
+    # via rasterio
+cligj==0.7.2
+    # via rasterio
contourpy==1.1.1
    # via matplotlib
croniter==1.4.1
@ -123,6 +134,7 @@ idna==3.7
    # via yarl
importlib-metadata==7.1.0
    # via markdown
+    # via rasterio
importlib-resources==6.4.0
    # via matplotlib
    # via typeshed-client
@ -199,9 +211,11 @@ numpy==1.24.4
    # via onnx
    # via pandas
    # via pytorch-lightning
+    # via rasterio
    # via scikit-learn
    # via scipy
    # via shapely
+    # via snuggs
    # via soxr
    # via tensorboard
    # via tensorboardx
@ -286,6 +300,7 @@ pyjwt==2.8.0
    # via lightning-cloud
pyparsing==3.1.2
    # via matplotlib
+    # via snuggs
python-dateutil==2.9.0.post0
    # via arrow
    # via botocore
@ -307,6 +322,8 @@ pyyaml==6.0.1
    # via lightning
    # via omegaconf
    # via pytorch-lightning
+rasterio==1.3.10
+    # via soundevent
readchar==4.0.6
    # via inquirer
redis==5.0.4
@ -336,6 +353,7 @@ scipy==1.10.1
    # via soundevent
setuptools==69.5.1
    # via lightning-utilities
+    # via rasterio
    # via readchar
    # via tensorboard
shapely==2.0.3
@ -347,7 +365,9 @@ six==1.16.0
    # via tensorboard
sniffio==1.3.1
    # via anyio
-soundevent==1.3.5
+snuggs==1.4.7
+    # via rasterio
+soundevent==2.0.0
    # via batdetect2
soundfile==0.12.1
    # via librosa