Changed bat_detect to batdetect2

This commit is contained in:
Santiago Martinez 2023-04-07 11:24:22 -06:00
parent 7c80441d60
commit 9e79849d6f
41 changed files with 89 additions and 120 deletions

6
app.py
View File

@ -3,9 +3,9 @@ import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz
import batdetect2.utils.audio_utils as au
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as viz
# setup the arguments
args = {}

View File

@ -1,6 +1,6 @@
"""Python API for bat_detect.
"""Python API for batdetect2.
This module provides a Python API for bat_detect. It can be used to
This module provides a Python API for batdetect2. It can be used to
process audio files or spectrograms with the default model or a custom
model.
@ -8,7 +8,7 @@ Example
-------
You can use the default model to process audio files. To process a single
file, use the `process_file` function.
>>> import bat_detect.api as api
>>> import batdetect2.api as api
>>> # Process audio file
>>> results = api.process_file("audio_file.wav")
@ -16,7 +16,7 @@ To process multiple files, use the `list_audio_files` function to get a list
of audio files in a directory. Then use the `process_file` function to
process each file.
>>> import bat_detect.api as api
>>> import batdetect2.api as api
>>> # Get list of audio files
>>> audio_files = api.list_audio_files("audio_directory")
>>> # Process audio files
@ -44,7 +44,7 @@ array directly, or `process_spectrogram` to process spectrograms. This
allows you to do other preprocessing steps before running the model for
predictions.
>>> import bat_detect.api as api
>>> import batdetect2.api as api
>>> # Load audio
>>> audio = api.load_audio("audio_file.wav")
>>> # Process the audio array
@ -73,7 +73,7 @@ following:
If you wish to interact directly with the model, you can use the `model`
attribute to get the default model.
>>> import bat_detect.api as api
>>> import batdetect2.api as api
>>> # Get the default model
>>> model = api.model
>>> # Process the spectrogram
@ -84,7 +84,7 @@ model outputs are a collection of raw tensors. The `postprocess`
function can be used to convert the model outputs into a list of
detections and a list of CNN features.
>>> import bat_detect.api as api
>>> import batdetect2.api as api
>>> # Get the default model
>>> model = api.model
>>> # Process the spectrogram
@ -102,22 +102,22 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
from bat_detect.detector.parameters import (
import batdetect2.utils.audio_utils as au
import batdetect2.utils.detector_utils as du
from batdetect2.detector.parameters import (
DEFAULT_MODEL_PATH,
DEFAULT_PROCESSING_CONFIGURATIONS,
DEFAULT_SPECTROGRAM_PARAMETERS,
TARGET_SAMPLERATE_HZ,
)
from bat_detect.types import (
from batdetect2.types import (
Annotation,
DetectionModel,
ModelOutput,
ProcessingConfiguration,
SpectrogramParameters,
)
from bat_detect.utils.detector_utils import list_audio_files, load_model
from batdetect2.utils.detector_utils import list_audio_files, load_model
# Remove warnings from torch
warnings.filterwarnings("ignore", category=UserWarning, module="torch")

View File

@ -3,9 +3,9 @@ import os
import click
from bat_detect import api
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
from bat_detect.utils.detector_utils import save_results_to_file
from batdetect2 import api
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.utils.detector_utils import save_results_to_file
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

View File

@ -3,14 +3,14 @@ import torch.fft
import torch.nn.functional as F
from torch import nn
from bat_detect.detector.model_helpers import (
from batdetect2.detector.model_helpers import (
ConvBlockDownCoordF,
ConvBlockDownStandard,
ConvBlockUpF,
ConvBlockUpStandard,
SelfAttention,
)
from bat_detect.types import ModelOutput
from batdetect2.types import ModelOutput
__all__ = [
"Net2DFast",
@ -104,7 +104,6 @@ class Net2DFast(nn.Module):
)
def forward(self, ip, return_feats=False) -> ModelOutput:
# encoder
x1 = self.conv_dn_0(ip)
x2 = self.conv_dn_1(x1)
@ -326,7 +325,6 @@ class Net2DFastNoCoordConv(nn.Module):
)
def forward(self, ip, return_feats=False) -> ModelOutput:
x1 = self.conv_dn_0(ip)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
@ -344,11 +342,12 @@ class Net2DFastNoCoordConv(nn.Module):
cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1)
pred_emb = (self.conv_emb(x) if self.emb_dim > 0 else None,)
return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb,
pred_class_un_norm=cls,
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
features=x,
)

View File

@ -1,10 +1,7 @@
import datetime
import os
from bat_detect.types import (
ProcessingConfiguration,
SpectrogramParameters,
)
from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
TARGET_SAMPLERATE_HZ = 256000
FFT_WIN_LENGTH_S = 512 / 256000.0

