diff --git a/run_batdetect.py b/run_batdetect.py deleted file mode 100644 index 3079eca..0000000 --- a/run_batdetect.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Run batdetect2.command.main() from the command line.""" -from batdetect2.cli import detect - -if __name__ == "__main__": - detect() diff --git a/src/batdetect2/cli/inference.py b/src/batdetect2/cli/inference.py index 70ec546..baa7b41 100644 --- a/src/batdetect2/cli/inference.py +++ b/src/batdetect2/cli/inference.py @@ -84,6 +84,7 @@ def common_predict_options(func): "--format", "format_name", type=str, + default="batdetect2", help=( "Output format name used by the prediction writer. If omitted, " "the config default is used." @@ -159,6 +160,7 @@ def _run_prediction( num_workers: int, format_name: str | None, detection_threshold: float | None, + audio_dir: Path | None = None, ) -> None: logger.info("Initiating prediction process...") @@ -182,11 +184,13 @@ def _run_prediction( detection_threshold=detection_threshold, ) - common_path = audio_files[0].parent if audio_files else None + if audio_dir is None: + audio_dir = audio_files[0].parent if audio_files else None + api.save_predictions( predictions, path=output_path, - audio_dir=common_path, + audio_dir=audio_dir, format=format_name, ) @@ -235,6 +239,7 @@ def predict_directory_command( num_workers=num_workers, format_name=format_name, detection_threshold=detection_threshold, + audio_dir=audio_dir, )