mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Changed public API to use trained model by default
This commit is contained in:
parent
a2deab9f3f
commit
b0d9576a24
@ -1,3 +1,4 @@
|
|||||||
|
import warnings
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -6,6 +7,7 @@ import torch
|
|||||||
import bat_detect.utils.audio_utils as au
|
import bat_detect.utils.audio_utils as au
|
||||||
import bat_detect.utils.detector_utils as du
|
import bat_detect.utils.detector_utils as du
|
||||||
from bat_detect.detector.parameters import (
|
from bat_detect.detector.parameters import (
|
||||||
|
DEFAULT_MODEL_PATH,
|
||||||
DEFAULT_PROCESSING_CONFIGURATIONS,
|
DEFAULT_PROCESSING_CONFIGURATIONS,
|
||||||
DEFAULT_SPECTROGRAM_PARAMETERS,
|
DEFAULT_SPECTROGRAM_PARAMETERS,
|
||||||
TARGET_SAMPLERATE_HZ,
|
TARGET_SAMPLERATE_HZ,
|
||||||
@ -18,8 +20,8 @@ from bat_detect.types import (
|
|||||||
)
|
)
|
||||||
from bat_detect.utils.detector_utils import list_audio_files, load_model
|
from bat_detect.utils.detector_utils import list_audio_files, load_model
|
||||||
|
|
||||||
# Use GPU if available
|
# Remove warnings from torch
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"load_model",
|
"load_model",
|
||||||
@ -30,9 +32,19 @@ __all__ = [
|
|||||||
"process_file",
|
"process_file",
|
||||||
"process_spectrogram",
|
"process_spectrogram",
|
||||||
"process_audio",
|
"process_audio",
|
||||||
|
"model",
|
||||||
|
"config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Use GPU if available
|
||||||
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
# Default model
|
||||||
|
MODEL, PARAMS = load_model(DEFAULT_MODEL_PATH, device=DEVICE)
|
||||||
|
|
||||||
|
|
||||||
def get_config(**kwargs) -> ProcessingConfiguration:
|
def get_config(**kwargs) -> ProcessingConfiguration:
|
||||||
"""Get default processing configuration.
|
"""Get default processing configuration.
|
||||||
|
|
||||||
@ -41,15 +53,22 @@ def get_config(**kwargs) -> ProcessingConfiguration:
|
|||||||
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore
|
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
# Default processing configuration
|
||||||
|
CONFIG = get_config(**PARAMS)
|
||||||
|
|
||||||
|
|
||||||
def load_audio(
|
def load_audio(
|
||||||
path: str,
|
path: str,
|
||||||
time_exp_fact: float = 1,
|
time_exp_fact: float = 1,
|
||||||
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
|
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
scale: bool = False,
|
scale: bool = False,
|
||||||
max_duration: Optional[float] = None,
|
max_duration: Optional[float] = None,
|
||||||
) -> Tuple[int, np.ndarray]:
|
) -> np.ndarray:
|
||||||
"""Load audio from file.
|
"""Load audio from file.
|
||||||
|
|
||||||
|
All audio will be resampled to the target sample rate. If the audio is
|
||||||
|
longer than max_duration, it will be truncated to max_duration.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
path : str
|
path : str
|
||||||
@ -67,21 +86,20 @@ def load_audio(
|
|||||||
-------
|
-------
|
||||||
np.ndarray
|
np.ndarray
|
||||||
Audio data.
|
Audio data.
|
||||||
int
|
|
||||||
Sample rate.
|
|
||||||
"""
|
"""
|
||||||
return au.load_audio(
|
_, audio = au.load_audio(
|
||||||
path,
|
path,
|
||||||
time_exp_fact,
|
time_exp_fact,
|
||||||
target_samp_rate,
|
target_samp_rate,
|
||||||
scale,
|
scale,
|
||||||
max_duration,
|
max_duration,
|
||||||
)
|
)
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
def generate_spectrogram(
|
def generate_spectrogram(
|
||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
samp_rate: int,
|
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
config: Optional[SpectrogramParameters] = None,
|
config: Optional[SpectrogramParameters] = None,
|
||||||
device: torch.device = DEVICE,
|
device: torch.device = DEVICE,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -91,8 +109,10 @@ def generate_spectrogram(
|
|||||||
----------
|
----------
|
||||||
audio : np.ndarray
|
audio : np.ndarray
|
||||||
Audio data.
|
Audio data.
|
||||||
samp_rate : int
|
samp_rate : int, optional
|
||||||
Sample rate.
|
Sample rate. Defaults to 256000 which is the target sample rate of
|
||||||
|
the default model. Only change if you loaded the audio with a
|
||||||
|
different sample rate.
|
||||||
config : Optional[SpectrogramParameters], optional
|
config : Optional[SpectrogramParameters], optional
|
||||||
Spectrogram parameters, by default None (uses default parameters).
|
Spectrogram parameters, by default None (uses default parameters).
|
||||||
|
|
||||||
@ -117,7 +137,7 @@ def generate_spectrogram(
|
|||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
audio_file: str,
|
audio_file: str,
|
||||||
model: DetectionModel,
|
model: DetectionModel = MODEL,
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: Optional[ProcessingConfiguration] = None,
|
||||||
device: torch.device = DEVICE,
|
device: torch.device = DEVICE,
|
||||||
) -> du.RunResults:
|
) -> du.RunResults:
|
||||||
@ -127,15 +147,15 @@ def process_file(
|
|||||||
----------
|
----------
|
||||||
audio_file : str
|
audio_file : str
|
||||||
Path to audio file.
|
Path to audio file.
|
||||||
model : DetectionModel
|
model : DetectionModel, optional
|
||||||
Detection model.
|
Detection model. Uses default model if not specified.
|
||||||
config : Optional[ProcessingConfiguration], optional
|
config : Optional[ProcessingConfiguration], optional
|
||||||
Processing configuration, by default None (uses default parameters).
|
Processing configuration, by default None (uses default parameters).
|
||||||
device : torch.device, optional
|
device : torch.device, optional
|
||||||
Device to use, by default tries to use GPU if available.
|
Device to use, by default tries to use GPU if available.
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = DEFAULT_PROCESSING_CONFIGURATIONS
|
config = CONFIG
|
||||||
|
|
||||||
return du.process_file(
|
return du.process_file(
|
||||||
audio_file,
|
audio_file,
|
||||||
@ -147,8 +167,8 @@ def process_file(
|
|||||||
|
|
||||||
def process_spectrogram(
|
def process_spectrogram(
|
||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
samp_rate: int,
|
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
model: DetectionModel,
|
model: DetectionModel = MODEL,
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: Optional[ProcessingConfiguration] = None,
|
||||||
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
||||||
"""Process spectrogram with model.
|
"""Process spectrogram with model.
|
||||||
@ -157,10 +177,13 @@ def process_spectrogram(
|
|||||||
----------
|
----------
|
||||||
spec : torch.Tensor
|
spec : torch.Tensor
|
||||||
Spectrogram.
|
Spectrogram.
|
||||||
samp_rate : int
|
samp_rate : int, optional
|
||||||
Sample rate of the audio from which the spectrogram was generated.
|
Sample rate of the audio from which the spectrogram was generated.
|
||||||
model : DetectionModel
|
Defaults to 256000 which is the target sample rate of the default
|
||||||
Detection model.
|
model. Only change if you generated the spectrogram with a different
|
||||||
|
sample rate.
|
||||||
|
model : DetectionModel, optional
|
||||||
|
Detection model. Uses default model if not specified.
|
||||||
config : Optional[ProcessingConfiguration], optional
|
config : Optional[ProcessingConfiguration], optional
|
||||||
Processing configuration, by default None (uses default parameters).
|
Processing configuration, by default None (uses default parameters).
|
||||||
|
|
||||||
@ -169,7 +192,7 @@ def process_spectrogram(
|
|||||||
DetectionResult
|
DetectionResult
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = DEFAULT_PROCESSING_CONFIGURATIONS
|
config = CONFIG
|
||||||
|
|
||||||
return du.process_spectrogram(
|
return du.process_spectrogram(
|
||||||
spec,
|
spec,
|
||||||
@ -181,8 +204,8 @@ def process_spectrogram(
|
|||||||
|
|
||||||
def process_audio(
|
def process_audio(
|
||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
samp_rate: int,
|
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
model: DetectionModel,
|
model: DetectionModel = MODEL,
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: Optional[ProcessingConfiguration] = None,
|
||||||
device: torch.device = DEVICE,
|
device: torch.device = DEVICE,
|
||||||
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
|
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
|
||||||
@ -192,10 +215,11 @@ def process_audio(
|
|||||||
----------
|
----------
|
||||||
audio : np.ndarray
|
audio : np.ndarray
|
||||||
Audio data.
|
Audio data.
|
||||||
samp_rate : int
|
samp_rate : int, optional
|
||||||
Sample rate.
|
Sample rate, by default 256000. Only change if you loaded the audio
|
||||||
model : DetectionModel
|
with a different sample rate.
|
||||||
Detection model.
|
model : DetectionModel, optional
|
||||||
|
Detection model. Uses default model if not specified.
|
||||||
config : Optional[ProcessingConfiguration], optional
|
config : Optional[ProcessingConfiguration], optional
|
||||||
Processing configuration, by default None (uses default parameters).
|
Processing configuration, by default None (uses default parameters).
|
||||||
device : torch.device, optional
|
device : torch.device, optional
|
||||||
@ -213,7 +237,7 @@ def process_audio(
|
|||||||
Spectrogram of the audio used for prediction.
|
Spectrogram of the audio used for prediction.
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = DEFAULT_PROCESSING_CONFIGURATIONS
|
config = CONFIG
|
||||||
|
|
||||||
return du.process_audio_array(
|
return du.process_audio_array(
|
||||||
audio,
|
audio,
|
||||||
@ -222,3 +246,10 @@ def process_audio(
|
|||||||
config,
|
config,
|
||||||
device,
|
device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
model: DetectionModel = MODEL
|
||||||
|
"""Base detection model."""
|
||||||
|
|
||||||
|
config: ProcessingConfiguration = CONFIG
|
||||||
|
"""Default processing configuration."""
|
||||||
|
@ -1,14 +1,11 @@
|
|||||||
"""BatDetect2 command line interface."""
|
"""BatDetect2 command line interface."""
|
||||||
import os
|
import os
|
||||||
import warnings
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
import click
|
||||||
|
|
||||||
import click # noqa: E402
|
from bat_detect import api
|
||||||
|
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
|
||||||
from bat_detect import api # noqa: E402
|
from bat_detect.utils.detector_utils import save_results_to_file
|
||||||
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH # noqa: E402
|
|
||||||
from bat_detect.utils.detector_utils import save_results_to_file # noqa: E402
|
|
||||||
|
|
||||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
@ -124,15 +121,14 @@ def detect(
|
|||||||
results_path = audio_file.replace(audio_dir, ann_dir)
|
results_path = audio_file.replace(audio_dir, ann_dir)
|
||||||
save_results_to_file(results, results_path)
|
save_results_to_file(results, results_path)
|
||||||
except (RuntimeError, ValueError, LookupError) as err:
|
except (RuntimeError, ValueError, LookupError) as err:
|
||||||
# TODO: Check what other errors can be thrown
|
|
||||||
error_files.append(audio_file)
|
error_files.append(audio_file)
|
||||||
click.echo(f"Error processing file!: {err}")
|
click.secho(f"Error processing file!: {err}", fg="red")
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
click.echo(f"\nResults saved to: {ann_dir}")
|
click.echo(f"\nResults saved to: {ann_dir}")
|
||||||
|
|
||||||
if len(error_files) > 0:
|
if len(error_files) > 0:
|
||||||
click.echo("\nUnable to process the follow files:")
|
click.secho("\nUnable to process the follow files:", fg="red")
|
||||||
for err in error_files:
|
for err in error_files:
|
||||||
click.echo(f" {err}")
|
click.echo(f" {err}")
|
||||||
|
|
||||||
|
@ -6,8 +6,6 @@ from matplotlib import patches
|
|||||||
from matplotlib.collections import PatchCollection
|
from matplotlib.collections import PatchCollection
|
||||||
from sklearn.metrics import confusion_matrix
|
from sklearn.metrics import confusion_matrix
|
||||||
|
|
||||||
from . import audio_utils as au
|
|
||||||
|
|
||||||
|
|
||||||
def create_box_image(
|
def create_box_image(
|
||||||
spec,
|
spec,
|
||||||
|
@ -7,16 +7,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from bat_detect.api import (
|
from bat_detect import api
|
||||||
generate_spectrogram,
|
|
||||||
get_config,
|
|
||||||
list_audio_files,
|
|
||||||
load_audio,
|
|
||||||
load_model,
|
|
||||||
process_audio,
|
|
||||||
process_file,
|
|
||||||
process_spectrogram,
|
|
||||||
)
|
|
||||||
|
|
||||||
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
|
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
|
||||||
@ -25,7 +16,7 @@ TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
|
|||||||
|
|
||||||
def test_load_model_with_default_params():
|
def test_load_model_with_default_params():
|
||||||
"""Test loading model with default parameters."""
|
"""Test loading model with default parameters."""
|
||||||
model, params = load_model()
|
model, params = api.load_model()
|
||||||
|
|
||||||
assert model is not None
|
assert model is not None
|
||||||
assert isinstance(model, nn.Module)
|
assert isinstance(model, nn.Module)
|
||||||
@ -50,7 +41,7 @@ def test_load_model_with_default_params():
|
|||||||
|
|
||||||
def test_list_audio_files():
|
def test_list_audio_files():
|
||||||
"""Test listing audio files."""
|
"""Test listing audio files."""
|
||||||
audio_files = list_audio_files(TEST_DATA_DIR)
|
audio_files = api.list_audio_files(TEST_DATA_DIR)
|
||||||
|
|
||||||
assert len(audio_files) == 3
|
assert len(audio_files) == 3
|
||||||
assert all(path.endswith((".wav", ".WAV")) for path in audio_files)
|
assert all(path.endswith((".wav", ".WAV")) for path in audio_files)
|
||||||
@ -58,18 +49,17 @@ def test_list_audio_files():
|
|||||||
|
|
||||||
def test_load_audio():
|
def test_load_audio():
|
||||||
"""Test loading audio."""
|
"""Test loading audio."""
|
||||||
samplerate, audio = load_audio(TEST_DATA[0])
|
audio = api.load_audio(TEST_DATA[0])
|
||||||
|
|
||||||
assert audio is not None
|
assert audio is not None
|
||||||
assert samplerate == 256000
|
|
||||||
assert isinstance(audio, np.ndarray)
|
assert isinstance(audio, np.ndarray)
|
||||||
assert audio.shape == (128000,)
|
assert audio.shape == (128000,)
|
||||||
|
|
||||||
|
|
||||||
def test_generate_spectrogram():
|
def test_generate_spectrogram():
|
||||||
"""Test generating spectrogram."""
|
"""Test generating spectrogram."""
|
||||||
samplerate, audio = load_audio(TEST_DATA[0])
|
audio = api.load_audio(TEST_DATA[0])
|
||||||
spectrogram = generate_spectrogram(audio, samplerate)
|
spectrogram = api.generate_spectrogram(audio)
|
||||||
|
|
||||||
assert spectrogram is not None
|
assert spectrogram is not None
|
||||||
assert isinstance(spectrogram, torch.Tensor)
|
assert isinstance(spectrogram, torch.Tensor)
|
||||||
@ -78,7 +68,7 @@ def test_generate_spectrogram():
|
|||||||
|
|
||||||
def test_get_default_config():
|
def test_get_default_config():
|
||||||
"""Test getting default configuration."""
|
"""Test getting default configuration."""
|
||||||
config = get_config()
|
config = api.get_config()
|
||||||
|
|
||||||
assert config is not None
|
assert config is not None
|
||||||
assert isinstance(config, dict)
|
assert isinstance(config, dict)
|
||||||
@ -110,11 +100,55 @@ def test_get_default_config():
|
|||||||
assert config["spec_slices"] is False
|
assert config["spec_slices"] is False
|
||||||
|
|
||||||
|
|
||||||
def test_process_file_with_model():
|
def test_api_exposes_default_model():
|
||||||
|
"""Test that API exposes default model."""
|
||||||
|
assert hasattr(api, "model")
|
||||||
|
assert isinstance(api.model, nn.Module)
|
||||||
|
assert type(api.model).__name__ == "Net2DFast"
|
||||||
|
|
||||||
|
# Check that model has expected attributes
|
||||||
|
assert api.model.num_classes == 17
|
||||||
|
assert api.model.num_filts == 128
|
||||||
|
assert api.model.emb_dim == 0
|
||||||
|
assert api.model.ip_height_rs == 128
|
||||||
|
assert api.model.resize_factor == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_exposes_default_config():
|
||||||
|
"""Test that API exposes default configuration."""
|
||||||
|
assert hasattr(api, "config")
|
||||||
|
assert isinstance(api.config, dict)
|
||||||
|
|
||||||
|
assert api.config["target_samp_rate"] == 256000
|
||||||
|
assert api.config["fft_win_length"] == 0.002
|
||||||
|
assert api.config["fft_overlap"] == 0.75
|
||||||
|
assert api.config["resize_factor"] == 0.5
|
||||||
|
assert api.config["spec_divide_factor"] == 32
|
||||||
|
assert api.config["spec_height"] == 256
|
||||||
|
assert api.config["spec_scale"] == "pcen"
|
||||||
|
assert api.config["denoise_spec_avg"] is True
|
||||||
|
assert api.config["max_scale_spec"] is False
|
||||||
|
assert api.config["scale_raw_audio"] is False
|
||||||
|
assert len(api.config["class_names"]) == 17
|
||||||
|
assert api.config["detection_threshold"] == 0.01
|
||||||
|
assert api.config["time_expansion"] == 1
|
||||||
|
assert api.config["top_n"] == 3
|
||||||
|
assert api.config["return_raw_preds"] is False
|
||||||
|
assert api.config["max_duration"] is None
|
||||||
|
assert api.config["nms_kernel_size"] == 9
|
||||||
|
assert api.config["max_freq"] == 120000
|
||||||
|
assert api.config["min_freq"] == 10000
|
||||||
|
assert api.config["nms_top_k_per_sec"] == 200
|
||||||
|
assert api.config["quiet"] is True
|
||||||
|
assert api.config["chunk_size"] == 3
|
||||||
|
assert api.config["cnn_features"] is False
|
||||||
|
assert api.config["spec_features"] is False
|
||||||
|
assert api.config["spec_slices"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_file_with_default_model():
|
||||||
"""Test processing file with model."""
|
"""Test processing file with model."""
|
||||||
model, params = load_model()
|
predictions = api.process_file(TEST_DATA[0])
|
||||||
config = get_config(**params)
|
|
||||||
predictions = process_file(TEST_DATA[0], model, config=config)
|
|
||||||
|
|
||||||
assert predictions is not None
|
assert predictions is not None
|
||||||
assert isinstance(predictions, dict)
|
assert isinstance(predictions, dict)
|
||||||
@ -141,18 +175,11 @@ def test_process_file_with_model():
|
|||||||
assert len(pred_dict["annotation"]) > 0
|
assert len(pred_dict["annotation"]) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_process_spectrogram_with_model():
|
def test_process_spectrogram_with_default_model():
|
||||||
"""Test processing spectrogram with model."""
|
"""Test processing spectrogram with model."""
|
||||||
model, params = load_model()
|
audio = api.load_audio(TEST_DATA[0])
|
||||||
config = get_config(**params)
|
spectrogram = api.generate_spectrogram(audio)
|
||||||
samplerate, audio = load_audio(TEST_DATA[0])
|
predictions, features = api.process_spectrogram(spectrogram)
|
||||||
spectrogram = generate_spectrogram(audio, samplerate)
|
|
||||||
predictions, features = process_spectrogram(
|
|
||||||
spectrogram,
|
|
||||||
samplerate,
|
|
||||||
model,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert predictions is not None
|
assert predictions is not None
|
||||||
assert isinstance(predictions, list)
|
assert isinstance(predictions, list)
|
||||||
@ -172,17 +199,10 @@ def test_process_spectrogram_with_model():
|
|||||||
assert len(features) == 1
|
assert len(features) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_process_audio_with_model():
|
def test_process_audio_with_default_model():
|
||||||
"""Test processing audio with model."""
|
"""Test processing audio with model."""
|
||||||
model, params = load_model()
|
audio = api.load_audio(TEST_DATA[0])
|
||||||
config = get_config(**params)
|
predictions, features, spec = api.process_audio(audio)
|
||||||
samplerate, audio = load_audio(TEST_DATA[0])
|
|
||||||
predictions, features, spec = process_audio(
|
|
||||||
audio,
|
|
||||||
samplerate,
|
|
||||||
model,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert predictions is not None
|
assert predictions is not None
|
||||||
assert isinstance(predictions, list)
|
assert isinstance(predictions, list)
|
||||||
|
Loading…
Reference in New Issue
Block a user