adding typing and documentation to load_device function

Santiago Martinez 2023-02-22 15:00:38 +00:00
parent d7ddf72c73
commit ca2a9c39a0


@@ -1,6 +1,6 @@
 import json
 import os
 import sys
+from typing import List, Tuple
 import numpy as np
 import pandas as pd
@@ -12,6 +12,20 @@ import bat_detect.detector.post_process as pp
 import bat_detect.utils.audio_utils as au
 from bat_detect.detector import models

+try:
+    from typing import TypedDict
+except ImportError:
+    from typing_extensions import TypedDict
+
+DEFAULT_MODEL_PATH = os.path.join(
+    os.path.dirname(os.path.dirname(__file__)),
+    "models",
+    "model.pth",
+)
+
+__all__ = ["load_model", "DEFAULT_MODEL_PATH"]


 def get_default_bd_args():
     args = {}
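The try/except import added in this hunk is the usual compatibility shim: TypedDict only landed in the standard typing module in Python 3.8, so older interpreters fall back to the typing_extensions backport. A minimal, self-contained sketch of the same pattern (the SpectrogramConfig name and its fields are only illustrative, not part of this commit):

try:
    from typing import TypedDict  # Python >= 3.8
except ImportError:
    from typing_extensions import TypedDict  # backport for Python 3.7

class SpectrogramConfig(TypedDict):
    # hypothetical example fields for illustration only
    fft_win_length: float
    fft_overlap: float

config: SpectrogramConfig = {"fft_win_length": 0.02, "fft_overlap": 0.75}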
@@ -29,25 +43,64 @@ def get_default_bd_args():
     return args


-def get_audio_files(ip_dir):
+def get_audio_files(ip_dir: str) -> List[str]:
+    """Get all audio files in directory.
+
+    Args:
+        ip_dir (str): Input directory.
+
+    Returns:
+        list: List of audio files. Only .wav files are returned. Paths are
+        relative to ip_dir.
+
+    Raises:
+        FileNotFoundError: Input directory not found.
+
+    """
     matches = []
-    for root, dirnames, filenames in os.walk(ip_dir):
+    for root, _, filenames in os.walk(ip_dir):
         for filename in filenames:
             if filename.lower().endswith(".wav"):
                 matches.append(os.path.join(root, filename))
     return matches
-def load_model(model_path, load_weights=True):
+class ModelParameters(TypedDict):
+    """Model parameters."""
+
+    model_name: str
+    num_filters: int
+    emb_dim: int
+    ip_height: int
+    resize_factor: int
+    class_names: List[str]
+    device: torch.device
+
+
+def load_model(
+    model_path: str = DEFAULT_MODEL_PATH,
+    load_weights: bool = True
+) -> Tuple[torch.nn.Module, ModelParameters]:
+    """Load model from file.
+
+    Args:
+        model_path (str): Path to model file. Defaults to DEFAULT_MODEL_PATH.
+        load_weights (bool, optional): Load weights. Defaults to True.
+
+    Returns:
+        model, params: Model and parameters.
+
+    Raises:
+        FileNotFoundError: Model file not found.
+        ValueError: Unknown model.
+
+    """
     # load model
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    if os.path.isfile(model_path):
-        net_params = torch.load(model_path, map_location=device)
-    else:
-        print("Error: model not found.")
-        sys.exit(1)
+
+    if not os.path.isfile(model_path):
+        raise FileNotFoundError("Model file not found.")
+
+    net_params = torch.load(model_path, map_location=device)

     params = net_params["params"]
     params["device"] = device
@@ -77,7 +130,7 @@ def load_model(model_path, load_weights=True):
             resize_factor=params["resize_factor"],
         )
     else:
-        print("Error: unknown model.")
+        raise ValueError("Unknown model.")

     if load_weights:
         model.load_state_dict(net_params["state_dict"])
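A rough usage sketch of the updated API from this hunk, assuming the edited module is importable (the import path below is a guess, not shown in the diff):

import bat_detect.detector_utils as du  # hypothetical module path

try:
    model, params = du.load_model()  # defaults to DEFAULT_MODEL_PATH
except FileNotFoundError:
    print("no model weights found")
except ValueError:
    print("unrecognised model name in the checkpoint")
else:
    # params is a ModelParameters TypedDict, so type checkers know each field's type
    print(params["model_name"], params["device"], len(params["class_names"]))

wav_files = du.get_audio_files("my_recordings")  # only .wav paths are returned

Raising FileNotFoundError and ValueError instead of print/sys.exit lets callers handle failures like this without the whole process dying.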
@@ -215,7 +268,6 @@ def save_results_to_file(results, op_path):
 def compute_spectrogram(audio, sampling_rate, params, return_np=False):
     # pad audio so it is evenly divisible by downsampling factors
     duration = audio.shape[0] / float(sampling_rate)
     audio = au.pad_audio(
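The comment in this last hunk refers to padding the waveform so its length divides evenly by the model's downsampling factors; the actual signature of au.pad_audio is not shown here, but the underlying arithmetic is just rounding the sample count up to the next multiple. A small illustrative helper, not the project's implementation:

import numpy as np

def pad_to_multiple(audio: np.ndarray, factor: int) -> np.ndarray:
    # zero-pad a 1-D signal so len(audio) becomes a multiple of `factor`
    n = audio.shape[0]
    padded_len = int(np.ceil(n / factor)) * factor
    return np.pad(audio, (0, padded_len - n))

x = np.zeros(22050, dtype=np.float32)
print(pad_to_multiple(x, 512).shape)  # (22528,), the next multiple of 512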