Format code

mbsantiago 2025-02-25 11:25:16 +00:00
parent 904e8f23ea
commit 150305a273
22 changed files with 170 additions and 39 deletions

View File

@@ -1,4 +1,5 @@
 """Functions to compute features from predictions."""
+
 from typing import Dict, List, Optional
 
 import numpy as np
@@ -219,7 +220,6 @@ def compute_call_interval(
     return round(prediction["start_time"] - previous["end_time"], 5)
 
 
-
 # NOTE: The order of the features in this dictionary is important. The
 # features are extracted in this order and the order of the columns in the
 # output csv file is determined by this order. In order to avoid breaking
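
The NOTE above is the only guard on column ordering, so a minimal sketch (not from the repository; the feature extractors here are hypothetical stand-ins) of why it matters: with csv.DictWriter, the columns come straight from the dict's insertion order, so reordering the feature dictionary reorders the output CSV.

import csv

# Hypothetical stand-ins for the feature extractors in this module.
FEATURES = {
    "duration": lambda p: round(p["end_time"] - p["start_time"], 5),
    "low_freq": lambda p: p["low_freq"],
}

prediction = {"start_time": 0.1, "end_time": 0.12, "low_freq": 20000}
row = {name: fn(prediction) for name, fn in FEATURES.items()}

with open("features.csv", "w", newline="") as f:
    # fieldnames follows dict insertion order (guaranteed since Python 3.7),
    # so reordering FEATURES silently reorders the CSV columns.
    writer = csv.DictWriter(f, fieldnames=list(FEATURES))
    writer.writeheader()
    writer.writerow(row)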

View File

@@ -206,7 +206,10 @@ class Net2DFastNoAttn(nn.Module):
             num_filts // 4, 2, kernel_size=1, padding=0
         )
         self.conv_classes_op = nn.Conv2d(
-            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
+            num_filts // 4,
+            self.num_classes + 1,
+            kernel_size=1,
+            padding=0,
         )
 
         if self.emb_dim > 0:
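
The reformatted call above is one of two 1x1 convolution prediction heads. A self-contained sketch with made-up num_filts/num_classes values shows the resulting output shapes: 2 channels for the size regression head and num_classes + 1 channels for the class map.

import torch
import torch.nn as nn

num_filts, num_classes = 128, 17  # placeholder values, not from the repo

conv_size_op = nn.Conv2d(num_filts // 4, 2, kernel_size=1, padding=0)
conv_classes_op = nn.Conv2d(
    num_filts // 4, num_classes + 1, kernel_size=1, padding=0
)

x = torch.randn(1, num_filts // 4, 64, 128)
print(conv_size_op(x).shape)     # torch.Size([1, 2, 64, 128])
print(conv_classes_op(x).shape)  # torch.Size([1, 18, 64, 128])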

View File

@@ -5,7 +5,10 @@ from typing import List, Optional, Union
 from pydantic import BaseModel, Field, computed_field
 
-from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
+from batdetect2.train.train_utils import (
+    get_genus_mapping,
+    get_short_class_names,
+)
 from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
 
 TARGET_SAMPLERATE_HZ = 256000

View File

@@ -1,4 +1,5 @@
 """Post-processing of the output of the model."""
+
 from typing import List, Tuple, Union
 
 import numpy as np

View File

@@ -20,7 +20,6 @@ from batdetect2.detector import parameters
 def get_blank_annotation(ip_str):
-
     res = {}
     res["class_name"] = ""
     res["duration"] = -1
@@ -77,7 +76,6 @@ def create_genus_mapping(gt_test, preds, class_names):
 def load_tadarida_pred(ip_dir, dataset, file_of_interest):
-
     res, ann = get_blank_annotation("Generated by Tadarida")
 
     # create the annotations in the correct format
@@ -120,7 +118,6 @@ def load_sonobat_meta(
     class_names,
     only_accepted_species=True,
 ):
-
     sp_dict = {}
     for ss in class_names:
         sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
@@ -182,7 +179,6 @@ def load_sonobat_meta(
 def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
-
     # create the annotations in the correct format
     res, ann = get_blank_annotation("Generated by Sonobat")
     res_c = copy.deepcopy(res)
@@ -221,7 +217,6 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
 def bb_overlap(bb_g_in, bb_p_in):
-
     freq_scale = 10000000.0  # ensure that both axis are roughly the same range
     bb_g = [
         bb_g_in["start_time"],
@@ -465,7 +460,6 @@ def check_classes_in_train(gt_list, class_names):
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "op_dir",

View File

@@ -13,5 +13,3 @@ else:
     def pairwise(iterable: Sequence) -> Iterable:
         for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y
-
-
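
A quick check of the shim's behaviour (the surrounding `else:` in the hunk header suggests this backfills itertools.pairwise, which only exists on Python >= 3.10; that context is an inference, not shown in the diff):

list(pairwise([1, 2, 3, 4]))  # -> [(1, 2), (2, 3), (3, 4)]
list(pairwise("ab"))          # -> [('a', 'b')]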

View File

@@ -13,5 +13,3 @@ else:
     def pairwise(iterable: Sequence) -> Iterable:
         for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y
-
-

View File

@@ -1,6 +1,6 @@
 """Module for postprocessing model outputs."""
 
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -30,6 +30,16 @@ class PostprocessConfig(BaseModel):
     top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
 
 
+class RawPrediction(BaseModel):
+    start_time: float
+    end_time: float
+    low_freq: float
+    high_freq: float
+    detection_score: float
+    class_scores: Dict[str, float]
+    features: np.ndarray
+
+
 def postprocess_model_outputs(
     outputs: ModelOutput,
     clips: List[data.Clip],
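
One caveat with the new RawPrediction model: pydantic cannot validate np.ndarray out of the box, so a field like features typically requires the model to opt in to arbitrary types. A minimal sketch of that opt-in (pydantic v2 spelling; how the repository handles this is not shown in the diff):

import numpy as np
from pydantic import BaseModel, ConfigDict

class ArrayModel(BaseModel):
    # Without this, pydantic raises a schema error for np.ndarray fields.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    features: np.ndarray

m = ArrayModel(features=np.zeros(32))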
@@ -125,6 +135,88 @@ def postprocess_model_outputs(
     return predictions
 
 
+def compute_predictions_from_outputs(
+    start: float,
+    end: float,
+    scores: torch.Tensor,
+    y_pos: torch.Tensor,
+    x_pos: torch.Tensor,
+    size_preds: torch.Tensor,
+    class_probs: torch.Tensor,
+    features: torch.Tensor,
+    classes: List[str],
+    min_freq: int = 10000,
+    max_freq: int = 120000,
+    detection_threshold: float = DETECTION_THRESHOLD,
+) -> List[RawPrediction]:
+    _, freq_bins, time_bins = size_preds.shape
+
+    sorted_indices = torch.argsort(x_pos)
+    valid_indices = sorted_indices[
+        scores[sorted_indices] > detection_threshold
+    ]
+
+    scores = scores[valid_indices]
+    x_pos = x_pos[valid_indices]
+    y_pos = y_pos[valid_indices]
+
+    predictions: List[RawPrediction] = []
+    for score, x, y in zip(scores, x_pos, y_pos):
+        width, height = size_preds[:, y, x]
+        class_prob = class_probs[:, y, x].detach().numpy()
+        feats = features[:, y, x].detach().numpy()
+
+        start_time = np.interp(
+            x.item(),
+            [0, time_bins],
+            [start, end],
+        )
+
+        end_time = np.interp(
+            x.item() + width.item(),
+            [0, time_bins],
+            [start, end],
+        )
+
+        low_freq = np.interp(
+            y.item(),
+            [0, freq_bins],
+            [max_freq, min_freq],
+        )
+
+        high_freq = np.interp(
+            y.item() - height.item(),
+            [0, freq_bins],
+            [max_freq, min_freq],
+        )
+
+        start_time, end_time = sorted([float(start_time), float(end_time)])
+        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
+
+        predictions.append(
+            RawPrediction(
+                start_time=start_time,
+                end_time=end_time,
+                low_freq=low_freq,
+                high_freq=high_freq,
+                detection_score=score.item(),
+                features=feats,
+                class_scores={
+                    class_name: prob
+                    for class_name, prob in zip(classes, class_prob)
+                },
+            )
+        )
+
+    return predictions
+
+
+def convert_raw_prediction_to_soundevent(
+    prediction: RawPrediction,
+) -> data.SoundEventPrediction:
+    pass
+
+
 def compute_sound_events_from_outputs(
     clip: data.Clip,
     scores: torch.Tensor,
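
The coordinate mapping in compute_predictions_from_outputs leans entirely on np.interp. A short worked example with toy values (a 2 s clip and a 128 x 64 output grid; none of these numbers come from the repo) makes the two mappings concrete. Note the frequency range is passed reversed, since row 0 sits at the top of the spectrogram:

import numpy as np

time_bins, freq_bins = 128, 64  # toy grid size
start, end = 0.0, 2.0           # clip bounds in seconds
min_freq, max_freq = 10_000, 120_000

x, y = 32, 16  # a detection peak in the output grid

start_time = np.interp(x, [0, time_bins], [start, end])
low_freq = np.interp(y, [0, freq_bins], [max_freq, min_freq])
print(start_time, low_freq)  # 0.5 92500.0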

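convert_raw_prediction_to_soundevent is still a stub in this commit. For orientation only, a hedged sketch of what the conversion could look like, assuming soundevent's BoundingBox/SoundEventPrediction constructors (field names may differ between soundevent versions) and adding a recording argument the stub does not yet take:

from soundevent import data

def convert_raw_prediction_to_soundevent(
    prediction: RawPrediction,
    recording: data.Recording,  # assumed extra argument, not in the stub
) -> data.SoundEventPrediction:
    # Assumed soundevent API: BoundingBox takes
    # [start_time, low_freq, end_time, high_freq].
    geometry = data.BoundingBox(
        coordinates=[
            prediction.start_time,
            prediction.low_freq,
            prediction.end_time,
            prediction.high_freq,
        ]
    )
    return data.SoundEventPrediction(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        score=prediction.detection_score,
        tags=[
            data.PredictedTag(
                tag=data.Tag(key="class", value=class_name),
                score=score,
            )
            for class_name, score in prediction.class_scores.items()
        ],
    )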
View File

@@ -1,10 +1,13 @@
 from typing import List
 
 import numpy as np
+import pandas as pd
 from sklearn.metrics import auc, roc_curve
 from soundevent import data
 from soundevent.evaluation import match_geometries
 
+from batdetect2.train.targets import build_encoder, get_class_names
+
 
 def match_predictions_and_annotations(
     clip_annotation: data.ClipAnnotation,
@@ -48,6 +51,13 @@ def match_predictions_and_annotations(
     return matches
 
 
+def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
+    ret = []
+
+    for match in matches:
+        pass
+
+
 def compute_error_auc(op_str, gt, pred, prob):
     # classification error
     pred_int = (pred > prob).astype(np.int32)
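
build_evaluation_dataframe is likewise a stub here. A sketch of one plausible completion, assuming each data.Match exposes source (prediction), target (annotation), and affinity, as produced by match_predictions_and_annotations above; the row fields are illustrative, not the repository's schema:

from typing import List

import pandas as pd
from soundevent import data

def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
    ret = []
    for match in matches:
        ret.append(
            {
                # Assumed Match fields; adjust to the actual soundevent model.
                "has_prediction": match.source is not None,
                "has_annotation": match.target is not None,
                "affinity": match.affinity,
            }
        )
    return pd.DataFrame(ret)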

View File

@@ -1,8 +1,10 @@
+from pathlib import Path
 from typing import Optional
 
 import pytorch_lightning as L
 import torch
 from pydantic import Field
+from soundevent import data
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.utils.data import DataLoader
@@ -19,7 +21,7 @@ from batdetect2.post_process import (
     PostprocessConfig,
     postprocess_model_outputs,
 )
-from batdetect2.preprocess import PreprocessingConfig
+from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
 from batdetect2.train.dataset import LabeledDataset, TrainExample
 from batdetect2.train.evaluate import match_predictions_and_annotations
 from batdetect2.train.losses import LossConfig, compute_loss
@@ -146,12 +148,14 @@ class DetectorModel(L.LightningModule):
             config=self.config.postprocessing,
         )[0]
 
-        self.validation_predictions.extend(
-            match_predictions_and_annotations(clip_annotation, clip_prediction)
+        matches = match_predictions_and_annotations(
+            clip_annotation,
+            clip_prediction,
         )
+        self.validation_predictions.extend(matches)
 
     def on_validation_epoch_end(self) -> None:
+        print(len(self.validation_predictions))
         self.validation_predictions.clear()
 
     def configure_optimizers(self):
@@ -159,3 +163,23 @@ class DetectorModel(L.LightningModule):
         optimizer = Adam(self.parameters(), lr=conf.learning_rate)
         scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
         return [optimizer], [scheduler]
+
+    def process_clip(
+        self,
+        clip: data.Clip,
+        audio_dir: Optional[Path] = None,
+    ) -> data.ClipPrediction:
+        spec = preprocess_audio_clip(
+            clip,
+            config=self.config.preprocessing,
+            audio_dir=audio_dir,
+        )
+        tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
+        outputs = self.forward(tensor)
+        return postprocess_model_outputs(
+            outputs,
+            clips=[clip],
+            classes=self.class_names,
+            decoder=self.decoder,
+            config=self.config.postprocessing,
+        )[0]
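
A hedged usage sketch for the new process_clip helper. The model construction arguments and the clip object are assumed, and inference is wrapped in no_grad since the method calls forward directly:

model = DetectorModel()  # assumed default construction; clip defined elsewhere
model.eval()

with torch.no_grad():
    prediction = model.process_clip(clip, audio_dir=Path("audio/"))

# prediction is a soundevent data.ClipPrediction for the whole clip
print(len(prediction.sound_events))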

View File

@@ -16,7 +16,6 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
 def split_diff(ann_dir, wav_dir, load_extra=True):
-
     train_sets = []
     if load_extra:
         train_sets.append(
@@ -144,7 +143,6 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
 def split_same(ann_dir, wav_dir, load_extra=True):
-
     train_sets = []
     if load_extra:
         train_sets.append(

View File

@@ -69,7 +69,8 @@ def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:
 def standardize_low_freq(
-    data: List[types.FileAnnotation], class_of_interest: str,
+    data: List[types.FileAnnotation],
+    class_of_interest: str,
 ) -> List[types.FileAnnotation]:
     # address the issue of highly variable low frequency annotations
     # this often happens for contstant frequency calls

View File

@@ -1,4 +1,5 @@
 """Types used in the code base."""
+
 from typing import Any, List, NamedTuple, Optional
@@ -594,8 +595,7 @@ class FeatureExtractor(Protocol):
         self,
         prediction: Prediction,
         **kwargs: Any,
-    ) -> float:
-        ...
+    ) -> float: ...
 
 
 class DatasetDict(TypedDict):
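
The collapsed `...` body is just the Protocol's abstract signature; any callable with a matching signature satisfies FeatureExtractor structurally. A minimal illustrative implementer (not from the repository; the prediction mapping is assumed to carry start/end times):

class DurationExtractor:
    """Extracts call duration; satisfies FeatureExtractor structurally."""

    def __call__(self, prediction, **kwargs) -> float:
        return prediction["end_time"] - prediction["start_time"]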

View File

@@ -417,7 +417,9 @@ def plot_confusion_matrix(
     cm_norm = cm.sum(1)
     valid_inds = np.where(cm_norm > 0)[0]
-    cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+    cm[valid_inds, :] = (
+        cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+    )
     cm[np.where(cm_norm == -0)[0], :] = np.nan
 
     if verbose:
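
For reference, the wrapped expression row-normalises the confusion matrix while skipping rows with no samples; a tiny worked example with a made-up 2x2 matrix:

import numpy as np

cm = np.array([[5.0, 1.0], [0.0, 0.0]])
cm_norm = cm.sum(1)
valid_inds = np.where(cm_norm > 0)[0]
cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
print(cm)  # [[0.8333 0.1667], [0. 0.]] -- the empty row awaits the NaN mask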

View File

@@ -155,9 +155,9 @@ class InteractivePlotter:
         # draw bounding box around call
         self.ax[1].patches[0].remove()
-        spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
-            1.0 + 2.0 * self.spec_pad
-        )
+        spec_width_orig = self.spec_slices[self.current_id].shape[
+            1
+        ] / (1.0 + 2.0 * self.spec_pad)
         xx = w_diff + self.spec_pad * spec_width_orig
         ww = spec_width_orig
         yy = self.call_info[self.current_id]["low_freq"] / 1000
@@ -183,7 +183,9 @@ class InteractivePlotter:
                 round(self.call_info[self.current_id]["start_time"], 3)
             )
             + ", prob="
-            + str(round(self.call_info[self.current_id]["det_prob"], 3))
+            + str(
+                round(self.call_info[self.current_id]["det_prob"], 3)
+            )
         )
         self.ax[0].set_xlabel(info_str)

View File

@@ -8,6 +8,7 @@ Functions
 `write`: Write a numpy array as a WAV file.
 """
+
 from __future__ import absolute_import, division, print_function
 
 import os
@@ -156,7 +157,6 @@ def read(filename, mmap=False):
     fid = open(filename, "rb")
     try:
-
         # some files seem to have the size recorded in the header greater than
         # the actual file size.
         fid.seek(0, os.SEEK_END)

View File

@@ -16,7 +16,6 @@ import batdetect2.train.train_utils as tu
 import batdetect2.utils.audio_utils as au
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "audio_path", type=str, help="Input directory for audio"
@@ -65,7 +64,9 @@ if __name__ == "__main__":
     else:
         # load uk data - special case
         print("\nLoading:", args["uk_split"], "\n")
-        dataset_name = "uk_" + args["uk_split"]  # should be uk_diff, or uk_same
+        dataset_name = (
+            "uk_" + args["uk_split"]
+        )  # should be uk_diff, or uk_same
         datasets, _ = ts.get_train_test_data(
             args["ann_file"],
             args["audio_path"],

View File

@@ -33,7 +33,6 @@ def filter_anns(anns, start_time, stop_time):
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument("audio_file", type=str, help="Path to audio file")
     parser.add_argument("model_path", type=str, help="Path to BatDetect model")
@@ -143,7 +142,9 @@ if __name__ == "__main__":
     # run model and filter detections so only keep ones in relevant time range
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    results = du.process_file(args_cmd["audio_file"], model, run_config, device)
+    results = du.process_file(
+        args_cmd["audio_file"], model, run_config, device
+    )
     pred_anns = filter_anns(
         results["pred_dict"]["annotation"],
         args_cmd["start_time"],

View File

@@ -25,7 +25,9 @@ import batdetect2.utils.plot_utils as viz
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("audio_file", type=str, help="Path to input audio file")
+    parser.add_argument(
+        "audio_file", type=str, help="Path to input audio file"
+    )
     parser.add_argument(
         "model_path", type=str, help="Path to trained BatDetect model"
     )

View File

@@ -198,7 +198,6 @@ def save_summary_image(
     )
 
     ii = 0
     for row in ax:
-
         if type(row) != np.ndarray:
             row = np.array([row])
@@ -215,7 +214,9 @@ def save_summary_image(
             )
             col.grid(color="w", alpha=0.3, linewidth=0.3)
             col.set_xticks([])
-            col.title.set_text(str(ii + 1) + " " + species_names[order[ii]])
+            col.title.set_text(
+                str(ii + 1) + " " + species_names[order[ii]]
+            )
             col.tick_params(axis="both", which="major", labelsize=7)
             ii += 1
ii += 1 ii += 1