Decouple config loading from preprocess function

This commit is contained in:
mbsantiago 2024-07-16 01:30:34 +01:00
parent 17cf958cd3
commit 335a05d51a
3 changed files with 12 additions and 10 deletions

View File

@ -32,7 +32,7 @@ results will be combined into a dictionary with the following keys:
for each detection. The CNN features are the output of the CNN before for each detection. The CNN features are the output of the CNN before
the final classification layer. You can use these features to train the final classification layer. You can use these features to train
your own classifier, or to do other processing on the detections. your own classifier, or to do other processing on the detections.
They are in the same order as the detections in They are in the same order as the detections in
`results['pred_dict']['annotation']`. Will only be returned if the `results['pred_dict']['annotation']`. Will only be returned if the
`cnn_feats` parameter in the config is set to `True`. `cnn_feats` parameter in the config is set to `True`.
- `spec_slices`: Optional. A list of `numpy` arrays containing the spectrogram - `spec_slices`: Optional. A list of `numpy` arrays containing the spectrogram
@ -96,6 +96,7 @@ If you wish to use a custom model or change the default parameters, please
consult the API documentation in the code. consult the API documentation in the code.
""" """
import warnings import warnings
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@ -410,7 +411,9 @@ def print_summary(results: RunResults) -> None:
Detection result. Detection result.
""" """
print("Results for " + results["pred_dict"]["id"]) print("Results for " + results["pred_dict"]["id"])
print("{} calls detected\n".format(len(results["pred_dict"]["annotation"]))) print(
"{} calls detected\n".format(len(results["pred_dict"]["annotation"]))
)
print("time\tprob\tlfreq\tspecies_name") print("time\tprob\tlfreq\tspecies_name")
for ann in results["pred_dict"]["annotation"]: for ann in results["pred_dict"]["annotation"]:

View File

@ -65,7 +65,6 @@ def generate_heatmaps(
# Get the position of the sound event # Get the position of the sound event
time, frequency = geometry.get_geometry_point(geom, position=position) time, frequency = geometry.get_geometry_point(geom, position=position)
print(time, frequency)
# Set 1.0 at the position of the sound event in the detection heatmap # Set 1.0 at the position of the sound event in the detection heatmap
detection_heatmap = arrays.set_value_at_pos( detection_heatmap = arrays.set_value_at_pos(

View File

@ -116,6 +116,9 @@ def preprocess_single_annotation(
if path.is_file() and not replace: if path.is_file() and not replace:
return return
if path.is_file() and replace:
path.unlink()
sample = generate_train_example( sample = generate_train_example(
clip_annotation, clip_annotation,
class_mapper, class_mapper,
@ -133,21 +136,18 @@ def preprocess_annotations(
target_sigma: float = TARGET_SIGMA, target_sigma: float = TARGET_SIGMA,
filename_fn: FilenameFn = _get_filename, filename_fn: FilenameFn = _get_filename,
replace: bool = False, replace: bool = False,
config_file: Optional[PathLike] = None, config: Optional[PreprocessingConfig] = None,
max_workers: Optional[int] = None, max_workers: Optional[int] = None,
**kwargs,
) -> None: ) -> None:
"""Preprocess annotations and save to disk.""" """Preprocess annotations and save to disk."""
output_dir = Path(output_dir) output_dir = Path(output_dir)
if config is None:
config = PreprocessingConfig()
if not output_dir.is_dir(): if not output_dir.is_dir():
output_dir.mkdir(parents=True) output_dir.mkdir(parents=True)
if config_file is not None:
config = load_config(config_file, **kwargs)
else:
config = PreprocessingConfig(**kwargs)
with Pool(max_workers) as pool: with Pool(max_workers) as pool:
list( list(
tqdm( tqdm(