Changed bat_detect to batdetect2

This commit is contained in:
Santiago Martinez 2023-04-07 11:24:22 -06:00
parent 7c80441d60
commit 9e79849d6f
41 changed files with 89 additions and 120 deletions

6
app.py
View File

@ -3,9 +3,9 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import bat_detect.utils.audio_utils as au import batdetect2.utils.audio_utils as au
import bat_detect.utils.detector_utils as du import batdetect2.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz import batdetect2.utils.plot_utils as viz
# setup the arguments # setup the arguments
args = {} args = {}

View File

@ -1,6 +1,6 @@
"""Python API for bat_detect. """Python API for batdetect2.
This module provides a Python API for bat_detect. It can be used to This module provides a Python API for batdetect2. It can be used to
process audio files or spectrograms with the default model or a custom process audio files or spectrograms with the default model or a custom
model. model.
@ -8,7 +8,7 @@ Example
------- -------
You can use the default model to process audio files. To process a single You can use the default model to process audio files. To process a single
file, use the `process_file` function. file, use the `process_file` function.
>>> import bat_detect.api as api >>> import batdetect2.api as api
>>> # Process audio file >>> # Process audio file
>>> results = api.process_file("audio_file.wav") >>> results = api.process_file("audio_file.wav")
@ -16,7 +16,7 @@ To process multiple files, use the `list_audio_files` function to get a list
of audio files in a directory. Then use the `process_file` function to of audio files in a directory. Then use the `process_file` function to
process each file. process each file.
>>> import bat_detect.api as api >>> import batdetect2.api as api
>>> # Get list of audio files >>> # Get list of audio files
>>> audio_files = api.list_audio_files("audio_directory") >>> audio_files = api.list_audio_files("audio_directory")
>>> # Process audio files >>> # Process audio files
@ -44,7 +44,7 @@ array directly, or `process_spectrogram` to process spectrograms. This
allows you to do other preprocessing steps before running the model for allows you to do other preprocessing steps before running the model for
predictions. predictions.
>>> import bat_detect.api as api >>> import batdetect2.api as api
>>> # Load audio >>> # Load audio
>>> audio = api.load_audio("audio_file.wav") >>> audio = api.load_audio("audio_file.wav")
>>> # Process the audio array >>> # Process the audio array
@ -73,7 +73,7 @@ following:
If you wish to interact directly with the model, you can use the `model` If you wish to interact directly with the model, you can use the `model`
attribute to get the default model. attribute to get the default model.
>>> import bat_detect.api as api >>> import batdetect2.api as api
>>> # Get the default model >>> # Get the default model
>>> model = api.model >>> model = api.model
>>> # Process the spectrogram >>> # Process the spectrogram
@ -84,7 +84,7 @@ model outputs are a collection of raw tensors. The `postprocess`
function can be used to convert the model outputs into a list of function can be used to convert the model outputs into a list of
detections and a list of CNN features. detections and a list of CNN features.
>>> import bat_detect.api as api >>> import batdetect2.api as api
>>> # Get the default model >>> # Get the default model
>>> model = api.model >>> model = api.model
>>> # Process the spectrogram >>> # Process the spectrogram
@ -102,22 +102,22 @@ from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
import bat_detect.utils.audio_utils as au import batdetect2.utils.audio_utils as au
import bat_detect.utils.detector_utils as du import batdetect2.utils.detector_utils as du
from bat_detect.detector.parameters import ( from batdetect2.detector.parameters import (
DEFAULT_MODEL_PATH, DEFAULT_MODEL_PATH,
DEFAULT_PROCESSING_CONFIGURATIONS, DEFAULT_PROCESSING_CONFIGURATIONS,
DEFAULT_SPECTROGRAM_PARAMETERS, DEFAULT_SPECTROGRAM_PARAMETERS,
TARGET_SAMPLERATE_HZ, TARGET_SAMPLERATE_HZ,
) )
from bat_detect.types import ( from batdetect2.types import (
Annotation, Annotation,
DetectionModel, DetectionModel,
ModelOutput, ModelOutput,
ProcessingConfiguration, ProcessingConfiguration,
SpectrogramParameters, SpectrogramParameters,
) )
from bat_detect.utils.detector_utils import list_audio_files, load_model from batdetect2.utils.detector_utils import list_audio_files, load_model
# Remove warnings from torch # Remove warnings from torch
warnings.filterwarnings("ignore", category=UserWarning, module="torch") warnings.filterwarnings("ignore", category=UserWarning, module="torch")

