Compare commits

...

3 Commits

Author SHA1 Message Date
mbsantiago
4b75e13fa2 Bump version: 1.1.1 → 1.2.0
Some checks failed
Python package / build (3.12) (push) Has been cancelled
Python package / build (3.9) (push) Has been cancelled
Python package / build (3.10) (push) Has been cancelled
Python package / build (3.11) (push) Has been cancelled
2025-03-12 18:03:43 +00:00
Santiago Martinez Balvanera
98bf506634
Merge pull request #45 from macaodha/feat/add-chunk-size-to-cli
Added the chunk_size param to the detect command
2025-03-12 18:01:48 +00:00
mbsantiago
b4c59f7de1 Added the chunk_size param to the detect command 2025-03-12 17:59:18 +00:00
5 changed files with 38 additions and 4 deletions

View File

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 1.1.1 current_version = 1.2.0
commit = True commit = True
tag = True tag = True

View File

@ -3,4 +3,4 @@ import logging
numba_logger = logging.getLogger("numba") numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING) numba_logger.setLevel(logging.WARNING)
__version__ = "1.1.1" __version__ = "1.2.0"

View File

@ -45,6 +45,12 @@ def cli():
default=False, default=False,
help="Extracts CNN call features", help="Extracts CNN call features",
) )
@click.option(
"--chunk_size",
type=float,
default=2,
help="Specifies the duration of chunks in seconds. BatDetect2 will divide longer files into smaller chunks and process them independently. Larger chunks increase computation time and memory usage but may provide more contextual information for inference.",
)
@click.option( @click.option(
"--spec_features", "--spec_features",
is_flag=True, is_flag=True,
@ -80,6 +86,7 @@ def detect(
ann_dir: str, ann_dir: str,
detection_threshold: float, detection_threshold: float,
time_expansion_factor: int, time_expansion_factor: int,
chunk_size: float,
**args, **args,
): ):
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR. """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
@ -108,7 +115,7 @@ def detect(
**args, **args,
"time_expansion": time_expansion_factor, "time_expansion": time_expansion_factor,
"spec_slices": False, "spec_slices": False,
"chunk_size": 2, "chunk_size": chunk_size,
"detection_threshold": detection_threshold, "detection_threshold": detection_threshold,
} }
) )
@ -147,6 +154,7 @@ def print_config(config: ProcessingConfiguration):
click.echo("\nProcessing Configuration:") click.echo("\nProcessing Configuration:")
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}") click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
click.echo(f"Detection Threshold: {config.get('detection_threshold')}") click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
click.echo(f"Chunk Size: {config.get('chunk_size')}s")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,6 +1,6 @@
[project] [project]
name = "batdetect2" name = "batdetect2"
version = "1.1.1" version = "1.2.0"
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings." description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
authors = [ authors = [
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" }, { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },

View File

@ -130,3 +130,29 @@ def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
) )
assert result.exit_code == 0 assert result.exit_code == 0
assert f"Error processing file {empty_file}" in result.output assert f"Error processing file {empty_file}" in result.output
def test_can_set_chunk_size(tmp_path: Path):
results_dir = tmp_path / "results"
# Remove results dir if it exists
if results_dir.exists():
results_dir.rmdir()
result = runner.invoke(
cli,
[
"detect",
"example_data/audio",
str(results_dir),
"0.3",
"--chunk_size",
"1",
],
)
assert "Chunk Size: 1.0s" in result.output
assert result.exit_code == 0
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 3
assert len(list(results_dir.glob("*.json"))) == 3