mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Checked that returned features is a 2d numpy array
This commit is contained in:
parent
29074c689e
commit
0a22f1798e
@ -115,6 +115,7 @@ from batdetect2.types import (
|
||||
DetectionModel,
|
||||
ModelOutput,
|
||||
ProcessingConfiguration,
|
||||
RunResults,
|
||||
SpectrogramParameters,
|
||||
)
|
||||
from batdetect2.utils.detector_utils import list_audio_files, load_model
|
||||
@ -134,6 +135,7 @@ __all__ = [
|
||||
"process_audio",
|
||||
"process_file",
|
||||
"process_spectrogram",
|
||||
"print_summary",
|
||||
]
|
||||
|
||||
|
||||
@ -150,11 +152,11 @@ def get_config(**kwargs) -> ProcessingConfiguration:
|
||||
|
||||
Can be used to override default parameters by passing keyword arguments.
|
||||
"""
|
||||
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore
|
||||
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **PARAMS, **kwargs} # type: ignore
|
||||
|
||||
|
||||
# Default processing configuration
|
||||
CONFIG = get_config(**PARAMS)
|
||||
CONFIG = get_config()
|
||||
|
||||
|
||||
def load_audio(
|
||||
@ -270,7 +272,7 @@ def process_spectrogram(
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
model: DetectionModel = MODEL,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
"""Process spectrogram with model.
|
||||
|
||||
Parameters
|
||||
@ -289,7 +291,11 @@ def process_spectrogram(
|
||||
|
||||
Returns
|
||||
-------
|
||||
DetectionResult
|
||||
detections : List[Annotation]
|
||||
List of detections.
|
||||
features: np.ndarray
|
||||
An array of features. The array has shape (n_detections, n_features)
|
||||
where each row is a feature vector for a detection.
|
||||
"""
|
||||
if config is None:
|
||||
config = CONFIG
|
||||
@ -308,7 +314,7 @@ def process_audio(
|
||||
model: DetectionModel = MODEL,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
|
||||
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
|
||||
"""Process audio array with model.
|
||||
|
||||
Parameters
|
||||
@ -329,10 +335,9 @@ def process_audio(
|
||||
-------
|
||||
annotations : List[Annotation]
|
||||
List of predicted annotations.
|
||||
|
||||
features: List[np.ndarray]
|
||||
List of extracted features for each annotation.
|
||||
|
||||
features: np.ndarray
|
||||
An array of features. The array has shape (n_detections, n_features)
|
||||
where each row is a feature vector for a detection.
|
||||
spec : torch.Tensor
|
||||
Spectrogram of the audio used for prediction.
|
||||
"""
|
||||
@ -395,3 +400,26 @@ model: DetectionModel = MODEL
|
||||
|
||||
config: ProcessingConfiguration = CONFIG
|
||||
"""Default processing configuration."""
|
||||
|
||||
|
||||
def print_summary(results: RunResults) -> None:
    """Print a tab-separated summary of the detections in a run result.

    Writes to standard output the recording id, the number of detected
    calls and then one row per detection with its start time, class
    probability, low frequency and class name.

    Parameters
    ----------
    results : RunResults
        Results of a run. Only ``results["pred_dict"]`` is read; its
        ``"id"`` and ``"annotation"`` entries are used.
    """
    # Look the prediction dictionary up once instead of on every access.
    pred_dict = results["pred_dict"]
    print("Results for " + pred_dict["id"])

    annotations = pred_dict["annotation"]
    print("{} calls detected\n".format(len(annotations)))

    print("time\tprob\tlfreq\tspecies_name")
    for ann in annotations:
        print(
            "{}\t{}\t{}\t{}".format(
                ann["start_time"],
                ann["class_prob"],
                ann["low_freq"],
                ann["class"],
            )
        )
|
||||
|
@ -295,7 +295,6 @@ def save_results_to_file(results, op_path: str) -> None:
|
||||
op_path (str): Output path.
|
||||
|
||||
"""
|
||||
|
||||
# make directory if it does not exist
|
||||
if not os.path.isdir(os.path.dirname(op_path)):
|
||||
os.makedirs(os.path.dirname(op_path))
|
||||
@ -474,7 +473,7 @@ def _process_spectrogram(
|
||||
samplerate: int,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
) -> Tuple[PredictionResults, List[np.ndarray]]:
|
||||
) -> Tuple[PredictionResults, np.ndarray]:
|
||||
# evaluate model
|
||||
with torch.no_grad():
|
||||
outputs = model(spec)
|
||||
@ -504,7 +503,7 @@ def _process_spectrogram(
|
||||
):
|
||||
pred_nms["class_probs"] = class_probs[:-1, :]
|
||||
|
||||
return pred_nms, features
|
||||
return pred_nms, np.concatenate(features, axis=0)
|
||||
|
||||
|
||||
def postprocess_model_outputs(
|
||||
@ -550,7 +549,7 @@ def process_spectrogram(
|
||||
samplerate: int,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
"""Process a spectrogram with detection model.
|
||||
|
||||
Will run non-maximum suppression on the output of the model.
|
||||
@ -569,10 +568,11 @@ def process_spectrogram(
|
||||
|
||||
Returns
|
||||
-------
|
||||
annotations : List[Annotation]
|
||||
List of annotations predicted by the model.
|
||||
features : List[np.ndarray]
|
||||
List of CNN features associated with each annotation.
|
||||
detections: List[Annotation]
|
||||
List of detections predicted by the model.
|
||||
features : np.ndarray
|
||||
An array of CNN features associated with each annotation.
|
||||
The array is of shape (num_detections, num_features).
|
||||
Is empty if `config["cnn_features"]` is False.
|
||||
"""
|
||||
pred_nms, features = _process_spectrogram(
|
||||
@ -582,12 +582,12 @@ def process_spectrogram(
|
||||
config,
|
||||
)
|
||||
|
||||
annotations = get_annotations_from_preds(
|
||||
detections = get_annotations_from_preds(
|
||||
pred_nms,
|
||||
config["class_names"],
|
||||
)
|
||||
|
||||
return annotations, features
|
||||
return detections, features
|
||||
|
||||
|
||||
def _process_audio_array(
|
||||
@ -596,7 +596,7 @@ def _process_audio_array(
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
device: torch.device,
|
||||
) -> Tuple[PredictionResults, List[np.ndarray], torch.Tensor]:
|
||||
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
|
||||
# load audio file and compute spectrogram
|
||||
_, spec, _ = compute_spectrogram(
|
||||
audio,
|
||||
@ -634,7 +634,7 @@ def process_audio_array(
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
device: torch.device,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
|
||||
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
|
||||
"""Process a single audio array with detection model.
|
||||
|
||||
Parameters
|
||||
@ -656,10 +656,9 @@ def process_audio_array(
|
||||
-------
|
||||
annotations : List[Annotation]
|
||||
List of annotations predicted by the model.
|
||||
|
||||
features : List[np.ndarray]
|
||||
List of CNN features associated with each annotation.
|
||||
|
||||
features : np.ndarray
|
||||
Array of CNN features associated with each annotation.
|
||||
The array is of shape (num_detections, num_features).
|
||||
spec : torch.Tensor
|
||||
Spectrogram of the audio used as input.
|
||||
|
||||
|
@ -81,7 +81,7 @@ def test_get_default_config():
|
||||
assert config["denoise_spec_avg"] is True
|
||||
assert config["max_scale_spec"] is False
|
||||
assert config["scale_raw_audio"] is False
|
||||
assert len(config["class_names"]) == 0
|
||||
assert len(config["class_names"]) == 17
|
||||
assert config["detection_threshold"] == 0.01
|
||||
assert config["time_expansion"] == 1
|
||||
assert config["top_n"] == 3
|
||||
@ -193,8 +193,8 @@ def test_process_spectrogram_with_default_model():
|
||||
assert "high_freq" in sample_pred
|
||||
|
||||
assert features is not None
|
||||
assert isinstance(features, list)
|
||||
assert len(features) == 1
|
||||
assert isinstance(features, np.ndarray)
|
||||
assert len(features) == len(predictions)
|
||||
|
||||
|
||||
def test_process_audio_with_default_model():
|
||||
@ -216,8 +216,8 @@ def test_process_audio_with_default_model():
|
||||
assert "high_freq" in sample_pred
|
||||
|
||||
assert features is not None
|
||||
assert isinstance(features, list)
|
||||
assert len(features) == 1
|
||||
assert isinstance(features, np.ndarray)
|
||||
assert len(features) == len(predictions)
|
||||
|
||||
assert spec is not None
|
||||
assert isinstance(spec, torch.Tensor)
|
||||
|
Loading…
Reference in New Issue
Block a user