Format code

mbsantiago 2025-02-25 11:25:16 +00:00
parent 904e8f23ea
commit 150305a273
22 changed files with 170 additions and 39 deletions

View File

@@ -1,4 +1,5 @@
"""Functions to compute features from predictions."""
from typing import Dict, List, Optional
import numpy as np
@@ -219,7 +220,6 @@ def compute_call_interval(
return round(prediction["start_time"] - previous["end_time"], 5)
# NOTE: The order of the features in this dictionary is important. The
# features are extracted in this order and the order of the columns in the
# output csv file is determined by this order. In order to avoid breaking
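
The NOTE above captures this file's key invariant: features are computed by iterating a dictionary, so the dictionary's insertion order fixes both the extraction order and the CSV column order. A minimal sketch of that pattern (the feature names and helper below are illustrative, not this module's actual definitions):

import csv
from typing import Callable, Dict

# Insertion order here determines the column order of the output CSV.
FEATURES: Dict[str, Callable[[dict], float]] = {
    "duration": lambda p: round(p["end_time"] - p["start_time"], 5),
    "low_freq": lambda p: p["low_freq"],
    "high_freq": lambda p: p["high_freq"],
}

def to_row(prediction: dict) -> Dict[str, float]:
    # Compute each feature in declaration order.
    return {name: fn(prediction) for name, fn in FEATURES.items()}

with open("features.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=list(FEATURES))
    writer.writeheader()
    writer.writerow(to_row(
        {"start_time": 0.1, "end_time": 0.14, "low_freq": 20000.0, "high_freq": 45000.0}
    ))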

View File

@@ -206,7 +206,10 @@ class Net2DFastNoAttn(nn.Module):
num_filts // 4, 2, kernel_size=1, padding=0
)
self.conv_classes_op = nn.Conv2d(
num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
num_filts // 4,
self.num_classes + 1,
kernel_size=1,
padding=0,
)
if self.emb_dim > 0:
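
For context, the reformatted call above is the classification head: a 1x1 convolution mapping num_filts // 4 feature channels to num_classes + 1 output maps (the extra channel presumably covers a background / generic-detection class). A standalone sketch with illustrative sizes:

import torch
import torch.nn as nn

num_filts, num_classes = 128, 17  # illustrative values, not the model's defaults
head = nn.Conv2d(num_filts // 4, num_classes + 1, kernel_size=1, padding=0)
logits = head(torch.randn(1, num_filts // 4, 64, 256))
print(logits.shape)  # torch.Size([1, 18, 64, 256]): one score map per class, plus one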

View File

@@ -5,7 +5,10 @@ from typing import List, Optional, Union
from pydantic import BaseModel, Field, computed_field
from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
from batdetect2.train.train_utils import (
get_genus_mapping,
get_short_class_names,
)
from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
TARGET_SAMPLERATE_HZ = 256000

View File

@@ -1,4 +1,5 @@
"""Post-processing of the output of the model."""
from typing import List, Tuple, Union
import numpy as np

View File

@@ -20,7 +20,6 @@ from batdetect2.detector import parameters
def get_blank_annotation(ip_str):
res = {}
res["class_name"] = ""
res["duration"] = -1
@@ -77,7 +76,6 @@ def create_genus_mapping(gt_test, preds, class_names):
def load_tadarida_pred(ip_dir, dataset, file_of_interest):
res, ann = get_blank_annotation("Generated by Tadarida")
# create the annotations in the correct format
@@ -120,7 +118,6 @@ def load_sonobat_meta(
class_names,
only_accepted_species=True,
):
sp_dict = {}
for ss in class_names:
sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
@@ -182,7 +179,6 @@ def load_sonobat_meta(
def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
# create the annotations in the correct format
res, ann = get_blank_annotation("Generated by Sonobat")
res_c = copy.deepcopy(res)
@@ -221,7 +217,6 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
def bb_overlap(bb_g_in, bb_p_in):
freq_scale = 10000000.0 # ensure that both axes are roughly the same range
bb_g = [
bb_g_in["start_time"],
@@ -465,7 +460,6 @@ def check_classes_in_train(gt_list, class_names):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"op_dir",

View File

@@ -13,5 +13,3 @@ else:
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y
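
For reference, this backported pairwise mirrors itertools.pairwise (available since Python 3.10): list(pairwise("abc")) yields [("a", "b"), ("b", "c")].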

View File

@@ -1,6 +1,6 @@
"""Module for postprocessing model outputs."""
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -30,6 +30,16 @@ class PostprocessConfig(BaseModel):
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
class RawPrediction(BaseModel):
start_time: float
end_time: float
low_freq: float
high_freq: float
detection_score: float
class_scores: Dict[str, float]
features: np.ndarray
def postprocess_model_outputs(
outputs: ModelOutput,
clips: List[data.Clip],
@@ -125,6 +135,88 @@ def postprocess_model_outputs(
return predictions
def compute_predictions_from_outputs(
start: float,
end: float,
scores: torch.Tensor,
y_pos: torch.Tensor,
x_pos: torch.Tensor,
size_preds: torch.Tensor,
class_probs: torch.Tensor,
features: torch.Tensor,
classes: List[str],
min_freq: int = 10000,
max_freq: int = 120000,
detection_threshold: float = DETECTION_THRESHOLD,
) -> List[RawPrediction]:
_, freq_bins, time_bins = size_preds.shape
sorted_indices = torch.argsort(x_pos)
valid_indices = sorted_indices[
scores[sorted_indices] > detection_threshold
]
scores = scores[valid_indices]
x_pos = x_pos[valid_indices]
y_pos = y_pos[valid_indices]
predictions: List[RawPrediction] = []
for score, x, y in zip(scores, x_pos, y_pos):
width, height = size_preds[:, y, x]
class_prob = class_probs[:, y, x].detach().numpy()
feats = features[:, y, x].detach().numpy()
start_time = np.interp(
x.item(),
[0, time_bins],
[start, end],
)
end_time = np.interp(
x.item() + width.item(),
[0, time_bins],
[start, end],
)
low_freq = np.interp(
y.item(),
[0, freq_bins],
[max_freq, min_freq],
)
high_freq = np.interp(
y.item() - height.item(),
[0, freq_bins],
[max_freq, min_freq],
)
start_time, end_time = sorted([float(start_time), float(end_time)])
low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
predictions.append(
RawPrediction(
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
detection_score=score.item(),
features=feats,
class_scores={
class_name: prob
for class_name, prob in zip(classes, class_prob)
},
)
)
return predictions
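
Two notes on the additions above. First, RawPrediction stores a raw np.ndarray, which pydantic models do not accept by default, so a definition like this typically needs arbitrary types enabled. Second, compute_predictions_from_outputs converts grid indices to physical units with np.interp, mapping time bins onto [start, end] and frequency bins onto an inverted [max_freq, min_freq] axis (row 0 of the spectrogram is the highest frequency). A minimal sketch of both points (the ConfigDict line, class name, and all values are assumptions, not part of this diff):

import numpy as np
from pydantic import BaseModel, ConfigDict

class Box(BaseModel):
    # pydantic cannot validate np.ndarray natively; allow it to pass through.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    features: np.ndarray

start, end, time_bins = 0.0, 3.0, 128          # a 3 s clip over 128 time bins
min_freq, max_freq, freq_bins = 10000, 120000, 64

x, y = 32, 16                                   # detection at grid cell (y, x)
start_time = np.interp(x, [0, time_bins], [start, end])        # 0.75 s
low_freq = np.interp(y, [0, freq_bins], [max_freq, min_freq])  # 92500.0 Hz
print(start_time, low_freq)
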
def convert_raw_prediction_to_soundevent(
prediction: RawPrediction,
) -> data.SoundEventPrediction:
pass
def compute_sound_events_from_outputs(
clip: data.Clip,
scores: torch.Tensor,

View File

@@ -1,10 +1,13 @@
from typing import List
import numpy as np
import pandas as pd
from sklearn.metrics import auc, roc_curve
from soundevent import data
from soundevent.evaluation import match_geometries
from batdetect2.train.targets import build_encoder, get_class_names
def match_predictions_and_annotations(
clip_annotation: data.ClipAnnotation,
@@ -48,6 +51,13 @@ match_predictions_and_annotations(
return matches
def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
ret = []
for match in matches:
pass
def compute_error_auc(op_str, gt, pred, prob):
# classification error
pred_int = (pred > prob).astype(np.int32)
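
compute_error_auc thresholds continuous scores into hard labels for a classification error rate and, per the sklearn imports above, pairs that with a ROC AUC. A compact illustration with invented data:

import numpy as np
from sklearn.metrics import auc, roc_curve

gt = np.array([0, 0, 1, 1])
pred = np.array([0.1, 0.4, 0.35, 0.8])

pred_int = (pred > 0.5).astype(np.int32)  # hard decisions at threshold 0.5
class_error = float((pred_int != gt).mean())

fpr, tpr, _ = roc_curve(gt, pred)  # threshold-free ranking quality
print(class_error, auc(fpr, tpr))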

View File

@@ -1,8 +1,10 @@
from pathlib import Path
from typing import Optional
import pytorch_lightning as L
import torch
from pydantic import Field
from soundevent import data
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
@@ -19,7 +21,7 @@ from batdetect2.post_process import (
PostprocessConfig,
postprocess_model_outputs,
)
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.evaluate import match_predictions_and_annotations
from batdetect2.train.losses import LossConfig, compute_loss
@@ -146,12 +148,14 @@ class DetectorModel(L.LightningModule):
config=self.config.postprocessing,
)[0]
self.validation_predictions.extend(
match_predictions_and_annotations(clip_annotation, clip_prediction)
matches = match_predictions_and_annotations(
clip_annotation,
clip_prediction,
)
self.validation_predictions.extend(matches)
def on_validation_epoch_end(self) -> None:
print(len(self.validation_predictions))
self.validation_predictions.clear()
def configure_optimizers(self):
@@ -159,3 +163,23 @@ class DetectorModel(L.LightningModule):
optimizer = Adam(self.parameters(), lr=conf.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
return [optimizer], [scheduler]
def process_clip(
self,
clip: data.Clip,
audio_dir: Optional[Path] = None,
) -> data.ClipPrediction:
spec = preprocess_audio_clip(
clip,
config=self.config.preprocessing,
audio_dir=audio_dir,
)
tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
outputs = self.forward(tensor)
return postprocess_model_outputs(
outputs,
clips=[clip],
classes=self.class_names,
decoder=self.decoder,
config=self.config.postprocessing,
)[0]
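
The new process_clip gives a single-clip inference path: preprocess the audio into a spectrogram, add batch and channel dimensions, run the network, and decode one ClipPrediction. A hedged usage sketch (the checkpoint path and recording are invented):

from pathlib import Path
from soundevent import data

# hypothetical checkpoint and recording; adjust to your setup
model = DetectorModel.load_from_checkpoint("detector.ckpt")
recording = data.Recording.from_file(Path("example.wav"))
clip = data.Clip(recording=recording, start_time=0.0, end_time=3.0)

prediction = model.process_clip(clip, audio_dir=Path("."))
print(len(prediction.sound_events))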

View File

@@ -16,7 +16,6 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
def split_diff(ann_dir, wav_dir, load_extra=True):
train_sets = []
if load_extra:
train_sets.append(
@@ -144,7 +143,6 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
def split_same(ann_dir, wav_dir, load_extra=True):
train_sets = []
if load_extra:
train_sets.append(

View File

@@ -69,7 +69,8 @@ def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:
def standardize_low_freq(
data: List[types.FileAnnotation], class_of_interest: str,
data: List[types.FileAnnotation],
class_of_interest: str,
) -> List[types.FileAnnotation]:
# address the issue of highly variable low frequency annotations
# this often happens for constant frequency calls
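
The comment describes the motivation: for constant-frequency calls, annotators place the low-frequency bound inconsistently, so it helps to replace each annotation's low_freq with a robust per-class statistic. A sketch of one way to do that, assuming a flat list of annotation dicts (an illustration of the idea, not the function's actual implementation):

import numpy as np

def standardize_low_freq_simple(anns, class_of_interest):
    lows = [a["low_freq"] for a in anns if a["class"] == class_of_interest]
    if not lows:
        return anns
    target = float(np.median(lows))  # robust to outlier annotations
    for a in anns:
        if a["class"] == class_of_interest:
            a["low_freq"] = target
    return anns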

View File

@@ -1,4 +1,5 @@
"""Types used in the code base."""
from typing import Any, List, NamedTuple, Optional
@@ -594,8 +595,7 @@ class FeatureExtractor(Protocol):
self,
prediction: Prediction,
**kwargs: Any,
) -> float:
...
) -> float: ...
class DatasetDict(TypedDict):

View File

@@ -100,7 +100,7 @@ def generate_spectrogram(
# log_scaling = (1.0 / sampling_rate)*10e4
spec = np.log1p(log_scaling * spec)
elif params["spec_scale"] == "pcen":
spec = pcen(spec , sampling_rate)
spec = pcen(spec, sampling_rate)
elif params["spec_scale"] == "none":
pass
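
For context, this branch selects the spectrogram amplitude scaling: a compressive log1p with a gain term, PCEN (per-channel energy normalization), or raw magnitudes. A standalone sketch of the same branching, using librosa.pcen as one possible PCEN implementation (the gain constant mirrors the commented-out line above; parameters are assumptions):

import numpy as np
import librosa

def scale_spectrogram(spec, sampling_rate, spec_scale):
    if spec_scale == "log":
        log_scaling = (1.0 / sampling_rate) * 10e4  # assumed gain
        return np.log1p(log_scaling * spec)
    if spec_scale == "pcen":
        # librosa expects non-negative energy; sr sets its time constants
        return librosa.pcen(spec, sr=sampling_rate)
    return spec  # "none": leave magnitudes untouched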

View File

@@ -417,7 +417,9 @@ def plot_confusion_matrix(
cm_norm = cm.sum(1)
valid_inds = np.where(cm_norm > 0)[0]
cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
cm[valid_inds, :] = (
cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
)
cm[np.where(cm_norm == -0)[0], :] = np.nan
if verbose:
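
The reformatted lines normalize each confusion-matrix row by its row sum, turning counts into per-true-class proportions, and blank out rows with no support using NaN. A small self-contained illustration (counts invented):

import numpy as np

cm = np.array([[8.0, 2.0], [0.0, 0.0]])
cm_norm = cm.sum(1)
valid_inds = np.where(cm_norm > 0)[0]
cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
cm[np.where(cm_norm == 0)[0], :] = np.nan
print(cm)  # [[0.8 0.2] [nan nan]]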

View File

@@ -155,9 +155,9 @@ class InteractivePlotter:
# draw bounding box around call
self.ax[1].patches[0].remove()
spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
1.0 + 2.0 * self.spec_pad
)
spec_width_orig = self.spec_slices[self.current_id].shape[
1
] / (1.0 + 2.0 * self.spec_pad)
xx = w_diff + self.spec_pad * spec_width_orig
ww = spec_width_orig
yy = self.call_info[self.current_id]["low_freq"] / 1000
@@ -183,7 +183,9 @@ class InteractivePlotter:
round(self.call_info[self.current_id]["start_time"], 3)
)
+ ", prob="
+ str(round(self.call_info[self.current_id]["det_prob"], 3))
+ str(
round(self.call_info[self.current_id]["det_prob"], 3)
)
)
self.ax[0].set_xlabel(info_str)

View File

@@ -8,6 +8,7 @@ Functions
`write`: Write a numpy array as a WAV file.
"""
from __future__ import absolute_import, division, print_function
import os
@@ -156,7 +157,6 @@ def read(filename, mmap=False):
fid = open(filename, "rb")
try:
# some files seem to have the size recorded in the header greater than
# the actual file size.
fid.seek(0, os.SEEK_END)
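
The comment explains the workaround that follows: some WAV headers declare a chunk size larger than the file actually is, so the reader seeks to the end to learn the true size and clamps against it. A minimal sketch of that guard (the helper name and declared_size are assumptions):

import os

def clamp_declared_size(fid, declared_size):
    pos = fid.tell()
    fid.seek(0, os.SEEK_END)  # learn the true file size
    actual_remaining = fid.tell() - pos
    fid.seek(pos)             # restore the read position
    return min(declared_size, actual_remaining)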

View File

@@ -16,7 +16,6 @@ import batdetect2.train.train_utils as tu
import batdetect2.utils.audio_utils as au
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"audio_path", type=str, help="Input directory for audio"
@@ -65,7 +64,9 @@ if __name__ == "__main__":
else:
# load uk data - special case
print("\nLoading:", args["uk_split"], "\n")
dataset_name = "uk_" + args["uk_split"] # should be uk_diff, or uk_same
dataset_name = (
"uk_" + args["uk_split"]
) # should be uk_diff or uk_same
datasets, _ = ts.get_train_test_data(
args["ann_file"],
args["audio_path"],

View File

@@ -33,7 +33,6 @@ def filter_anns(anns, start_time, stop_time):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("audio_file", type=str, help="Path to audio file")
parser.add_argument("model_path", type=str, help="Path to BatDetect model")
@@ -143,7 +142,9 @@ if __name__ == "__main__":
# run model and filter detections to keep only those in the relevant time range
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
results = du.process_file(args_cmd["audio_file"], model, run_config, device)
results = du.process_file(
args_cmd["audio_file"], model, run_config, device
)
pred_anns = filter_anns(
results["pred_dict"]["annotation"],
args_cmd["start_time"],

View File

@@ -25,7 +25,9 @@ import batdetect2.utils.plot_utils as viz
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("audio_file", type=str, help="Path to input audio file")
parser.add_argument(
"audio_file", type=str, help="Path to input audio file"
)
parser.add_argument(
"model_path", type=str, help="Path to trained BatDetect model"
)

View File

@@ -198,7 +198,6 @@ def save_summary_image(
)
ii = 0
for row in ax:
if type(row) != np.ndarray:
row = np.array([row])
@@ -215,7 +214,9 @@ def save_summary_image(
)
col.grid(color="w", alpha=0.3, linewidth=0.3)
col.set_xticks([])
col.title.set_text(str(ii + 1) + " " + species_names[order[ii]])
col.title.set_text(
str(ii + 1) + " " + species_names[order[ii]]
)
col.tick_params(axis="both", which="major", labelsize=7)
ii += 1