View File

@ -3,9 +3,9 @@ import os
import click import click
from bat_detect import api from batdetect2 import api
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from bat_detect.utils.detector_utils import save_results_to_file from batdetect2.utils.detector_utils import save_results_to_file
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

View File

@ -3,14 +3,14 @@ import torch.fft
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from bat_detect.detector.model_helpers import ( from batdetect2.detector.model_helpers import (
ConvBlockDownCoordF, ConvBlockDownCoordF,
ConvBlockDownStandard, ConvBlockDownStandard,
ConvBlockUpF, ConvBlockUpF,
ConvBlockUpStandard, ConvBlockUpStandard,
SelfAttention, SelfAttention,
) )
from bat_detect.types import ModelOutput from batdetect2.types import ModelOutput
__all__ = [ __all__ = [
"Net2DFast", "Net2DFast",
@ -104,7 +104,6 @@ class Net2DFast(nn.Module):
) )
def forward(self, ip, return_feats=False) -> ModelOutput: def forward(self, ip, return_feats=False) -> ModelOutput:
# encoder # encoder
x1 = self.conv_dn_0(ip) x1 = self.conv_dn_0(ip)
x2 = self.conv_dn_1(x1) x2 = self.conv_dn_1(x1)
@ -326,7 +325,6 @@ class Net2DFastNoCoordConv(nn.Module):
) )
def forward(self, ip, return_feats=False) -> ModelOutput: def forward(self, ip, return_feats=False) -> ModelOutput:
x1 = self.conv_dn_0(ip) x1 = self.conv_dn_0(ip)
x2 = self.conv_dn_1(x1) x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2) x3 = self.conv_dn_2(x2)
@ -344,11 +342,12 @@ class Net2DFastNoCoordConv(nn.Module):
cls = self.conv_classes_op(x) cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1) comb = torch.softmax(cls, 1)
pred_emb = (self.conv_emb(x) if self.emb_dim > 0 else None,)
return ModelOutput( return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x), inplace=True), pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb, pred_class=comb,
pred_class_un_norm=cls, pred_class_un_norm=cls,
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
features=x, features=x,
) )

View File

@ -1,10 +1,7 @@
import datetime import datetime
import os import os
from bat_detect.types import ( from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
ProcessingConfiguration,
SpectrogramParameters,
)
TARGET_SAMPLERATE_HZ = 256000 TARGET_SAMPLERATE_HZ = 256000
FFT_WIN_LENGTH_S = 512 / 256000.0 FFT_WIN_LENGTH_S = 512 / 256000.0

View File

@ -5,8 +5,8 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from bat_detect.detector.models import ModelOutput from batdetect2.detector.models import ModelOutput
from bat_detect.types import NonMaximumSuppressionConfig, PredictionResults from batdetect2.types import NonMaximumSuppressionConfig, PredictionResults
np.seterr(divide="ignore", invalid="ignore") np.seterr(divide="ignore", invalid="ignore")

View File

@ -11,11 +11,11 @@ import numpy as np
import pandas as pd import pandas as pd
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
from bat_detect.detector import parameters from batdetect2.detector import parameters
import bat_detect.train.evaluate as evl import batdetect2.train.evaluate as evl
import bat_detect.train.train_utils as tu import batdetect2.train.train_utils as tu
import bat_detect.utils.detector_utils as du import batdetect2.utils.detector_utils as du
import bat_detect.utils.plot_utils as pu import batdetect2.utils.plot_utils as pu
def get_blank_annotation(ip_str): def get_blank_annotation(ip_str):

