Add a --chunk_size option to the `detect` CLI command.

Summary of the change carried by this patch:
  * batdetect2/cli.py: declares a `--chunk_size` click option (float,
    default 2), threads it through `detect(...)` into the processing
    configuration in place of the hard-coded value 2, and echoes it from
    `print_config`.
  * tests/test_cli.py: adds `test_can_set_chunk_size`, which invokes
    `detect` with `--chunk_size 1` and checks the echoed value, the exit
    code and the number of result files written.

NOTE(review): the line structure and indentation below were reconstructed
(4-space indentation inferred from bracket nesting) from a copy in which
all whitespace had been collapsed. If a context line fails to match, try
`git apply --ignore-whitespace`.
NOTE(review): `--chunk_size` is declared with `type=float`, so zero and
negative durations are accepted at the CLI layer; consider a bounded type
such as `click.FloatRange` - confirm how downstream code treats such values.
NOTE(review): the new test removes `results_dir` if it already exists; that
directory lives under pytest's `tmp_path`, so the guard presumably never
fires - confirm before relying on it.

diff --git a/batdetect2/cli.py b/batdetect2/cli.py
index 9706312..8379a00 100644
--- a/batdetect2/cli.py
+++ b/batdetect2/cli.py
@@ -45,6 +45,12 @@ def cli():
     default=False,
     help="Extracts CNN call features",
 )
+@click.option(
+    "--chunk_size",
+    type=float,
+    default=2,
+    help="Specifies the duration of chunks in seconds. BatDetect2 will divide longer files into smaller chunks and process them independently. Larger chunks increase computation time and memory usage but may provide more contextual information for inference.",
+)
 @click.option(
     "--spec_features",
     is_flag=True,
@@ -80,6 +86,7 @@ def detect(
     ann_dir: str,
     detection_threshold: float,
     time_expansion_factor: int,
+    chunk_size: float,
     **args,
 ):
     """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
@@ -108,7 +115,7 @@ def detect(
             **args,
             "time_expansion": time_expansion_factor,
             "spec_slices": False,
-            "chunk_size": 2,
+            "chunk_size": chunk_size,
             "detection_threshold": detection_threshold,
         }
     )
@@ -147,6 +154,7 @@ def print_config(config: ProcessingConfiguration):
     click.echo("\nProcessing Configuration:")
     click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
     click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
+    click.echo(f"Chunk Size: {config.get('chunk_size')}s")
 
 
 if __name__ == "__main__":
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 4bcbe08..adba969 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -130,3 +130,29 @@ def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
     )
     assert result.exit_code == 0
     assert f"Error processing file {empty_file}" in result.output
+
+
+def test_can_set_chunk_size(tmp_path: Path):
+    results_dir = tmp_path / "results"
+
+    # Remove results dir if it exists
+    if results_dir.exists():
+        results_dir.rmdir()
+
+    result = runner.invoke(
+        cli,
+        [
+            "detect",
+            "example_data/audio",
+            str(results_dir),
+            "0.3",
+            "--chunk_size",
+            "1",
+        ],
+    )
+
+    assert "Chunk Size: 1.0s" in result.output
+    assert result.exit_code == 0
+    assert results_dir.exists()
+    assert len(list(results_dir.glob("*.csv"))) == 3
+    assert len(list(results_dir.glob("*.json"))) == 3