adding typing and documentation to load_device function

This commit is contained in:
Santiago Martinez 2023-02-22 15:00:38 +00:00
parent d7ddf72c73
commit ca2a9c39a0

View File

@ -1,6 +1,6 @@
import json import json
import os import os
import sys from typing import List, Tuple
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -12,6 +12,20 @@ import bat_detect.detector.post_process as pp
import bat_detect.utils.audio_utils as au import bat_detect.utils.audio_utils as au
from bat_detect.detector import models from bat_detect.detector import models
try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
DEFAULT_MODEL_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"models",
"model.pth",
)
__all__ = ["load_model", "DEFAULT_MODEL_PATH"]
def get_default_bd_args(): def get_default_bd_args():
args = {} args = {}
@ -29,25 +43,64 @@ def get_default_bd_args():
return args return args
def get_audio_files(ip_dir): def get_audio_files(ip_dir: str) -> List[str]:
"""Get all audio files in directory.
Args:
ip_dir (str): Input directory.
Returns:
list: List of audio files. Only .wav files are returned. Paths are
relative to ip_dir.
Raises:
FileNotFoundError: Input directory not found.
"""
matches = [] matches = []
for root, dirnames, filenames in os.walk(ip_dir): for root, _, filenames in os.walk(ip_dir):
for filename in filenames: for filename in filenames:
if filename.lower().endswith(".wav"): if filename.lower().endswith(".wav"):
matches.append(os.path.join(root, filename)) matches.append(os.path.join(root, filename))
return matches return matches
def load_model(model_path, load_weights=True): class ModelParameters(TypedDict):
"""Model parameters."""
model_name: str
num_filters: int
emb_dim: int
ip_height: int
resize_factor: int
class_names: List[str]
device: torch.device
def load_model(
model_path: str=DEFAULT_MODEL_PATH,
load_weights: bool=True
) -> Tuple[torch.nn.Module, ModelParameters]:
"""Load model from file.
Args:
model_path (str): Path to model file. Defaults to DEFAULT_MODEL_PATH.
load_weights (bool, optional): Load weights. Defaults to True.
Returns:
model, params: Model and parameters.
Raises:
FileNotFoundError: Model file not found.
ValueError: Unknown model.
"""
# load model # load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if os.path.isfile(model_path):
if not os.path.isfile(model_path):
raise FileNotFoundError("Model file not found.")
net_params = torch.load(model_path, map_location=device) net_params = torch.load(model_path, map_location=device)
else:
print("Error: model not found.")
sys.exit(1)
params = net_params["params"] params = net_params["params"]
params["device"] = device params["device"] = device
@ -77,7 +130,7 @@ def load_model(model_path, load_weights=True):
resize_factor=params["resize_factor"], resize_factor=params["resize_factor"],
) )
else: else:
print("Error: unknown model.") raise ValueError("Unknown model.")
if load_weights: if load_weights:
model.load_state_dict(net_params["state_dict"]) model.load_state_dict(net_params["state_dict"])
@ -215,7 +268,6 @@ def save_results_to_file(results, op_path):
def compute_spectrogram(audio, sampling_rate, params, return_np=False): def compute_spectrogram(audio, sampling_rate, params, return_np=False):
# pad audio so it is evenly divisible by downsampling factors # pad audio so it is evenly divisible by downsampling factors
duration = audio.shape[0] / float(sampling_rate) duration = audio.shape[0] / float(sampling_rate)
audio = au.pad_audio( audio = au.pad_audio(