View File

@ -5,8 +5,8 @@ import numpy as np
import torch
from torch import nn
from bat_detect.detector.models import ModelOutput
from bat_detect.types import NonMaximumSuppressionConfig, PredictionResults
from batdetect2.detector.models import ModelOutput
from batdetect2.types import NonMaximumSuppressionConfig, PredictionResults
np.seterr(divide="ignore", invalid="ignore")

View File

@ -11,11 +11,11 @@ import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from bat_detect.detector import parameters
import bat_detect.train.evaluate as evl
import bat_detect.train.train_utils as tu
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as pu
from batdetect2.detector import parameters
import batdetect2.train.evaluate as evl
import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
def get_blank_annotation(ip_str):

View File

@ -1,4 +1,8 @@
# Evaluating BatDetect2
> **Warning**
> This code in currently broken. Will fix soon, stay tuned.
This script evaluates a trained model and outputs several plots summarizing the performance. It is used as follows:
`python path_to_store_images/ path_to_audio_files/ path_to_annotation_file/ path_to_trained_model/`

View File

@ -10,20 +10,18 @@ import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
sys.path.append(os.path.join("..", ".."))
import bat_detect.detector.models as models
import bat_detect.detector.parameters as parameters
import bat_detect.detector.post_process as pp
import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl
import bat_detect.train.losses as losses
import bat_detect.train.train_model as tm
import bat_detect.train.train_utils as tu
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as pu
import batdetect2.detector.models as models
import batdetect2.detector.parameters as parameters
import batdetect2.detector.post_process as pp
import batdetect2.train.audio_dataloader as adl
import batdetect2.train.evaluate as evl
import batdetect2.train.losses as losses
import batdetect2.train.train_model as tm
import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
if __name__ == "__main__":
info_str = "\nBatDetect - Finetune Model\n"
print(info_str)
@ -272,7 +270,6 @@ if __name__ == "__main__":
# main train loop
for epoch in range(0, params["num_epochs"] + 1):
train_loss = tm.train(
model,
epoch,

View File

@ -1,16 +1,13 @@
import argparse
import json
import os
import sys
import numpy as np
sys.path.append(os.path.join("..", ".."))
import bat_detect.train.train_utils as tu
import batdetect2.train.train_utils as tu
def print_dataset_stats(data, split_name, classes_to_ignore):
print("\nSplit:", split_name)
print("Num files:", len(data))
@ -37,7 +34,6 @@ def print_dataset_stats(data, split_name, classes_to_ignore):
def load_file_names(file_name):
if os.path.isfile(file_name):
with open(file_name) as da:
files = [line.rstrip() for line in da.readlines()]
@ -53,7 +49,6 @@ def load_file_names(file_name):
if __name__ == "__main__":
info_str = "\nBatDetect - Prepare Data for Finetuning\n"
print(info_str)

View File

@ -1,5 +1,9 @@
# Finetuning the BatDetet2 model on your own data
| :warning: WARNING |
|:---------------------------|
| This is not currently working, but we are working on fixing this code |
Main steps:
1. Annotate your data using the annotation GUI.
2. Run `prep_data_finetune.py` to create a training and validation split for your data.

View File

@ -7,8 +7,8 @@ import torch
import torch.nn.functional as F
import torchaudio
import bat_detect.utils.audio_utils as au
from bat_detect.types import AnnotationGroup, HeatmapParameters
import batdetect2.utils.audio_utils as au
from batdetect2.types import AnnotationGroup, HeatmapParameters
def generate_gt_heatmaps(

View File

@ -7,15 +7,15 @@ import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from bat_detect.detector import models
from bat_detect.detector import parameters
from bat_detect.train import losses
import bat_detect.detector.post_process as pp
import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl
import bat_detect.train.train_split as ts
import bat_detect.train.train_utils as tu
import bat_detect.utils.plot_utils as pu
from batdetect2.detector import models
from batdetect2.detector import parameters
from batdetect2.train import losses
import batdetect2.detector.post_process as pp
import batdetect2.train.audio_dataloader as adl
import batdetect2.train.evaluate as evl
import batdetect2.train.train_split as ts
import batdetect2.train.train_utils as tu
import batdetect2.utils.plot_utils as pu
warnings.filterwarnings("ignore", category=UserWarning)

View File

@ -26,7 +26,7 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
"is_binary": True, # just a bat / not bat dataset ie no classes
"ann_path": ann_dir
+ "train_set_bulgaria_batdetective_with_bbs.json",
"wav_path": wav_dir + "bat_detective/audio/",
"wav_path": wav_dir + "batdetect2ive/audio/",
}
)
train_sets.append(
@ -154,7 +154,7 @@ def split_same(ann_dir, wav_dir, load_extra=True):
"is_binary": True,
"ann_path": ann_dir
+ "train_set_bulgaria_batdetective_with_bbs.json",
"wav_path": wav_dir + "bat_detective/audio/",
"wav_path": wav_dir + "batdetect2ive/audio/",
}
)
train_sets.append(

View File

@ -6,36 +6,12 @@ import librosa.core.spectrum
import numpy as np
import torch
from bat_detect.detector.parameters import (
DENOISE_SPEC_AVG,
DETECTION_THRESHOLD,
FFT_OVERLAP,
FFT_WIN_LENGTH_S,
MAX_FREQ_HZ,
MAX_SCALE_SPEC,
MIN_FREQ_HZ,
NMS_KERNEL_SIZE,
NMS_TOP_K_PER_SEC,
RESIZE_FACTOR,
SCALE_RAW_AUDIO,
SPEC_DIVIDE_FACTOR,
SPEC_HEIGHT,
SPEC_SCALE,
)
from . import wavfile
try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
__all__ = [
"load_audio",
"generate_spectrogram",
"pad_audio",
"SpectrogramParameters",
"DEFAULT_SPECTROGRAM_PARAMETERS",
]
@ -60,7 +36,6 @@ def generate_spectrogram(
return_spec_for_viz=False,
check_spec_size=True,
):
# generate spectrogram
spec = gen_mag_spectrogram(
audio,

View File

@ -7,12 +7,12 @@ import pandas as pd
import torch
import torch.nn.functional as F
import bat_detect.detector.compute_features as feats
import bat_detect.detector.post_process as pp
import bat_detect.utils.audio_utils as au
from bat_detect.detector import models
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
from bat_detect.types import (
import batdetect2.detector.compute_features as feats
import batdetect2.detector.post_process as pp
import batdetect2.utils.audio_utils as au
from batdetect2.detector import models
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import (
Annotation,
DetectionModel,
FileAnnotations,

View File

@ -217,7 +217,6 @@ def plot_spec(
plot_boxes=True,
fixed_aspect=True,
):
if fixed_aspect:
# ouptut image will be this width irrespective of the duration of the audio file
width = 12

View File

@ -1,5 +1,5 @@
"""Run bat_detect.command.main() from the command line."""
from bat_detect.cli import detect
"""Run batdetect2.command.main() from the command line."""
from batdetect2.cli import detect
if __name__ == "__main__":
detect()

View File

@ -5,17 +5,15 @@ is the mean spectrogram for each class.
import argparse
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import viz_helpers as vz
sys.path.append(os.path.join(".."))
import bat_detect.detector.parameters as parameters
import bat_detect.train.train_split as ts
import bat_detect.train.train_utils as tu
import bat_detect.utils.audio_utils as au
import batdetect2.detector.parameters as parameters
import batdetect2.train.train_split as ts
import batdetect2.train.train_utils as tu
import batdetect2.utils.audio_utils as au
if __name__ == "__main__":

View File

@ -15,11 +15,10 @@ import sys
import matplotlib.pyplot as plt
import numpy as np
sys.path.append(os.path.join(".."))
import bat_detect.evaluate.evaluate_models as evlm
import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz
import batdetect2.evaluate.evaluate_models as evlm
import batdetect2.utils.audio_utils as au
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as viz
def filter_anns(anns, start_time, stop_time):

View File

@ -17,11 +17,10 @@ import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
sys.path.append(os.path.join(".."))
import bat_detect.detector.parameters as parameters
import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz
import batdetect2.detector.parameters as parameters
import batdetect2.utils.audio_utils as au
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as viz
if __name__ == "__main__":

View File

@ -7,7 +7,7 @@ from scipy import ndimage
sys.path.append(os.path.join(".."))
import bat_detect.utils.audio_utils as au
import batdetect2.utils.audio_utils as au
def generate_spectrogram_data(

View File

@ -7,7 +7,7 @@ import numpy as np
import torch
from torch import nn
from bat_detect import api
from batdetect2 import api
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")

View File

@ -1,10 +1,11 @@
"""Test the command line interface."""
from click.testing import CliRunner
from bat_detect.cli import cli
from batdetect2.cli import cli
def test_cli_base_command():
"""Test the base command."""
runner = CliRunner()
result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0
@ -12,6 +13,7 @@ def test_cli_base_command():
def test_cli_detect_command_help():
"""Test the detect command help."""
runner = CliRunner()
result = runner.invoke(cli, ["detect", "--help"])
assert result.exit_code == 0
@ -19,6 +21,7 @@ def test_cli_detect_command_help():
def test_cli_detect_command_on_test_audio(tmp_path):
"""Test the detect command on test audio."""
results_dir = tmp_path / "results"
# Remove results dir if it exists