View File

@ -1,4 +1,8 @@
# Evaluating BatDetect2 # Evaluating BatDetect2
> **Warning**
> This code in currently broken. Will fix soon, stay tuned.
This script evaluates a trained model and outputs several plots summarizing the performance. It is used as follows: This script evaluates a trained model and outputs several plots summarizing the performance. It is used as follows:
`python path_to_store_images/ path_to_audio_files/ path_to_annotation_file/ path_to_trained_model/` `python path_to_store_images/ path_to_audio_files/ path_to_annotation_file/ path_to_trained_model/`

View File

@ -10,20 +10,18 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
sys.path.append(os.path.join("..", "..")) import batdetect2.detector.models as models
import bat_detect.detector.models as models import batdetect2.detector.parameters as parameters
import bat_detect.detector.parameters as parameters import batdetect2.detector.post_process as pp
import bat_detect.detector.post_process as pp import batdetect2.train.audio_dataloader as adl
import bat_detect.train.audio_dataloader as adl import batdetect2.train.evaluate as evl
import bat_detect.train.evaluate as evl import batdetect2.train.losses as losses
import bat_detect.train.losses as losses import batdetect2.train.train_model as tm
import bat_detect.train.train_model as tm import batdetect2.train.train_utils as tu
import bat_detect.train.train_utils as tu import batdetect2.utils.detector_utils as du
import bat_detect.utils.detector_utils as du import batdetect2.utils.plot_utils as pu
import bat_detect.utils.plot_utils as pu
if __name__ == "__main__": if __name__ == "__main__":
info_str = "\nBatDetect - Finetune Model\n" info_str = "\nBatDetect - Finetune Model\n"
print(info_str) print(info_str)
@ -272,7 +270,6 @@ if __name__ == "__main__":
# main train loop # main train loop
for epoch in range(0, params["num_epochs"] + 1): for epoch in range(0, params["num_epochs"] + 1):
train_loss = tm.train( train_loss = tm.train(
model, model,
epoch, epoch,

View File

@ -1,16 +1,13 @@
import argparse import argparse
import json import json
import os import os
import sys
import numpy as np import numpy as np
sys.path.append(os.path.join("..", "..")) import batdetect2.train.train_utils as tu
import bat_detect.train.train_utils as tu
def print_dataset_stats(data, split_name, classes_to_ignore): def print_dataset_stats(data, split_name, classes_to_ignore):
print("\nSplit:", split_name) print("\nSplit:", split_name)
print("Num files:", len(data)) print("Num files:", len(data))
@ -37,7 +34,6 @@ def print_dataset_stats(data, split_name, classes_to_ignore):
def load_file_names(file_name): def load_file_names(file_name):
if os.path.isfile(file_name): if os.path.isfile(file_name):
with open(file_name) as da: with open(file_name) as da:
files = [line.rstrip() for line in da.readlines()] files = [line.rstrip() for line in da.readlines()]
@ -53,7 +49,6 @@ def load_file_names(file_name):
if __name__ == "__main__": if __name__ == "__main__":
info_str = "\nBatDetect - Prepare Data for Finetuning\n" info_str = "\nBatDetect - Prepare Data for Finetuning\n"
print(info_str) print(info_str)

View File

@ -1,5 +1,9 @@
# Finetuning the BatDetet2 model on your own data # Finetuning the BatDetet2 model on your own data
| :warning: WARNING |
|:---------------------------|
| This is not currently working, but we are working on fixing this code |
Main steps: Main steps:
1. Annotate your data using the annotation GUI. 1. Annotate your data using the annotation GUI.
2. Run `prep_data_finetune.py` to create a training and validation split for your data. 2. Run `prep_data_finetune.py` to create a training and validation split for your data.

View File

@ -7,8 +7,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
import bat_detect.utils.audio_utils as au import batdetect2.utils.audio_utils as au
from bat_detect.types import AnnotationGroup, HeatmapParameters from batdetect2.types import AnnotationGroup, HeatmapParameters
def generate_gt_heatmaps( def generate_gt_heatmaps(

View File

@ -7,15 +7,15 @@ import numpy as np
import torch import torch
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from bat_detect.detector import models from batdetect2.detector import models
from bat_detect.detector import parameters from batdetect2.detector import parameters
from bat_detect.train import losses from batdetect2.train import losses
import bat_detect.detector.post_process as pp import batdetect2.detector.post_process as pp
import bat_detect.train.audio_dataloader as adl import batdetect2.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl import batdetect2.train.evaluate as evl
import bat_detect.train.train_split as ts import batdetect2.train.train_split as ts
import bat_detect.train.train_utils as tu import batdetect2.train.train_utils as tu
import bat_detect.utils.plot_utils as pu import batdetect2.utils.plot_utils as pu
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)

View File

@ -26,7 +26,7 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
"is_binary": True, # just a bat / not bat dataset ie no classes "is_binary": True, # just a bat / not bat dataset ie no classes
"ann_path": ann_dir "ann_path": ann_dir
+ "train_set_bulgaria_batdetective_with_bbs.json", + "train_set_bulgaria_batdetective_with_bbs.json",
"wav_path": wav_dir + "bat_detective/audio/", "wav_path": wav_dir + "batdetect2ive/audio/",
} }
) )
train_sets.append( train_sets.append(
@ -154,7 +154,7 @@ def split_same(ann_dir, wav_dir, load_extra=True):
"is_binary": True, "is_binary": True,
"ann_path": ann_dir "ann_path": ann_dir
+ "train_set_bulgaria_batdetective_with_bbs.json", + "train_set_bulgaria_batdetective_with_bbs.json",
"wav_path": wav_dir + "bat_detective/audio/", "wav_path": wav_dir + "batdetect2ive/audio/",
} }
) )
train_sets.append( train_sets.append(

View File

@ -6,36 +6,12 @@ import librosa.core.spectrum
import numpy as np import numpy as np
import torch import torch
from bat_detect.detector.parameters import (
DENOISE_SPEC_AVG,
DETECTION_THRESHOLD,
FFT_OVERLAP,
FFT_WIN_LENGTH_S,
MAX_FREQ_HZ,
MAX_SCALE_SPEC,
MIN_FREQ_HZ,
NMS_KERNEL_SIZE,
NMS_TOP_K_PER_SEC,
RESIZE_FACTOR,
SCALE_RAW_AUDIO,
SPEC_DIVIDE_FACTOR,
SPEC_HEIGHT,
SPEC_SCALE,
)
from . import wavfile from . import wavfile
try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
__all__ = [ __all__ = [
"load_audio", "load_audio",
"generate_spectrogram", "generate_spectrogram",
"pad_audio", "pad_audio",
"SpectrogramParameters",
"DEFAULT_SPECTROGRAM_PARAMETERS",
] ]
@ -60,7 +36,6 @@ def generate_spectrogram(
return_spec_for_viz=False, return_spec_for_viz=False,
check_spec_size=True, check_spec_size=True,
): ):
# generate spectrogram # generate spectrogram
spec = gen_mag_spectrogram( spec = gen_mag_spectrogram(
audio, audio,

View File

@ -7,12 +7,12 @@ import pandas as pd
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import bat_detect.detector.compute_features as feats import batdetect2.detector.compute_features as feats
import bat_detect.detector.post_process as pp import batdetect2.detector.post_process as pp
import bat_detect.utils.audio_utils as au import batdetect2.utils.audio_utils as au
from bat_detect.detector import models from batdetect2.detector import models
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from bat_detect.types import ( from batdetect2.types import (
Annotation, Annotation,
DetectionModel, DetectionModel,
FileAnnotations, FileAnnotations,

View File

@ -217,7 +217,6 @@ def plot_spec(
plot_boxes=True, plot_boxes=True,
fixed_aspect=True, fixed_aspect=True,
): ):
if fixed_aspect: if fixed_aspect:
# ouptut image will be this width irrespective of the duration of the audio file # ouptut image will be this width irrespective of the duration of the audio file
width = 12 width = 12

View File

@ -1,5 +1,5 @@
"""Run bat_detect.command.main() from the command line.""" """Run batdetect2.command.main() from the command line."""
from bat_detect.cli import detect from batdetect2.cli import detect
if __name__ == "__main__": if __name__ == "__main__":
detect() detect()

View File

@ -5,17 +5,15 @@ is the mean spectrogram for each class.
import argparse import argparse
import os import os
import sys
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import viz_helpers as vz import viz_helpers as vz
sys.path.append(os.path.join("..")) import batdetect2.detector.parameters as parameters
import bat_detect.detector.parameters as parameters import batdetect2.train.train_split as ts
import bat_detect.train.train_split as ts import batdetect2.train.train_utils as tu
import bat_detect.train.train_utils as tu import batdetect2.utils.audio_utils as au
import bat_detect.utils.audio_utils as au
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -15,11 +15,10 @@ import sys
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
sys.path.append(os.path.join("..")) import batdetect2.evaluate.evaluate_models as evlm
import bat_detect.evaluate.evaluate_models as evlm import batdetect2.utils.audio_utils as au
import bat_detect.utils.audio_utils as au import batdetect2.utils.detector_utils as du
import bat_detect.utils.detector_utils as du import batdetect2.utils.plot_utils as viz
import bat_detect.utils.plot_utils as viz
def filter_anns(anns, start_time, stop_time): def filter_anns(anns, start_time, stop_time):

View File

@ -17,11 +17,10 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
from scipy.io import wavfile from scipy.io import wavfile
sys.path.append(os.path.join("..")) import batdetect2.detector.parameters as parameters
import bat_detect.detector.parameters as parameters import batdetect2.utils.audio_utils as au
import bat_detect.utils.audio_utils as au import batdetect2.utils.detector_utils as du
import bat_detect.utils.detector_utils as du import batdetect2.utils.plot_utils as viz
import bat_detect.utils.plot_utils as viz
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -7,7 +7,7 @@ from scipy import ndimage
sys.path.append(os.path.join("..")) sys.path.append(os.path.join(".."))
import bat_detect.utils.audio_utils as au import batdetect2.utils.audio_utils as au
def generate_spectrogram_data( def generate_spectrogram_data(

View File

@ -7,7 +7,7 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from bat_detect import api from batdetect2 import api
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio") TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")

View File

@ -1,10 +1,11 @@
"""Test the command line interface.""" """Test the command line interface."""
from click.testing import CliRunner from click.testing import CliRunner
from bat_detect.cli import cli from batdetect2.cli import cli
def test_cli_base_command(): def test_cli_base_command():
"""Test the base command."""
runner = CliRunner() runner = CliRunner()
result = runner.invoke(cli, ["--help"]) result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0 assert result.exit_code == 0
@ -12,6 +13,7 @@ def test_cli_base_command():
def test_cli_detect_command_help(): def test_cli_detect_command_help():
"""Test the detect command help."""
runner = CliRunner() runner = CliRunner()
result = runner.invoke(cli, ["detect", "--help"]) result = runner.invoke(cli, ["detect", "--help"])
assert result.exit_code == 0 assert result.exit_code == 0
@ -19,6 +21,7 @@ def test_cli_detect_command_help():
def test_cli_detect_command_on_test_audio(tmp_path): def test_cli_detect_command_on_test_audio(tmp_path):
"""Test the detect command on test audio."""
results_dir = tmp_path / "results" results_dir = tmp_path / "results"
# Remove results dir if it exists # Remove results dir if it exists