diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py
index 2815cc0..b4c0e47 100644
--- a/bat_detect/utils/detector_utils.py
+++ b/bat_detect/utils/detector_utils.py
@@ -1,6 +1,6 @@
 import json
 import os
-import sys
+from typing import List, Tuple
 
 import numpy as np
 import pandas as pd
@@ -12,6 +12,20 @@ import bat_detect.detector.post_process as pp
 import bat_detect.utils.audio_utils as au
 from bat_detect.detector import models
 
+try:
+    from typing import TypedDict
+except ImportError:
+    from typing_extensions import TypedDict
+
+
+DEFAULT_MODEL_PATH = os.path.join(
+    os.path.dirname(os.path.dirname(__file__)),
+    "models",
+    "model.pth",
+)
+
+__all__ = ["load_model", "DEFAULT_MODEL_PATH"]
+
 
 def get_default_bd_args():
     args = {}
@@ -29,25 +43,64 @@ def get_default_bd_args():
     return args
 
 
-def get_audio_files(ip_dir):
+def get_audio_files(ip_dir: str) -> List[str]:
+    """Get all audio files in directory.
+    Args:
+        ip_dir (str): Input directory.
+
+    Returns:
+        list: List of audio files. Only .wav files are returned. Paths are
+            relative to ip_dir.
+
+    Raises:
+        FileNotFoundError: Input directory not found.
+
+    """
     matches = []
-    for root, dirnames, filenames in os.walk(ip_dir):
+    for root, _, filenames in os.walk(ip_dir):
        for filename in filenames:
            if filename.lower().endswith(".wav"):
                matches.append(os.path.join(root, filename))
     return matches
 
 
-def load_model(model_path, load_weights=True):
+class ModelParameters(TypedDict):
+    """Model parameters."""
+    model_name: str
+    num_filters: int
+    emb_dim: int
+    ip_height: int
+    resize_factor: int
+    class_names: List[str]
+    device: torch.device
+
+
+def load_model(
+    model_path: str = DEFAULT_MODEL_PATH,
+    load_weights: bool = True,
+) -> Tuple[torch.nn.Module, ModelParameters]:
+    """Load model from file.
+
+    Args:
+        model_path (str): Path to model file. Defaults to DEFAULT_MODEL_PATH.
+        load_weights (bool, optional): Load weights. Defaults to True.
+
+    Returns:
+        model, params: Model and parameters.
+
+    Raises:
+        FileNotFoundError: Model file not found.
+        ValueError: Unknown model.
+    """
     # load model
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    if os.path.isfile(model_path):
-        net_params = torch.load(model_path, map_location=device)
-    else:
-        print("Error: model not found.")
-        sys.exit(1)
+
+    if not os.path.isfile(model_path):
+        raise FileNotFoundError("Model file not found.")
+
+    net_params = torch.load(model_path, map_location=device)
 
     params = net_params["params"]
     params["device"] = device
@@ -77,7 +130,7 @@ def load_model(model_path, load_weights=True):
             resize_factor=params["resize_factor"],
         )
     else:
-        print("Error: unknown model.")
+        raise ValueError("Unknown model.")
 
     if load_weights:
         model.load_state_dict(net_params["state_dict"])
@@ -215,7 +268,6 @@ def save_results_to_file(results, op_path):
 
 
 def compute_spectrogram(audio, sampling_rate, params, return_np=False):
-
     # pad audio so it is evenly divisible by downsampling factors
    duration = audio.shape[0] / float(sampling_rate)
    audio = au.pad_audio(
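
Usage note (not part of the patch): a minimal sketch of how the API introduced above might be called once this change lands. It assumes the bat_detect package layout shown in the diff, that a model file actually exists at DEFAULT_MODEL_PATH, and it uses an illustrative directory name "example_audio"; none of these callers are part of the diff itself.

    from bat_detect.utils.detector_utils import (
        DEFAULT_MODEL_PATH,
        get_audio_files,
        load_model,
    )

    # A missing model file now raises FileNotFoundError instead of
    # printing an error and calling sys.exit(1).
    model, params = load_model(DEFAULT_MODEL_PATH)
    print(params["class_names"])

    # Collect .wav files to run detection on (directory name is illustrative).
    wav_files = get_audio_files("example_audio")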