mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
WIP
This commit is contained in:
parent
c865b53c17
commit
b252f23093
@ -306,6 +306,9 @@ def _compute_spec_extent(
|
||||
|
||||
# If the spectrogram is not resized, the duration is correct
|
||||
# but if it is resized, the duration needs to be adjusted
|
||||
# NOTE: For now we can only detect if the spectrogram is resized
|
||||
# by checking if the height is equal to the specified height,
|
||||
# but this could fail.
|
||||
resize_factor = params["resize_factor"]
|
||||
spec_height = params["spec_height"]
|
||||
if spec_height * resize_factor == shape[0]:
|
||||
|
@ -1,14 +1,22 @@
|
||||
"""Functions and dataloaders for training and testing the model."""
|
||||
import copy
|
||||
from typing import Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
import torchaudio
|
||||
|
||||
import batdetect2.utils.audio_utils as au
|
||||
from batdetect2.types import AnnotationGroup, HeatmapParameters
|
||||
from batdetect2.types import (
|
||||
Annotation,
|
||||
AnnotationGroup,
|
||||
AudioLoaderAnnotationGroup,
|
||||
FileAnnotations,
|
||||
HeatmapParameters,
|
||||
)
|
||||
|
||||
|
||||
def generate_gt_heatmaps(
|
||||
@ -32,22 +40,17 @@ def generate_gt_heatmaps(
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
y_2d_det : np.ndarray
|
||||
2D heatmap of the presence of an event.
|
||||
|
||||
y_2d_size : np.ndarray
|
||||
2D heatmap of the size of the bounding box associated to event.
|
||||
|
||||
y_2d_classes : np.ndarray
|
||||
3D array containing the ground-truth class probabilities for each
|
||||
pixel.
|
||||
|
||||
ann_aug : AnnotationGroup
|
||||
A dictionary containing the annotation information of the
|
||||
annotations that are within the input spectrogram, augmented with
|
||||
the x and y indices of their pixel location in the input spectrogram.
|
||||
|
||||
"""
|
||||
# spec may be resized on input into the network
|
||||
num_classes = len(params["class_names"])
|
||||
@ -62,20 +65,20 @@ def generate_gt_heatmaps(
|
||||
params["fft_win_length"],
|
||||
params["fft_overlap"],
|
||||
)
|
||||
x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int)
|
||||
x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int32)
|
||||
x_pos_end = au.time_to_x_coords(
|
||||
ann["end_times"],
|
||||
sampling_rate,
|
||||
params["fft_win_length"],
|
||||
params["fft_overlap"],
|
||||
)
|
||||
x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int)
|
||||
x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int32)
|
||||
|
||||
# location on y axis i.e. frequency
|
||||
y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin
|
||||
y_pos_low = (op_height - y_pos_low).astype(np.int)
|
||||
y_pos_low = (op_height - y_pos_low).astype(np.int32)
|
||||
y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin
|
||||
y_pos_high = (op_height - y_pos_high).astype(np.int)
|
||||
y_pos_high = (op_height - y_pos_high).astype(np.int32)
|
||||
bb_widths = x_pos_end - x_pos_start
|
||||
bb_heights = y_pos_low - y_pos_high
|
||||
|
||||
@ -127,7 +130,12 @@ def generate_gt_heatmaps(
|
||||
(x_pos_start[ii], y_pos_low[ii]),
|
||||
params["target_sigma"],
|
||||
)
|
||||
# draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
|
||||
# draw_gaussian(
|
||||
# y_2d_det[0, :],
|
||||
# (x_pos_start[ii], y_pos_low[ii]),
|
||||
# params["target_sigma"],
|
||||
# params["target_sigma"] * 2,
|
||||
# )
|
||||
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
|
||||
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
|
||||
|
||||
@ -138,10 +146,15 @@ def generate_gt_heatmaps(
|
||||
(x_pos_start[ii], y_pos_low[ii]),
|
||||
params["target_sigma"],
|
||||
)
|
||||
# draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
|
||||
# draw_gaussian(
|
||||
# y_2d_classes[cls_id, :],
|
||||
# (x_pos_start[ii], y_pos_low[ii]),
|
||||
# params["target_sigma"],
|
||||
# params["target_sigma"] * 2,
|
||||
# )
|
||||
|
||||
# be careful as this will have a 1.0 places where we have event but dont know gt class
|
||||
# this will be masked in training anyway
|
||||
# be careful as this will have a 1.0 places where we have event but
|
||||
# dont know gt class this will be masked in training anyway
|
||||
y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
|
||||
y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
|
||||
y_2d_classes[np.isnan(y_2d_classes)] = 0.0
|
||||
@ -149,7 +162,37 @@ def generate_gt_heatmaps(
|
||||
return y_2d_det, y_2d_size, y_2d_classes, ann_aug
|
||||
|
||||
|
||||
def draw_gaussian(heatmap, center, sigmax, sigmay=None):
|
||||
def draw_gaussian(
|
||||
heatmap: np.ndarray,
|
||||
center: Tuple[int, int],
|
||||
sigmax: float,
|
||||
sigmay: Optional[float] = None,
|
||||
) -> bool:
|
||||
"""Draw a 2D gaussian into the heatmap.
|
||||
|
||||
If the gaussian center is outside the heatmap, then the gaussian is not
|
||||
drawn.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
heatmap : np.ndarray
|
||||
The heatmap to draw into. Should be of shape (height, width).
|
||||
center : Tuple[int, int]
|
||||
The center of the gaussian in (x, y) format.
|
||||
sigmax : float
|
||||
The standard deviation of the gaussian in the x direction.
|
||||
sigmay : Optional[float], optional
|
||||
The standard deviation of the gaussian in the y direction. If None,
|
||||
then sigmay = sigmax, by default None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the gaussian was drawn, False if it was not (because
|
||||
the center was outside the heatmap).
|
||||
|
||||
|
||||
"""
|
||||
# center is (x, y)
|
||||
# this edits the heatmap inplace
|
||||
|
||||
@ -185,23 +228,46 @@ def draw_gaussian(heatmap, center, sigmax, sigmay=None):
|
||||
return True
|
||||
|
||||
|
||||
def pad_aray(ip_array, pad_size):
|
||||
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int) * -1))
|
||||
def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:
|
||||
"""Pad array with -1s."""
|
||||
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int32) * -1))
|
||||
|
||||
|
||||
def warp_spec_aug(spec, ann, return_spec_for_viz, params):
|
||||
def warp_spec_aug(
|
||||
spec: torch.Tensor,
|
||||
ann: AnnotationGroup,
|
||||
params: dict,
|
||||
) -> torch.Tensor:
|
||||
"""Warp spectrogram by randomly stretching and squeezing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec: torch.Tensor
|
||||
Spectrogram to warp.
|
||||
ann: AnnotationGroup
|
||||
Annotation group for the spectrogram. Must be provided to sync
|
||||
the start and stop times with the spectrogram after warping.
|
||||
params: dict
|
||||
Parameters for the augmentation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Warped spectrogram.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function modifies the annotation group in place.
|
||||
"""
|
||||
# This is messy
|
||||
# Augment spectrogram by randomly stretch and squeezing
|
||||
# NOTE this also changes the start and stop time in place
|
||||
|
||||
# not taking care of spec for viz
|
||||
if return_spec_for_viz:
|
||||
assert False
|
||||
|
||||
delta = params["stretch_squeeze_delta"]
|
||||
op_size = (spec.shape[1], spec.shape[2])
|
||||
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
|
||||
resize_amt = int(spec.shape[2] * resize_fract_r)
|
||||
|
||||
if resize_amt >= spec.shape[2]:
|
||||
spec_r = torch.cat(
|
||||
(
|
||||
@ -215,41 +281,124 @@ def warp_spec_aug(spec, ann, return_spec_for_viz, params):
|
||||
)
|
||||
else:
|
||||
spec_r = spec[:, :, :resize_amt]
|
||||
|
||||
# Resize the spectrogram
|
||||
spec = F.interpolate(
|
||||
spec_r.unsqueeze(0), size=op_size, mode="bilinear", align_corners=False
|
||||
spec_r.unsqueeze(0),
|
||||
size=op_size,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).squeeze(0)
|
||||
|
||||
# Update the start and stop times
|
||||
ann["start_times"] *= 1.0 / resize_fract_r
|
||||
ann["end_times"] *= 1.0 / resize_fract_r
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def mask_time_aug(spec, params):
|
||||
# Mask out a random block of time - repeat up to 3 times
|
||||
# SpecAugment: A Simple Data Augmentation Methodfor Automatic Speech Recognition
|
||||
def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
|
||||
"""Mask out random blocks of time.
|
||||
|
||||
Will randomly mask out a block of time in the spectrogram. The block
|
||||
will be between 0.0 and `mask_max_time_perc` of the total time.
|
||||
A random number of blocks will be masked out between 1 and 3.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec: torch.Tensor
|
||||
Spectrogram to mask.
|
||||
params: dict
|
||||
Parameters for the augmentation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Spectrogram with masked out time blocks.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function is based on the implementation in::
|
||||
|
||||
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
|
||||
Recognition
|
||||
"""
|
||||
fm = torchaudio.transforms.TimeMasking(
|
||||
int(spec.shape[1] * params["mask_max_time_perc"])
|
||||
)
|
||||
for ii in range(np.random.randint(1, 4)):
|
||||
for _ in range(np.random.randint(1, 4)):
|
||||
spec = fm(spec)
|
||||
return spec
|
||||
|
||||
|
||||
def mask_freq_aug(spec, params):
|
||||
# Mask out a random frequncy range - repeat up to 3 times
|
||||
# SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
|
||||
def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
|
||||
"""Mask out random blocks of frequency.
|
||||
|
||||
Will randomly mask out a block of frequency in the spectrogram. The block
|
||||
will be between 0.0 and `mask_max_freq_perc` of the total frequency.
|
||||
A random number of blocks will be masked out between 1 and 3.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec: torch.Tensor
|
||||
Spectrogram to mask.
|
||||
params: dict
|
||||
Parameters for the augmentation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Spectrogram with masked out frequency blocks.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function is based on the implementation in::
|
||||
|
||||
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
|
||||
Recognition
|
||||
"""
|
||||
fm = torchaudio.transforms.FrequencyMasking(
|
||||
int(spec.shape[1] * params["mask_max_freq_perc"])
|
||||
)
|
||||
for ii in range(np.random.randint(1, 4)):
|
||||
for _ in range(np.random.randint(1, 4)):
|
||||
spec = fm(spec)
|
||||
return spec
|
||||
|
||||
|
||||
def scale_vol_aug(spec, params):
|
||||
def scale_vol_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
|
||||
"""Scale the volume of the spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec: torch.Tensor
|
||||
Spectrogram to scale.
|
||||
params: dict
|
||||
Parameters for the augmentation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
"""
|
||||
return spec * np.random.random() * params["spec_amp_scaling"]
|
||||
|
||||
|
||||
def echo_aug(audio, sampling_rate, params):
|
||||
def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
|
||||
"""Add echo to audio.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: np.ndarray
|
||||
Audio to add echo to.
|
||||
sampling_rate: int
|
||||
Sampling rate of the audio.
|
||||
params: dict
|
||||
Parameters for the augmentation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Audio with echo added.
|
||||
"""
|
||||
sample_offset = (
|
||||
int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
|
||||
)
|
||||
@ -257,7 +406,35 @@ def echo_aug(audio, sampling_rate, params):
|
||||
return audio
|
||||
|
||||
|
||||
def resample_aug(audio, sampling_rate, params):
|
||||
def resample_aug(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: int,
|
||||
params: dict,
|
||||
) -> Tuple[np.ndarray, int, float]:
|
||||
"""Resample audio augmentation.
|
||||
|
||||
Will resample the audio to a random sampling rate from the list of
|
||||
sampling rates in `aug_sampling_rates`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: np.ndarray
|
||||
Audio to resample.
|
||||
sampling_rate: int
|
||||
Original sampling rate of the audio.
|
||||
params: dict
|
||||
Parameters for the augmentation. Includes the list of sampling rates
|
||||
to choose from for resampling in `aug_sampling_rates`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
audio : np.ndarray
|
||||
Resampled audio.
|
||||
sampling_rate : int
|
||||
New sampling rate.
|
||||
duration : float
|
||||
Duration of the audio in seconds.
|
||||
"""
|
||||
sampling_rate_old = sampling_rate
|
||||
sampling_rate = np.random.choice(params["aug_sampling_rates"])
|
||||
audio = librosa.resample(
|
||||
@ -280,7 +457,33 @@ def resample_aug(audio, sampling_rate, params):
|
||||
return audio, sampling_rate, duration
|
||||
|
||||
|
||||
def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
|
||||
def resample_audio(
|
||||
num_samples: int,
|
||||
sampling_rate: int,
|
||||
audio2: np.ndarray,
|
||||
sampling_rate2: int,
|
||||
) -> Tuple[np.ndarray, int]:
|
||||
"""Resample audio.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_samples: int
|
||||
Expected number of samples for the output audio.
|
||||
sampling_rate: int
|
||||
Original sampling rate of the audio.
|
||||
audio2: np.ndarray
|
||||
Audio to resample.
|
||||
sampling_rate2: int
|
||||
Target sampling rate of the audio.
|
||||
|
||||
Returns
|
||||
-------
|
||||
audio2 : np.ndarray
|
||||
Resampled audio.
|
||||
sampling_rate2 : int
|
||||
New sampling rate.
|
||||
"""
|
||||
# resample to target sampling rate
|
||||
if sampling_rate != sampling_rate2:
|
||||
audio2 = librosa.resample(
|
||||
audio2,
|
||||
@ -289,6 +492,8 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
|
||||
res_type="polyphase",
|
||||
)
|
||||
sampling_rate2 = sampling_rate
|
||||
|
||||
# pad or trim to the correct length
|
||||
if audio2.shape[0] < num_samples:
|
||||
audio2 = np.hstack(
|
||||
(
|
||||
@ -298,14 +503,52 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
|
||||
)
|
||||
elif audio2.shape[0] > num_samples:
|
||||
audio2 = audio2[:num_samples]
|
||||
|
||||
return audio2, sampling_rate2
|
||||
|
||||
|
||||
def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
|
||||
def combine_audio_aug(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: int,
|
||||
ann: AnnotationGroup,
|
||||
audio2: np.ndarray,
|
||||
sampling_rate2: int,
|
||||
ann2: AnnotationGroup,
|
||||
) -> Tuple[np.ndarray, AnnotationGroup]:
|
||||
"""Combine two audio files.
|
||||
|
||||
Will combine two audio files by resampling them to the same sampling rate
|
||||
and then combining them with a random weight. The annotations will be
|
||||
combined by taking the union of the two sets of annotations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: np.ndarray
|
||||
First Audio to combine.
|
||||
sampling_rate: int
|
||||
Sampling rate of the first audio.
|
||||
ann: AnnotationGroup
|
||||
Annotations for the first audio.
|
||||
audio2: np.ndarray
|
||||
Second Audio to combine.
|
||||
sampling_rate2: int
|
||||
Sampling rate of the second audio.
|
||||
ann2: AnnotationGroup
|
||||
Annotations for the second audio.
|
||||
|
||||
Returns
|
||||
-------
|
||||
audio : np.ndarray
|
||||
Combined audio.
|
||||
ann : AnnotationGroup
|
||||
Combined annotations.
|
||||
"""
|
||||
# resample so they are the same
|
||||
audio2, sampling_rate2 = resample_audio(
|
||||
audio.shape[0], sampling_rate, audio2, sampling_rate2
|
||||
audio.shape[0],
|
||||
sampling_rate,
|
||||
audio2,
|
||||
sampling_rate2,
|
||||
)
|
||||
|
||||
# # set mean and std to be the same
|
||||
@ -314,8 +557,8 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
|
||||
# audio2 = audio2 + audio.mean()
|
||||
|
||||
if (
|
||||
ann["annotated"]
|
||||
and (ann2["annotated"])
|
||||
ann.get("annotated", False)
|
||||
and (ann2.get("annotated", False))
|
||||
and (sampling_rate2 == sampling_rate)
|
||||
and (audio.shape[0] == audio2.shape[0])
|
||||
):
|
||||
@ -323,8 +566,8 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
|
||||
audio = comb_weight * audio + (1 - comb_weight) * audio2
|
||||
inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
|
||||
for kk in ann.keys():
|
||||
|
||||
# when combining calls from different files, assume they come from different individuals
|
||||
# when combining calls from different files, assume they come
|
||||
# from different individuals
|
||||
if kk == "individual_ids":
|
||||
if (ann[kk] > -1).sum() > 0:
|
||||
ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1
|
||||
@ -335,68 +578,139 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
|
||||
return audio, ann
|
||||
|
||||
|
||||
def _prepare_annotation(
|
||||
annotation: Annotation, class_names: List[str]
|
||||
) -> Annotation:
|
||||
try:
|
||||
class_id = class_names.index(annotation["class"])
|
||||
except ValueError:
|
||||
class_id = -1
|
||||
|
||||
ann: Annotation = {
|
||||
**annotation,
|
||||
"class_id": class_id,
|
||||
}
|
||||
|
||||
if "individual" in ann:
|
||||
ann["individual"] = int(ann["individual"]) # type: ignore
|
||||
|
||||
return ann
|
||||
|
||||
|
||||
def _prepare_file_annotation(
|
||||
annotation: FileAnnotations,
|
||||
class_names: List[str],
|
||||
classes_to_ignore: List[str],
|
||||
) -> AudioLoaderAnnotationGroup:
|
||||
annotations = [
|
||||
_prepare_annotation(ann, class_names)
|
||||
for ann in annotation["annotation"]
|
||||
if ann["class"] not in classes_to_ignore
|
||||
]
|
||||
|
||||
try:
|
||||
class_id_file = class_names.index(annotation["class_name"])
|
||||
except ValueError:
|
||||
class_id_file = -1
|
||||
|
||||
ret: AudioLoaderAnnotationGroup = {
|
||||
"id": annotation["id"],
|
||||
"annotated": annotation["annotated"],
|
||||
"duration": annotation["duration"],
|
||||
"issues": annotation["issues"],
|
||||
"time_exp": annotation["time_exp"],
|
||||
"class_name": annotation["class_name"],
|
||||
"notes": annotation["notes"],
|
||||
"annotation": annotations,
|
||||
"start_times": np.array([ann["start_time"] for ann in annotations]),
|
||||
"end_times": np.array([ann["end_time"] for ann in annotations]),
|
||||
"high_freqs": np.array([ann["high_freq"] for ann in annotations]),
|
||||
"low_freqs": np.array([ann["low_freq"] for ann in annotations]),
|
||||
"class_ids": np.array([ann.get("class_id", -1) for ann in annotations]),
|
||||
"individual_ids": np.array([ann["individual"] for ann in annotations]),
|
||||
"class_id_file": class_id_file,
|
||||
}
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class AudioLoader(torch.utils.data.Dataset):
|
||||
def __init__(self, data_anns_ip, params, dataset_name=None, is_train=False):
|
||||
"""Main AudioLoader for training and testing."""
|
||||
|
||||
self.data_anns = []
|
||||
self.is_train = is_train
|
||||
self.params = params
|
||||
self.return_spec_for_viz = False
|
||||
def __init__(
|
||||
self,
|
||||
data_anns_ip: List[FileAnnotations],
|
||||
params,
|
||||
dataset_name: Optional[str] = None,
|
||||
is_train: bool = False,
|
||||
):
|
||||
self.is_train: bool = is_train
|
||||
self.params: dict = params
|
||||
self.return_spec_for_viz: bool = False
|
||||
|
||||
for ii in range(len(data_anns_ip)):
|
||||
dd = copy.deepcopy(data_anns_ip[ii])
|
||||
|
||||
# filter out unused annotation here
|
||||
filtered_annotations = []
|
||||
for ii, aa in enumerate(dd["annotation"]):
|
||||
|
||||
if "individual" in aa.keys():
|
||||
aa["individual"] = int(aa["individual"])
|
||||
|
||||
# if only one call labeled it has to be from the same individual
|
||||
if len(dd["annotation"]) == 1:
|
||||
aa["individual"] = 0
|
||||
|
||||
# convert class name into class label
|
||||
if aa["class"] in self.params["class_names"]:
|
||||
aa["class_id"] = self.params["class_names"].index(
|
||||
aa["class"]
|
||||
)
|
||||
else:
|
||||
aa["class_id"] = -1
|
||||
|
||||
if aa["class"] not in self.params["classes_to_ignore"]:
|
||||
filtered_annotations.append(aa)
|
||||
|
||||
dd["annotation"] = filtered_annotations
|
||||
dd["start_times"] = np.array(
|
||||
[aa["start_time"] for aa in dd["annotation"]]
|
||||
self.data_anns: List[AudioLoaderAnnotationGroup] = [
|
||||
_prepare_file_annotation(
|
||||
ann,
|
||||
params["class_names"],
|
||||
params["classes_to_ignore"],
|
||||
)
|
||||
dd["end_times"] = np.array(
|
||||
[aa["end_time"] for aa in dd["annotation"]]
|
||||
)
|
||||
dd["high_freqs"] = np.array(
|
||||
[float(aa["high_freq"]) for aa in dd["annotation"]]
|
||||
)
|
||||
dd["low_freqs"] = np.array(
|
||||
[float(aa["low_freq"]) for aa in dd["annotation"]]
|
||||
)
|
||||
dd["class_ids"] = np.array(
|
||||
[aa["class_id"] for aa in dd["annotation"]]
|
||||
).astype(np.int)
|
||||
dd["individual_ids"] = np.array(
|
||||
[aa["individual"] for aa in dd["annotation"]]
|
||||
).astype(np.int)
|
||||
for ann in data_anns_ip
|
||||
]
|
||||
|
||||
# file level class name
|
||||
dd["class_id_file"] = -1
|
||||
if "class_name" in dd.keys():
|
||||
if dd["class_name"] in self.params["class_names"]:
|
||||
dd["class_id_file"] = self.params["class_names"].index(
|
||||
dd["class_name"]
|
||||
)
|
||||
|
||||
self.data_anns.append(dd)
|
||||
# for ii in range(len(data_anns_ip)):
|
||||
# dd = copy.deepcopy(data_anns_ip[ii])
|
||||
#
|
||||
# # filter out unused annotation here
|
||||
# filtered_annotations = []
|
||||
# for ii, aa in enumerate(dd["annotation"]):
|
||||
# if "individual" in aa.keys():
|
||||
# aa["individual"] = int(aa["individual"])
|
||||
#
|
||||
# # if only one call labeled it has to be from the same
|
||||
# # individual
|
||||
# if len(dd["annotation"]) == 1:
|
||||
# aa["individual"] = 0
|
||||
#
|
||||
# # convert class name into class label
|
||||
# if aa["class"] in self.params["class_names"]:
|
||||
# aa["class_id"] = self.params["class_names"].index(
|
||||
# aa["class"]
|
||||
# )
|
||||
# else:
|
||||
# aa["class_id"] = -1
|
||||
#
|
||||
# if aa["class"] not in self.params["classes_to_ignore"]:
|
||||
# filtered_annotations.append(aa)
|
||||
#
|
||||
# dd["annotation"] = filtered_annotations
|
||||
# dd["start_times"] = np.array(
|
||||
# [aa["start_time"] for aa in dd["annotation"]]
|
||||
# )
|
||||
# dd["end_times"] = np.array(
|
||||
# [aa["end_time"] for aa in dd["annotation"]]
|
||||
# )
|
||||
# dd["high_freqs"] = np.array(
|
||||
# [float(aa["high_freq"]) for aa in dd["annotation"]]
|
||||
# )
|
||||
# dd["low_freqs"] = np.array(
|
||||
# [float(aa["low_freq"]) for aa in dd["annotation"]]
|
||||
# )
|
||||
# dd["class_ids"] = np.array(
|
||||
# [aa["class_id"] for aa in dd["annotation"]]
|
||||
# ).astype(np.int32)
|
||||
# dd["individual_ids"] = np.array(
|
||||
# [aa["individual"] for aa in dd["annotation"]]
|
||||
# ).astype(np.int32)
|
||||
#
|
||||
# # file level class name
|
||||
# dd["class_id_file"] = -1
|
||||
# if "class_name" in dd.keys():
|
||||
# if dd["class_name"] in self.params["class_names"]:
|
||||
# dd["class_id_file"] = self.params["class_names"].index(
|
||||
# dd["class_name"]
|
||||
# )
|
||||
#
|
||||
# self.data_anns.append(dd)
|
||||
|
||||
ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
|
||||
self.max_num_anns = 2 * np.max(
|
||||
@ -413,10 +727,30 @@ class AudioLoader(torch.utils.data.Dataset):
|
||||
print("Num files : " + str(len(self.data_anns)))
|
||||
print("Num calls : " + str(np.sum(ann_cnt)))
|
||||
|
||||
def get_file_and_anns(self, index=None):
|
||||
def get_file_and_anns(
|
||||
self,
|
||||
index: Optional[int] = None,
|
||||
) -> Tuple[np.ndarray, int, float, AudioLoaderAnnotationGroup]:
|
||||
"""Get an audio file and its annotations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index : int, optional
|
||||
Index of the file to be loaded. If None, a random file is chosen.
|
||||
|
||||
Returns
|
||||
-------
|
||||
audio_raw : np.ndarray
|
||||
Loaded audio file.
|
||||
sampling_rate : int
|
||||
Sampling rate of the audio file.
|
||||
duration : float
|
||||
Duration of the audio file in seconds.
|
||||
ann : AnnotationGroup
|
||||
AnnotationGroup object containing the annotations for the audio file.
|
||||
"""
|
||||
# if no file specified, choose random one
|
||||
if index == None:
|
||||
if index is None:
|
||||
index = np.random.randint(0, len(self.data_anns))
|
||||
|
||||
audio_file = self.data_anns[index]["file_path"]
|
||||
@ -428,19 +762,19 @@ class AudioLoader(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
# copy annotation
|
||||
ann = {}
|
||||
ann["annotated"] = self.data_anns[index]["annotated"]
|
||||
ann["class_id_file"] = self.data_anns[index]["class_id_file"]
|
||||
keys = [
|
||||
"start_times",
|
||||
"end_times",
|
||||
"high_freqs",
|
||||
"low_freqs",
|
||||
"class_ids",
|
||||
"individual_ids",
|
||||
]
|
||||
for kk in keys:
|
||||
ann[kk] = self.data_anns[index][kk].copy()
|
||||
ann = copy.deepcopy(self.data_anns[index])
|
||||
# ann["annotated"] = self.data_anns[index]["annotated"]
|
||||
# ann["class_id_file"] = self.data_anns[index]["class_id_file"]
|
||||
# keys = [
|
||||
# "start_times",
|
||||
# "end_times",
|
||||
# "high_freqs",
|
||||
# "low_freqs",
|
||||
# "class_ids",
|
||||
# "individual_ids",
|
||||
# ]
|
||||
# for kk in keys:
|
||||
# ann[kk] = self.data_anns[index][kk].copy()
|
||||
|
||||
# if train then grab a random crop
|
||||
if self.is_train:
|
||||
@ -489,7 +823,7 @@ class AudioLoader(torch.utils.data.Dataset):
|
||||
return audio_raw, sampling_rate, duration, ann
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
"""Get an item from the dataset."""
|
||||
# load audio file
|
||||
audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
|
||||
|
||||
@ -515,12 +849,17 @@ class AudioLoader(torch.utils.data.Dataset):
|
||||
audio = echo_aug(audio, sampling_rate, self.params)
|
||||
|
||||
# resample the audio
|
||||
# if np.random.random() < self.params['aug_prob']:
|
||||
# audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params)
|
||||
# if np.random.random() < self.params["aug_prob"]:
|
||||
# audio, sampling_rate, duration = resample_aug(
|
||||
# audio, sampling_rate, self.params
|
||||
# )
|
||||
|
||||
# create spectrogram
|
||||
spec, spec_for_viz = au.generate_spectrogram(
|
||||
audio, sampling_rate, self.params, self.return_spec_for_viz
|
||||
audio,
|
||||
sampling_rate,
|
||||
self.params,
|
||||
self.return_spec_for_viz,
|
||||
)
|
||||
rsf = self.params["resize_factor"]
|
||||
spec_op_shape = (
|
||||
@ -531,18 +870,22 @@ class AudioLoader(torch.utils.data.Dataset):
|
||||
# resize the spec
|
||||
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
|
||||
spec = F.interpolate(
|
||||
spec, size=spec_op_shape, mode="bilinear", align_corners=False
|
||||
spec,
|
||||
size=spec_op_shape,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).squeeze(0)
|
||||
|
||||
# augment spectrogram
|
||||
if self.is_train and self.params["augment_at_train"]:
|
||||
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
spec = scale_vol_aug(spec, self.params)
|
||||
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
spec = warp_spec_aug(
|
||||
spec, ann, self.return_spec_for_viz, self.params
|
||||
spec,
|
||||
ann,
|
||||
self.params,
|
||||
)
|
||||
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
@ -564,10 +907,15 @@ class AudioLoader(torch.utils.data.Dataset):
|
||||
outputs["y_2d_size"],
|
||||
outputs["y_2d_classes"],
|
||||
ann_aug,
|
||||
) = generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params)
|
||||
) = generate_gt_heatmaps(
|
||||
spec_op_shape,
|
||||
sampling_rate,
|
||||
ann,
|
||||
self.params,
|
||||
)
|
||||
|
||||
# hack to get around requirement that all vectors are the same length in
|
||||
# the output batch
|
||||
# hack to get around requirement that all vectors are the same length
|
||||
# in the output batch
|
||||
pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
|
||||
outputs["is_valid"] = pad_aray(
|
||||
np.ones(len(ann_aug["individual_ids"])), pad_size
|
||||
@ -600,4 +948,5 @@ class AudioLoader(torch.utils.data.Dataset):
|
||||
return outputs
|
||||
|
||||
def __len__(self):
|
||||
"""Denotes the total number of samples."""
|
||||
return len(self.data_anns)
|
||||
|
@ -1,18 +1,21 @@
|
||||
## How to train a model from scratch
|
||||
`python train_model.py data_dir annotation_dir` e.g.
|
||||
|
||||
> **Warning**
|
||||
> This code in currently broken. Will fix soon, stay tuned.
|
||||
|
||||
`python train_model.py data_dir annotation_dir` e.g.
|
||||
`python train_model.py /data1/bat_data/data/ /data1/bat_data/annotations/anns/`
|
||||
|
||||
More comprehensive instructions are provided in the finetune directory.
|
||||
|
||||
|
||||
## Training on your own data
|
||||
You can either use the finetuning scripts to finetune from an existing training dataset. Follow the instructions in the `../finetune/` directory.
|
||||
|
||||
Alternatively, you can train from scratch. First, you will need to create your own annotation file (like in the finetune example), and then you will need to edit `train_split.py` to add your new dataset and specify which combination of files you want to train on.
|
||||
Alternatively, you can train from scratch. First, you will need to create your own annotation file (like in the finetune example), and then you will need to edit `train_split.py` to add your new dataset and specify which combination of files you want to train on.
|
||||
|
||||
Note, if training from scratch and you want to include the existing data, you may need to set all the class names to the generic class name ('Bat') so that the existing species are not added to your model, but instead just used to help perform the bat/not bat task.
|
||||
Note, if training from scratch and you want to include the existing data, you may need to set all the class names to the generic class name ('Bat') so that the existing species are not added to your model, but instead just used to help perform the bat/not bat task.
|
||||
|
||||
## Additional notes
|
||||
Having blank files with no bats in them is also useful, just make sure that the annotation files lists them as not being annotated (i.e. `is_annotated=True`).
|
||||
Having blank files with no bats in them is also useful, just make sure that the annotation files lists them as not being annotated (i.e. `is_annotated=True`).
|
||||
|
||||
Training will be slow without a GPU.
|
||||
Training will be slow without a GPU.
|
||||
|
@ -34,6 +34,7 @@ __all__ = [
|
||||
"ResultParams",
|
||||
"RunResults",
|
||||
"SpectrogramParameters",
|
||||
"AudioLoaderAnnotationGroup",
|
||||
]
|
||||
|
||||
|
||||
@ -99,7 +100,7 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
|
||||
class Annotation(DictWithClass):
|
||||
"""Format of annotations.
|
||||
|
||||
This is the format of a single annotation as expected by the annotation
|
||||
This is the format of a single annotation as expected by the annotation
|
||||
tool.
|
||||
"""
|
||||
|
||||
@ -127,6 +128,9 @@ class Annotation(DictWithClass):
|
||||
event: str
|
||||
"""Type of detected event."""
|
||||
|
||||
class_id: NotRequired[int]
|
||||
"""Numeric ID for the class of the annotation."""
|
||||
|
||||
|
||||
class FileAnnotations(TypedDict):
|
||||
"""Format of results.
|
||||
@ -468,8 +472,28 @@ class AnnotationGroup(TypedDict):
|
||||
individual_ids: np.ndarray
|
||||
"""Individual IDs of the annotations."""
|
||||
|
||||
annotated: NotRequired[bool]
|
||||
"""Wether the annotation group is complete or not.
|
||||
|
||||
Usually annotation groups are associated to a single
|
||||
audio clip. If the annotation group is complete, it means that all
|
||||
relevant sound events have been annotated. If it is not complete, it
|
||||
means that some sound events might not have been annotated.
|
||||
"""
|
||||
|
||||
x_inds: NotRequired[np.ndarray]
|
||||
"""X coordinate of the annotations in the spectrogram."""
|
||||
|
||||
y_inds: NotRequired[np.ndarray]
|
||||
"""Y coordinate of the annotations in the spectrogram."""
|
||||
|
||||
|
||||
class AudioLoaderAnnotationGroup(AnnotationGroup, FileAnnotations):
|
||||
"""Group of annotation items for the training audio loader.
|
||||
|
||||
This class is used to store the annotations for the training audio
|
||||
loader. It inherits from `AnnotationGroup` and `FileAnnotations`.
|
||||
"""
|
||||
|
||||
class_id_file: int
|
||||
"""ID of the class of the file."""
|
||||
|
@ -127,18 +127,28 @@ def load_audio(
|
||||
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
||||
Only mono files are supported.
|
||||
|
||||
Args:
|
||||
audio_file (str): Path to the audio file.
|
||||
target_samp_rate (int): Target sampling rate.
|
||||
scale (bool): Whether to scale the audio to [-1, 1].
|
||||
max_duration (float): Maximum duration of the audio in seconds.
|
||||
Parameters
|
||||
----------
|
||||
audio_file: str
|
||||
Path to the audio file.
|
||||
target_samp_rate: int
|
||||
Target sampling rate.
|
||||
scale: bool, optional
|
||||
Whether to scale the audio to [-1, 1]. Default: False.
|
||||
max_duration: float, optional
|
||||
Maximum duration of the audio in seconds. Defaults to None.
|
||||
If provided, the audio is clipped to this duration.
|
||||
|
||||
Returns:
|
||||
sampling_rate: The sampling rate of the audio.
|
||||
audio_raw: The audio signal in a numpy array.
|
||||
Returns
|
||||
-------
|
||||
sampling_rate: int
|
||||
The sampling rate of the audio.
|
||||
audio_raw: np.ndarray
|
||||
The audio signal in a numpy array.
|
||||
|
||||
Raises:
|
||||
ValueError: If the audio file is stereo.
|
||||
Raises
|
||||
------
|
||||
ValueError: If the audio file is stereo.
|
||||
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
|
Loading…
Reference in New Issue
Block a user