diff --git a/app.py b/app.py index 1b820b2..9c82b01 100644 --- a/app.py +++ b/app.py @@ -3,9 +3,9 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -import bat_detect.utils.audio_utils as au -import bat_detect.utils.detector_utils as du -import bat_detect.utils.plot_utils as viz +import batdetect2.utils.audio_utils as au +import batdetect2.utils.detector_utils as du +import batdetect2.utils.plot_utils as viz # setup the arguments args = {} diff --git a/bat_detect/models/Net2DFast_UK_same.pth.tar b/bat_detect/models/Net2DFast_UK_same.pth.tar deleted file mode 100644 index e3704b7..0000000 Binary files a/bat_detect/models/Net2DFast_UK_same.pth.tar and /dev/null differ diff --git a/bat_detect/__init__.py b/batdetect2/__init__.py similarity index 100% rename from bat_detect/__init__.py rename to batdetect2/__init__.py diff --git a/bat_detect/api.py b/batdetect2/api.py similarity index 95% rename from bat_detect/api.py rename to batdetect2/api.py index df9c987..a1988ca 100644 --- a/bat_detect/api.py +++ b/batdetect2/api.py @@ -1,6 +1,6 @@ -"""Python API for bat_detect. +"""Python API for batdetect2. -This module provides a Python API for bat_detect. It can be used to +This module provides a Python API for batdetect2. It can be used to process audio files or spectrograms with the default model or a custom model. @@ -8,7 +8,7 @@ Example ------- You can use the default model to process audio files. To process a single file, use the `process_file` function. ->>> import bat_detect.api as api +>>> import batdetect2.api as api >>> # Process audio file >>> results = api.process_file("audio_file.wav") @@ -16,7 +16,7 @@ To process multiple files, use the `list_audio_files` function to get a list of audio files in a directory. Then use the `process_file` function to process each file. ->>> import bat_detect.api as api +>>> import batdetect2.api as api >>> # Get list of audio files >>> audio_files = api.list_audio_files("audio_directory") >>> # Process audio files @@ -44,7 +44,7 @@ array directly, or `process_spectrogram` to process spectrograms. This allows you to do other preprocessing steps before running the model for predictions. ->>> import bat_detect.api as api +>>> import batdetect2.api as api >>> # Load audio >>> audio = api.load_audio("audio_file.wav") >>> # Process the audio array @@ -73,7 +73,7 @@ following: If you wish to interact directly with the model, you can use the `model` attribute to get the default model. ->>> import bat_detect.api as api +>>> import batdetect2.api as api >>> # Get the default model >>> model = api.model >>> # Process the spectrogram @@ -84,7 +84,7 @@ model outputs are a collection of raw tensors. The `postprocess` function can be used to convert the model outputs into a list of detections and a list of CNN features. ->>> import bat_detect.api as api +>>> import batdetect2.api as api >>> # Get the default model >>> model = api.model >>> # Process the spectrogram @@ -102,22 +102,22 @@ from typing import List, Optional, Tuple import numpy as np import torch -import bat_detect.utils.audio_utils as au -import bat_detect.utils.detector_utils as du -from bat_detect.detector.parameters import ( +import batdetect2.utils.audio_utils as au +import batdetect2.utils.detector_utils as du +from batdetect2.detector.parameters import ( DEFAULT_MODEL_PATH, DEFAULT_PROCESSING_CONFIGURATIONS, DEFAULT_SPECTROGRAM_PARAMETERS, TARGET_SAMPLERATE_HZ, ) -from bat_detect.types import ( +from batdetect2.types import ( Annotation, DetectionModel, ModelOutput, ProcessingConfiguration, SpectrogramParameters, ) -from bat_detect.utils.detector_utils import list_audio_files, load_model +from batdetect2.utils.detector_utils import list_audio_files, load_model # Remove warnings from torch warnings.filterwarnings("ignore", category=UserWarning, module="torch") diff --git a/bat_detect/cli.py b/batdetect2/cli.py similarity index 95% rename from bat_detect/cli.py rename to batdetect2/cli.py index 29f4142..b5ef01a 100644 --- a/bat_detect/cli.py +++ b/batdetect2/cli.py @@ -3,9 +3,9 @@ import os import click -from bat_detect import api -from bat_detect.detector.parameters import DEFAULT_MODEL_PATH -from bat_detect.utils.detector_utils import save_results_to_file +from batdetect2 import api +from batdetect2.detector.parameters import DEFAULT_MODEL_PATH +from batdetect2.utils.detector_utils import save_results_to_file CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/bat_detect/detector/__init__.py b/batdetect2/detector/__init__.py similarity index 100% rename from bat_detect/detector/__init__.py rename to batdetect2/detector/__init__.py diff --git a/bat_detect/detector/compute_features.py b/batdetect2/detector/compute_features.py similarity index 100% rename from bat_detect/detector/compute_features.py rename to batdetect2/detector/compute_features.py diff --git a/bat_detect/detector/model_helpers.py b/batdetect2/detector/model_helpers.py similarity index 100% rename from bat_detect/detector/model_helpers.py rename to batdetect2/detector/model_helpers.py diff --git a/bat_detect/detector/models.py b/batdetect2/detector/models.py similarity index 98% rename from bat_detect/detector/models.py rename to batdetect2/detector/models.py index 99a48e1..56e63f3 100644 --- a/bat_detect/detector/models.py +++ b/batdetect2/detector/models.py @@ -3,14 +3,14 @@ import torch.fft import torch.nn.functional as F from torch import nn -from bat_detect.detector.model_helpers import ( +from batdetect2.detector.model_helpers import ( ConvBlockDownCoordF, ConvBlockDownStandard, ConvBlockUpF, ConvBlockUpStandard, SelfAttention, ) -from bat_detect.types import ModelOutput +from batdetect2.types import ModelOutput __all__ = [ "Net2DFast", @@ -104,7 +104,6 @@ class Net2DFast(nn.Module): ) def forward(self, ip, return_feats=False) -> ModelOutput: - # encoder x1 = self.conv_dn_0(ip) x2 = self.conv_dn_1(x1) @@ -326,7 +325,6 @@ class Net2DFastNoCoordConv(nn.Module): ) def forward(self, ip, return_feats=False) -> ModelOutput: - x1 = self.conv_dn_0(ip) x2 = self.conv_dn_1(x1) x3 = self.conv_dn_2(x2) @@ -344,11 +342,12 @@ class Net2DFastNoCoordConv(nn.Module): cls = self.conv_classes_op(x) comb = torch.softmax(cls, 1) + pred_emb = (self.conv_emb(x) if self.emb_dim > 0 else None,) + return ModelOutput( pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), pred_size=F.relu(self.conv_size_op(x), inplace=True), pred_class=comb, pred_class_un_norm=cls, - pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None, features=x, ) diff --git a/bat_detect/detector/parameters.py b/batdetect2/detector/parameters.py similarity index 98% rename from bat_detect/detector/parameters.py rename to batdetect2/detector/parameters.py index f733062..04544ed 100644 --- a/bat_detect/detector/parameters.py +++ b/batdetect2/detector/parameters.py @@ -1,10 +1,7 @@ import datetime import os -from bat_detect.types import ( - ProcessingConfiguration, - SpectrogramParameters, -) +from batdetect2.types import ProcessingConfiguration, SpectrogramParameters TARGET_SAMPLERATE_HZ = 256000 FFT_WIN_LENGTH_S = 512 / 256000.0 diff --git a/bat_detect/detector/post_process.py b/batdetect2/detector/post_process.py similarity index 98% rename from bat_detect/detector/post_process.py rename to batdetect2/detector/post_process.py index 5aa6895..a2ba353 100644 --- a/bat_detect/detector/post_process.py +++ b/batdetect2/detector/post_process.py @@ -5,8 +5,8 @@ import numpy as np import torch from torch import nn -from bat_detect.detector.models import ModelOutput -from bat_detect.types import NonMaximumSuppressionConfig, PredictionResults +from batdetect2.detector.models import ModelOutput +from batdetect2.types import NonMaximumSuppressionConfig, PredictionResults np.seterr(divide="ignore", invalid="ignore") diff --git a/bat_detect/evaluate/__init__.py b/batdetect2/evaluate/__init__.py similarity index 100% rename from bat_detect/evaluate/__init__.py rename to batdetect2/evaluate/__init__.py diff --git a/bat_detect/evaluate/evaluate_models.py b/batdetect2/evaluate/evaluate_models.py similarity index 99% rename from bat_detect/evaluate/evaluate_models.py rename to batdetect2/evaluate/evaluate_models.py index bf70f15..97c1bd1 100644 --- a/bat_detect/evaluate/evaluate_models.py +++ b/batdetect2/evaluate/evaluate_models.py @@ -11,11 +11,11 @@ import numpy as np import pandas as pd from sklearn.ensemble import RandomForestClassifier -from bat_detect.detector import parameters -import bat_detect.train.evaluate as evl -import bat_detect.train.train_utils as tu -import bat_detect.utils.detector_utils as du -import bat_detect.utils.plot_utils as pu +from batdetect2.detector import parameters +import batdetect2.train.evaluate as evl +import batdetect2.train.train_utils as tu +import batdetect2.utils.detector_utils as du +import batdetect2.utils.plot_utils as pu def get_blank_annotation(ip_str): diff --git a/bat_detect/evaluate/readme.md b/batdetect2/evaluate/readme.md similarity index 96% rename from bat_detect/evaluate/readme.md rename to batdetect2/evaluate/readme.md index b199cfa..fec91ca 100644 --- a/bat_detect/evaluate/readme.md +++ b/batdetect2/evaluate/readme.md @@ -1,4 +1,8 @@ # Evaluating BatDetect2 + +> **Warning** +> This code in currently broken. Will fix soon, stay tuned. + This script evaluates a trained model and outputs several plots summarizing the performance. It is used as follows: `python path_to_store_images/ path_to_audio_files/ path_to_annotation_file/ path_to_trained_model/` diff --git a/bat_detect/finetune/__init__.py b/batdetect2/finetune/__init__.py similarity index 100% rename from bat_detect/finetune/__init__.py rename to batdetect2/finetune/__init__.py diff --git a/bat_detect/finetune/finetune_model.py b/batdetect2/finetune/finetune_model.py similarity index 95% rename from bat_detect/finetune/finetune_model.py rename to batdetect2/finetune/finetune_model.py index 8988096..77a2711 100644 --- a/bat_detect/finetune/finetune_model.py +++ b/batdetect2/finetune/finetune_model.py @@ -10,20 +10,18 @@ import torch import torch.nn.functional as F from torch.optim.lr_scheduler import CosineAnnealingLR -sys.path.append(os.path.join("..", "..")) -import bat_detect.detector.models as models -import bat_detect.detector.parameters as parameters -import bat_detect.detector.post_process as pp -import bat_detect.train.audio_dataloader as adl -import bat_detect.train.evaluate as evl -import bat_detect.train.losses as losses -import bat_detect.train.train_model as tm -import bat_detect.train.train_utils as tu -import bat_detect.utils.detector_utils as du -import bat_detect.utils.plot_utils as pu +import batdetect2.detector.models as models +import batdetect2.detector.parameters as parameters +import batdetect2.detector.post_process as pp +import batdetect2.train.audio_dataloader as adl +import batdetect2.train.evaluate as evl +import batdetect2.train.losses as losses +import batdetect2.train.train_model as tm +import batdetect2.train.train_utils as tu +import batdetect2.utils.detector_utils as du +import batdetect2.utils.plot_utils as pu if __name__ == "__main__": - info_str = "\nBatDetect - Finetune Model\n" print(info_str) @@ -272,7 +270,6 @@ if __name__ == "__main__": # main train loop for epoch in range(0, params["num_epochs"] + 1): - train_loss = tm.train( model, epoch, diff --git a/bat_detect/finetune/prep_data_finetune.py b/batdetect2/finetune/prep_data_finetune.py similarity index 98% rename from bat_detect/finetune/prep_data_finetune.py rename to batdetect2/finetune/prep_data_finetune.py index d8d1df8..11702a9 100644 --- a/bat_detect/finetune/prep_data_finetune.py +++ b/batdetect2/finetune/prep_data_finetune.py @@ -1,16 +1,13 @@ import argparse import json import os -import sys import numpy as np -sys.path.append(os.path.join("..", "..")) -import bat_detect.train.train_utils as tu +import batdetect2.train.train_utils as tu def print_dataset_stats(data, split_name, classes_to_ignore): - print("\nSplit:", split_name) print("Num files:", len(data)) @@ -37,7 +34,6 @@ def print_dataset_stats(data, split_name, classes_to_ignore): def load_file_names(file_name): - if os.path.isfile(file_name): with open(file_name) as da: files = [line.rstrip() for line in da.readlines()] @@ -53,7 +49,6 @@ def load_file_names(file_name): if __name__ == "__main__": - info_str = "\nBatDetect - Prepare Data for Finetuning\n" print(info_str) diff --git a/bat_detect/finetune/readme.md b/batdetect2/finetune/readme.md similarity index 95% rename from bat_detect/finetune/readme.md rename to batdetect2/finetune/readme.md index 5ee54bb..29d2e36 100644 --- a/bat_detect/finetune/readme.md +++ b/batdetect2/finetune/readme.md @@ -1,5 +1,9 @@ - # Finetuning the BatDetet2 model on your own data + +| :warning: WARNING | +|:---------------------------| +| This is not currently working, but we are working on fixing this code | + Main steps: 1. Annotate your data using the annotation GUI. 2. Run `prep_data_finetune.py` to create a training and validation split for your data. diff --git a/bat_detect/models/readme.md b/batdetect2/models/readme.md similarity index 100% rename from bat_detect/models/readme.md rename to batdetect2/models/readme.md diff --git a/bat_detect/train/__init__.py b/batdetect2/train/__init__.py similarity index 100% rename from bat_detect/train/__init__.py rename to batdetect2/train/__init__.py diff --git a/bat_detect/train/audio_dataloader.py b/batdetect2/train/audio_dataloader.py similarity index 99% rename from bat_detect/train/audio_dataloader.py rename to batdetect2/train/audio_dataloader.py index 6d4d9d8..8130ec6 100644 --- a/bat_detect/train/audio_dataloader.py +++ b/batdetect2/train/audio_dataloader.py @@ -7,8 +7,8 @@ import torch import torch.nn.functional as F import torchaudio -import bat_detect.utils.audio_utils as au -from bat_detect.types import AnnotationGroup, HeatmapParameters +import batdetect2.utils.audio_utils as au +from batdetect2.types import AnnotationGroup, HeatmapParameters def generate_gt_heatmaps( diff --git a/bat_detect/train/evaluate.py b/batdetect2/train/evaluate.py similarity index 100% rename from bat_detect/train/evaluate.py rename to batdetect2/train/evaluate.py diff --git a/bat_detect/train/losses.py b/batdetect2/train/losses.py similarity index 100% rename from bat_detect/train/losses.py rename to batdetect2/train/losses.py diff --git a/bat_detect/train/readme.md b/batdetect2/train/readme.md similarity index 100% rename from bat_detect/train/readme.md rename to batdetect2/train/readme.md diff --git a/bat_detect/train/train_model.py b/batdetect2/train/train_model.py similarity index 97% rename from bat_detect/train/train_model.py rename to batdetect2/train/train_model.py index 1f4ea5f..759c2d7 100644 --- a/bat_detect/train/train_model.py +++ b/batdetect2/train/train_model.py @@ -7,15 +7,15 @@ import numpy as np import torch from torch.optim.lr_scheduler import CosineAnnealingLR -from bat_detect.detector import models -from bat_detect.detector import parameters -from bat_detect.train import losses -import bat_detect.detector.post_process as pp -import bat_detect.train.audio_dataloader as adl -import bat_detect.train.evaluate as evl -import bat_detect.train.train_split as ts -import bat_detect.train.train_utils as tu -import bat_detect.utils.plot_utils as pu +from batdetect2.detector import models +from batdetect2.detector import parameters +from batdetect2.train import losses +import batdetect2.detector.post_process as pp +import batdetect2.train.audio_dataloader as adl +import batdetect2.train.evaluate as evl +import batdetect2.train.train_split as ts +import batdetect2.train.train_utils as tu +import batdetect2.utils.plot_utils as pu warnings.filterwarnings("ignore", category=UserWarning) diff --git a/bat_detect/train/train_split.py b/batdetect2/train/train_split.py similarity index 99% rename from bat_detect/train/train_split.py rename to batdetect2/train/train_split.py index 01b5c03..902fe82 100644 --- a/bat_detect/train/train_split.py +++ b/batdetect2/train/train_split.py @@ -26,7 +26,7 @@ def split_diff(ann_dir, wav_dir, load_extra=True): "is_binary": True, # just a bat / not bat dataset ie no classes "ann_path": ann_dir + "train_set_bulgaria_batdetective_with_bbs.json", - "wav_path": wav_dir + "bat_detective/audio/", + "wav_path": wav_dir + "batdetect2ive/audio/", } ) train_sets.append( @@ -154,7 +154,7 @@ def split_same(ann_dir, wav_dir, load_extra=True): "is_binary": True, "ann_path": ann_dir + "train_set_bulgaria_batdetective_with_bbs.json", - "wav_path": wav_dir + "bat_detective/audio/", + "wav_path": wav_dir + "batdetect2ive/audio/", } ) train_sets.append( diff --git a/bat_detect/train/train_utils.py b/batdetect2/train/train_utils.py similarity index 100% rename from bat_detect/train/train_utils.py rename to batdetect2/train/train_utils.py diff --git a/bat_detect/types.py b/batdetect2/types.py similarity index 100% rename from bat_detect/types.py rename to batdetect2/types.py diff --git a/bat_detect/utils/__init__.py b/batdetect2/utils/__init__.py similarity index 100% rename from bat_detect/utils/__init__.py rename to batdetect2/utils/__init__.py diff --git a/bat_detect/utils/audio_utils.py b/batdetect2/utils/audio_utils.py similarity index 94% rename from bat_detect/utils/audio_utils.py rename to batdetect2/utils/audio_utils.py index ba12798..7c5852a 100644 --- a/bat_detect/utils/audio_utils.py +++ b/batdetect2/utils/audio_utils.py @@ -6,36 +6,12 @@ import librosa.core.spectrum import numpy as np import torch -from bat_detect.detector.parameters import ( - DENOISE_SPEC_AVG, - DETECTION_THRESHOLD, - FFT_OVERLAP, - FFT_WIN_LENGTH_S, - MAX_FREQ_HZ, - MAX_SCALE_SPEC, - MIN_FREQ_HZ, - NMS_KERNEL_SIZE, - NMS_TOP_K_PER_SEC, - RESIZE_FACTOR, - SCALE_RAW_AUDIO, - SPEC_DIVIDE_FACTOR, - SPEC_HEIGHT, - SPEC_SCALE, -) - from . import wavfile -try: - from typing import TypedDict -except ImportError: - from typing_extensions import TypedDict - __all__ = [ "load_audio", "generate_spectrogram", "pad_audio", - "SpectrogramParameters", - "DEFAULT_SPECTROGRAM_PARAMETERS", ] @@ -60,7 +36,6 @@ def generate_spectrogram( return_spec_for_viz=False, check_spec_size=True, ): - # generate spectrogram spec = gen_mag_spectrogram( audio, diff --git a/bat_detect/utils/detector_utils.py b/batdetect2/utils/detector_utils.py similarity index 98% rename from bat_detect/utils/detector_utils.py rename to batdetect2/utils/detector_utils.py index cd71ee6..1dbc3a1 100644 --- a/bat_detect/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -7,12 +7,12 @@ import pandas as pd import torch import torch.nn.functional as F -import bat_detect.detector.compute_features as feats -import bat_detect.detector.post_process as pp -import bat_detect.utils.audio_utils as au -from bat_detect.detector import models -from bat_detect.detector.parameters import DEFAULT_MODEL_PATH -from bat_detect.types import ( +import batdetect2.detector.compute_features as feats +import batdetect2.detector.post_process as pp +import batdetect2.utils.audio_utils as au +from batdetect2.detector import models +from batdetect2.detector.parameters import DEFAULT_MODEL_PATH +from batdetect2.types import ( Annotation, DetectionModel, FileAnnotations, diff --git a/bat_detect/utils/plot_utils.py b/batdetect2/utils/plot_utils.py similarity index 99% rename from bat_detect/utils/plot_utils.py rename to batdetect2/utils/plot_utils.py index 6fcb387..4bfde7a 100644 --- a/bat_detect/utils/plot_utils.py +++ b/batdetect2/utils/plot_utils.py @@ -217,7 +217,6 @@ def plot_spec( plot_boxes=True, fixed_aspect=True, ): - if fixed_aspect: # ouptut image will be this width irrespective of the duration of the audio file width = 12 diff --git a/bat_detect/utils/visualize.py b/batdetect2/utils/visualize.py similarity index 100% rename from bat_detect/utils/visualize.py rename to batdetect2/utils/visualize.py diff --git a/bat_detect/utils/wavfile.py b/batdetect2/utils/wavfile.py similarity index 100% rename from bat_detect/utils/wavfile.py rename to batdetect2/utils/wavfile.py diff --git a/run_batdetect.py b/run_batdetect.py index adab803..3079eca 100644 --- a/run_batdetect.py +++ b/run_batdetect.py @@ -1,5 +1,5 @@ -"""Run bat_detect.command.main() from the command line.""" -from bat_detect.cli import detect +"""Run batdetect2.command.main() from the command line.""" +from batdetect2.cli import detect if __name__ == "__main__": detect() diff --git a/scripts/gen_dataset_summary_image.py b/scripts/gen_dataset_summary_image.py index 7e424ad..a916900 100644 --- a/scripts/gen_dataset_summary_image.py +++ b/scripts/gen_dataset_summary_image.py @@ -5,17 +5,15 @@ is the mean spectrogram for each class. import argparse import os -import sys import matplotlib.pyplot as plt import numpy as np import viz_helpers as vz -sys.path.append(os.path.join("..")) -import bat_detect.detector.parameters as parameters -import bat_detect.train.train_split as ts -import bat_detect.train.train_utils as tu -import bat_detect.utils.audio_utils as au +import batdetect2.detector.parameters as parameters +import batdetect2.train.train_split as ts +import batdetect2.train.train_utils as tu +import batdetect2.utils.audio_utils as au if __name__ == "__main__": diff --git a/scripts/gen_spec_image.py b/scripts/gen_spec_image.py index c8f8639..e296979 100644 --- a/scripts/gen_spec_image.py +++ b/scripts/gen_spec_image.py @@ -15,11 +15,10 @@ import sys import matplotlib.pyplot as plt import numpy as np -sys.path.append(os.path.join("..")) -import bat_detect.evaluate.evaluate_models as evlm -import bat_detect.utils.audio_utils as au -import bat_detect.utils.detector_utils as du -import bat_detect.utils.plot_utils as viz +import batdetect2.evaluate.evaluate_models as evlm +import batdetect2.utils.audio_utils as au +import batdetect2.utils.detector_utils as du +import batdetect2.utils.plot_utils as viz def filter_anns(anns, start_time, stop_time): diff --git a/scripts/gen_spec_video.py b/scripts/gen_spec_video.py index 2588ede..e7ffc06 100644 --- a/scripts/gen_spec_video.py +++ b/scripts/gen_spec_video.py @@ -17,11 +17,10 @@ import matplotlib.pyplot as plt import numpy as np from scipy.io import wavfile -sys.path.append(os.path.join("..")) -import bat_detect.detector.parameters as parameters -import bat_detect.utils.audio_utils as au -import bat_detect.utils.detector_utils as du -import bat_detect.utils.plot_utils as viz +import batdetect2.detector.parameters as parameters +import batdetect2.utils.audio_utils as au +import batdetect2.utils.detector_utils as du +import batdetect2.utils.plot_utils as viz if __name__ == "__main__": diff --git a/scripts/viz_helpers.py b/scripts/viz_helpers.py index 5044b8e..13d09b6 100644 --- a/scripts/viz_helpers.py +++ b/scripts/viz_helpers.py @@ -7,7 +7,7 @@ from scipy import ndimage sys.path.append(os.path.join("..")) -import bat_detect.utils.audio_utils as au +import batdetect2.utils.audio_utils as au def generate_spectrogram_data( diff --git a/tests/test_api.py b/tests/test_api.py index 8158a1f..ed0202a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -7,7 +7,7 @@ import numpy as np import torch from torch import nn -from bat_detect import api +from batdetect2 import api PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio") diff --git a/tests/test_cli.py b/tests/test_cli.py index 4570cf5..ffad17e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,10 +1,11 @@ """Test the command line interface.""" from click.testing import CliRunner -from bat_detect.cli import cli +from batdetect2.cli import cli def test_cli_base_command(): + """Test the base command.""" runner = CliRunner() result = runner.invoke(cli, ["--help"]) assert result.exit_code == 0 @@ -12,6 +13,7 @@ def test_cli_base_command(): def test_cli_detect_command_help(): + """Test the detect command help.""" runner = CliRunner() result = runner.invoke(cli, ["detect", "--help"]) assert result.exit_code == 0 @@ -19,6 +21,7 @@ def test_cli_detect_command_help(): def test_cli_detect_command_on_test_audio(tmp_path): + """Test the detect command on test audio.""" results_dir = tmp_path / "results" # Remove results dir if it exists