mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00

Format code

This commit is contained in:
parent 904e8f23ea
commit 150305a273

@@ -1,4 +1,5 @@
 """Functions to compute features from predictions."""
+
 from typing import Dict, List, Optional

 import numpy as np

@@ -219,7 +220,6 @@ def compute_call_interval(
     return round(prediction["start_time"] - previous["end_time"], 5)


-
 # NOTE: The order of the features in this dictionary is important. The
 # features are extracted in this order and the order of the columns in the
 # output csv file is determined by this order. In order to avoid breaking

@@ -206,7 +206,10 @@ class Net2DFastNoAttn(nn.Module):
             num_filts // 4, 2, kernel_size=1, padding=0
         )
         self.conv_classes_op = nn.Conv2d(
-            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
+            num_filts // 4,
+            self.num_classes + 1,
+            kernel_size=1,
+            padding=0,
         )

         if self.emb_dim > 0:

@@ -5,7 +5,10 @@ from typing import List, Optional, Union

 from pydantic import BaseModel, Field, computed_field

-from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
+from batdetect2.train.train_utils import (
+    get_genus_mapping,
+    get_short_class_names,
+)
 from batdetect2.types import ProcessingConfiguration, SpectrogramParameters

 TARGET_SAMPLERATE_HZ = 256000

@@ -1,4 +1,5 @@
 """Post-processing of the output of the model."""
+
 from typing import List, Tuple, Union

 import numpy as np

@@ -20,7 +20,6 @@ from batdetect2.detector import parameters


 def get_blank_annotation(ip_str):
-
     res = {}
     res["class_name"] = ""
     res["duration"] = -1

@@ -77,7 +76,6 @@ def create_genus_mapping(gt_test, preds, class_names):


 def load_tadarida_pred(ip_dir, dataset, file_of_interest):
-
     res, ann = get_blank_annotation("Generated by Tadarida")

     # create the annotations in the correct format

@@ -120,7 +118,6 @@ def load_sonobat_meta(
     class_names,
     only_accepted_species=True,
 ):
-
     sp_dict = {}
     for ss in class_names:
         sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]

@@ -182,7 +179,6 @@ def load_sonobat_meta(


 def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
-
     # create the annotations in the correct format
     res, ann = get_blank_annotation("Generated by Sonobat")
     res_c = copy.deepcopy(res)

@@ -221,7 +217,6 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):


 def bb_overlap(bb_g_in, bb_p_in):
-
     freq_scale = 10000000.0  # ensure that both axis are roughly the same range
     bb_g = [
         bb_g_in["start_time"],

@@ -465,7 +460,6 @@ def check_classes_in_train(gt_list, class_names):


 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "op_dir",

@@ -13,5 +13,3 @@ else:
     def pairwise(iterable: Sequence) -> Iterable:
         for x, y in zip(iterable[:-1], iterable[1:]):
             yield x, y
-
-

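The `else:` context in this hunk implies `pairwise` is a backport used when `itertools.pairwise` (Python 3.10+) is unavailable. A minimal sketch of that version-gating pattern, assuming nothing about the repo beyond the hunk itself:

import sys

if sys.version_info >= (3, 10):
    from itertools import pairwise
else:
    # Sequence-based fallback matching the diff; unlike itertools.pairwise
    # it requires a sliceable input, not an arbitrary iterable.
    def pairwise(iterable):
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y

print(list(pairwise("abc")))  # [('a', 'b'), ('b', 'c')]
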
@@ -17,6 +17,6 @@ def create_ax(
 ) -> axes.Axes:
     """Create a new axis if none is provided"""
     if ax is None:
-        _, ax = plt.subplots(figsize=figsize, **kwargs) # type: ignore
+        _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore

     return ax  # type: ignore

@@ -1,6 +1,6 @@
 """Module for postprocessing model outputs."""

-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch

@@ -30,6 +30,16 @@ class PostprocessConfig(BaseModel):
     top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)


+class RawPrediction(BaseModel):
+    start_time: float
+    end_time: float
+    low_freq: float
+    high_freq: float
+    detection_score: float
+    class_scores: Dict[str, float]
+    features: np.ndarray
+
+
 def postprocess_model_outputs(
     outputs: ModelOutput,
     clips: List[data.Clip],

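One caveat with the new `RawPrediction` model: pydantic rejects a plain `np.ndarray` annotation unless arbitrary types are allowed, so a config line along these lines is presumably needed somewhere (the sketch below uses the pydantic v2 API; the hunk itself does not show it):

import numpy as np
from pydantic import BaseModel, ConfigDict

class ArrayField(BaseModel):
    # Without this, class creation raises a schema error for np.ndarray.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    features: np.ndarray

ArrayField(features=np.zeros(32))  # validates the type only
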
@@ -125,6 +135,88 @@ def postprocess_model_outputs(
     return predictions


+def compute_predictions_from_outputs(
+    start: float,
+    end: float,
+    scores: torch.Tensor,
+    y_pos: torch.Tensor,
+    x_pos: torch.Tensor,
+    size_preds: torch.Tensor,
+    class_probs: torch.Tensor,
+    features: torch.Tensor,
+    classes: List[str],
+    min_freq: int = 10000,
+    max_freq: int = 120000,
+    detection_threshold: float = DETECTION_THRESHOLD,
+) -> List[RawPrediction]:
+    _, freq_bins, time_bins = size_preds.shape
+
+    sorted_indices = torch.argsort(x_pos)
+    valid_indices = sorted_indices[
+        scores[sorted_indices] > detection_threshold
+    ]
+
+    scores = scores[valid_indices]
+    x_pos = x_pos[valid_indices]
+    y_pos = y_pos[valid_indices]
+
+    predictions: List[RawPrediction] = []
+    for score, x, y in zip(scores, x_pos, y_pos):
+        width, height = size_preds[:, y, x]
+        class_prob = class_probs[:, y, x].detach().numpy()
+        feats = features[:, y, x].detach().numpy()
+
+        start_time = np.interp(
+            x.item(),
+            [0, time_bins],
+            [start, end],
+        )
+
+        end_time = np.interp(
+            x.item() + width.item(),
+            [0, time_bins],
+            [start, end],
+        )
+
+        low_freq = np.interp(
+            y.item(),
+            [0, freq_bins],
+            [max_freq, min_freq],
+        )
+
+        high_freq = np.interp(
+            y.item() - height.item(),
+            [0, freq_bins],
+            [max_freq, min_freq],
+        )
+
+        start_time, end_time = sorted([float(start_time), float(end_time)])
+        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
+
+        predictions.append(
+            RawPrediction(
+                start_time=start_time,
+                end_time=end_time,
+                low_freq=low_freq,
+                high_freq=high_freq,
+                detection_score=score.item(),
+                features=feats,
+                class_scores={
+                    class_name: prob
+                    for class_name, prob in zip(classes, class_prob)
+                },
+            )
+        )
+
+    return predictions
+
+
+def convert_raw_prediction_to_soundevent(
+    prediction: RawPrediction,
+) -> data.SoundEventPrediction:
+    pass
+
+
 def compute_sound_events_from_outputs(
     clip: data.Clip,
     scores: torch.Tensor,

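The `np.interp` calls in `compute_predictions_from_outputs` map detection-grid indices onto clip coordinates; the `[max_freq, min_freq]` target range is deliberately inverted because row 0 of the output grid corresponds to the highest frequency. A standalone illustration with made-up bin counts and clip bounds:

import numpy as np

time_bins, start, end = 128, 0.0, 0.512
print(np.interp(32, [0, time_bins], [start, end]))  # 0.128 (seconds)

freq_bins, min_freq, max_freq = 64, 10000, 120000
print(np.interp(16, [0, freq_bins], [max_freq, min_freq]))  # 92500.0 (Hz)
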
@@ -1,10 +1,13 @@
 from typing import List

 import numpy as np
+import pandas as pd
 from sklearn.metrics import auc, roc_curve
 from soundevent import data
 from soundevent.evaluation import match_geometries

+from batdetect2.train.targets import build_encoder, get_class_names
+

 def match_predictions_and_annotations(
     clip_annotation: data.ClipAnnotation,

@@ -48,6 +51,13 @@ match_predictions_and_annotations
     return matches


+def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
+    ret = []
+
+    for match in matches:
+        pass
+
+
 def compute_error_auc(op_str, gt, pred, prob):
     # classification error
     pred_int = (pred > prob).astype(np.int32)

@@ -1,8 +1,10 @@
+from pathlib import Path
 from typing import Optional

 import pytorch_lightning as L
 import torch
 from pydantic import Field
+from soundevent import data
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.utils.data import DataLoader

@@ -19,7 +21,7 @@ from batdetect2.post_process import (
     PostprocessConfig,
     postprocess_model_outputs,
 )
-from batdetect2.preprocess import PreprocessingConfig
+from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
 from batdetect2.train.dataset import LabeledDataset, TrainExample
 from batdetect2.train.evaluate import match_predictions_and_annotations
 from batdetect2.train.losses import LossConfig, compute_loss

@@ -146,12 +148,14 @@ class DetectorModel(L.LightningModule):
             config=self.config.postprocessing,
         )[0]

-        self.validation_predictions.extend(
-            match_predictions_and_annotations(clip_annotation, clip_prediction)
+        matches = match_predictions_and_annotations(
+            clip_annotation,
+            clip_prediction,
         )

+        self.validation_predictions.extend(matches)
+
     def on_validation_epoch_end(self) -> None:
-        print(len(self.validation_predictions))
         self.validation_predictions.clear()

     def configure_optimizers(self):

@@ -159,3 +163,23 @@ class DetectorModel(L.LightningModule):
         optimizer = Adam(self.parameters(), lr=conf.learning_rate)
         scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
         return [optimizer], [scheduler]
+
+    def process_clip(
+        self,
+        clip: data.Clip,
+        audio_dir: Optional[Path] = None,
+    ) -> data.ClipPrediction:
+        spec = preprocess_audio_clip(
+            clip,
+            config=self.config.preprocessing,
+            audio_dir=audio_dir,
+        )
+        tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
+        outputs = self.forward(tensor)
+        return postprocess_model_outputs(
+            outputs,
+            clips=[clip],
+            classes=self.class_names,
+            decoder=self.decoder,
+            config=self.config.postprocessing,
+        )[0]

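In `process_clip`, the two chained `unsqueeze(0)` calls add the batch and channel dimensions the detector's forward pass expects on top of the 2-D spectrogram. A quick shape check (the array size here is illustrative, not the model's actual input shape):

import numpy as np
import torch

spec = np.zeros((128, 512), dtype=np.float32)  # (freq, time)
tensor = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
print(tensor.shape)  # torch.Size([1, 1, 128, 512])
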
@@ -16,7 +16,6 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):


 def split_diff(ann_dir, wav_dir, load_extra=True):
-
     train_sets = []
     if load_extra:
         train_sets.append(

@@ -144,7 +143,6 @@ def split_diff(ann_dir, wav_dir, load_extra=True):


 def split_same(ann_dir, wav_dir, load_extra=True):
-
     train_sets = []
     if load_extra:
         train_sets.append(

@@ -69,7 +69,8 @@ def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:


 def standardize_low_freq(
-    data: List[types.FileAnnotation], class_of_interest: str,
+    data: List[types.FileAnnotation],
+    class_of_interest: str,
 ) -> List[types.FileAnnotation]:
     # address the issue of highly variable low frequency annotations
     # this often happens for contstant frequency calls

@@ -1,4 +1,5 @@
 """Types used in the code base."""
+
 from typing import Any, List, NamedTuple, Optional


@@ -594,8 +595,7 @@ class FeatureExtractor(Protocol):
         self,
         prediction: Prediction,
         **kwargs: Any,
-    ) -> float:
-        ...
+    ) -> float: ...


 class DatasetDict(TypedDict):

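Folding the `...` body onto the signature line is the stub style newer black releases apply to protocol members. In isolation the pattern looks like this (the method name and signature are illustrative, reconstructed from the hunk's context lines):

from typing import Any, Protocol

class FeatureExtractor(Protocol):
    def __call__(self, prediction: Any, **kwargs: Any) -> float: ...
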
@@ -100,7 +100,7 @@ def generate_spectrogram(
         # log_scaling = (1.0 / sampling_rate)*10e4
         spec = np.log1p(log_scaling * spec)
     elif params["spec_scale"] == "pcen":
-        spec = pcen(spec , sampling_rate)
+        spec = pcen(spec, sampling_rate)

     elif params["spec_scale"] == "none":
         pass

@@ -417,7 +417,9 @@ def plot_confusion_matrix(
     cm_norm = cm.sum(1)

     valid_inds = np.where(cm_norm > 0)[0]
-    cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+    cm[valid_inds, :] = (
+        cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+    )
     cm[np.where(cm_norm == -0)[0], :] = np.nan

     if verbose:

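The reformatted assignment above row-normalises the confusion matrix while leaving rows with no samples untouched. With toy numbers:

import numpy as np

cm = np.array([[8.0, 2.0], [0.0, 0.0]])
cm_norm = cm.sum(1)
valid_inds = np.where(cm_norm > 0)[0]
cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
print(cm)  # [[0.8 0.2] [0.  0. ]]
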
@@ -155,9 +155,9 @@ class InteractivePlotter:

         # draw bounding box around call
         self.ax[1].patches[0].remove()
-        spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
-            1.0 + 2.0 * self.spec_pad
-        )
+        spec_width_orig = self.spec_slices[self.current_id].shape[
+            1
+        ] / (1.0 + 2.0 * self.spec_pad)
         xx = w_diff + self.spec_pad * spec_width_orig
         ww = spec_width_orig
         yy = self.call_info[self.current_id]["low_freq"] / 1000

@@ -183,7 +183,9 @@ class InteractivePlotter:
                 round(self.call_info[self.current_id]["start_time"], 3)
             )
             + ", prob="
-            + str(round(self.call_info[self.current_id]["det_prob"], 3))
+            + str(
+                round(self.call_info[self.current_id]["det_prob"], 3)
+            )
         )
         self.ax[0].set_xlabel(info_str)


@@ -8,6 +8,7 @@ Functions
 `write`: Write a numpy array as a WAV file.

 """
+
 from __future__ import absolute_import, division, print_function

 import os

@@ -156,7 +157,6 @@ def read(filename, mmap=False):
     fid = open(filename, "rb")

     try:
-
         # some files seem to have the size recorded in the header greater than
         # the actual file size.
         fid.seek(0, os.SEEK_END)

@@ -16,7 +16,6 @@ import batdetect2.train.train_utils as tu
 import batdetect2.utils.audio_utils as au

 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "audio_path", type=str, help="Input directory for audio"

@@ -65,7 +64,9 @@ if __name__ == "__main__":
     else:
         # load uk data - special case
         print("\nLoading:", args["uk_split"], "\n")
-        dataset_name = "uk_" + args["uk_split"]  # should be uk_diff, or uk_same
+        dataset_name = (
+            "uk_" + args["uk_split"]
+        )  # should be uk_diff, or uk_same
         datasets, _ = ts.get_train_test_data(
             args["ann_file"],
             args["audio_path"],

@@ -33,7 +33,6 @@ def filter_anns(anns, start_time, stop_time):


 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument("audio_file", type=str, help="Path to audio file")
     parser.add_argument("model_path", type=str, help="Path to BatDetect model")

@@ -143,7 +142,9 @@ if __name__ == "__main__":

     # run model and filter detections so only keep ones in relevant time range
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    results = du.process_file(args_cmd["audio_file"], model, run_config, device)
+    results = du.process_file(
+        args_cmd["audio_file"], model, run_config, device
+    )
     pred_anns = filter_anns(
         results["pred_dict"]["annotation"],
         args_cmd["start_time"],

@@ -25,7 +25,9 @@ import batdetect2.utils.plot_utils as viz

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("audio_file", type=str, help="Path to input audio file")
+    parser.add_argument(
+        "audio_file", type=str, help="Path to input audio file"
+    )
     parser.add_argument(
         "model_path", type=str, help="Path to trained BatDetect model"
     )

@@ -198,7 +198,6 @@ def save_summary_image(
     )
     ii = 0
     for row in ax:
-
         if type(row) != np.ndarray:
             row = np.array([row])


@@ -215,7 +214,9 @@ def save_summary_image(
             )
             col.grid(color="w", alpha=0.3, linewidth=0.3)
             col.set_xticks([])
-            col.title.set_text(str(ii + 1) + " " + species_names[order[ii]])
+            col.title.set_text(
+                str(ii + 1) + " " + species_names[order[ii]]
+            )
             col.tick_params(axis="both", which="major", labelsize=7)
             ii += 1