mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00

WIP

This commit is contained in:
parent c865b53c17
commit b252f23093
@@ -306,6 +306,9 @@ def _compute_spec_extent(
     # If the spectrogram is not resized, the duration is correct,
     # but if it is resized, the duration needs to be adjusted.
     # NOTE: For now we can only detect if the spectrogram is resized
     # by checking if the height is equal to the specified height,
     # but this could fail.
     resize_factor = params["resize_factor"]
     spec_height = params["spec_height"]
     if spec_height * resize_factor == shape[0]:
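The NOTE above describes a heuristic rather than a guaranteed check. A minimal sketch of the idea, using the `shape` and `params` names from the hunk (the helper function itself is hypothetical):

```python
# Hypothetical helper illustrating the heuristic from the NOTE above: a
# spectrogram is assumed to have been resized if its height matches
# spec_height * resize_factor. This can fail if an unresized spectrogram
# happens to have exactly that height.
def _looks_resized(shape, spec_height: int, resize_factor: float) -> bool:
    return spec_height * resize_factor == shape[0]
```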
@@ -1,14 +1,22 @@
 """Functions and dataloaders for training and testing the model."""
 import copy
-from typing import Tuple
+from typing import List, Optional, Tuple

 import librosa
 import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.data
 import torchaudio

 import batdetect2.utils.audio_utils as au
-from batdetect2.types import AnnotationGroup, HeatmapParameters
+from batdetect2.types import (
+    Annotation,
+    AnnotationGroup,
+    AudioLoaderAnnotationGroup,
+    FileAnnotations,
+    HeatmapParameters,
+)


 def generate_gt_heatmaps(
@@ -32,22 +40,17 @@ def generate_gt_heatmaps(

     Returns
     -------
     y_2d_det : np.ndarray
         2D heatmap of the presence of an event.
     y_2d_size : np.ndarray
         2D heatmap of the size of the bounding box associated with the event.
     y_2d_classes : np.ndarray
         3D array containing the ground-truth class probabilities for each
         pixel.
     ann_aug : AnnotationGroup
         A dictionary containing the annotation information of the
         annotations that are within the input spectrogram, augmented with
         the x and y indices of their pixel location in the input spectrogram.

     """
     # spec may be resized on input into the network
     num_classes = len(params["class_names"])
@@ -62,20 +65,20 @@ def generate_gt_heatmaps(
         params["fft_win_length"],
         params["fft_overlap"],
     )
-    x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int)
+    x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int32)
     x_pos_end = au.time_to_x_coords(
         ann["end_times"],
         sampling_rate,
         params["fft_win_length"],
         params["fft_overlap"],
     )
-    x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int)
+    x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int32)

     # location on y axis i.e. frequency
     y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin
-    y_pos_low = (op_height - y_pos_low).astype(np.int)
+    y_pos_low = (op_height - y_pos_low).astype(np.int32)
     y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin
-    y_pos_high = (op_height - y_pos_high).astype(np.int)
+    y_pos_high = (op_height - y_pos_high).astype(np.int32)
     bb_widths = x_pos_end - x_pos_start
     bb_heights = y_pos_low - y_pos_high
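The pixel bookkeeping above (call time to spectrogram column, call frequency to a vertically flipped row) is easy to sanity-check in isolation. A hedged sketch with made-up parameter values; `time_to_x_coords` here is a simplified stand-in for `au.time_to_x_coords`, and the 0.5 multiplier plays the role of `resize_factor`:

```python
import numpy as np

# Simplified stand-in for au.time_to_x_coords: map times (s) to STFT
# column indices given the window length (s) and fractional overlap.
def time_to_x_coords(t, sr, win_len, overlap):
    nfft = np.floor(win_len * sr)
    hop = nfft - np.floor(overlap * nfft)
    return t * sr / hop

sr, win_len, overlap = 256000, 0.02, 0.75  # assumed example values
starts = np.array([0.05, 0.12])            # call start times in seconds
x_pos_start = (0.5 * time_to_x_coords(starts, sr, win_len, overlap)).astype(np.int32)

# The frequency axis is flipped: row 0 is the highest frequency, so the
# low-frequency edge of a box maps to op_height - bin_index.
op_height, min_freq, freq_per_bin = 128, 10000, 1000.0
low_freqs = np.array([35000, 52000])
y_pos_low = (op_height - (low_freqs - min_freq) / freq_per_bin).astype(np.int32)
print(x_pos_start, y_pos_low)
```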
@@ -127,7 +130,12 @@ def generate_gt_heatmaps(
             (x_pos_start[ii], y_pos_low[ii]),
             params["target_sigma"],
         )
-        # draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
+        # draw_gaussian(
+        #     y_2d_det[0, :],
+        #     (x_pos_start[ii], y_pos_low[ii]),
+        #     params["target_sigma"],
+        #     params["target_sigma"] * 2,
+        # )
         y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
         y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
@@ -138,10 +146,15 @@ def generate_gt_heatmaps(
                 (x_pos_start[ii], y_pos_low[ii]),
                 params["target_sigma"],
             )
-            # draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
+            # draw_gaussian(
+            #     y_2d_classes[cls_id, :],
+            #     (x_pos_start[ii], y_pos_low[ii]),
+            #     params["target_sigma"],
+            #     params["target_sigma"] * 2,
+            # )

-    # be careful as this will have a 1.0 places where we have event but dont know gt class
-    # this will be masked in training anyway
+    # be careful as this will have a 1.0 in places where we have an event but
+    # don't know the gt class; this will be masked in training anyway
     y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
     y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
     y_2d_classes[np.isnan(y_2d_classes)] = 0.0
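The last three lines implement a soft background class: whatever probability mass the class Gaussians do not claim at a pixel goes to channel `num_classes`, the stack is renormalized so each pixel sums to one, and any 0/0 pixels (NaN after the division) are zeroed. A small self-contained illustration of the same three steps:

```python
import numpy as np

num_classes = 2
y = np.zeros((num_classes + 1, 2, 2))  # (classes + background, H, W)
y[0, 0, 0] = 0.8  # class-0 gaussian mass at one pixel

y[num_classes] = 1.0 - y.sum(0)       # background channel takes the remainder
y = y / y.sum(0)[np.newaxis, ...]     # renormalize each pixel to sum to 1
y[np.isnan(y)] = 0.0                  # guard against 0/0 pixels

print(y[:, 0, 0])  # [0.8, 0.0, 0.2]
```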
@@ -149,7 +162,37 @@ def generate_gt_heatmaps(
     return y_2d_det, y_2d_size, y_2d_classes, ann_aug


-def draw_gaussian(heatmap, center, sigmax, sigmay=None):
+def draw_gaussian(
+    heatmap: np.ndarray,
+    center: Tuple[int, int],
+    sigmax: float,
+    sigmay: Optional[float] = None,
+) -> bool:
+    """Draw a 2D gaussian into the heatmap.
+
+    If the gaussian center is outside the heatmap, then the gaussian is not
+    drawn.
+
+    Parameters
+    ----------
+    heatmap : np.ndarray
+        The heatmap to draw into. Should be of shape (height, width).
+    center : Tuple[int, int]
+        The center of the gaussian in (x, y) format.
+    sigmax : float
+        The standard deviation of the gaussian in the x direction.
+    sigmay : Optional[float], optional
+        The standard deviation of the gaussian in the y direction. If None,
+        then sigmay = sigmax, by default None.
+
+    Returns
+    -------
+    bool
+        True if the gaussian was drawn, False if it was not (because
+        the center was outside the heatmap).
+    """
     # center is (x, y)
     # this edits the heatmap inplace
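The body of `draw_gaussian` is elided by the hunk above. For orientation, a minimal sketch of what such a function typically does; this is an illustration matching the documented contract, not the repository's implementation:

```python
import numpy as np
from typing import Optional, Tuple

def draw_gaussian_sketch(
    heatmap: np.ndarray,
    center: Tuple[int, int],
    sigmax: float,
    sigmay: Optional[float] = None,
) -> bool:
    """Illustrative only: splat an unnormalized 2D gaussian into heatmap."""
    if sigmay is None:
        sigmay = sigmax
    h, w = heatmap.shape
    x0, y0 = center
    if not (0 <= x0 < w and 0 <= y0 < h):
        return False  # center outside the heatmap: draw nothing
    ys, xs = np.mgrid[0:h, 0:w]
    g = np.exp(
        -((xs - x0) ** 2) / (2 * sigmax**2) - ((ys - y0) ** 2) / (2 * sigmay**2)
    )
    np.maximum(heatmap, g, out=heatmap)  # edit in place, keep existing peaks
    return True
```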
@@ -185,23 +228,46 @@ def draw_gaussian(heatmap, center, sigmax, sigmay=None):
     return True


-def pad_aray(ip_array, pad_size):
-    return np.hstack((ip_array, np.ones(pad_size, dtype=np.int) * -1))
+def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:
+    """Pad array with -1s."""
+    return np.hstack((ip_array, np.ones(pad_size, dtype=np.int32) * -1))
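`pad_aray` (the spelling is the repository's) is used later to pad per-file annotation vectors to a fixed length so they can be batched; -1 marks padding. For example:

```python
import numpy as np

ids = np.array([0, 1, 1], dtype=np.int32)
padded = np.hstack((ids, np.ones(2, dtype=np.int32) * -1))
print(padded)  # [ 0  1  1 -1 -1]
```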
-def warp_spec_aug(spec, ann, return_spec_for_viz, params):
+def warp_spec_aug(
+    spec: torch.Tensor,
+    ann: AnnotationGroup,
+    params: dict,
+) -> torch.Tensor:
+    """Warp spectrogram by randomly stretching and squeezing.
+
+    Parameters
+    ----------
+    spec: torch.Tensor
+        Spectrogram to warp.
+    ann: AnnotationGroup
+        Annotation group for the spectrogram. Must be provided to sync
+        the start and stop times with the spectrogram after warping.
+    params: dict
+        Parameters for the augmentation.
+
+    Returns
+    -------
+    torch.Tensor
+        Warped spectrogram.
+
+    Notes
+    -----
+    This function modifies the annotation group in place.
+    """
-    # This is messy
-    # Augment spectrogram by randomly stretch and squeezing
-    # NOTE this also changes the start and stop time in place
-
-    # not taking care of spec for viz
-    if return_spec_for_viz:
-        assert False
-
     delta = params["stretch_squeeze_delta"]
     op_size = (spec.shape[1], spec.shape[2])
     resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
     resize_amt = int(spec.shape[2] * resize_fract_r)

     if resize_amt >= spec.shape[2]:
         spec_r = torch.cat(
             (
@@ -215,41 +281,124 @@ def warp_spec_aug(spec, ann, return_spec_for_viz, params):
         )
     else:
         spec_r = spec[:, :, :resize_amt]

+    # Resize the spectrogram
     spec = F.interpolate(
-        spec_r.unsqueeze(0), size=op_size, mode="bilinear", align_corners=False
+        spec_r.unsqueeze(0),
+        size=op_size,
+        mode="bilinear",
+        align_corners=False,
     ).squeeze(0)

+    # Update the start and stop times
     ann["start_times"] *= 1.0 / resize_fract_r
     ann["end_times"] *= 1.0 / resize_fract_r

     return spec
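The bookkeeping at the end is the part that is easy to get wrong: the time axis is cropped or padded to `resize_amt` columns and then interpolated back to the original width, so content that sat at time `t` now sits at `t / resize_fract_r`, which is exactly the in-place update applied to `ann`. A quick numeric check with an assumed stretch factor:

```python
import numpy as np

resize_fract_r = 1.1           # time axis stretched by 10% before resizing back
start_times = np.array([0.5])  # original call start (s)

# After interpolating resize_amt columns back to the original width, a
# feature at time t lands at t / resize_fract_r.
print(start_times * (1.0 / resize_fract_r))  # ~[0.4545]
```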
-def mask_time_aug(spec, params):
-    # Mask out a random block of time - repeat up to 3 times
-    # SpecAugment: A Simple Data Augmentation Methodfor Automatic Speech Recognition
+def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
+    """Mask out random blocks of time.
+
+    Will randomly mask out a block of time in the spectrogram. The block
+    will be between 0.0 and `mask_max_time_perc` of the total time.
+    Between one and three blocks will be masked out.
+
+    Parameters
+    ----------
+    spec: torch.Tensor
+        Spectrogram to mask.
+    params: dict
+        Parameters for the augmentation.
+
+    Returns
+    -------
+    torch.Tensor
+        Spectrogram with masked out time blocks.
+
+    Notes
+    -----
+    This function is based on the implementation in::
+
+        SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+        Recognition
+    """
     fm = torchaudio.transforms.TimeMasking(
         int(spec.shape[1] * params["mask_max_time_perc"])
     )
-    for ii in range(np.random.randint(1, 4)):
+    for _ in range(np.random.randint(1, 4)):
         spec = fm(spec)
     return spec
-def mask_freq_aug(spec, params):
-    # Mask out a random frequncy range - repeat up to 3 times
-    # SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
+def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
+    """Mask out random blocks of frequency.
+
+    Will randomly mask out a block of frequency in the spectrogram. The block
+    will be between 0.0 and `mask_max_freq_perc` of the total frequency.
+    Between one and three blocks will be masked out.
+
+    Parameters
+    ----------
+    spec: torch.Tensor
+        Spectrogram to mask.
+    params: dict
+        Parameters for the augmentation.
+
+    Returns
+    -------
+    torch.Tensor
+        Spectrogram with masked out frequency blocks.
+
+    Notes
+    -----
+    This function is based on the implementation in::
+
+        SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+        Recognition
+    """
     fm = torchaudio.transforms.FrequencyMasking(
         int(spec.shape[1] * params["mask_max_freq_perc"])
     )
-    for ii in range(np.random.randint(1, 4)):
+    for _ in range(np.random.randint(1, 4)):
         spec = fm(spec)
     return spec
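Both helpers delegate to torchaudio's SpecAugment transforms, with the mask width derived from the spectrogram size. A hedged usage sketch with assumed parameter values:

```python
import torch
import torchaudio

spec = torch.rand(1, 128, 512)  # (channel, freq bins, time frames)
params = {"mask_max_time_perc": 0.05, "mask_max_freq_perc": 0.10}  # assumed

# Note that both transforms are constructed from spec.shape[1] in the diff,
# so the time-mask width is also derived from the frequency dimension.
time_mask = torchaudio.transforms.TimeMasking(
    int(spec.shape[1] * params["mask_max_time_perc"])
)
freq_mask = torchaudio.transforms.FrequencyMasking(
    int(spec.shape[1] * params["mask_max_freq_perc"])
)
masked = freq_mask(time_mask(spec))
```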
-def scale_vol_aug(spec, params):
+def scale_vol_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
+    """Scale the volume of the spectrogram.
+
+    Parameters
+    ----------
+    spec: torch.Tensor
+        Spectrogram to scale.
+    params: dict
+        Parameters for the augmentation.
+
+    Returns
+    -------
+    torch.Tensor
+        Volume-scaled spectrogram.
+    """
     return spec * np.random.random() * params["spec_amp_scaling"]
-def echo_aug(audio, sampling_rate, params):
+def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
+    """Add echo to audio.
+
+    Parameters
+    ----------
+    audio: np.ndarray
+        Audio to add echo to.
+    sampling_rate: int
+        Sampling rate of the audio.
+    params: dict
+        Parameters for the augmentation.
+
+    Returns
+    -------
+    np.ndarray
+        Audio with echo added.
+    """
     sample_offset = (
         int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
     )
@@ -257,7 +406,35 @@ def echo_aug(audio, sampling_rate, params):
     return audio
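The hunk elides the mixing step between computing `sample_offset` and `return audio`. A plausible sketch of a delay-and-add echo under those names; this is illustrative, not the repository's exact code:

```python
import numpy as np

def echo_sketch(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
    """Illustrative delay-and-add echo: mix a delayed copy into the signal."""
    sample_offset = (
        int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
    )
    audio = audio.copy()
    # Assumed mixing rule: add a randomly weighted, delayed copy in place.
    audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
    return audio

out = echo_sketch(
    np.random.randn(48000).astype(np.float32), 48000, {"echo_max_delay": 0.005}
)
```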
-def resample_aug(audio, sampling_rate, params):
+def resample_aug(
+    audio: np.ndarray,
+    sampling_rate: int,
+    params: dict,
+) -> Tuple[np.ndarray, int, float]:
+    """Resample audio augmentation.
+
+    Will resample the audio to a random sampling rate from the list of
+    sampling rates in `aug_sampling_rates`.
+
+    Parameters
+    ----------
+    audio: np.ndarray
+        Audio to resample.
+    sampling_rate: int
+        Original sampling rate of the audio.
+    params: dict
+        Parameters for the augmentation. Includes the list of sampling rates
+        to choose from for resampling in `aug_sampling_rates`.
+
+    Returns
+    -------
+    audio : np.ndarray
+        Resampled audio.
+    sampling_rate : int
+        New sampling rate.
+    duration : float
+        Duration of the audio in seconds.
+    """
     sampling_rate_old = sampling_rate
     sampling_rate = np.random.choice(params["aug_sampling_rates"])
     audio = librosa.resample(
@@ -280,7 +457,33 @@ def resample_aug(audio, sampling_rate, params):
     return audio, sampling_rate, duration
-def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
+def resample_audio(
+    num_samples: int,
+    sampling_rate: int,
+    audio2: np.ndarray,
+    sampling_rate2: int,
+) -> Tuple[np.ndarray, int]:
+    """Resample audio.
+
+    Parameters
+    ----------
+    num_samples: int
+        Expected number of samples for the output audio.
+    sampling_rate: int
+        Target sampling rate for the output audio.
+    audio2: np.ndarray
+        Audio to resample.
+    sampling_rate2: int
+        Original sampling rate of `audio2`.
+
+    Returns
+    -------
+    audio2 : np.ndarray
+        Resampled audio.
+    sampling_rate2 : int
+        New sampling rate.
+    """
     # resample to target sampling rate
     if sampling_rate != sampling_rate2:
         audio2 = librosa.resample(
             audio2,
@@ -289,6 +492,8 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
             res_type="polyphase",
         )
         sampling_rate2 = sampling_rate

+    # pad or trim to the correct length
     if audio2.shape[0] < num_samples:
         audio2 = np.hstack(
             (
@@ -298,14 +503,52 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
             )
         )
     elif audio2.shape[0] > num_samples:
         audio2 = audio2[:num_samples]

     return audio2, sampling_rate2
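After resampling, the helper forces `audio2` to exactly `num_samples` samples so the two clips can later be mixed elementwise. The pad-or-trim logic in isolation; the hunk elides what gets concatenated, so the zero filler below is an assumption:

```python
import numpy as np

num_samples = 8
audio2 = np.arange(5, dtype=np.float32)

if audio2.shape[0] < num_samples:
    # assumed filler: zeros appended up to the expected length
    audio2 = np.hstack(
        (audio2, np.zeros(num_samples - audio2.shape[0], dtype=audio2.dtype))
    )
elif audio2.shape[0] > num_samples:
    audio2 = audio2[:num_samples]

print(audio2.shape)  # (8,)
```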
-def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
+def combine_audio_aug(
+    audio: np.ndarray,
+    sampling_rate: int,
+    ann: AnnotationGroup,
+    audio2: np.ndarray,
+    sampling_rate2: int,
+    ann2: AnnotationGroup,
+) -> Tuple[np.ndarray, AnnotationGroup]:
+    """Combine two audio files.
+
+    Will combine two audio files by resampling them to the same sampling rate
+    and then combining them with a random weight. The annotations will be
+    combined by taking the union of the two sets of annotations.
+
+    Parameters
+    ----------
+    audio: np.ndarray
+        First audio clip to combine.
+    sampling_rate: int
+        Sampling rate of the first audio clip.
+    ann: AnnotationGroup
+        Annotations for the first audio clip.
+    audio2: np.ndarray
+        Second audio clip to combine.
+    sampling_rate2: int
+        Sampling rate of the second audio clip.
+    ann2: AnnotationGroup
+        Annotations for the second audio clip.
+
+    Returns
+    -------
+    audio : np.ndarray
+        Combined audio.
+    ann : AnnotationGroup
+        Combined annotations.
+    """
     # resample so they are the same
     audio2, sampling_rate2 = resample_audio(
-        audio.shape[0], sampling_rate, audio2, sampling_rate2
+        audio.shape[0],
+        sampling_rate,
+        audio2,
+        sampling_rate2,
     )

     # # set mean and std to be the same
@@ -314,8 +557,8 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
     # audio2 = audio2 + audio.mean()

     if (
-        ann["annotated"]
-        and (ann2["annotated"])
+        ann.get("annotated", False)
+        and (ann2.get("annotated", False))
         and (sampling_rate2 == sampling_rate)
         and (audio.shape[0] == audio2.shape[0])
     ):
@@ -323,8 +566,8 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
         audio = comb_weight * audio + (1 - comb_weight) * audio2
         inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
         for kk in ann.keys():
-            # when combining calls from different files, assume they come from different individuals
+            # when combining calls from different files, assume they come
+            # from different individuals
             if kk == "individual_ids":
                 if (ann[kk] > -1).sum() > 0:
                     ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1
@@ -335,68 +578,139 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
     return audio, ann
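The `individual_ids` handling above keeps IDs from the two files disjoint: the second file's IDs are shifted past the first file's maximum, with -1 reserved for "unknown". A small check:

```python
import numpy as np

ids1 = np.array([0, 1, -1])  # first file: individuals 0, 1 and one unknown
ids2 = np.array([0, 0, 1])   # second file reuses the same small IDs

if (ids1 > -1).sum() > 0:
    ids2[ids2 > -1] += np.max(ids1[ids1 > -1]) + 1

print(ids2)  # [2 2 3] -> no collision with the first file's IDs
```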
+def _prepare_annotation(
+    annotation: Annotation, class_names: List[str]
+) -> Annotation:
+    try:
+        class_id = class_names.index(annotation["class"])
+    except ValueError:
+        class_id = -1
+
+    ann: Annotation = {
+        **annotation,
+        "class_id": class_id,
+    }
+
+    if "individual" in ann:
+        ann["individual"] = int(ann["individual"])  # type: ignore
+
+    return ann
+
+
+def _prepare_file_annotation(
+    annotation: FileAnnotations,
+    class_names: List[str],
+    classes_to_ignore: List[str],
+) -> AudioLoaderAnnotationGroup:
+    annotations = [
+        _prepare_annotation(ann, class_names)
+        for ann in annotation["annotation"]
+        if ann["class"] not in classes_to_ignore
+    ]
+
+    try:
+        class_id_file = class_names.index(annotation["class_name"])
+    except ValueError:
+        class_id_file = -1
+
+    ret: AudioLoaderAnnotationGroup = {
+        "id": annotation["id"],
+        "annotated": annotation["annotated"],
+        "duration": annotation["duration"],
+        "issues": annotation["issues"],
+        "time_exp": annotation["time_exp"],
+        "class_name": annotation["class_name"],
+        "notes": annotation["notes"],
+        "annotation": annotations,
+        "start_times": np.array([ann["start_time"] for ann in annotations]),
+        "end_times": np.array([ann["end_time"] for ann in annotations]),
+        "high_freqs": np.array([ann["high_freq"] for ann in annotations]),
+        "low_freqs": np.array([ann["low_freq"] for ann in annotations]),
+        "class_ids": np.array([ann.get("class_id", -1) for ann in annotations]),
+        "individual_ids": np.array([ann["individual"] for ann in annotations]),
+        "class_id_file": class_id_file,
+    }
+
+    return ret
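A sketch of what `_prepare_file_annotation` consumes and produces, with a minimal hand-built `FileAnnotations`-style dict; the field values here are assumed for illustration:

```python
file_ann = {  # minimal FileAnnotations-like dict, values assumed
    "id": "rec_001.wav", "annotated": True, "duration": 1.0,
    "issues": [], "time_exp": 1, "class_name": "Myotis daubentonii",
    "notes": "",
    "annotation": [
        {"class": "Myotis daubentonii", "start_time": 0.1, "end_time": 0.12,
         "low_freq": 35000, "high_freq": 80000, "individual": 0,
         "event": "Echolocation"},
    ],
}

group = _prepare_file_annotation(
    file_ann, class_names=["Myotis daubentonii"], classes_to_ignore=["Unknown"]
)
print(group["class_ids"], group["class_id_file"])  # [0] 0
```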
 class AudioLoader(torch.utils.data.Dataset):
-    def __init__(self, data_anns_ip, params, dataset_name=None, is_train=False):
-        """Main AudioLoader for training and testing."""
-
-        self.data_anns = []
-        self.is_train = is_train
-        self.params = params
-        self.return_spec_for_viz = False
+    def __init__(
+        self,
+        data_anns_ip: List[FileAnnotations],
+        params,
+        dataset_name: Optional[str] = None,
+        is_train: bool = False,
+    ):
+        self.is_train: bool = is_train
+        self.params: dict = params
+        self.return_spec_for_viz: bool = False

-        for ii in range(len(data_anns_ip)):
-            dd = copy.deepcopy(data_anns_ip[ii])
-
-            # filter out unused annotation here
-            filtered_annotations = []
-            for ii, aa in enumerate(dd["annotation"]):
-                if "individual" in aa.keys():
-                    aa["individual"] = int(aa["individual"])
-
-                    # if only one call labeled it has to be from the same individual
-                    if len(dd["annotation"]) == 1:
-                        aa["individual"] = 0
-
-                # convert class name into class label
-                if aa["class"] in self.params["class_names"]:
-                    aa["class_id"] = self.params["class_names"].index(
-                        aa["class"]
-                    )
-                else:
-                    aa["class_id"] = -1
-
-                if aa["class"] not in self.params["classes_to_ignore"]:
-                    filtered_annotations.append(aa)
-
-            dd["annotation"] = filtered_annotations
-            dd["start_times"] = np.array(
-                [aa["start_time"] for aa in dd["annotation"]]
-            )
-            dd["end_times"] = np.array(
-                [aa["end_time"] for aa in dd["annotation"]]
-            )
-            dd["high_freqs"] = np.array(
-                [float(aa["high_freq"]) for aa in dd["annotation"]]
-            )
-            dd["low_freqs"] = np.array(
-                [float(aa["low_freq"]) for aa in dd["annotation"]]
-            )
-            dd["class_ids"] = np.array(
-                [aa["class_id"] for aa in dd["annotation"]]
-            ).astype(np.int)
-            dd["individual_ids"] = np.array(
-                [aa["individual"] for aa in dd["annotation"]]
-            ).astype(np.int)
-
-            # file level class name
-            dd["class_id_file"] = -1
-            if "class_name" in dd.keys():
-                if dd["class_name"] in self.params["class_names"]:
-                    dd["class_id_file"] = self.params["class_names"].index(
-                        dd["class_name"]
-                    )
-
-            self.data_anns.append(dd)
+        self.data_anns: List[AudioLoaderAnnotationGroup] = [
+            _prepare_file_annotation(
+                ann,
+                params["class_names"],
+                params["classes_to_ignore"],
+            )
+            for ann in data_anns_ip
+        ]

+        # for ii in range(len(data_anns_ip)):
+        #     dd = copy.deepcopy(data_anns_ip[ii])
+        #
+        #     # filter out unused annotation here
+        #     filtered_annotations = []
+        #     for ii, aa in enumerate(dd["annotation"]):
+        #         if "individual" in aa.keys():
+        #             aa["individual"] = int(aa["individual"])
+        #
+        #             # if only one call labeled it has to be from the same
+        #             # individual
+        #             if len(dd["annotation"]) == 1:
+        #                 aa["individual"] = 0
+        #
+        #         # convert class name into class label
+        #         if aa["class"] in self.params["class_names"]:
+        #             aa["class_id"] = self.params["class_names"].index(
+        #                 aa["class"]
+        #             )
+        #         else:
+        #             aa["class_id"] = -1
+        #
+        #         if aa["class"] not in self.params["classes_to_ignore"]:
+        #             filtered_annotations.append(aa)
+        #
+        #     dd["annotation"] = filtered_annotations
+        #     dd["start_times"] = np.array(
+        #         [aa["start_time"] for aa in dd["annotation"]]
+        #     )
+        #     dd["end_times"] = np.array(
+        #         [aa["end_time"] for aa in dd["annotation"]]
+        #     )
+        #     dd["high_freqs"] = np.array(
+        #         [float(aa["high_freq"]) for aa in dd["annotation"]]
+        #     )
+        #     dd["low_freqs"] = np.array(
+        #         [float(aa["low_freq"]) for aa in dd["annotation"]]
+        #     )
+        #     dd["class_ids"] = np.array(
+        #         [aa["class_id"] for aa in dd["annotation"]]
+        #     ).astype(np.int32)
+        #     dd["individual_ids"] = np.array(
+        #         [aa["individual"] for aa in dd["annotation"]]
+        #     ).astype(np.int32)
+        #
+        #     # file level class name
+        #     dd["class_id_file"] = -1
+        #     if "class_name" in dd.keys():
+        #         if dd["class_name"] in self.params["class_names"]:
+        #             dd["class_id_file"] = self.params["class_names"].index(
+        #                 dd["class_name"]
+        #             )
+        #
+        #     self.data_anns.append(dd)

         ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
         self.max_num_anns = 2 * np.max(
@@ -413,10 +727,30 @@ class AudioLoader(torch.utils.data.Dataset):
         print("Num files : " + str(len(self.data_anns)))
         print("Num calls : " + str(np.sum(ann_cnt)))

-    def get_file_and_anns(self, index=None):
+    def get_file_and_anns(
+        self,
+        index: Optional[int] = None,
+    ) -> Tuple[np.ndarray, int, float, AudioLoaderAnnotationGroup]:
+        """Get an audio file and its annotations.
+
+        Parameters
+        ----------
+        index : int, optional
+            Index of the file to be loaded. If None, a random file is chosen.
+
+        Returns
+        -------
+        audio_raw : np.ndarray
+            Loaded audio file.
+        sampling_rate : int
+            Sampling rate of the audio file.
+        duration : float
+            Duration of the audio file in seconds.
+        ann : AnnotationGroup
+            AnnotationGroup object containing the annotations for the audio
+            file.
+        """
         # if no file specified, choose random one
-        if index == None:
+        if index is None:
             index = np.random.randint(0, len(self.data_anns))

         audio_file = self.data_anns[index]["file_path"]
@@ -428,19 +762,19 @@ class AudioLoader(torch.utils.data.Dataset):
         )

         # copy annotation
-        ann = {}
-        ann["annotated"] = self.data_anns[index]["annotated"]
-        ann["class_id_file"] = self.data_anns[index]["class_id_file"]
-        keys = [
-            "start_times",
-            "end_times",
-            "high_freqs",
-            "low_freqs",
-            "class_ids",
-            "individual_ids",
-        ]
-        for kk in keys:
-            ann[kk] = self.data_anns[index][kk].copy()
+        ann = copy.deepcopy(self.data_anns[index])
+        # ann["annotated"] = self.data_anns[index]["annotated"]
+        # ann["class_id_file"] = self.data_anns[index]["class_id_file"]
+        # keys = [
+        #     "start_times",
+        #     "end_times",
+        #     "high_freqs",
+        #     "low_freqs",
+        #     "class_ids",
+        #     "individual_ids",
+        # ]
+        # for kk in keys:
+        #     ann[kk] = self.data_anns[index][kk].copy()

         # if train then grab a random crop
         if self.is_train:
@@ -489,7 +823,7 @@ class AudioLoader(torch.utils.data.Dataset):
         return audio_raw, sampling_rate, duration, ann

     def __getitem__(self, index):
+        """Get an item from the dataset."""
         # load audio file
         audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
@@ -515,12 +849,17 @@ class AudioLoader(torch.utils.data.Dataset):
                 audio = echo_aug(audio, sampling_rate, self.params)

             # resample the audio
-            # if np.random.random() < self.params['aug_prob']:
-            #     audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params)
+            # if np.random.random() < self.params["aug_prob"]:
+            #     audio, sampling_rate, duration = resample_aug(
+            #         audio, sampling_rate, self.params
+            #     )

         # create spectrogram
         spec, spec_for_viz = au.generate_spectrogram(
-            audio, sampling_rate, self.params, self.return_spec_for_viz
+            audio,
+            sampling_rate,
+            self.params,
+            self.return_spec_for_viz,
         )
         rsf = self.params["resize_factor"]
         spec_op_shape = (
@@ -531,18 +870,22 @@ class AudioLoader(torch.utils.data.Dataset):
         # resize the spec
         spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
         spec = F.interpolate(
-            spec, size=spec_op_shape, mode="bilinear", align_corners=False
+            spec,
+            size=spec_op_shape,
+            mode="bilinear",
+            align_corners=False,
         ).squeeze(0)

         # augment spectrogram
         if self.is_train and self.params["augment_at_train"]:
             if np.random.random() < self.params["aug_prob"]:
                 spec = scale_vol_aug(spec, self.params)

             if np.random.random() < self.params["aug_prob"]:
                 spec = warp_spec_aug(
-                    spec, ann, self.return_spec_for_viz, self.params
+                    spec,
+                    ann,
+                    self.params,
                 )

             if np.random.random() < self.params["aug_prob"]:
@@ -564,10 +907,15 @@ class AudioLoader(torch.utils.data.Dataset):
                 outputs["y_2d_size"],
                 outputs["y_2d_classes"],
                 ann_aug,
-            ) = generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params)
+            ) = generate_gt_heatmaps(
+                spec_op_shape,
+                sampling_rate,
+                ann,
+                self.params,
+            )

-        # hack to get around requirement that all vectors are the same length in
-        # the output batch
+        # hack to get around requirement that all vectors are the same length
+        # in the output batch
         pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
         outputs["is_valid"] = pad_aray(
             np.ones(len(ann_aug["individual_ids"])), pad_size
@@ -600,4 +948,5 @@ class AudioLoader(torch.utils.data.Dataset):
         return outputs

     def __len__(self):
+        """Denotes the total number of samples."""
         return len(self.data_anns)
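Putting the pieces together, the dataset plugs into a standard PyTorch `DataLoader`. A hedged sketch, given a list of `FileAnnotations` dicts (`file_annotations`, name assumed); the `params` keys shown are the ones used in this diff, and real training configs define many more:

```python
import torch

# Assumed minimal params; the real training config defines many more keys.
params = {
    "class_names": ["Myotis daubentonii"],
    "classes_to_ignore": ["Unknown"],
    "augment_at_train": True,
    "aug_prob": 0.2,
    "resize_factor": 0.5,
    # ... spectrogram, augmentation and heatmap settings elided ...
}

dataset = AudioLoader(file_annotations, params, is_train=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
for batch in loader:
    spec = batch["spec"]  # assumed output key for the input spectrogram
    break
```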
@@ -1,10 +1,13 @@
 ## How to train a model from scratch

+> **Warning**
+> This code is currently broken. Will fix soon, stay tuned.
+
 `python train_model.py data_dir annotation_dir` e.g.
 `python train_model.py /data1/bat_data/data/ /data1/bat_data/annotations/anns/`

 More comprehensive instructions are provided in the finetune directory.


 ## Training on your own data
 You can use the finetuning scripts to finetune from an existing training dataset. Follow the instructions in the `../finetune/` directory.
@@ -34,6 +34,7 @@ __all__ = [
     "ResultParams",
     "RunResults",
     "SpectrogramParameters",
+    "AudioLoaderAnnotationGroup",
 ]
@@ -127,6 +128,9 @@ class Annotation(DictWithClass):
     event: str
     """Type of detected event."""

+    class_id: NotRequired[int]
+    """Numeric ID for the class of the annotation."""
+

 class FileAnnotations(TypedDict):
     """Format of results.
@@ -468,8 +472,28 @@ class AnnotationGroup(TypedDict):
     individual_ids: np.ndarray
     """Individual IDs of the annotations."""

+    annotated: NotRequired[bool]
+    """Whether the annotation group is complete or not.
+
+    Usually annotation groups are associated with a single audio clip. If
+    the annotation group is complete, it means that all relevant sound
+    events have been annotated. If it is not complete, it means that some
+    sound events might not have been annotated.
+    """
+
+    x_inds: NotRequired[np.ndarray]
+    """X coordinate of the annotations in the spectrogram."""
+
+    y_inds: NotRequired[np.ndarray]
+    """Y coordinate of the annotations in the spectrogram."""
+
+
+class AudioLoaderAnnotationGroup(AnnotationGroup, FileAnnotations):
+    """Group of annotation items for the training audio loader.
+
+    This class is used to store the annotations for the training audio
+    loader. It inherits from `AnnotationGroup` and `FileAnnotations`.
+    """
+
+    class_id_file: int
+    """ID of the class of the file."""
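Marking `annotated` as `NotRequired` is what makes the dataloader's `ann.get("annotated", False)` change above safe: existing dicts without the key still type-check and default to False at runtime. A minimal sketch:

```python
from typing import TypedDict
from typing_extensions import NotRequired  # in stdlib typing on Python >= 3.11

class Group(TypedDict):
    name: str
    annotated: NotRequired[bool]

g: Group = {"name": "clip_0"}            # valid without "annotated"
is_complete = g.get("annotated", False)  # -> False
```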
@@ -127,17 +127,27 @@ def load_audio(
     The audio is also scaled to [-1, 1] and clipped to the maximum duration.
     Only mono files are supported.

-    Args:
-        audio_file (str): Path to the audio file.
-        target_samp_rate (int): Target sampling rate.
-        scale (bool): Whether to scale the audio to [-1, 1].
-        max_duration (float): Maximum duration of the audio in seconds.
+    Parameters
+    ----------
+    audio_file: str
+        Path to the audio file.
+    target_samp_rate: int
+        Target sampling rate.
+    scale: bool, optional
+        Whether to scale the audio to [-1, 1]. Default: False.
+    max_duration: float, optional
+        Maximum duration of the audio in seconds. Defaults to None.
+        If provided, the audio is clipped to this duration.

-    Returns:
-        sampling_rate: The sampling rate of the audio.
-        audio_raw: The audio signal in a numpy array.
+    Returns
+    -------
+    sampling_rate: int
+        The sampling rate of the audio.
+    audio_raw: np.ndarray
+        The audio signal in a numpy array.

-    Raises:
+    Raises
+    ------
     ValueError: If the audio file is stereo.

     """
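A hedged usage sketch of `load_audio` as documented above; the file path and target sampling rate are assumed example values:

```python
import batdetect2.utils.audio_utils as au

# Assumed example: load a recording, scaling to [-1, 1] and capping length.
sampling_rate, audio_raw = au.load_audio(
    "example.wav",            # hypothetical path
    target_samp_rate=256000,  # assumed target rate for bat recordings
    scale=True,
    max_duration=2.0,
)
```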