mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-30 07:02:01 +02:00
adding typing and documentation to load_device function
This commit is contained in:
parent
d7ddf72c73
commit
ca2a9c39a0
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -12,6 +12,20 @@ import bat_detect.detector.post_process as pp
|
|||||||
import bat_detect.utils.audio_utils as au
|
import bat_detect.utils.audio_utils as au
|
||||||
from bat_detect.detector import models
|
from bat_detect.detector import models
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import TypedDict
|
||||||
|
except ImportError:
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_MODEL_PATH = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(__file__)),
|
||||||
|
"models",
|
||||||
|
"model.pth",
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["load_model", "DEFAULT_MODEL_PATH"]
|
||||||
|
|
||||||
|
|
||||||
def get_default_bd_args():
|
def get_default_bd_args():
|
||||||
args = {}
|
args = {}
|
||||||
@ -29,25 +43,64 @@ def get_default_bd_args():
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def get_audio_files(ip_dir):
|
def get_audio_files(ip_dir: str) -> List[str]:
|
||||||
|
"""Get all audio files in directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip_dir (str): Input directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of audio files. Only .wav files are returned. Paths are
|
||||||
|
relative to ip_dir.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: Input directory not found.
|
||||||
|
|
||||||
|
"""
|
||||||
matches = []
|
matches = []
|
||||||
for root, dirnames, filenames in os.walk(ip_dir):
|
for root, _, filenames in os.walk(ip_dir):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
if filename.lower().endswith(".wav"):
|
if filename.lower().endswith(".wav"):
|
||||||
matches.append(os.path.join(root, filename))
|
matches.append(os.path.join(root, filename))
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path, load_weights=True):
|
class ModelParameters(TypedDict):
|
||||||
|
"""Model parameters."""
|
||||||
|
model_name: str
|
||||||
|
num_filters: int
|
||||||
|
emb_dim: int
|
||||||
|
ip_height: int
|
||||||
|
resize_factor: int
|
||||||
|
class_names: List[str]
|
||||||
|
device: torch.device
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(
|
||||||
|
model_path: str=DEFAULT_MODEL_PATH,
|
||||||
|
load_weights: bool=True
|
||||||
|
) -> Tuple[torch.nn.Module, ModelParameters]:
|
||||||
|
"""Load model from file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str): Path to model file. Defaults to DEFAULT_MODEL_PATH.
|
||||||
|
load_weights (bool, optional): Load weights. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
model, params: Model and parameters.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: Model file not found.
|
||||||
|
ValueError: Unknown model.
|
||||||
|
"""
|
||||||
|
|
||||||
# load model
|
# load model
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
if os.path.isfile(model_path):
|
|
||||||
net_params = torch.load(model_path, map_location=device)
|
if not os.path.isfile(model_path):
|
||||||
else:
|
raise FileNotFoundError("Model file not found.")
|
||||||
print("Error: model not found.")
|
|
||||||
sys.exit(1)
|
net_params = torch.load(model_path, map_location=device)
|
||||||
|
|
||||||
params = net_params["params"]
|
params = net_params["params"]
|
||||||
params["device"] = device
|
params["device"] = device
|
||||||
@ -77,7 +130,7 @@ def load_model(model_path, load_weights=True):
|
|||||||
resize_factor=params["resize_factor"],
|
resize_factor=params["resize_factor"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print("Error: unknown model.")
|
raise ValueError("Unknown model.")
|
||||||
|
|
||||||
if load_weights:
|
if load_weights:
|
||||||
model.load_state_dict(net_params["state_dict"])
|
model.load_state_dict(net_params["state_dict"])
|
||||||
@ -215,7 +268,6 @@ def save_results_to_file(results, op_path):
|
|||||||
|
|
||||||
|
|
||||||
def compute_spectrogram(audio, sampling_rate, params, return_np=False):
|
def compute_spectrogram(audio, sampling_rate, params, return_np=False):
|
||||||
|
|
||||||
# pad audio so it is evenly divisible by downsampling factors
|
# pad audio so it is evenly divisible by downsampling factors
|
||||||
duration = audio.shape[0] / float(sampling_rate)
|
duration = audio.shape[0] / float(sampling_rate)
|
||||||
audio = au.pad_audio(
|
audio = au.pad_audio(
|
||||||
|
Loading…
Reference in New Issue
Block a user