diff --git a/batdetect2/detector/compute_features.py b/batdetect2/detector/compute_features.py
index 31b06bc..f9d1da5 100644
--- a/batdetect2/detector/compute_features.py
+++ b/batdetect2/detector/compute_features.py
@@ -1,4 +1,5 @@
 """Functions to compute features from predictions."""
+
 from typing import Dict, List, Optional
 
 import numpy as np
@@ -219,7 +220,6 @@ def compute_call_interval(
     return round(prediction["start_time"] - previous["end_time"], 5)
 
 
-
 # NOTE: The order of the features in this dictionary is important. The
 # features are extracted in this order and the order of the columns in the
 # output csv file is determined by this order. In order to avoid breaking
diff --git a/batdetect2/detector/models.py b/batdetect2/detector/models.py
index 84168bc..105ddaf 100644
--- a/batdetect2/detector/models.py
+++ b/batdetect2/detector/models.py
@@ -206,7 +206,10 @@ class Net2DFastNoAttn(nn.Module):
             num_filts // 4, 2, kernel_size=1, padding=0
         )
         self.conv_classes_op = nn.Conv2d(
-            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
+            num_filts // 4,
+            self.num_classes + 1,
+            kernel_size=1,
+            padding=0,
         )
 
         if self.emb_dim > 0:
diff --git a/batdetect2/detector/parameters.py b/batdetect2/detector/parameters.py
index 54dab95..fcc7a8c 100644
--- a/batdetect2/detector/parameters.py
+++ b/batdetect2/detector/parameters.py
@@ -5,7 +5,10 @@ from typing import List, Optional, Union
 
 from pydantic import BaseModel, Field, computed_field
 
-from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
+from batdetect2.train.train_utils import (
+    get_genus_mapping,
+    get_short_class_names,
+)
 from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
 
 TARGET_SAMPLERATE_HZ = 256000
diff --git a/batdetect2/detector/post_process.py b/batdetect2/detector/post_process.py
index b47eec6..eaa8152 100644
--- a/batdetect2/detector/post_process.py
+++ b/batdetect2/detector/post_process.py
@@ -1,4 +1,5 @@
 """Post-processing of the output of the model."""
+
 from typing import List, Tuple, Union
 
 import numpy as np
diff --git a/batdetect2/evaluate/evaluate_models.py b/batdetect2/evaluate/evaluate_models.py
index 9c7717a..602faa8 100644
--- a/batdetect2/evaluate/evaluate_models.py
+++ b/batdetect2/evaluate/evaluate_models.py
@@ -20,7 +20,6 @@ from batdetect2.detector import parameters
 
 
 def get_blank_annotation(ip_str):
-
     res = {}
     res["class_name"] = ""
     res["duration"] = -1
@@ -77,7 +76,6 @@ def create_genus_mapping(gt_test, preds, class_names):
 
 
 def load_tadarida_pred(ip_dir, dataset, file_of_interest):
-
     res, ann = get_blank_annotation("Generated by Tadarida")
 
     # create the annotations in the correct format
@@ -120,7 +118,6 @@ def load_sonobat_meta(
     class_names,
     only_accepted_species=True,
 ):
-
     sp_dict = {}
     for ss in class_names:
         sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
@@ -182,7 +179,6 @@ def load_sonobat_meta(
 
 
 def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
-
     # create the annotations in the correct format
     res, ann = get_blank_annotation("Generated by Sonobat")
     res_c = copy.deepcopy(res)
@@ -221,7 +217,6 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
 
 
 def bb_overlap(bb_g_in, bb_p_in):
-
     freq_scale = 10000000.0  # ensure that both axis are roughly the same range
     bb_g = [
         bb_g_in["start_time"],
@@ -465,7 +460,6 @@ def check_classes_in_train(gt_list, class_names):
 
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "op_dir",
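Note: the reformatted `conv_classes_op` in `models.py` above is a 1x1 convolution, i.e. a per-grid-cell linear classifier over the feature channels. A minimal sketch of the shape contract it implies; all sizes here are illustrative assumptions rather than values from this repo, and the extra `+ 1` output channel is, as far as I can tell, the class-agnostic detection channel:

import torch
from torch import nn

num_filts, num_classes = 128, 17  # assumed sizes, for illustration only
head = nn.Conv2d(num_filts // 4, num_classes + 1, kernel_size=1, padding=0)

# a (batch, channels, freq_bins, time_bins) feature map from the backbone
features = torch.randn(1, num_filts // 4, 32, 64)
scores = head(features)
print(scores.shape)  # torch.Size([1, 18, 32, 64]): one score per class per grid cell
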
diff --git a/batdetect2/models/decoder.py b/batdetect2/models/decoder.py
index 17bbebf..75863c0 100644
--- a/batdetect2/models/decoder.py
+++ b/batdetect2/models/decoder.py
@@ -13,5 +13,3 @@ else:
     def pairwise(iterable: Sequence) -> Iterable:
         for x, y in zip(iterable[:-1], iterable[1:]):
             yield x, y
-
-
diff --git a/batdetect2/models/encoder.py b/batdetect2/models/encoder.py
index 3f56130..4192e62 100644
--- a/batdetect2/models/encoder.py
+++ b/batdetect2/models/encoder.py
@@ -13,5 +13,3 @@ else:
     def pairwise(iterable: Sequence) -> Iterable:
         for x, y in zip(iterable[:-1], iterable[1:]):
             yield x, y
-
-
diff --git a/batdetect2/plotting/common.py b/batdetect2/plotting/common.py
index 82ce8c9..0e5f003 100644
--- a/batdetect2/plotting/common.py
+++ b/batdetect2/plotting/common.py
@@ -17,6 +17,6 @@ def create_ax(
 ) -> axes.Axes:
     """Create a new axis if none is provided"""
     if ax is None:
-        _, ax = plt.subplots(figsize=figsize, **kwargs) # type: ignore
+        _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore
 
-    return ax # type: ignore
+    return ax  # type: ignore
diff --git a/batdetect2/post_process.py b/batdetect2/post_process.py
index 871b0a2..3b1df2d 100644
--- a/batdetect2/post_process.py
+++ b/batdetect2/post_process.py
@@ -1,6 +1,6 @@
 """Module for postprocessing model outputs."""
 
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -30,6 +30,20 @@ class PostprocessConfig(BaseModel):
     top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
 
 
+class RawPrediction(BaseModel):
+    # np.ndarray is not a type pydantic can validate natively, so arbitrary
+    # types must be allowed for the features field
+    model_config = {"arbitrary_types_allowed": True}
+
+    start_time: float
+    end_time: float
+    low_freq: float
+    high_freq: float
+    detection_score: float
+    class_scores: Dict[str, float]
+    features: np.ndarray
+
+
 def postprocess_model_outputs(
     outputs: ModelOutput,
     clips: List[data.Clip],
@@ -125,6 +139,91 @@
     return predictions
 
 
+def compute_predictions_from_outputs(
+    start: float,
+    end: float,
+    scores: torch.Tensor,
+    y_pos: torch.Tensor,
+    x_pos: torch.Tensor,
+    size_preds: torch.Tensor,
+    class_probs: torch.Tensor,
+    features: torch.Tensor,
+    classes: List[str],
+    min_freq: int = 10000,
+    max_freq: int = 120000,
+    detection_threshold: float = DETECTION_THRESHOLD,
+) -> List[RawPrediction]:
+    _, freq_bins, time_bins = size_preds.shape
+
+    # keep only detections above the score threshold, ordered by time
+    sorted_indices = torch.argsort(x_pos)
+    valid_indices = sorted_indices[
+        scores[sorted_indices] > detection_threshold
+    ]
+
+    scores = scores[valid_indices]
+    x_pos = x_pos[valid_indices]
+    y_pos = y_pos[valid_indices]
+
+    predictions: List[RawPrediction] = []
+    for score, x, y in zip(scores, x_pos, y_pos):
+        width, height = size_preds[:, y, x]
+        class_prob = class_probs[:, y, x].detach().numpy()
+        feats = features[:, y, x].detach().numpy()
+
+        # map grid indices onto the clip's time span; the frequency axis is
+        # inverted, row 0 corresponds to max_freq
+        start_time = np.interp(
+            x.item(),
+            [0, time_bins],
+            [start, end],
+        )
+
+        end_time = np.interp(
+            x.item() + width.item(),
+            [0, time_bins],
+            [start, end],
+        )
+
+        low_freq = np.interp(
+            y.item(),
+            [0, freq_bins],
+            [max_freq, min_freq],
+        )
+
+        high_freq = np.interp(
+            y.item() - height.item(),
+            [0, freq_bins],
+            [max_freq, min_freq],
+        )
+
+        start_time, end_time = sorted([float(start_time), float(end_time)])
+        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
+
+        predictions.append(
+            RawPrediction(
+                start_time=start_time,
+                end_time=end_time,
+                low_freq=low_freq,
+                high_freq=high_freq,
+                detection_score=score.item(),
+                features=feats,
+                class_scores={
+                    class_name: float(prob)
+                    for class_name, prob in zip(classes, class_prob)
+                },
+            )
+        )
+
+    return predictions
+
+
+def convert_raw_prediction_to_soundevent(
+    prediction: RawPrediction,
+) -> data.SoundEventPrediction:
+    pass
+
+
 def compute_sound_events_from_outputs(
     clip: data.Clip,
     scores: torch.Tensor,
diff --git a/batdetect2/train/evaluate.py b/batdetect2/train/evaluate.py
index cf3d4d1..dd0cbfe 100755
--- a/batdetect2/train/evaluate.py
+++ b/batdetect2/train/evaluate.py
@@ -1,10 +1,13 @@
 from typing import List
 
 import numpy as np
+import pandas as pd
 from sklearn.metrics import auc, roc_curve
 from soundevent import data
 from soundevent.evaluation import match_geometries
 
+from batdetect2.train.targets import build_encoder, get_class_names
+
 
 def match_predictions_and_annotations(
     clip_annotation: data.ClipAnnotation,
@@ -48,6 +51,15 @@
     return matches
 
 
+def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
+    ret = []
+
+    for match in matches:
+        pass
+
+    return pd.DataFrame(ret)
+
+
 def compute_error_auc(op_str, gt, pred, prob):
     # classification error
     pred_int = (pred > prob).astype(np.int32)
diff --git a/batdetect2/train/modules.py b/batdetect2/train/modules.py
index ff44e84..6a6bb85 100644
--- a/batdetect2/train/modules.py
+++ b/batdetect2/train/modules.py
@@ -1,8 +1,10 @@
+from pathlib import Path
 from typing import Optional
 
 import pytorch_lightning as L
 import torch
 from pydantic import Field
+from soundevent import data
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.utils.data import DataLoader
@@ -19,7 +21,7 @@ from batdetect2.post_process import (
     PostprocessConfig,
     postprocess_model_outputs,
 )
-from batdetect2.preprocess import PreprocessingConfig
+from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
 from batdetect2.train.dataset import LabeledDataset, TrainExample
 from batdetect2.train.evaluate import match_predictions_and_annotations
 from batdetect2.train.losses import LossConfig, compute_loss
@@ -146,12 +148,14 @@ class DetectorModel(L.LightningModule):
             config=self.config.postprocessing,
         )[0]
 
-        self.validation_predictions.extend(
-            match_predictions_and_annotations(clip_annotation, clip_prediction)
+        matches = match_predictions_and_annotations(
+            clip_annotation,
+            clip_prediction,
         )
 
+        self.validation_predictions.extend(matches)
+
     def on_validation_epoch_end(self) -> None:
-        print(len(self.validation_predictions))
         self.validation_predictions.clear()
 
     def configure_optimizers(self):
@@ -159,3 +163,23 @@
         optimizer = Adam(self.parameters(), lr=conf.learning_rate)
         scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
         return [optimizer], [scheduler]
+
+    def process_clip(
+        self,
+        clip: data.Clip,
+        audio_dir: Optional[Path] = None,
+    ) -> data.ClipPrediction:
+        spec = preprocess_audio_clip(
+            clip,
+            config=self.config.preprocessing,
+            audio_dir=audio_dir,
+        )
+        tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
+        outputs = self.forward(tensor)
+        return postprocess_model_outputs(
+            outputs,
+            clips=[clip],
+            classes=self.class_names,
+            decoder=self.decoder,
+            config=self.config.postprocessing,
+        )[0]
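Note: the new `DetectorModel.process_clip` in `modules.py` bundles preprocessing, the forward pass, and postprocessing for a single clip. A usage sketch, assuming a Lightning checkpoint and an audio file on disk (both paths are hypothetical):

from pathlib import Path

import torch
from soundevent import data

from batdetect2.train.modules import DetectorModel

model = DetectorModel.load_from_checkpoint("checkpoints/last.ckpt")
model.eval()

recording = data.Recording.from_file(Path("audio/example.wav"))
clip = data.Clip(recording=recording, start_time=0.0, end_time=3.0)

# process_clip does not disable gradient tracking itself, so wrap the call
with torch.no_grad():
    prediction = model.process_clip(clip)

for se_prediction in prediction.sound_events:
    print(se_prediction.score, se_prediction.sound_event.geometry)
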
diff --git a/batdetect2/train/train_split.py b/batdetect2/train/train_split.py
index 902fe82..272e8f8 100644
--- a/batdetect2/train/train_split.py
+++ b/batdetect2/train/train_split.py
@@ -16,7 +16,6 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
 
 
 def split_diff(ann_dir, wav_dir, load_extra=True):
-
     train_sets = []
     if load_extra:
         train_sets.append(
@@ -144,7 +143,6 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
 
 
 def split_same(ann_dir, wav_dir, load_extra=True):
-
     train_sets = []
     if load_extra:
         train_sets.append(
diff --git a/batdetect2/train/train_utils.py b/batdetect2/train/train_utils.py
index 5414d4a..1d0a6bd 100644
--- a/batdetect2/train/train_utils.py
+++ b/batdetect2/train/train_utils.py
@@ -69,7 +69,8 @@ def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:
 
 
 def standardize_low_freq(
-    data: List[types.FileAnnotation], class_of_interest: str,
+    data: List[types.FileAnnotation],
+    class_of_interest: str,
 ) -> List[types.FileAnnotation]:
     # address the issue of highly variable low frequency annotations
     # this often happens for contstant frequency calls
diff --git a/batdetect2/types.py b/batdetect2/types.py
index 9d55ae7..ec9ea8b 100644
--- a/batdetect2/types.py
+++ b/batdetect2/types.py
@@ -1,4 +1,5 @@
 """Types used in the code base."""
+
 from typing import Any, List, NamedTuple, Optional
 
 
@@ -594,8 +595,7 @@ class FeatureExtractor(Protocol):
         self,
         prediction: Prediction,
         **kwargs: Any,
-    ) -> float:
-        ...
+    ) -> float: ...
 
 
 class DatasetDict(TypedDict):
diff --git a/batdetect2/utils/audio_utils.py b/batdetect2/utils/audio_utils.py
index 347e3ff..531c8a5 100644
--- a/batdetect2/utils/audio_utils.py
+++ b/batdetect2/utils/audio_utils.py
@@ -100,7 +100,7 @@ def generate_spectrogram(
         # log_scaling = (1.0 / sampling_rate)*10e4
         spec = np.log1p(log_scaling * spec)
     elif params["spec_scale"] == "pcen":
-        spec = pcen(spec , sampling_rate)
+        spec = pcen(spec, sampling_rate)
     elif params["spec_scale"] == "none":
         pass
 
diff --git a/batdetect2/utils/plot_utils.py b/batdetect2/utils/plot_utils.py
index 4bfde7a..da4664d 100644
--- a/batdetect2/utils/plot_utils.py
+++ b/batdetect2/utils/plot_utils.py
@@ -417,7 +417,9 @@ def plot_confusion_matrix(
     cm_norm = cm.sum(1)
 
     valid_inds = np.where(cm_norm > 0)[0]
-    cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+    cm[valid_inds, :] = (
+        cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+    )
     cm[np.where(cm_norm == -0)[0], :] = np.nan
 
     if verbose:
diff --git a/batdetect2/utils/visualize.py b/batdetect2/utils/visualize.py
index d79f322..54be1df 100644
--- a/batdetect2/utils/visualize.py
+++ b/batdetect2/utils/visualize.py
@@ -155,9 +155,9 @@ class InteractivePlotter:
 
         # draw bounding box around call
        self.ax[1].patches[0].remove()
-        spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
-            1.0 + 2.0 * self.spec_pad
-        )
+        spec_width_orig = self.spec_slices[self.current_id].shape[
+            1
+        ] / (1.0 + 2.0 * self.spec_pad)
         xx = w_diff + self.spec_pad * spec_width_orig
         ww = spec_width_orig
         yy = self.call_info[self.current_id]["low_freq"] / 1000
@@ -183,7 +183,9 @@ class InteractivePlotter:
                 round(self.call_info[self.current_id]["start_time"], 3)
             )
             + ", prob="
-            + str(round(self.call_info[self.current_id]["det_prob"], 3))
+            + str(
+                round(self.call_info[self.current_id]["det_prob"], 3)
+            )
         )
         self.ax[0].set_xlabel(info_str)
 
diff --git a/batdetect2/utils/wavfile.py b/batdetect2/utils/wavfile.py
index 7fee660..2a8232f 100644
--- a/batdetect2/utils/wavfile.py
+++ b/batdetect2/utils/wavfile.py
@@ -8,6 +8,7 @@ Functions
 `write`: Write a numpy array as a WAV file.
 
 """
+
 from __future__ import absolute_import, division, print_function
 
 import os
@@ -156,7 +157,6 @@ def read(filename, mmap=False):
     fid = open(filename, "rb")
 
     try:
-
         # some files seem to have the size recorded in the header greater than
        # the actual file size.
         fid.seek(0, os.SEEK_END)
diff --git a/scripts/gen_dataset_summary_image.py b/scripts/gen_dataset_summary_image.py
index a916900..3e0a26b 100644
--- a/scripts/gen_dataset_summary_image.py
+++ b/scripts/gen_dataset_summary_image.py
@@ -16,7 +16,6 @@ import batdetect2.train.train_utils as tu
 import batdetect2.utils.audio_utils as au
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "audio_path", type=str, help="Input directory for audio"
@@ -65,7 +64,9 @@ if __name__ == "__main__":
     else:
         # load uk data - special case
         print("\nLoading:", args["uk_split"], "\n")
-        dataset_name = "uk_" + args["uk_split"]  # should be uk_diff, or uk_same
+        dataset_name = (
+            "uk_" + args["uk_split"]
+        )  # should be uk_diff, or uk_same
         datasets, _ = ts.get_train_test_data(
             args["ann_file"],
             args["audio_path"],
diff --git a/scripts/gen_spec_image.py b/scripts/gen_spec_image.py
index 2249b58..490cad3 100644
--- a/scripts/gen_spec_image.py
+++ b/scripts/gen_spec_image.py
@@ -33,7 +33,6 @@ def filter_anns(anns, start_time, stop_time):
 
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument("audio_file", type=str, help="Path to audio file")
     parser.add_argument("model_path", type=str, help="Path to BatDetect model")
@@ -143,7 +142,9 @@ if __name__ == "__main__":
     # run model and filter detections so only keep ones in relevant time range
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    results = du.process_file(args_cmd["audio_file"], model, run_config, device)
+    results = du.process_file(
+        args_cmd["audio_file"], model, run_config, device
+    )
     pred_anns = filter_anns(
         results["pred_dict"]["annotation"],
         args_cmd["start_time"],
diff --git a/scripts/gen_spec_video.py b/scripts/gen_spec_video.py
index 9636cae..625ba1a 100644
--- a/scripts/gen_spec_video.py
+++ b/scripts/gen_spec_video.py
@@ -25,7 +25,9 @@ import batdetect2.utils.plot_utils as viz
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("audio_file", type=str, help="Path to input audio file")
+    parser.add_argument(
+        "audio_file", type=str, help="Path to input audio file"
+    )
     parser.add_argument(
         "model_path", type=str, help="Path to trained BatDetect model"
     )
diff --git a/scripts/viz_helpers.py b/scripts/viz_helpers.py
index 13d09b6..4d86283 100644
--- a/scripts/viz_helpers.py
+++ b/scripts/viz_helpers.py
@@ -198,7 +198,6 @@ def save_summary_image(
     )
     ii = 0
     for row in ax:
-
         if type(row) != np.ndarray:
             row = np.array([row])
 
@@ -215,7 +214,9 @@ def save_summary_image(
             )
             col.grid(color="w", alpha=0.3, linewidth=0.3)
             col.set_xticks([])
-            col.title.set_text(str(ii + 1) + " " + species_names[order[ii]])
+            col.title.set_text(
+                str(ii + 1) + " " + species_names[order[ii]]
+            )
             col.tick_params(axis="both", which="major", labelsize=7)
 
             ii += 1
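Note: the grid-to-clip coordinate mapping that `compute_predictions_from_outputs` (batdetect2/post_process.py above) applies via np.interp can be checked in isolation. A small worked example; the grid shape and detection position are assumed values, while the frequency defaults match the new function's signature:

import numpy as np

start, end = 0.0, 3.0  # clip span in seconds
min_freq, max_freq = 10000, 120000  # defaults of compute_predictions_from_outputs
freq_bins, time_bins = 128, 512  # assumed output grid shape
x, y = 256, 32  # hypothetical detection at grid column x, row y

# columns map linearly onto the clip's time span
start_time = np.interp(x, [0, time_bins], [start, end])

# rows map top-down: row 0 is max_freq and the last row is min_freq,
# which is why the target range is given in decreasing order
low_freq = np.interp(y, [0, freq_bins], [max_freq, min_freq])

print(start_time)  # 1.5
print(low_freq)  # 92500.0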