Added the chunk_size param to the detect command

This commit is contained in:
mbsantiago 2025-03-12 17:59:18 +00:00
parent 2100a3e483
commit b4c59f7de1
2 changed files with 35 additions and 1 deletions

View File

@ -45,6 +45,12 @@ def cli():
default=False, default=False,
help="Extracts CNN call features", help="Extracts CNN call features",
) )
@click.option(
"--chunk_size",
type=float,
default=2,
help="Specifies the duration of chunks in seconds. BatDetect2 will divide longer files into smaller chunks and process them independently. Larger chunks increase computation time and memory usage but may provide more contextual information for inference.",
)
@click.option( @click.option(
"--spec_features", "--spec_features",
is_flag=True, is_flag=True,
@ -80,6 +86,7 @@ def detect(
ann_dir: str, ann_dir: str,
detection_threshold: float, detection_threshold: float,
time_expansion_factor: int, time_expansion_factor: int,
chunk_size: float,
**args, **args,
): ):
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR. """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
@ -108,7 +115,7 @@ def detect(
**args, **args,
"time_expansion": time_expansion_factor, "time_expansion": time_expansion_factor,
"spec_slices": False, "spec_slices": False,
"chunk_size": 2, "chunk_size": chunk_size,
"detection_threshold": detection_threshold, "detection_threshold": detection_threshold,
} }
) )
@ -147,6 +154,7 @@ def print_config(config: ProcessingConfiguration):
click.echo("\nProcessing Configuration:") click.echo("\nProcessing Configuration:")
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}") click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
click.echo(f"Detection Threshold: {config.get('detection_threshold')}") click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
click.echo(f"Chunk Size: {config.get('chunk_size')}s")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -130,3 +130,29 @@ def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
) )
assert result.exit_code == 0 assert result.exit_code == 0
assert f"Error processing file {empty_file}" in result.output assert f"Error processing file {empty_file}" in result.output
def test_can_set_chunk_size(tmp_path: Path):
results_dir = tmp_path / "results"
# Remove results dir if it exists
if results_dir.exists():
results_dir.rmdir()
result = runner.invoke(
cli,
[
"detect",
"example_data/audio",
str(results_dir),
"0.3",
"--chunk_size",
"1",
],
)
assert "Chunk Size: 1.0s" in result.output
assert result.exit_code == 0
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 3
assert len(list(results_dir.glob("*.json"))) == 3