From b252f2309344d58fd5020c70ae428d14254ca9d7 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Thu, 13 Apr 2023 07:58:01 -0600 Subject: [PATCH] WIP --- batdetect2/plot.py | 3 + batdetect2/train/audio_dataloader.py | 597 +++++++++++++++++++++------ batdetect2/train/readme.md | 15 +- batdetect2/types.py | 26 +- batdetect2/utils/audio_utils.py | 30 +- 5 files changed, 530 insertions(+), 141 deletions(-) diff --git a/batdetect2/plot.py b/batdetect2/plot.py index cdcdbd8..1f1a343 100644 --- a/batdetect2/plot.py +++ b/batdetect2/plot.py @@ -306,6 +306,9 @@ def _compute_spec_extent( # If the spectrogram is not resized, the duration is correct # but if it is resized, the duration needs to be adjusted + # NOTE: For now we can only detect if the spectrogram is resized + # by checking if the height is equal to the specified height, + # but this could fail. resize_factor = params["resize_factor"] spec_height = params["spec_height"] if spec_height * resize_factor == shape[0]: diff --git a/batdetect2/train/audio_dataloader.py b/batdetect2/train/audio_dataloader.py index 8130ec6..68d86d4 100644 --- a/batdetect2/train/audio_dataloader.py +++ b/batdetect2/train/audio_dataloader.py @@ -1,14 +1,22 @@ +"""Functions and dataloaders for training and testing the model.""" import copy -from typing import Tuple +from typing import List, Optional, Tuple import librosa import numpy as np import torch import torch.nn.functional as F +import torch.utils.data import torchaudio import batdetect2.utils.audio_utils as au -from batdetect2.types import AnnotationGroup, HeatmapParameters +from batdetect2.types import ( + Annotation, + AnnotationGroup, + AudioLoaderAnnotationGroup, + FileAnnotations, + HeatmapParameters, +) def generate_gt_heatmaps( @@ -32,22 +40,17 @@ def generate_gt_heatmaps( Returns ------- - y_2d_det : np.ndarray 2D heatmap of the presence of an event. - y_2d_size : np.ndarray 2D heatmap of the size of the bounding box associated to event. - y_2d_classes : np.ndarray 3D array containing the ground-truth class probabilities for each pixel. - ann_aug : AnnotationGroup A dictionary containing the annotation information of the annotations that are within the input spectrogram, augmented with the x and y indices of their pixel location in the input spectrogram. - """ # spec may be resized on input into the network num_classes = len(params["class_names"]) @@ -62,20 +65,20 @@ def generate_gt_heatmaps( params["fft_win_length"], params["fft_overlap"], ) - x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int) + x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int32) x_pos_end = au.time_to_x_coords( ann["end_times"], sampling_rate, params["fft_win_length"], params["fft_overlap"], ) - x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int) + x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int32) # location on y axis i.e. frequency y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin - y_pos_low = (op_height - y_pos_low).astype(np.int) + y_pos_low = (op_height - y_pos_low).astype(np.int32) y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin - y_pos_high = (op_height - y_pos_high).astype(np.int) + y_pos_high = (op_height - y_pos_high).astype(np.int32) bb_widths = x_pos_end - x_pos_start bb_heights = y_pos_low - y_pos_high @@ -127,7 +130,12 @@ def generate_gt_heatmaps( (x_pos_start[ii], y_pos_low[ii]), params["target_sigma"], ) - # draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2) + # draw_gaussian( + # y_2d_det[0, :], + # (x_pos_start[ii], y_pos_low[ii]), + # params["target_sigma"], + # params["target_sigma"] * 2, + # ) y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii] y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii] @@ -138,10 +146,15 @@ def generate_gt_heatmaps( (x_pos_start[ii], y_pos_low[ii]), params["target_sigma"], ) - # draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2) + # draw_gaussian( + # y_2d_classes[cls_id, :], + # (x_pos_start[ii], y_pos_low[ii]), + # params["target_sigma"], + # params["target_sigma"] * 2, + # ) - # be careful as this will have a 1.0 places where we have event but dont know gt class - # this will be masked in training anyway + # be careful as this will have a 1.0 places where we have event but + # dont know gt class this will be masked in training anyway y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0) y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...] y_2d_classes[np.isnan(y_2d_classes)] = 0.0 @@ -149,7 +162,37 @@ def generate_gt_heatmaps( return y_2d_det, y_2d_size, y_2d_classes, ann_aug -def draw_gaussian(heatmap, center, sigmax, sigmay=None): +def draw_gaussian( + heatmap: np.ndarray, + center: Tuple[int, int], + sigmax: float, + sigmay: Optional[float] = None, +) -> bool: + """Draw a 2D gaussian into the heatmap. + + If the gaussian center is outside the heatmap, then the gaussian is not + drawn. + + Parameters + ---------- + heatmap : np.ndarray + The heatmap to draw into. Should be of shape (height, width). + center : Tuple[int, int] + The center of the gaussian in (x, y) format. + sigmax : float + The standard deviation of the gaussian in the x direction. + sigmay : Optional[float], optional + The standard deviation of the gaussian in the y direction. If None, + then sigmay = sigmax, by default None. + + Returns + ------- + bool + True if the gaussian was drawn, False if it was not (because + the center was outside the heatmap). + + + """ # center is (x, y) # this edits the heatmap inplace @@ -185,23 +228,46 @@ def draw_gaussian(heatmap, center, sigmax, sigmay=None): return True -def pad_aray(ip_array, pad_size): - return np.hstack((ip_array, np.ones(pad_size, dtype=np.int) * -1)) +def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray: + """Pad array with -1s.""" + return np.hstack((ip_array, np.ones(pad_size, dtype=np.int32) * -1)) -def warp_spec_aug(spec, ann, return_spec_for_viz, params): +def warp_spec_aug( + spec: torch.Tensor, + ann: AnnotationGroup, + params: dict, +) -> torch.Tensor: + """Warp spectrogram by randomly stretching and squeezing. + + Parameters + ---------- + spec: torch.Tensor + Spectrogram to warp. + ann: AnnotationGroup + Annotation group for the spectrogram. Must be provided to sync + the start and stop times with the spectrogram after warping. + params: dict + Parameters for the augmentation. + + Returns + ------- + torch.Tensor + Warped spectrogram. + + Notes + ----- + This function modifies the annotation group in place. + """ # This is messy # Augment spectrogram by randomly stretch and squeezing # NOTE this also changes the start and stop time in place - # not taking care of spec for viz - if return_spec_for_viz: - assert False - delta = params["stretch_squeeze_delta"] op_size = (spec.shape[1], spec.shape[2]) resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0 resize_amt = int(spec.shape[2] * resize_fract_r) + if resize_amt >= spec.shape[2]: spec_r = torch.cat( ( @@ -215,41 +281,124 @@ def warp_spec_aug(spec, ann, return_spec_for_viz, params): ) else: spec_r = spec[:, :, :resize_amt] + + # Resize the spectrogram spec = F.interpolate( - spec_r.unsqueeze(0), size=op_size, mode="bilinear", align_corners=False + spec_r.unsqueeze(0), + size=op_size, + mode="bilinear", + align_corners=False, ).squeeze(0) + + # Update the start and stop times ann["start_times"] *= 1.0 / resize_fract_r ann["end_times"] *= 1.0 / resize_fract_r + return spec -def mask_time_aug(spec, params): - # Mask out a random block of time - repeat up to 3 times - # SpecAugment: A Simple Data Augmentation Methodfor Automatic Speech Recognition +def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: + """Mask out random blocks of time. + + Will randomly mask out a block of time in the spectrogram. The block + will be between 0.0 and `mask_max_time_perc` of the total time. + A random number of blocks will be masked out between 1 and 3. + + Parameters + ---------- + spec: torch.Tensor + Spectrogram to mask. + params: dict + Parameters for the augmentation. + + Returns + ------- + torch.Tensor + Spectrogram with masked out time blocks. + + Notes + ----- + This function is based on the implementation in:: + + SpecAugment: A Simple Data Augmentation Method for Automatic Speech + Recognition + """ fm = torchaudio.transforms.TimeMasking( int(spec.shape[1] * params["mask_max_time_perc"]) ) - for ii in range(np.random.randint(1, 4)): + for _ in range(np.random.randint(1, 4)): spec = fm(spec) return spec -def mask_freq_aug(spec, params): - # Mask out a random frequncy range - repeat up to 3 times - # SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition +def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: + """Mask out random blocks of frequency. + + Will randomly mask out a block of frequency in the spectrogram. The block + will be between 0.0 and `mask_max_freq_perc` of the total frequency. + A random number of blocks will be masked out between 1 and 3. + + Parameters + ---------- + spec: torch.Tensor + Spectrogram to mask. + params: dict + Parameters for the augmentation. + + Returns + ------- + torch.Tensor + Spectrogram with masked out frequency blocks. + + Notes + ----- + This function is based on the implementation in:: + + SpecAugment: A Simple Data Augmentation Method for Automatic Speech + Recognition + """ fm = torchaudio.transforms.FrequencyMasking( int(spec.shape[1] * params["mask_max_freq_perc"]) ) - for ii in range(np.random.randint(1, 4)): + for _ in range(np.random.randint(1, 4)): spec = fm(spec) return spec -def scale_vol_aug(spec, params): +def scale_vol_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: + """Scale the volume of the spectrogram. + + Parameters + ---------- + spec: torch.Tensor + Spectrogram to scale. + params: dict + Parameters for the augmentation. + + Returns + ------- + torch.Tensor + """ return spec * np.random.random() * params["spec_amp_scaling"] -def echo_aug(audio, sampling_rate, params): +def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray: + """Add echo to audio. + + Parameters + ---------- + audio: np.ndarray + Audio to add echo to. + sampling_rate: int + Sampling rate of the audio. + params: dict + Parameters for the augmentation. + + Returns + ------- + np.ndarray + Audio with echo added. + """ sample_offset = ( int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1 ) @@ -257,7 +406,35 @@ def echo_aug(audio, sampling_rate, params): return audio -def resample_aug(audio, sampling_rate, params): +def resample_aug( + audio: np.ndarray, + sampling_rate: int, + params: dict, +) -> Tuple[np.ndarray, int, float]: + """Resample audio augmentation. + + Will resample the audio to a random sampling rate from the list of + sampling rates in `aug_sampling_rates`. + + Parameters + ---------- + audio: np.ndarray + Audio to resample. + sampling_rate: int + Original sampling rate of the audio. + params: dict + Parameters for the augmentation. Includes the list of sampling rates + to choose from for resampling in `aug_sampling_rates`. + + Returns + ------- + audio : np.ndarray + Resampled audio. + sampling_rate : int + New sampling rate. + duration : float + Duration of the audio in seconds. + """ sampling_rate_old = sampling_rate sampling_rate = np.random.choice(params["aug_sampling_rates"]) audio = librosa.resample( @@ -280,7 +457,33 @@ def resample_aug(audio, sampling_rate, params): return audio, sampling_rate, duration -def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2): +def resample_audio( + num_samples: int, + sampling_rate: int, + audio2: np.ndarray, + sampling_rate2: int, +) -> Tuple[np.ndarray, int]: + """Resample audio. + + Parameters + ---------- + num_samples: int + Expected number of samples for the output audio. + sampling_rate: int + Original sampling rate of the audio. + audio2: np.ndarray + Audio to resample. + sampling_rate2: int + Target sampling rate of the audio. + + Returns + ------- + audio2 : np.ndarray + Resampled audio. + sampling_rate2 : int + New sampling rate. + """ + # resample to target sampling rate if sampling_rate != sampling_rate2: audio2 = librosa.resample( audio2, @@ -289,6 +492,8 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2): res_type="polyphase", ) sampling_rate2 = sampling_rate + + # pad or trim to the correct length if audio2.shape[0] < num_samples: audio2 = np.hstack( ( @@ -298,14 +503,52 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2): ) elif audio2.shape[0] > num_samples: audio2 = audio2[:num_samples] + return audio2, sampling_rate2 -def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2): +def combine_audio_aug( + audio: np.ndarray, + sampling_rate: int, + ann: AnnotationGroup, + audio2: np.ndarray, + sampling_rate2: int, + ann2: AnnotationGroup, +) -> Tuple[np.ndarray, AnnotationGroup]: + """Combine two audio files. + Will combine two audio files by resampling them to the same sampling rate + and then combining them with a random weight. The annotations will be + combined by taking the union of the two sets of annotations. + + Parameters + ---------- + audio: np.ndarray + First Audio to combine. + sampling_rate: int + Sampling rate of the first audio. + ann: AnnotationGroup + Annotations for the first audio. + audio2: np.ndarray + Second Audio to combine. + sampling_rate2: int + Sampling rate of the second audio. + ann2: AnnotationGroup + Annotations for the second audio. + + Returns + ------- + audio : np.ndarray + Combined audio. + ann : AnnotationGroup + Combined annotations. + """ # resample so they are the same audio2, sampling_rate2 = resample_audio( - audio.shape[0], sampling_rate, audio2, sampling_rate2 + audio.shape[0], + sampling_rate, + audio2, + sampling_rate2, ) # # set mean and std to be the same @@ -314,8 +557,8 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2): # audio2 = audio2 + audio.mean() if ( - ann["annotated"] - and (ann2["annotated"]) + ann.get("annotated", False) + and (ann2.get("annotated", False)) and (sampling_rate2 == sampling_rate) and (audio.shape[0] == audio2.shape[0]) ): @@ -323,8 +566,8 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2): audio = comb_weight * audio + (1 - comb_weight) * audio2 inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"]))) for kk in ann.keys(): - - # when combining calls from different files, assume they come from different individuals + # when combining calls from different files, assume they come + # from different individuals if kk == "individual_ids": if (ann[kk] > -1).sum() > 0: ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1 @@ -335,68 +578,139 @@ def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2): return audio, ann +def _prepare_annotation( + annotation: Annotation, class_names: List[str] +) -> Annotation: + try: + class_id = class_names.index(annotation["class"]) + except ValueError: + class_id = -1 + + ann: Annotation = { + **annotation, + "class_id": class_id, + } + + if "individual" in ann: + ann["individual"] = int(ann["individual"]) # type: ignore + + return ann + + +def _prepare_file_annotation( + annotation: FileAnnotations, + class_names: List[str], + classes_to_ignore: List[str], +) -> AudioLoaderAnnotationGroup: + annotations = [ + _prepare_annotation(ann, class_names) + for ann in annotation["annotation"] + if ann["class"] not in classes_to_ignore + ] + + try: + class_id_file = class_names.index(annotation["class_name"]) + except ValueError: + class_id_file = -1 + + ret: AudioLoaderAnnotationGroup = { + "id": annotation["id"], + "annotated": annotation["annotated"], + "duration": annotation["duration"], + "issues": annotation["issues"], + "time_exp": annotation["time_exp"], + "class_name": annotation["class_name"], + "notes": annotation["notes"], + "annotation": annotations, + "start_times": np.array([ann["start_time"] for ann in annotations]), + "end_times": np.array([ann["end_time"] for ann in annotations]), + "high_freqs": np.array([ann["high_freq"] for ann in annotations]), + "low_freqs": np.array([ann["low_freq"] for ann in annotations]), + "class_ids": np.array([ann.get("class_id", -1) for ann in annotations]), + "individual_ids": np.array([ann["individual"] for ann in annotations]), + "class_id_file": class_id_file, + } + + return ret + + class AudioLoader(torch.utils.data.Dataset): - def __init__(self, data_anns_ip, params, dataset_name=None, is_train=False): + """Main AudioLoader for training and testing.""" - self.data_anns = [] - self.is_train = is_train - self.params = params - self.return_spec_for_viz = False + def __init__( + self, + data_anns_ip: List[FileAnnotations], + params, + dataset_name: Optional[str] = None, + is_train: bool = False, + ): + self.is_train: bool = is_train + self.params: dict = params + self.return_spec_for_viz: bool = False - for ii in range(len(data_anns_ip)): - dd = copy.deepcopy(data_anns_ip[ii]) - - # filter out unused annotation here - filtered_annotations = [] - for ii, aa in enumerate(dd["annotation"]): - - if "individual" in aa.keys(): - aa["individual"] = int(aa["individual"]) - - # if only one call labeled it has to be from the same individual - if len(dd["annotation"]) == 1: - aa["individual"] = 0 - - # convert class name into class label - if aa["class"] in self.params["class_names"]: - aa["class_id"] = self.params["class_names"].index( - aa["class"] - ) - else: - aa["class_id"] = -1 - - if aa["class"] not in self.params["classes_to_ignore"]: - filtered_annotations.append(aa) - - dd["annotation"] = filtered_annotations - dd["start_times"] = np.array( - [aa["start_time"] for aa in dd["annotation"]] + self.data_anns: List[AudioLoaderAnnotationGroup] = [ + _prepare_file_annotation( + ann, + params["class_names"], + params["classes_to_ignore"], ) - dd["end_times"] = np.array( - [aa["end_time"] for aa in dd["annotation"]] - ) - dd["high_freqs"] = np.array( - [float(aa["high_freq"]) for aa in dd["annotation"]] - ) - dd["low_freqs"] = np.array( - [float(aa["low_freq"]) for aa in dd["annotation"]] - ) - dd["class_ids"] = np.array( - [aa["class_id"] for aa in dd["annotation"]] - ).astype(np.int) - dd["individual_ids"] = np.array( - [aa["individual"] for aa in dd["annotation"]] - ).astype(np.int) + for ann in data_anns_ip + ] - # file level class name - dd["class_id_file"] = -1 - if "class_name" in dd.keys(): - if dd["class_name"] in self.params["class_names"]: - dd["class_id_file"] = self.params["class_names"].index( - dd["class_name"] - ) - - self.data_anns.append(dd) + # for ii in range(len(data_anns_ip)): + # dd = copy.deepcopy(data_anns_ip[ii]) + # + # # filter out unused annotation here + # filtered_annotations = [] + # for ii, aa in enumerate(dd["annotation"]): + # if "individual" in aa.keys(): + # aa["individual"] = int(aa["individual"]) + # + # # if only one call labeled it has to be from the same + # # individual + # if len(dd["annotation"]) == 1: + # aa["individual"] = 0 + # + # # convert class name into class label + # if aa["class"] in self.params["class_names"]: + # aa["class_id"] = self.params["class_names"].index( + # aa["class"] + # ) + # else: + # aa["class_id"] = -1 + # + # if aa["class"] not in self.params["classes_to_ignore"]: + # filtered_annotations.append(aa) + # + # dd["annotation"] = filtered_annotations + # dd["start_times"] = np.array( + # [aa["start_time"] for aa in dd["annotation"]] + # ) + # dd["end_times"] = np.array( + # [aa["end_time"] for aa in dd["annotation"]] + # ) + # dd["high_freqs"] = np.array( + # [float(aa["high_freq"]) for aa in dd["annotation"]] + # ) + # dd["low_freqs"] = np.array( + # [float(aa["low_freq"]) for aa in dd["annotation"]] + # ) + # dd["class_ids"] = np.array( + # [aa["class_id"] for aa in dd["annotation"]] + # ).astype(np.int32) + # dd["individual_ids"] = np.array( + # [aa["individual"] for aa in dd["annotation"]] + # ).astype(np.int32) + # + # # file level class name + # dd["class_id_file"] = -1 + # if "class_name" in dd.keys(): + # if dd["class_name"] in self.params["class_names"]: + # dd["class_id_file"] = self.params["class_names"].index( + # dd["class_name"] + # ) + # + # self.data_anns.append(dd) ann_cnt = [len(aa["annotation"]) for aa in self.data_anns] self.max_num_anns = 2 * np.max( @@ -413,10 +727,30 @@ class AudioLoader(torch.utils.data.Dataset): print("Num files : " + str(len(self.data_anns))) print("Num calls : " + str(np.sum(ann_cnt))) - def get_file_and_anns(self, index=None): + def get_file_and_anns( + self, + index: Optional[int] = None, + ) -> Tuple[np.ndarray, int, float, AudioLoaderAnnotationGroup]: + """Get an audio file and its annotations. + Parameters + ---------- + index : int, optional + Index of the file to be loaded. If None, a random file is chosen. + + Returns + ------- + audio_raw : np.ndarray + Loaded audio file. + sampling_rate : int + Sampling rate of the audio file. + duration : float + Duration of the audio file in seconds. + ann : AnnotationGroup + AnnotationGroup object containing the annotations for the audio file. + """ # if no file specified, choose random one - if index == None: + if index is None: index = np.random.randint(0, len(self.data_anns)) audio_file = self.data_anns[index]["file_path"] @@ -428,19 +762,19 @@ class AudioLoader(torch.utils.data.Dataset): ) # copy annotation - ann = {} - ann["annotated"] = self.data_anns[index]["annotated"] - ann["class_id_file"] = self.data_anns[index]["class_id_file"] - keys = [ - "start_times", - "end_times", - "high_freqs", - "low_freqs", - "class_ids", - "individual_ids", - ] - for kk in keys: - ann[kk] = self.data_anns[index][kk].copy() + ann = copy.deepcopy(self.data_anns[index]) + # ann["annotated"] = self.data_anns[index]["annotated"] + # ann["class_id_file"] = self.data_anns[index]["class_id_file"] + # keys = [ + # "start_times", + # "end_times", + # "high_freqs", + # "low_freqs", + # "class_ids", + # "individual_ids", + # ] + # for kk in keys: + # ann[kk] = self.data_anns[index][kk].copy() # if train then grab a random crop if self.is_train: @@ -489,7 +823,7 @@ class AudioLoader(torch.utils.data.Dataset): return audio_raw, sampling_rate, duration, ann def __getitem__(self, index): - + """Get an item from the dataset.""" # load audio file audio, sampling_rate, duration, ann = self.get_file_and_anns(index) @@ -515,12 +849,17 @@ class AudioLoader(torch.utils.data.Dataset): audio = echo_aug(audio, sampling_rate, self.params) # resample the audio - # if np.random.random() < self.params['aug_prob']: - # audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params) + # if np.random.random() < self.params["aug_prob"]: + # audio, sampling_rate, duration = resample_aug( + # audio, sampling_rate, self.params + # ) # create spectrogram spec, spec_for_viz = au.generate_spectrogram( - audio, sampling_rate, self.params, self.return_spec_for_viz + audio, + sampling_rate, + self.params, + self.return_spec_for_viz, ) rsf = self.params["resize_factor"] spec_op_shape = ( @@ -531,18 +870,22 @@ class AudioLoader(torch.utils.data.Dataset): # resize the spec spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0) spec = F.interpolate( - spec, size=spec_op_shape, mode="bilinear", align_corners=False + spec, + size=spec_op_shape, + mode="bilinear", + align_corners=False, ).squeeze(0) # augment spectrogram if self.is_train and self.params["augment_at_train"]: - if np.random.random() < self.params["aug_prob"]: spec = scale_vol_aug(spec, self.params) if np.random.random() < self.params["aug_prob"]: spec = warp_spec_aug( - spec, ann, self.return_spec_for_viz, self.params + spec, + ann, + self.params, ) if np.random.random() < self.params["aug_prob"]: @@ -564,10 +907,15 @@ class AudioLoader(torch.utils.data.Dataset): outputs["y_2d_size"], outputs["y_2d_classes"], ann_aug, - ) = generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params) + ) = generate_gt_heatmaps( + spec_op_shape, + sampling_rate, + ann, + self.params, + ) - # hack to get around requirement that all vectors are the same length in - # the output batch + # hack to get around requirement that all vectors are the same length + # in the output batch pad_size = self.max_num_anns - len(ann_aug["individual_ids"]) outputs["is_valid"] = pad_aray( np.ones(len(ann_aug["individual_ids"])), pad_size @@ -600,4 +948,5 @@ class AudioLoader(torch.utils.data.Dataset): return outputs def __len__(self): + """Denotes the total number of samples.""" return len(self.data_anns) diff --git a/batdetect2/train/readme.md b/batdetect2/train/readme.md index e406c7d..eaa1253 100644 --- a/batdetect2/train/readme.md +++ b/batdetect2/train/readme.md @@ -1,18 +1,21 @@ ## How to train a model from scratch -`python train_model.py data_dir annotation_dir` e.g. + +> **Warning** +> This code in currently broken. Will fix soon, stay tuned. + +`python train_model.py data_dir annotation_dir` e.g. `python train_model.py /data1/bat_data/data/ /data1/bat_data/annotations/anns/` More comprehensive instructions are provided in the finetune directory. - ## Training on your own data You can either use the finetuning scripts to finetune from an existing training dataset. Follow the instructions in the `../finetune/` directory. -Alternatively, you can train from scratch. First, you will need to create your own annotation file (like in the finetune example), and then you will need to edit `train_split.py` to add your new dataset and specify which combination of files you want to train on. +Alternatively, you can train from scratch. First, you will need to create your own annotation file (like in the finetune example), and then you will need to edit `train_split.py` to add your new dataset and specify which combination of files you want to train on. -Note, if training from scratch and you want to include the existing data, you may need to set all the class names to the generic class name ('Bat') so that the existing species are not added to your model, but instead just used to help perform the bat/not bat task. +Note, if training from scratch and you want to include the existing data, you may need to set all the class names to the generic class name ('Bat') so that the existing species are not added to your model, but instead just used to help perform the bat/not bat task. ## Additional notes -Having blank files with no bats in them is also useful, just make sure that the annotation files lists them as not being annotated (i.e. `is_annotated=True`). +Having blank files with no bats in them is also useful, just make sure that the annotation files lists them as not being annotated (i.e. `is_annotated=True`). -Training will be slow without a GPU. +Training will be slow without a GPU. diff --git a/batdetect2/types.py b/batdetect2/types.py index 3bc810b..8a6437c 100644 --- a/batdetect2/types.py +++ b/batdetect2/types.py @@ -34,6 +34,7 @@ __all__ = [ "ResultParams", "RunResults", "SpectrogramParameters", + "AudioLoaderAnnotationGroup", ] @@ -99,7 +100,7 @@ DictWithClass = TypedDict("DictWithClass", {"class": str}) class Annotation(DictWithClass): """Format of annotations. - This is the format of a single annotation as expected by the annotation + This is the format of a single annotation as expected by the annotation tool. """ @@ -127,6 +128,9 @@ class Annotation(DictWithClass): event: str """Type of detected event.""" + class_id: NotRequired[int] + """Numeric ID for the class of the annotation.""" + class FileAnnotations(TypedDict): """Format of results. @@ -468,8 +472,28 @@ class AnnotationGroup(TypedDict): individual_ids: np.ndarray """Individual IDs of the annotations.""" + annotated: NotRequired[bool] + """Wether the annotation group is complete or not. + + Usually annotation groups are associated to a single + audio clip. If the annotation group is complete, it means that all + relevant sound events have been annotated. If it is not complete, it + means that some sound events might not have been annotated. + """ + x_inds: NotRequired[np.ndarray] """X coordinate of the annotations in the spectrogram.""" y_inds: NotRequired[np.ndarray] """Y coordinate of the annotations in the spectrogram.""" + + +class AudioLoaderAnnotationGroup(AnnotationGroup, FileAnnotations): + """Group of annotation items for the training audio loader. + + This class is used to store the annotations for the training audio + loader. It inherits from `AnnotationGroup` and `FileAnnotations`. + """ + + class_id_file: int + """ID of the class of the file.""" diff --git a/batdetect2/utils/audio_utils.py b/batdetect2/utils/audio_utils.py index 7c5852a..908d971 100644 --- a/batdetect2/utils/audio_utils.py +++ b/batdetect2/utils/audio_utils.py @@ -127,18 +127,28 @@ def load_audio( The audio is also scaled to [-1, 1] and clipped to the maximum duration. Only mono files are supported. - Args: - audio_file (str): Path to the audio file. - target_samp_rate (int): Target sampling rate. - scale (bool): Whether to scale the audio to [-1, 1]. - max_duration (float): Maximum duration of the audio in seconds. + Parameters + ---------- + audio_file: str + Path to the audio file. + target_samp_rate: int + Target sampling rate. + scale: bool, optional + Whether to scale the audio to [-1, 1]. Default: False. + max_duration: float, optional + Maximum duration of the audio in seconds. Defaults to None. + If provided, the audio is clipped to this duration. - Returns: - sampling_rate: The sampling rate of the audio. - audio_raw: The audio signal in a numpy array. + Returns + ------- + sampling_rate: int + The sampling rate of the audio. + audio_raw: np.ndarray + The audio signal in a numpy array. - Raises: - ValueError: If the audio file is stereo. + Raises + ------ + ValueError: If the audio file is stereo. """ with warnings.catch_warnings():