mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
adding typing and documentation to load_device function
This commit is contained in:
parent
d7ddf72c73
commit
ca2a9c39a0
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -12,6 +12,20 @@ import bat_detect.detector.post_process as pp
|
||||
import bat_detect.utils.audio_utils as au
|
||||
from bat_detect.detector import models
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
except ImportError:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
DEFAULT_MODEL_PATH = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"models",
|
||||
"model.pth",
|
||||
)
|
||||
|
||||
__all__ = ["load_model", "DEFAULT_MODEL_PATH"]
|
||||
|
||||
|
||||
def get_default_bd_args():
|
||||
args = {}
|
||||
@ -29,25 +43,64 @@ def get_default_bd_args():
|
||||
return args
|
||||
|
||||
|
||||
def get_audio_files(ip_dir):
|
||||
def get_audio_files(ip_dir: str) -> List[str]:
|
||||
"""Get all audio files in directory.
|
||||
|
||||
Args:
|
||||
ip_dir (str): Input directory.
|
||||
|
||||
Returns:
|
||||
list: List of audio files. Only .wav files are returned. Paths are
|
||||
relative to ip_dir.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: Input directory not found.
|
||||
|
||||
"""
|
||||
matches = []
|
||||
for root, dirnames, filenames in os.walk(ip_dir):
|
||||
for root, _, filenames in os.walk(ip_dir):
|
||||
for filename in filenames:
|
||||
if filename.lower().endswith(".wav"):
|
||||
matches.append(os.path.join(root, filename))
|
||||
return matches
|
||||
|
||||
|
||||
def load_model(model_path, load_weights=True):
|
||||
class ModelParameters(TypedDict):
|
||||
"""Model parameters."""
|
||||
model_name: str
|
||||
num_filters: int
|
||||
emb_dim: int
|
||||
ip_height: int
|
||||
resize_factor: int
|
||||
class_names: List[str]
|
||||
device: torch.device
|
||||
|
||||
|
||||
def load_model(
|
||||
model_path: str=DEFAULT_MODEL_PATH,
|
||||
load_weights: bool=True
|
||||
) -> Tuple[torch.nn.Module, ModelParameters]:
|
||||
"""Load model from file.
|
||||
|
||||
Args:
|
||||
model_path (str): Path to model file. Defaults to DEFAULT_MODEL_PATH.
|
||||
load_weights (bool, optional): Load weights. Defaults to True.
|
||||
|
||||
Returns:
|
||||
model, params: Model and parameters.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: Model file not found.
|
||||
ValueError: Unknown model.
|
||||
"""
|
||||
|
||||
# load model
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if os.path.isfile(model_path):
|
||||
net_params = torch.load(model_path, map_location=device)
|
||||
else:
|
||||
print("Error: model not found.")
|
||||
sys.exit(1)
|
||||
|
||||
if not os.path.isfile(model_path):
|
||||
raise FileNotFoundError("Model file not found.")
|
||||
|
||||
net_params = torch.load(model_path, map_location=device)
|
||||
|
||||
params = net_params["params"]
|
||||
params["device"] = device
|
||||
@ -77,7 +130,7 @@ def load_model(model_path, load_weights=True):
|
||||
resize_factor=params["resize_factor"],
|
||||
)
|
||||
else:
|
||||
print("Error: unknown model.")
|
||||
raise ValueError("Unknown model.")
|
||||
|
||||
if load_weights:
|
||||
model.load_state_dict(net_params["state_dict"])
|
||||
@ -215,7 +268,6 @@ def save_results_to_file(results, op_path):
|
||||
|
||||
|
||||
def compute_spectrogram(audio, sampling_rate, params, return_np=False):
|
||||
|
||||
# pad audio so it is evenly divisible by downsampling factors
|
||||
duration = audio.shape[0] / float(sampling_rate)
|
||||
audio = au.pad_audio(
|
||||
|
Loading…
Reference in New Issue
Block a user