mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Format code
This commit is contained in:
parent
904e8f23ea
commit
150305a273
@ -1,4 +1,5 @@
|
||||
"""Functions to compute features from predictions."""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
@ -219,7 +220,6 @@ def compute_call_interval(
|
||||
return round(prediction["start_time"] - previous["end_time"], 5)
|
||||
|
||||
|
||||
|
||||
# NOTE: The order of the features in this dictionary is important. The
|
||||
# features are extracted in this order and the order of the columns in the
|
||||
# output csv file is determined by this order. In order to avoid breaking
|
||||
|
@ -206,7 +206,10 @@ class Net2DFastNoAttn(nn.Module):
|
||||
num_filts // 4, 2, kernel_size=1, padding=0
|
||||
)
|
||||
self.conv_classes_op = nn.Conv2d(
|
||||
num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
|
||||
num_filts // 4,
|
||||
self.num_classes + 1,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
if self.emb_dim > 0:
|
||||
|
@ -5,7 +5,10 @@ from typing import List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
|
||||
from batdetect2.train.train_utils import (
|
||||
get_genus_mapping,
|
||||
get_short_class_names,
|
||||
)
|
||||
from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
|
||||
|
||||
TARGET_SAMPLERATE_HZ = 256000
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Post-processing of the output of the model."""
|
||||
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
@ -20,7 +20,6 @@ from batdetect2.detector import parameters
|
||||
|
||||
|
||||
def get_blank_annotation(ip_str):
|
||||
|
||||
res = {}
|
||||
res["class_name"] = ""
|
||||
res["duration"] = -1
|
||||
@ -77,7 +76,6 @@ def create_genus_mapping(gt_test, preds, class_names):
|
||||
|
||||
|
||||
def load_tadarida_pred(ip_dir, dataset, file_of_interest):
|
||||
|
||||
res, ann = get_blank_annotation("Generated by Tadarida")
|
||||
|
||||
# create the annotations in the correct format
|
||||
@ -120,7 +118,6 @@ def load_sonobat_meta(
|
||||
class_names,
|
||||
only_accepted_species=True,
|
||||
):
|
||||
|
||||
sp_dict = {}
|
||||
for ss in class_names:
|
||||
sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
|
||||
@ -182,7 +179,6 @@ def load_sonobat_meta(
|
||||
|
||||
|
||||
def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
|
||||
|
||||
# create the annotations in the correct format
|
||||
res, ann = get_blank_annotation("Generated by Sonobat")
|
||||
res_c = copy.deepcopy(res)
|
||||
@ -221,7 +217,6 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
|
||||
|
||||
|
||||
def bb_overlap(bb_g_in, bb_p_in):
|
||||
|
||||
freq_scale = 10000000.0 # ensure that both axis are roughly the same range
|
||||
bb_g = [
|
||||
bb_g_in["start_time"],
|
||||
@ -465,7 +460,6 @@ def check_classes_in_train(gt_list, class_names):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"op_dir",
|
||||
|
@ -13,5 +13,3 @@ else:
|
||||
def pairwise(iterable: Sequence) -> Iterable:
|
||||
for x, y in zip(iterable[:-1], iterable[1:]):
|
||||
yield x, y
|
||||
|
||||
|
||||
|
@ -13,5 +13,3 @@ else:
|
||||
def pairwise(iterable: Sequence) -> Iterable:
|
||||
for x, y in zip(iterable[:-1], iterable[1:]):
|
||||
yield x, y
|
||||
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Module for postprocessing model outputs."""
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -30,6 +30,16 @@ class PostprocessConfig(BaseModel):
|
||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||
|
||||
|
||||
class RawPrediction(BaseModel):
|
||||
start_time: float
|
||||
end_time: float
|
||||
low_freq: float
|
||||
high_freq: float
|
||||
detection_score: float
|
||||
class_scores: Dict[str, float]
|
||||
features: np.ndarray
|
||||
|
||||
|
||||
def postprocess_model_outputs(
|
||||
outputs: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
@ -125,6 +135,88 @@ def postprocess_model_outputs(
|
||||
return predictions
|
||||
|
||||
|
||||
def compute_predictions_from_outputs(
|
||||
start: float,
|
||||
end: float,
|
||||
scores: torch.Tensor,
|
||||
y_pos: torch.Tensor,
|
||||
x_pos: torch.Tensor,
|
||||
size_preds: torch.Tensor,
|
||||
class_probs: torch.Tensor,
|
||||
features: torch.Tensor,
|
||||
classes: List[str],
|
||||
min_freq: int = 10000,
|
||||
max_freq: int = 120000,
|
||||
detection_threshold: float = DETECTION_THRESHOLD,
|
||||
) -> List[RawPrediction]:
|
||||
_, freq_bins, time_bins = size_preds.shape
|
||||
|
||||
sorted_indices = torch.argsort(x_pos)
|
||||
valid_indices = sorted_indices[
|
||||
scores[sorted_indices] > detection_threshold
|
||||
]
|
||||
|
||||
scores = scores[valid_indices]
|
||||
x_pos = x_pos[valid_indices]
|
||||
y_pos = y_pos[valid_indices]
|
||||
|
||||
predictions: List[RawPrediction] = []
|
||||
for score, x, y in zip(scores, x_pos, y_pos):
|
||||
width, height = size_preds[:, y, x]
|
||||
class_prob = class_probs[:, y, x].detach().numpy()
|
||||
feats = features[:, y, x].detach().numpy()
|
||||
|
||||
start_time = np.interp(
|
||||
x.item(),
|
||||
[0, time_bins],
|
||||
[start, end],
|
||||
)
|
||||
|
||||
end_time = np.interp(
|
||||
x.item() + width.item(),
|
||||
[0, time_bins],
|
||||
[start, end],
|
||||
)
|
||||
|
||||
low_freq = np.interp(
|
||||
y.item(),
|
||||
[0, freq_bins],
|
||||
[max_freq, min_freq],
|
||||
)
|
||||
|
||||
high_freq = np.interp(
|
||||
y.item() - height.item(),
|
||||
[0, freq_bins],
|
||||
[max_freq, min_freq],
|
||||
)
|
||||
|
||||
start_time, end_time = sorted([float(start_time), float(end_time)])
|
||||
low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
|
||||
|
||||
predictions.append(
|
||||
RawPrediction(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
low_freq=low_freq,
|
||||
high_freq=high_freq,
|
||||
detection_score=score.item(),
|
||||
features=feats,
|
||||
class_scores={
|
||||
class_name: prob
|
||||
for class_name, prob in zip(classes, class_prob)
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
def convert_raw_prediction_to_soundevent(
|
||||
prediction: RawPrediction,
|
||||
) -> data.SoundEventPrediction:
|
||||
pass
|
||||
|
||||
|
||||
def compute_sound_events_from_outputs(
|
||||
clip: data.Clip,
|
||||
scores: torch.Tensor,
|
||||
|
@ -1,10 +1,13 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import auc, roc_curve
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_geometries
|
||||
|
||||
from batdetect2.train.targets import build_encoder, get_class_names
|
||||
|
||||
|
||||
def match_predictions_and_annotations(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
@ -48,6 +51,13 @@ def match_predictions_and_annotations(
|
||||
return matches
|
||||
|
||||
|
||||
def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
|
||||
ret = []
|
||||
|
||||
for match in matches:
|
||||
pass
|
||||
|
||||
|
||||
def compute_error_auc(op_str, gt, pred, prob):
|
||||
# classification error
|
||||
pred_int = (pred > prob).astype(np.int32)
|
||||
|
@ -1,8 +1,10 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytorch_lightning as L
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.utils.data import DataLoader
|
||||
@ -19,7 +21,7 @@ from batdetect2.post_process import (
|
||||
PostprocessConfig,
|
||||
postprocess_model_outputs,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessingConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
|
||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||
from batdetect2.train.evaluate import match_predictions_and_annotations
|
||||
from batdetect2.train.losses import LossConfig, compute_loss
|
||||
@ -146,12 +148,14 @@ class DetectorModel(L.LightningModule):
|
||||
config=self.config.postprocessing,
|
||||
)[0]
|
||||
|
||||
self.validation_predictions.extend(
|
||||
match_predictions_and_annotations(clip_annotation, clip_prediction)
|
||||
matches = match_predictions_and_annotations(
|
||||
clip_annotation,
|
||||
clip_prediction,
|
||||
)
|
||||
|
||||
self.validation_predictions.extend(matches)
|
||||
|
||||
def on_validation_epoch_end(self) -> None:
|
||||
print(len(self.validation_predictions))
|
||||
self.validation_predictions.clear()
|
||||
|
||||
def configure_optimizers(self):
|
||||
@ -159,3 +163,23 @@ class DetectorModel(L.LightningModule):
|
||||
optimizer = Adam(self.parameters(), lr=conf.learning_rate)
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
def process_clip(
|
||||
self,
|
||||
clip: data.Clip,
|
||||
audio_dir: Optional[Path] = None,
|
||||
) -> data.ClipPrediction:
|
||||
spec = preprocess_audio_clip(
|
||||
clip,
|
||||
config=self.config.preprocessing,
|
||||
audio_dir=audio_dir,
|
||||
)
|
||||
tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
|
||||
outputs = self.forward(tensor)
|
||||
return postprocess_model_outputs(
|
||||
outputs,
|
||||
clips=[clip],
|
||||
classes=self.class_names,
|
||||
decoder=self.decoder,
|
||||
config=self.config.postprocessing,
|
||||
)[0]
|
||||
|
@ -16,7 +16,6 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
|
||||
|
||||
|
||||
def split_diff(ann_dir, wav_dir, load_extra=True):
|
||||
|
||||
train_sets = []
|
||||
if load_extra:
|
||||
train_sets.append(
|
||||
@ -144,7 +143,6 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
|
||||
|
||||
|
||||
def split_same(ann_dir, wav_dir, load_extra=True):
|
||||
|
||||
train_sets = []
|
||||
if load_extra:
|
||||
train_sets.append(
|
||||
|
@ -69,7 +69,8 @@ def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:
|
||||
|
||||
|
||||
def standardize_low_freq(
|
||||
data: List[types.FileAnnotation], class_of_interest: str,
|
||||
data: List[types.FileAnnotation],
|
||||
class_of_interest: str,
|
||||
) -> List[types.FileAnnotation]:
|
||||
# address the issue of highly variable low frequency annotations
|
||||
# this often happens for contstant frequency calls
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Types used in the code base."""
|
||||
|
||||
from typing import Any, List, NamedTuple, Optional
|
||||
|
||||
|
||||
@ -594,8 +595,7 @@ class FeatureExtractor(Protocol):
|
||||
self,
|
||||
prediction: Prediction,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
...
|
||||
) -> float: ...
|
||||
|
||||
|
||||
class DatasetDict(TypedDict):
|
||||
|
@ -100,7 +100,7 @@ def generate_spectrogram(
|
||||
# log_scaling = (1.0 / sampling_rate)*10e4
|
||||
spec = np.log1p(log_scaling * spec)
|
||||
elif params["spec_scale"] == "pcen":
|
||||
spec = pcen(spec , sampling_rate)
|
||||
spec = pcen(spec, sampling_rate)
|
||||
|
||||
elif params["spec_scale"] == "none":
|
||||
pass
|
||||
|
@ -417,7 +417,9 @@ def plot_confusion_matrix(
|
||||
cm_norm = cm.sum(1)
|
||||
|
||||
valid_inds = np.where(cm_norm > 0)[0]
|
||||
cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
|
||||
cm[valid_inds, :] = (
|
||||
cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
|
||||
)
|
||||
cm[np.where(cm_norm == -0)[0], :] = np.nan
|
||||
|
||||
if verbose:
|
||||
|
@ -155,9 +155,9 @@ class InteractivePlotter:
|
||||
|
||||
# draw bounding box around call
|
||||
self.ax[1].patches[0].remove()
|
||||
spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
|
||||
1.0 + 2.0 * self.spec_pad
|
||||
)
|
||||
spec_width_orig = self.spec_slices[self.current_id].shape[
|
||||
1
|
||||
] / (1.0 + 2.0 * self.spec_pad)
|
||||
xx = w_diff + self.spec_pad * spec_width_orig
|
||||
ww = spec_width_orig
|
||||
yy = self.call_info[self.current_id]["low_freq"] / 1000
|
||||
@ -183,7 +183,9 @@ class InteractivePlotter:
|
||||
round(self.call_info[self.current_id]["start_time"], 3)
|
||||
)
|
||||
+ ", prob="
|
||||
+ str(round(self.call_info[self.current_id]["det_prob"], 3))
|
||||
+ str(
|
||||
round(self.call_info[self.current_id]["det_prob"], 3)
|
||||
)
|
||||
)
|
||||
self.ax[0].set_xlabel(info_str)
|
||||
|
||||
|
@ -8,6 +8,7 @@ Functions
|
||||
`write`: Write a numpy array as a WAV file.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
@ -156,7 +157,6 @@ def read(filename, mmap=False):
|
||||
fid = open(filename, "rb")
|
||||
|
||||
try:
|
||||
|
||||
# some files seem to have the size recorded in the header greater than
|
||||
# the actual file size.
|
||||
fid.seek(0, os.SEEK_END)
|
||||
|
@ -16,7 +16,6 @@ import batdetect2.train.train_utils as tu
|
||||
import batdetect2.utils.audio_utils as au
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"audio_path", type=str, help="Input directory for audio"
|
||||
@ -65,7 +64,9 @@ if __name__ == "__main__":
|
||||
else:
|
||||
# load uk data - special case
|
||||
print("\nLoading:", args["uk_split"], "\n")
|
||||
dataset_name = "uk_" + args["uk_split"] # should be uk_diff, or uk_same
|
||||
dataset_name = (
|
||||
"uk_" + args["uk_split"]
|
||||
) # should be uk_diff, or uk_same
|
||||
datasets, _ = ts.get_train_test_data(
|
||||
args["ann_file"],
|
||||
args["audio_path"],
|
||||
|
@ -33,7 +33,6 @@ def filter_anns(anns, start_time, stop_time):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("audio_file", type=str, help="Path to audio file")
|
||||
parser.add_argument("model_path", type=str, help="Path to BatDetect model")
|
||||
@ -143,7 +142,9 @@ if __name__ == "__main__":
|
||||
|
||||
# run model and filter detections so only keep ones in relevant time range
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
results = du.process_file(args_cmd["audio_file"], model, run_config, device)
|
||||
results = du.process_file(
|
||||
args_cmd["audio_file"], model, run_config, device
|
||||
)
|
||||
pred_anns = filter_anns(
|
||||
results["pred_dict"]["annotation"],
|
||||
args_cmd["start_time"],
|
||||
|
@ -25,7 +25,9 @@ import batdetect2.utils.plot_utils as viz
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("audio_file", type=str, help="Path to input audio file")
|
||||
parser.add_argument(
|
||||
"audio_file", type=str, help="Path to input audio file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"model_path", type=str, help="Path to trained BatDetect model"
|
||||
)
|
||||
|
@ -198,7 +198,6 @@ def save_summary_image(
|
||||
)
|
||||
ii = 0
|
||||
for row in ax:
|
||||
|
||||
if type(row) != np.ndarray:
|
||||
row = np.array([row])
|
||||
|
||||
@ -215,7 +214,9 @@ def save_summary_image(
|
||||
)
|
||||
col.grid(color="w", alpha=0.3, linewidth=0.3)
|
||||
col.set_xticks([])
|
||||
col.title.set_text(str(ii + 1) + " " + species_names[order[ii]])
|
||||
col.title.set_text(
|
||||
str(ii + 1) + " " + species_names[order[ii]]
|
||||
)
|
||||
col.tick_params(axis="both", which="major", labelsize=7)
|
||||
ii += 1
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user