This commit is contained in:
Santiago Martinez 2023-04-13 07:58:01 -06:00
parent c865b53c17
commit b252f23093
5 changed files with 530 additions and 141 deletions

View File

@ -306,6 +306,9 @@ def _compute_spec_extent(
# If the spectrogram is not resized, the duration is correct
# but if it is resized, the duration needs to be adjusted
# NOTE: For now we can only detect if the spectrogram is resized
# by checking if the height is equal to the specified height,
# but this could fail.
resize_factor = params["resize_factor"]
spec_height = params["spec_height"]
if spec_height * resize_factor == shape[0]:

View File

@ -1,14 +1,22 @@
"""Functions and dataloaders for training and testing the model."""
import copy
from typing import Tuple
from typing import List, Optional, Tuple
import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
import batdetect2.utils.audio_utils as au
from batdetect2.types import AnnotationGroup, HeatmapParameters
from batdetect2.types import (
Annotation,
AnnotationGroup,
AudioLoaderAnnotationGroup,
FileAnnotations,
HeatmapParameters,
)
def generate_gt_heatmaps(
@ -32,22 +40,17 @@ def generate_gt_heatmaps(
Returns
-------
y_2d_det : np.ndarray
2D heatmap of the presence of an event.
y_2d_size : np.ndarray
2D heatmap of the size of the bounding box associated with each event.
y_2d_classes : np.ndarray
3D array containing the ground-truth class probabilities for each
pixel.
ann_aug : AnnotationGroup
A dictionary containing the annotations that fall within the input
spectrogram, augmented with the x and y indices of their pixel
locations in the spectrogram.
"""
# spec may be resized on input into the network
num_classes = len(params["class_names"])
@ -62,20 +65,20 @@ def generate_gt_heatmaps(
params["fft_win_length"],
params["fft_overlap"],
)
x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int)
x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int32)
x_pos_end = au.time_to_x_coords(
ann["end_times"],
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
)
x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int)
x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int32)
# location on y axis i.e. frequency
y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin
y_pos_low = (op_height - y_pos_low).astype(np.int)
y_pos_low = (op_height - y_pos_low).astype(np.int32)
y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin
y_pos_high = (op_height - y_pos_high).astype(np.int)
y_pos_high = (op_height - y_pos_high).astype(np.int32)
bb_widths = x_pos_end - x_pos_start
bb_heights = y_pos_low - y_pos_high
@ -127,7 +130,12 @@ def generate_gt_heatmaps(
(x_pos_start[ii], y_pos_low[ii]),
params["target_sigma"],
)
# draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
# draw_gaussian(
# y_2d_det[0, :],
# (x_pos_start[ii], y_pos_low[ii]),
# params["target_sigma"],
# params["target_sigma"] * 2,
# )
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
@ -138,10 +146,15 @@ def generate_gt_heatmaps(
(x_pos_start[ii], y_pos_low[ii]),
params["target_sigma"],
)
# draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
# draw_gaussian(
# y_2d_classes[cls_id, :],
# (x_pos_start[ii], y_pos_low[ii]),
# params["target_sigma"],
# params["target_sigma"] * 2,
# )
# be careful as this will have a 1.0 places where we have event but dont know gt class
# this will be masked in training anyway
# be careful as this will have a 1.0 in places where we have an event but
# don't know the gt class; this will be masked in training anyway
y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
y_2d_classes[np.isnan(y_2d_classes)] = 0.0
@ -149,7 +162,37 @@ def generate_gt_heatmaps(
return y_2d_det, y_2d_size, y_2d_classes, ann_aug
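A minimal usage sketch of `generate_gt_heatmaps`, mirroring the call made in `AudioLoader.__getitem__` further down; the variable names follow that call and are otherwise illustrative:
y_2d_det, y_2d_size, y_2d_classes, ann_aug = generate_gt_heatmaps(
    spec_op_shape,   # (height, width) of the resized spectrogram
    sampling_rate,   # sampling rate of the underlying audio
    ann,             # AnnotationGroup for the clip
    params,          # training parameters (class_names, target_sigma, ...)
)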
def draw_gaussian(heatmap, center, sigmax, sigmay=None):
def draw_gaussian(
heatmap: np.ndarray,
center: Tuple[int, int],
sigmax: float,
sigmay: Optional[float] = None,
) -> bool:
"""Draw a 2D gaussian into the heatmap.
If the gaussian center is outside the heatmap, then the gaussian is not
drawn.
Parameters
----------
heatmap : np.ndarray
The heatmap to draw into. Should be of shape (height, width).
center : Tuple[int, int]
The center of the gaussian in (x, y) format.
sigmax : float
The standard deviation of the gaussian in the x direction.
sigmay : Optional[float], optional
The standard deviation of the gaussian in the y direction. If None,
then sigmay = sigmax, by default None.
Returns
-------
bool
True if the gaussian was drawn, False if it was not (because
the center was outside the heatmap).
"""
# center is (x, y)
# this edits the heatmap inplace
@ -185,23 +228,46 @@ def draw_gaussian(heatmap, center, sigmax, sigmay=None):
return True
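A small usage sketch of `draw_gaussian` as documented above; the heatmap shape, center and sigma are illustrative (`np` is numpy as imported at the top of this module):
heatmap = np.zeros((128, 256), dtype=np.float32)   # (height, width)
drawn = draw_gaussian(heatmap, center=(40, 64), sigmax=2.0)  # edits heatmap in place
# drawn is False when the center falls outside the heatmap, True otherwise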
def pad_aray(ip_array, pad_size):
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int) * -1))
def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:
"""Pad array with -1s."""
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int32) * -1))
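For example, with an illustrative input:
pad_aray(np.array([2, 0, 1], dtype=np.int32), 2)
# -> array([ 2,  0,  1, -1, -1], dtype=int32)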
def warp_spec_aug(spec, ann, return_spec_for_viz, params):
def warp_spec_aug(
spec: torch.Tensor,
ann: AnnotationGroup,
params: dict,
) -> torch.Tensor:
"""Warp spectrogram by randomly stretching and squeezing.
Parameters
----------
spec: torch.Tensor
Spectrogram to warp.
ann: AnnotationGroup
Annotation group for the spectrogram. Must be provided to sync
the start and stop times with the spectrogram after warping.
params: dict
Parameters for the augmentation.
Returns
-------
torch.Tensor
Warped spectrogram.
Notes
-----
This function modifies the annotation group in place.
"""
# This is messy
# Augment spectrogram by randomly stretch and squeezing
# NOTE this also changes the start and stop time in place
# not taking care of spec for viz
if return_spec_for_viz:
assert False
delta = params["stretch_squeeze_delta"]
op_size = (spec.shape[1], spec.shape[2])
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
resize_amt = int(spec.shape[2] * resize_fract_r)
if resize_amt >= spec.shape[2]:
spec_r = torch.cat(
(
@ -215,41 +281,124 @@ def warp_spec_aug(spec, ann, return_spec_for_viz, params):
)
else:
spec_r = spec[:, :, :resize_amt]
# Resize the spectrogram
spec = F.interpolate(
spec_r.unsqueeze(0), size=op_size, mode="bilinear", align_corners=False
spec_r.unsqueeze(0),
size=op_size,
mode="bilinear",
align_corners=False,
).squeeze(0)
# Update the start and stop times
ann["start_times"] *= 1.0 / resize_fract_r
ann["end_times"] *= 1.0 / resize_fract_r
return spec
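A hedged usage sketch; only the `stretch_squeeze_delta` key used above is shown (its value is illustrative), and note that `ann["start_times"]` and `ann["end_times"]` are rescaled in place:
params = {"stretch_squeeze_delta": 0.04}
spec = warp_spec_aug(spec, ann, params)   # output keeps the original (freq, time) shape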
def mask_time_aug(spec, params):
# Mask out a random block of time - repeat up to 3 times
# SpecAugment: A Simple Data Augmentation Methodfor Automatic Speech Recognition
def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
"""Mask out random blocks of time.
Randomly masks out blocks of time in the spectrogram. Each block covers
between 0.0 and `mask_max_time_perc` of the total duration, and between
1 and 3 blocks are masked out.
Parameters
----------
spec: torch.Tensor
Spectrogram to mask.
params: dict
Parameters for the augmentation.
Returns
-------
torch.Tensor
Spectrogram with masked out time blocks.
Notes
-----
This function is based on the implementation in::
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition
"""
fm = torchaudio.transforms.TimeMasking(
int(spec.shape[1] * params["mask_max_time_perc"])
)
for ii in range(np.random.randint(1, 4)):
for _ in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
def mask_freq_aug(spec, params):
# Mask out a random frequncy range - repeat up to 3 times
# SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
"""Mask out random blocks of frequency.
Randomly masks out blocks of frequency in the spectrogram. Each block
covers between 0.0 and `mask_max_freq_perc` of the total frequency range,
and between 1 and 3 blocks are masked out.
Parameters
----------
spec: torch.Tensor
Spectrogram to mask.
params: dict
Parameters for the augmentation.
Returns
-------
torch.Tensor
Spectrogram with masked out frequency blocks.
Notes
-----
This function is based on the implementation in::
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition
"""
fm = torchaudio.transforms.FrequencyMasking(
int(spec.shape[1] * params["mask_max_freq_perc"])
)
for ii in range(np.random.randint(1, 4)):
for _ in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
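A hedged sketch applying both SpecAugment-style masks above; the tensor shape and mask percentages are illustrative:
spec = torch.rand(1, 128, 512)  # (channels, freq_bins, time_steps)
params = {"mask_max_time_perc": 0.05, "mask_max_freq_perc": 0.10}
spec = mask_time_aug(spec, params)   # masks 1-3 random time blocks
spec = mask_freq_aug(spec, params)   # masks 1-3 random frequency blocks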
def scale_vol_aug(spec, params):
def scale_vol_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
"""Scale the volume of the spectrogram.
Parameters
----------
spec: torch.Tensor
Spectrogram to scale.
params: dict
Parameters for the augmentation.
Returns
-------
torch.Tensor
Spectrogram with its amplitude scaled by a random factor.
"""
return spec * np.random.random() * params["spec_amp_scaling"]
def echo_aug(audio, sampling_rate, params):
def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
"""Add echo to audio.
Parameters
----------
audio: np.ndarray
Audio to add echo to.
sampling_rate: int
Sampling rate of the audio.
params: dict
Parameters for the augmentation.
Returns
-------
np.ndarray
Audio with echo added.
"""
sample_offset = (
int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
)
@ -257,7 +406,35 @@ def echo_aug(audio, sampling_rate, params):
return audio
def resample_aug(audio, sampling_rate, params):
def resample_aug(
audio: np.ndarray,
sampling_rate: int,
params: dict,
) -> Tuple[np.ndarray, int, float]:
"""Resample audio augmentation.
Will resample the audio to a random sampling rate from the list of
sampling rates in `aug_sampling_rates`.
Parameters
----------
audio: np.ndarray
Audio to resample.
sampling_rate: int
Original sampling rate of the audio.
params: dict
Parameters for the augmentation. Includes the list of sampling rates
to choose from for resampling in `aug_sampling_rates`.
Returns
-------
audio : np.ndarray
Resampled audio.
sampling_rate : int
New sampling rate.
duration : float
Duration of the audio in seconds.
"""
sampling_rate_old = sampling_rate
sampling_rate = np.random.choice(params["aug_sampling_rates"])
audio = librosa.resample(
@ -280,7 +457,33 @@ def resample_aug(audio, sampling_rate, params):
return audio, sampling_rate, duration
def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
def resample_audio(
num_samples: int,
sampling_rate: int,
audio2: np.ndarray,
sampling_rate2: int,
) -> Tuple[np.ndarray, int]:
"""Resample audio.
Parameters
----------
num_samples: int
Expected number of samples for the output audio.
sampling_rate: int
Target sampling rate, i.e. the sampling rate of the first audio.
audio2: np.ndarray
Audio to resample.
sampling_rate2: int
Original sampling rate of the audio to resample.
Returns
-------
audio2 : np.ndarray
Resampled audio.
sampling_rate2 : int
New sampling rate.
"""
# resample to target sampling rate
if sampling_rate != sampling_rate2:
audio2 = librosa.resample(
audio2,
@ -289,6 +492,8 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
res_type="polyphase",
)
sampling_rate2 = sampling_rate
# pad or trim to the correct length
if audio2.shape[0] < num_samples:
audio2 = np.hstack(
(
@ -298,14 +503,52 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
)
elif audio2.shape[0] > num_samples:
audio2 = audio2[:num_samples]
return audio2, sampling_rate2
def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
def combine_audio_aug(
audio: np.ndarray,
sampling_rate: int,
ann: AnnotationGroup,
audio2: np.ndarray,
sampling_rate2: int,
ann2: AnnotationGroup,
) -> Tuple[np.ndarray, AnnotationGroup]:
"""Combine two audio files.
Will combine two audio files by resampling them to the same sampling rate
and then combining them with a random weight. The annotations will be
combined by taking the union of the two sets of annotations.
Parameters
----------
audio: np.ndarray
First audio to combine.
sampling_rate: int
Sampling rate of the first audio.
ann: AnnotationGroup
Annotations for the first audio.
audio2: np.ndarray
Second audio to combine.
sampling_rate2: int
Sampling rate of the second audio.
ann2: AnnotationGroup
Annotations for the second audio.
Returns
-------
audio : np.ndarray
Combined audio.
ann : AnnotationGroup
Combined annotations.
"""
# resample so they are the same
audio2, sampling_rate2 = resample_audio(
audio.shape[0], sampling_rate, audio2, sampling_rate2
audio.shape[0],
sampling_rate,
audio2,
sampling_rate2,
)
# # set mean and std to be the same
@ -314,8 +557,8 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
# audio2 = audio2 + audio.mean()
if (
ann["annotated"]
and (ann2["annotated"])
ann.get("annotated", False)
and (ann2.get("annotated", False))
and (sampling_rate2 == sampling_rate)
and (audio.shape[0] == audio2.shape[0])
):
@ -323,8 +566,8 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
audio = comb_weight * audio + (1 - comb_weight) * audio2
inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
for kk in ann.keys():
# when combining calls from different files, assume they come from different individuals
# when combining calls from different files, assume they come
# from different individuals
if kk == "individual_ids":
if (ann[kk] > -1).sum() > 0:
ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1
@ -335,68 +578,139 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
return audio, ann
def _prepare_annotation(
annotation: Annotation, class_names: List[str]
) -> Annotation:
try:
class_id = class_names.index(annotation["class"])
except ValueError:
class_id = -1
ann: Annotation = {
**annotation,
"class_id": class_id,
}
if "individual" in ann:
ann["individual"] = int(ann["individual"]) # type: ignore
return ann
def _prepare_file_annotation(
annotation: FileAnnotations,
class_names: List[str],
classes_to_ignore: List[str],
) -> AudioLoaderAnnotationGroup:
annotations = [
_prepare_annotation(ann, class_names)
for ann in annotation["annotation"]
if ann["class"] not in classes_to_ignore
]
try:
class_id_file = class_names.index(annotation["class_name"])
except ValueError:
class_id_file = -1
ret: AudioLoaderAnnotationGroup = {
"id": annotation["id"],
"annotated": annotation["annotated"],
"duration": annotation["duration"],
"issues": annotation["issues"],
"time_exp": annotation["time_exp"],
"class_name": annotation["class_name"],
"notes": annotation["notes"],
"annotation": annotations,
"start_times": np.array([ann["start_time"] for ann in annotations]),
"end_times": np.array([ann["end_time"] for ann in annotations]),
"high_freqs": np.array([ann["high_freq"] for ann in annotations]),
"low_freqs": np.array([ann["low_freq"] for ann in annotations]),
"class_ids": np.array([ann.get("class_id", -1) for ann in annotations]),
"individual_ids": np.array([ann["individual"] for ann in annotations]),
"class_id_file": class_id_file,
}
return ret
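A hedged sketch of `_prepare_annotation`; the class list and annotation values are illustrative and only fields already used in this module are filled in:
example_ann = {
    "class": "Myotis daubentonii",
    "event": "Echolocation",
    "start_time": 0.12,
    "end_time": 0.15,
    "low_freq": 30000,
    "high_freq": 60000,
    "individual": "0",
}
prepared = _prepare_annotation(example_ann, ["Myotis daubentonii", "Bat"])
# prepared["class_id"] == 0; a class missing from the list maps to -1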
class AudioLoader(torch.utils.data.Dataset):
def __init__(self, data_anns_ip, params, dataset_name=None, is_train=False):
"""Main AudioLoader for training and testing."""
self.data_anns = []
self.is_train = is_train
self.params = params
self.return_spec_for_viz = False
def __init__(
self,
data_anns_ip: List[FileAnnotations],
params,
dataset_name: Optional[str] = None,
is_train: bool = False,
):
self.is_train: bool = is_train
self.params: dict = params
self.return_spec_for_viz: bool = False
for ii in range(len(data_anns_ip)):
dd = copy.deepcopy(data_anns_ip[ii])
# filter out unused annotation here
filtered_annotations = []
for ii, aa in enumerate(dd["annotation"]):
if "individual" in aa.keys():
aa["individual"] = int(aa["individual"])
# if only one call labeled it has to be from the same individual
if len(dd["annotation"]) == 1:
aa["individual"] = 0
# convert class name into class label
if aa["class"] in self.params["class_names"]:
aa["class_id"] = self.params["class_names"].index(
aa["class"]
)
else:
aa["class_id"] = -1
if aa["class"] not in self.params["classes_to_ignore"]:
filtered_annotations.append(aa)
dd["annotation"] = filtered_annotations
dd["start_times"] = np.array(
[aa["start_time"] for aa in dd["annotation"]]
self.data_anns: List[AudioLoaderAnnotationGroup] = [
_prepare_file_annotation(
ann,
params["class_names"],
params["classes_to_ignore"],
)
dd["end_times"] = np.array(
[aa["end_time"] for aa in dd["annotation"]]
)
dd["high_freqs"] = np.array(
[float(aa["high_freq"]) for aa in dd["annotation"]]
)
dd["low_freqs"] = np.array(
[float(aa["low_freq"]) for aa in dd["annotation"]]
)
dd["class_ids"] = np.array(
[aa["class_id"] for aa in dd["annotation"]]
).astype(np.int)
dd["individual_ids"] = np.array(
[aa["individual"] for aa in dd["annotation"]]
).astype(np.int)
for ann in data_anns_ip
]
# file level class name
dd["class_id_file"] = -1
if "class_name" in dd.keys():
if dd["class_name"] in self.params["class_names"]:
dd["class_id_file"] = self.params["class_names"].index(
dd["class_name"]
)
self.data_anns.append(dd)
# for ii in range(len(data_anns_ip)):
# dd = copy.deepcopy(data_anns_ip[ii])
#
# # filter out unused annotation here
# filtered_annotations = []
# for ii, aa in enumerate(dd["annotation"]):
# if "individual" in aa.keys():
# aa["individual"] = int(aa["individual"])
#
# # if only one call labeled it has to be from the same
# # individual
# if len(dd["annotation"]) == 1:
# aa["individual"] = 0
#
# # convert class name into class label
# if aa["class"] in self.params["class_names"]:
# aa["class_id"] = self.params["class_names"].index(
# aa["class"]
# )
# else:
# aa["class_id"] = -1
#
# if aa["class"] not in self.params["classes_to_ignore"]:
# filtered_annotations.append(aa)
#
# dd["annotation"] = filtered_annotations
# dd["start_times"] = np.array(
# [aa["start_time"] for aa in dd["annotation"]]
# )
# dd["end_times"] = np.array(
# [aa["end_time"] for aa in dd["annotation"]]
# )
# dd["high_freqs"] = np.array(
# [float(aa["high_freq"]) for aa in dd["annotation"]]
# )
# dd["low_freqs"] = np.array(
# [float(aa["low_freq"]) for aa in dd["annotation"]]
# )
# dd["class_ids"] = np.array(
# [aa["class_id"] for aa in dd["annotation"]]
# ).astype(np.int32)
# dd["individual_ids"] = np.array(
# [aa["individual"] for aa in dd["annotation"]]
# ).astype(np.int32)
#
# # file level class name
# dd["class_id_file"] = -1
# if "class_name" in dd.keys():
# if dd["class_name"] in self.params["class_names"]:
# dd["class_id_file"] = self.params["class_names"].index(
# dd["class_name"]
# )
#
# self.data_anns.append(dd)
ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
self.max_num_anns = 2 * np.max(
@ -413,10 +727,30 @@ class AudioLoader(torch.utils.data.Dataset):
print("Num files : " + str(len(self.data_anns)))
print("Num calls : " + str(np.sum(ann_cnt)))
def get_file_and_anns(self, index=None):
def get_file_and_anns(
self,
index: Optional[int] = None,
) -> Tuple[np.ndarray, int, float, AudioLoaderAnnotationGroup]:
"""Get an audio file and its annotations.
Parameters
----------
index : int, optional
Index of the file to be loaded. If None, a random file is chosen.
Returns
-------
audio_raw : np.ndarray
Loaded audio file.
sampling_rate : int
Sampling rate of the audio file.
duration : float
Duration of the audio file in seconds.
ann : AudioLoaderAnnotationGroup
Annotation group containing the annotations for the audio file.
"""
# if no file specified, choose random one
if index == None:
if index is None:
index = np.random.randint(0, len(self.data_anns))
audio_file = self.data_anns[index]["file_path"]
@ -428,19 +762,19 @@ class AudioLoader(torch.utils.data.Dataset):
)
# copy annotation
ann = {}
ann["annotated"] = self.data_anns[index]["annotated"]
ann["class_id_file"] = self.data_anns[index]["class_id_file"]
keys = [
"start_times",
"end_times",
"high_freqs",
"low_freqs",
"class_ids",
"individual_ids",
]
for kk in keys:
ann[kk] = self.data_anns[index][kk].copy()
ann = copy.deepcopy(self.data_anns[index])
# ann["annotated"] = self.data_anns[index]["annotated"]
# ann["class_id_file"] = self.data_anns[index]["class_id_file"]
# keys = [
# "start_times",
# "end_times",
# "high_freqs",
# "low_freqs",
# "class_ids",
# "individual_ids",
# ]
# for kk in keys:
# ann[kk] = self.data_anns[index][kk].copy()
# if train then grab a random crop
if self.is_train:
@ -489,7 +823,7 @@ class AudioLoader(torch.utils.data.Dataset):
return audio_raw, sampling_rate, duration, ann
def __getitem__(self, index):
"""Get an item from the dataset."""
# load audio file
audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
@ -515,12 +849,17 @@ class AudioLoader(torch.utils.data.Dataset):
audio = echo_aug(audio, sampling_rate, self.params)
# resample the audio
# if np.random.random() < self.params['aug_prob']:
# audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params)
# if np.random.random() < self.params["aug_prob"]:
# audio, sampling_rate, duration = resample_aug(
# audio, sampling_rate, self.params
# )
# create spectrogram
spec, spec_for_viz = au.generate_spectrogram(
audio, sampling_rate, self.params, self.return_spec_for_viz
audio,
sampling_rate,
self.params,
self.return_spec_for_viz,
)
rsf = self.params["resize_factor"]
spec_op_shape = (
@ -531,18 +870,22 @@ class AudioLoader(torch.utils.data.Dataset):
# resize the spec
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
spec = F.interpolate(
spec, size=spec_op_shape, mode="bilinear", align_corners=False
spec,
size=spec_op_shape,
mode="bilinear",
align_corners=False,
).squeeze(0)
# augment spectrogram
if self.is_train and self.params["augment_at_train"]:
if np.random.random() < self.params["aug_prob"]:
spec = scale_vol_aug(spec, self.params)
if np.random.random() < self.params["aug_prob"]:
spec = warp_spec_aug(
spec, ann, self.return_spec_for_viz, self.params
spec,
ann,
self.params,
)
if np.random.random() < self.params["aug_prob"]:
@ -564,10 +907,15 @@ class AudioLoader(torch.utils.data.Dataset):
outputs["y_2d_size"],
outputs["y_2d_classes"],
ann_aug,
) = generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params)
) = generate_gt_heatmaps(
spec_op_shape,
sampling_rate,
ann,
self.params,
)
# hack to get around requirement that all vectors are the same length in
# the output batch
# hack to get around requirement that all vectors are the same length
# in the output batch
pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
outputs["is_valid"] = pad_aray(
np.ones(len(ann_aug["individual_ids"])), pad_size
@ -600,4 +948,5 @@ class AudioLoader(torch.utils.data.Dataset):
return outputs
def __len__(self):
"""Denotes the total number of samples."""
return len(self.data_anns)
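A hedged sketch of wrapping the dataset in a standard PyTorch `DataLoader`; `train_anns` (a `List[FileAnnotations]` loaded elsewhere) and the loader arguments are illustrative, not taken from the training config:
train_dataset = AudioLoader(train_anns, params, is_train=True)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,      # illustrative
    shuffle=True,
    num_workers=4,     # illustrative
)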

View File

@ -1,18 +1,21 @@
## How to train a model from scratch
> **Warning**
> This code is currently broken. Will fix soon, stay tuned.
`python train_model.py data_dir annotation_dir` e.g.
`python train_model.py /data1/bat_data/data/ /data1/bat_data/annotations/anns/`
More comprehensive instructions are provided in the finetune directory.
## Training on your own data
You can either use the finetuning scripts to finetune from an existing training dataset. Follow the instructions in the `../finetune/` directory.
Alternatively, you can train from scratch. First, you will need to create your own annotation file (like in the finetune example), and then you will need to edit `train_split.py` to add your new dataset and specify which combination of files you want to train on.
Note, if training from scratch and you want to include the existing data, you may need to set all the class names to the generic class name ('Bat') so that the existing species are not added to your model, but instead just used to help perform the bat/not bat task.
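A hedged sketch of that relabelling, assuming the per-file JSON annotation format described by `FileAnnotations` in `batdetect2/types.py` (the path and file handling are illustrative):
import json

path = "annotations/anns/example.wav.json"   # illustrative path
with open(path) as f:
    file_ann = json.load(f)
for ann in file_ann["annotation"]:
    ann["class"] = "Bat"                     # collapse species labels to the generic class
with open(path, "w") as f:
    json.dump(file_ann, f, indent=2)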
## Additional notes
Having blank files with no bats in them is also useful; just make sure that the annotation file lists them as annotated (i.e. `is_annotated=True`) so the absence of calls is treated as ground truth.
Training will be slow without a GPU.

View File

@ -34,6 +34,7 @@ __all__ = [
"ResultParams",
"RunResults",
"SpectrogramParameters",
"AudioLoaderAnnotationGroup",
]
@ -99,7 +100,7 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
class Annotation(DictWithClass):
"""Format of annotations.
This is the format of a single annotation as expected by the annotation
tool.
"""
@ -127,6 +128,9 @@ class Annotation(DictWithClass):
event: str
"""Type of detected event."""
class_id: NotRequired[int]
"""Numeric ID for the class of the annotation."""
class FileAnnotations(TypedDict):
"""Format of results.
@ -468,8 +472,28 @@ class AnnotationGroup(TypedDict):
individual_ids: np.ndarray
"""Individual IDs of the annotations."""
annotated: NotRequired[bool]
"""Wether the annotation group is complete or not.
Usually annotation groups are associated to a single
audio clip. If the annotation group is complete, it means that all
relevant sound events have been annotated. If it is not complete, it
means that some sound events might not have been annotated.
"""
x_inds: NotRequired[np.ndarray]
"""X coordinate of the annotations in the spectrogram."""
y_inds: NotRequired[np.ndarray]
"""Y coordinate of the annotations in the spectrogram."""
class AudioLoaderAnnotationGroup(AnnotationGroup, FileAnnotations):
"""Group of annotation items for the training audio loader.
This class is used to store the annotations for the training audio
loader. It inherits from `AnnotationGroup` and `FileAnnotations`.
"""
class_id_file: int
"""ID of the class of the file."""

View File

@ -127,18 +127,28 @@ def load_audio(
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
Only mono files are supported.
Args:
audio_file (str): Path to the audio file.
target_samp_rate (int): Target sampling rate.
scale (bool): Whether to scale the audio to [-1, 1].
max_duration (float): Maximum duration of the audio in seconds.
Parameters
----------
audio_file: str
Path to the audio file.
target_samp_rate: int
Target sampling rate.
scale: bool, optional
Whether to scale the audio to [-1, 1]. Default: False.
max_duration: float, optional
Maximum duration of the audio in seconds. Defaults to None.
If provided, the audio is clipped to this duration.
Returns:
sampling_rate: The sampling rate of the audio.
audio_raw: The audio signal in a numpy array.
Returns
-------
sampling_rate: int
The sampling rate of the audio.
audio_raw: np.ndarray
The audio signal in a numpy array.
Raises:
ValueError: If the audio file is stereo.
Raises
------
ValueError
If the audio file is stereo.
"""
with warnings.catch_warnings():
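A hedged usage sketch of `load_audio` following the documented signature and return order; the path and target sampling rate are illustrative:
sampling_rate, audio_raw = load_audio(
    "recordings/example.wav",   # illustrative path
    target_samp_rate=256000,    # illustrative target rate
    scale=True,
    max_duration=None,
)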