mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00

Commit 0a22f1798e (parent 29074c689e)
"Checked that returned features is a 2d numpy array"
@@ -115,6 +115,7 @@ from batdetect2.types import (
     DetectionModel,
     ModelOutput,
     ProcessingConfiguration,
+    RunResults,
     SpectrogramParameters,
 )
 from batdetect2.utils.detector_utils import list_audio_files, load_model
@@ -134,6 +135,7 @@ __all__ = [
     "process_audio",
     "process_file",
     "process_spectrogram",
+    "print_summary",
 ]
 
 
@@ -150,11 +152,11 @@ def get_config(**kwargs) -> ProcessingConfiguration:
 
     Can be used to override default parameters by passing keyword arguments.
     """
-    return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}  # type: ignore
+    return {**DEFAULT_PROCESSING_CONFIGURATIONS, **PARAMS, **kwargs}  # type: ignore
 
 
 # Default processing configuration
-CONFIG = get_config(**PARAMS)
+CONFIG = get_config()
 
 
 def load_audio(
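
With this change `get_config` folds the model-derived `PARAMS` into the returned configuration itself, so `CONFIG = get_config()` no longer has to pass them explicitly. The merge relies on standard dict-unpacking order, where later mappings override earlier ones. A minimal sketch of that merge order, using illustrative keys rather than the real configuration schema:

# Minimal sketch of the {**defaults, **PARAMS, **kwargs} merge order used in
# get_config above. The dicts and keys here are illustrative stand-ins, not
# the real batdetect2 configuration schema.
defaults = {"threshold": 0.01, "chunk_size": 3}
model_params = {"threshold": 0.3}   # plays the role of PARAMS (model-derived values)
overrides = {"chunk_size": 5}       # plays the role of **kwargs from the caller

merged = {**defaults, **model_params, **overrides}
# Later mappings win: PARAMS overrides the defaults, and explicit keyword
# arguments override both.
assert merged == {"threshold": 0.3, "chunk_size": 5}
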
@@ -270,7 +272,7 @@ def process_spectrogram(
     samp_rate: int = TARGET_SAMPLERATE_HZ,
     model: DetectionModel = MODEL,
     config: Optional[ProcessingConfiguration] = None,
-) -> Tuple[List[Annotation], List[np.ndarray]]:
+) -> Tuple[List[Annotation], np.ndarray]:
     """Process spectrogram with model.
 
     Parameters
@@ -289,7 +291,11 @@ def process_spectrogram(
 
     Returns
     -------
-    DetectionResult
+    detections : List[Annotation]
+        List of detections.
+    features: np.ndarray
+        An array of features. The array has shape (n_detections, n_features)
+        where each row is a feature vector for a detection.
     """
     if config is None:
         config = CONFIG
@@ -308,7 +314,7 @@ def process_audio(
     model: DetectionModel = MODEL,
     config: Optional[ProcessingConfiguration] = None,
     device: torch.device = DEVICE,
-) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
+) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
     """Process audio array with model.
 
     Parameters
@@ -329,10 +335,9 @@ def process_audio(
     -------
     annotations : List[Annotation]
         List of predicted annotations.
-
-    features: List[np.ndarray]
-        List of extracted features for each annotation.
+    features: np.ndarray
+        An array of features. The array has shape (n_detections, n_features)
+        where each row is a feature vector for a detection.
 
     spec : torch.Tensor
         Spectrogram of the audio used for prediction.
     """
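
After these hunks, both `process_spectrogram` and `process_audio` return the features as a single 2-D array of shape (n_detections, n_features) instead of a list of per-detection arrays. A rough usage sketch; the import path and the synthetic audio clip are assumptions made so the snippet is self-contained, not something shown in this diff:

import numpy as np

from batdetect2 import api  # assumed import path for the module edited above

# Synthetic one-second clip purely so the sketch is self-contained; in practice
# the audio would come from a recording loaded at the target sample rate.
audio = np.random.randn(256_000).astype(np.float32)

detections, features, spec = api.process_audio(audio)

# `features` is now a single 2-D array rather than a list of arrays:
# one row per detection, shape (n_detections, n_features).
print(type(features), getattr(features, "shape", None))
print(len(detections), "detections")
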
@@ -395,3 +400,26 @@ model: DetectionModel = MODEL
 
 config: ProcessingConfiguration = CONFIG
 """Default processing configuration."""
+
+
+def print_summary(results: RunResults) -> None:
+    """Print summary of results.
+
+    Parameters
+    ----------
+    results : DetectionResult
+        Detection result.
+    """
+    print("Results for " + results["pred_dict"]["id"])
+    print("{} calls detected\n".format(len(results["pred_dict"]["annotation"])))
+
+    print("time\tprob\tlfreq\tspecies_name")
+    for ann in results["pred_dict"]["annotation"]:
+        print(
+            "{}\t{}\t{}\t{}".format(
+                ann["start_time"],
+                ann["class_prob"],
+                ann["low_freq"],
+                ann["class"],
+            )
+        )
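
The new `print_summary` helper only reads a handful of keys from the run results, so a hand-built dictionary is enough to exercise it. A small usage sketch with made-up field values (the import path is an assumption; a real `RunResults` would presumably come from `process_file`):

from batdetect2 import api  # assumed import path, as above

# Hand-built stand-in for a RunResults dict, covering only the keys that
# print_summary reads. All values are made up.
fake_results = {
    "pred_dict": {
        "id": "example_recording.wav",
        "annotation": [
            {
                "start_time": 0.12,
                "class_prob": 0.87,
                "low_freq": 45000,
                "class": "Myotis daubentonii",
            },
        ],
    }
}

api.print_summary(fake_results)  # type: ignore[arg-type]
# Expected output (tab-separated):
# Results for example_recording.wav
# 1 calls detected
#
# time	prob	lfreq	species_name
# 0.12	0.87	45000	Myotis daubentonii
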
@@ -295,7 +295,6 @@ def save_results_to_file(results, op_path: str) -> None:
         op_path (str): Output path.
 
-
     """
     # make directory if it does not exist
     if not os.path.isdir(os.path.dirname(op_path)):
         os.makedirs(os.path.dirname(op_path))
@@ -474,7 +473,7 @@ def _process_spectrogram(
     samplerate: int,
     model: DetectionModel,
     config: ProcessingConfiguration,
-) -> Tuple[PredictionResults, List[np.ndarray]]:
+) -> Tuple[PredictionResults, np.ndarray]:
     # evaluate model
     with torch.no_grad():
         outputs = model(spec)
@@ -504,7 +503,7 @@ def _process_spectrogram(
     ):
         pred_nms["class_probs"] = class_probs[:-1, :]
 
-    return pred_nms, features
+    return pred_nms, np.concatenate(features, axis=0)
 
 
 def postprocess_model_outputs(
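
The internal `_process_spectrogram` previously returned the collected feature arrays as a list; concatenating along axis 0 is what turns them into the single (n_detections, n_features) array exposed by the public functions. A small sketch of that reshaping, assuming each list entry is already a 2-D block of feature rows (the sizes below are made up):

import numpy as np

# Made-up feature blocks standing in for the per-chunk outputs collected in
# `features`: each block is (detections_in_block, n_features) with n_features=32.
features = [
    np.zeros((3, 32), dtype=np.float32),  # 3 detections from one block
    np.zeros((2, 32), dtype=np.float32),  # 2 detections from the next block
]

stacked = np.concatenate(features, axis=0)
# Rows from every block are stacked into one 2-D array: (5, 32) here,
# i.e. (n_detections, n_features) overall.
print(stacked.shape)  # (5, 32)
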
@@ -550,7 +549,7 @@ def process_spectrogram(
     samplerate: int,
     model: DetectionModel,
     config: ProcessingConfiguration,
-) -> Tuple[List[Annotation], List[np.ndarray]]:
+) -> Tuple[List[Annotation], np.ndarray]:
     """Process a spectrogram with detection model.
 
     Will run non-maximum suppression on the output of the model.
@@ -569,10 +568,11 @@ def process_spectrogram(
 
     Returns
     -------
-    annotations : List[Annotation]
-        List of annotations predicted by the model.
-    features : List[np.ndarray]
-        List of CNN features associated with each annotation.
+    detections: List[Annotation]
+        List of detections predicted by the model.
+    features : np.ndarray
+        An array of CNN features associated with each annotation.
+        The array is of shape (num_detections, num_features).
         Is empty if `config["cnn_features"]` is False.
     """
     pred_nms, features = _process_spectrogram(
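
As the docstring notes, the feature array is only populated when `config["cnn_features"]` is enabled. A hedged sketch of requesting the features through the public API, assuming (as in the earlier sketch) the `batdetect2.api` import path and that `cnn_features` can be toggled via the `get_config(**kwargs)` override shown in this commit:

import numpy as np

from batdetect2 import api  # assumed import path

# Ask for CNN features explicitly; get_config merges keyword overrides on top
# of the defaults, so this should flip the `cnn_features` flag.
config = api.get_config(cnn_features=True)

audio = np.random.randn(256_000).astype(np.float32)  # placeholder clip
detections, features, spec = api.process_audio(audio, config=config)
print(features.shape if isinstance(features, np.ndarray) else features)
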
@@ -582,12 +582,12 @@ def process_spectrogram(
         config,
     )
 
-    annotations = get_annotations_from_preds(
+    detections = get_annotations_from_preds(
         pred_nms,
         config["class_names"],
     )
 
-    return annotations, features
+    return detections, features
 
 
 def _process_audio_array(
@@ -596,7 +596,7 @@ def _process_audio_array(
     model: DetectionModel,
     config: ProcessingConfiguration,
     device: torch.device,
-) -> Tuple[PredictionResults, List[np.ndarray], torch.Tensor]:
+) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
     # load audio file and compute spectrogram
     _, spec, _ = compute_spectrogram(
         audio,
@@ -634,7 +634,7 @@ def process_audio_array(
     model: DetectionModel,
     config: ProcessingConfiguration,
     device: torch.device,
-) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
+) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
     """Process a single audio array with detection model.
 
     Parameters
@@ -656,10 +656,9 @@ def process_audio_array(
     -------
     annotations : List[Annotation]
        List of annotations predicted by the model.
-
-    features : List[np.ndarray]
-        List of CNN features associated with each annotation.
+    features : np.ndarray
+        Array of CNN features associated with each annotation.
+        The array is of shape (num_detections, num_features).
 
     spec : torch.Tensor
         Spectrogram of the audio used as input.
 
@@ -81,7 +81,7 @@ def test_get_default_config():
     assert config["denoise_spec_avg"] is True
     assert config["max_scale_spec"] is False
     assert config["scale_raw_audio"] is False
-    assert len(config["class_names"]) == 0
+    assert len(config["class_names"]) == 17
     assert config["detection_threshold"] == 0.01
     assert config["time_expansion"] == 1
     assert config["top_n"] == 3
@@ -193,8 +193,8 @@ def test_process_spectrogram_with_default_model():
     assert "high_freq" in sample_pred
 
     assert features is not None
-    assert isinstance(features, list)
-    assert len(features) == 1
+    assert isinstance(features, np.ndarray)
+    assert len(features) == len(predictions)
 
 
 def test_process_audio_with_default_model():
@@ -216,8 +216,8 @@ def test_process_audio_with_default_model():
     assert "high_freq" in sample_pred
 
     assert features is not None
-    assert isinstance(features, list)
-    assert len(features) == 1
+    assert isinstance(features, np.ndarray)
+    assert len(features) == len(predictions)
 
     assert spec is not None
     assert isinstance(spec, torch.Tensor)
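
Given the commit's intent ("returned features is a 2d numpy array"), a natural companion to the updated assertions is a check on the dimensionality itself. A hedged sketch of that property, mirroring but not part of the test changes above:

import numpy as np

def check_features(features: np.ndarray, predictions: list) -> None:
    # Property described by the commit message and the updated docstrings:
    # a single 2-D array with one feature row per detection.
    assert isinstance(features, np.ndarray)
    assert features.ndim == 2                      # (n_detections, n_features)
    assert features.shape[0] == len(predictions)   # one row per detection

# Illustrative call with made-up shapes, just to show how it would be used.
check_features(np.zeros((3, 32), dtype=np.float32), [None] * 3)
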
|
Loading…
Reference in New Issue
Block a user