mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Change inference command to predict
This commit is contained in:
parent
2f03abe8f6
commit
f0af5dd79e
@ -2,6 +2,7 @@ from batdetect2.cli.base import cli
|
|||||||
from batdetect2.cli.compat import detect
|
from batdetect2.cli.compat import detect
|
||||||
from batdetect2.cli.data import data
|
from batdetect2.cli.data import data
|
||||||
from batdetect2.cli.evaluate import evaluate_command
|
from batdetect2.cli.evaluate import evaluate_command
|
||||||
|
from batdetect2.cli.inference import predict
|
||||||
from batdetect2.cli.train import train_command
|
from batdetect2.cli.train import train_command
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -10,6 +11,7 @@ __all__ = [
|
|||||||
"data",
|
"data",
|
||||||
"train_command",
|
"train_command",
|
||||||
"evaluate_command",
|
"evaluate_command",
|
||||||
|
"predict",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
229
src/batdetect2/cli/inference.py
Normal file
229
src/batdetect2/cli/inference.py
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import click
|
||||||
|
from loguru import logger
|
||||||
|
from soundevent import io
|
||||||
|
from soundevent.audio.files import get_audio_files
|
||||||
|
|
||||||
|
from batdetect2.cli.base import cli
|
||||||
|
|
||||||
|
__all__ = ["predict"]
|
||||||
|
|
||||||
|
|
||||||
|
@cli.group(name="predict")
def predict() -> None:
    """Run prediction with BatDetect2 API v2."""
    # Container group only: the callback itself does nothing. The
    # `directory`, `file_list` and `dataset` subcommands attached to this
    # group carry the actual behaviour.
|
||||||
|
|
||||||
|
|
||||||
|
def _build_api(
    model_path: Path,
    audio_config: Path | None,
    inference_config: Path | None,
    outputs_config: Path | None,
    logging_config: Path | None,
):
    """Load the optional config files and build the API from a checkpoint.

    Every ``*_config`` argument is either a path handed to the matching
    config class's ``load`` method, or ``None``, in which case nothing is
    loaded and ``None`` is forwarded to ``BatDetect2API.from_checkpoint``.

    Returns a 4-tuple ``(api, audio_conf, inference_conf, outputs_conf)``.
    The logging config is passed to the API but is not part of the tuple.
    """
    # Function-scope imports, kept in their original order.
    # NOTE(review): presumably deferred to keep CLI start-up light — confirm.
    from batdetect2.api_v2 import BatDetect2API
    from batdetect2.audio import AudioConfig
    from batdetect2.inference import InferenceConfig
    from batdetect2.logging import AppLoggingConfig
    from batdetect2.outputs import OutputsConfig

    def _load_or_none(config_cls, config_path):
        # A missing path means "no override": skip loading entirely.
        if config_path is None:
            return None
        return config_cls.load(config_path)

    loaded_audio = _load_or_none(AudioConfig, audio_config)
    loaded_inference = _load_or_none(InferenceConfig, inference_config)
    loaded_outputs = _load_or_none(OutputsConfig, outputs_config)
    loaded_logging = _load_or_none(AppLoggingConfig, logging_config)

    built_api = BatDetect2API.from_checkpoint(
        model_path,
        audio_config=loaded_audio,
        inference_config=loaded_inference,
        outputs_config=loaded_outputs,
        logging_config=loaded_logging,
    )
    return built_api, loaded_audio, loaded_inference, loaded_outputs
|
||||||
|
|
||||||
|
|
||||||
|
def _common_audio_dir(audio_files: list[Path]) -> Path | None:
    """Return the deepest directory that contains every file in *audio_files*.

    Returns ``None`` for an empty list. If no common ancestor can be
    computed (``os.path.commonpath`` raises ``ValueError``, e.g. for a mix
    of absolute and relative paths), the parent of the first file is
    returned instead.
    """
    import os.path

    if not audio_files:
        return None

    parent_dirs = [str(audio_file.parent) for audio_file in audio_files]
    try:
        return Path(os.path.commonpath(parent_dirs))
    except ValueError:
        return audio_files[0].parent


def _run_inference(
    model_path: Path,
    audio_files: list[Path],
    output_path: Path,
    audio_config: Path | None,
    inference_config: Path | None,
    outputs_config: Path | None,
    logging_config: Path | None,
    batch_size: int | None,
    num_workers: int,
    format_name: str | None,
) -> None:
    """Run the model over *audio_files* and save the predictions.

    Builds the API from ``model_path`` and the optional config paths via
    ``_build_api``, processes the files with ``api.process_files`` and
    writes the result with ``api.save_predictions`` to ``output_path``,
    passing ``format_name`` as the ``format`` argument.

    Args:
        model_path: Checkpoint path forwarded to ``_build_api``.
        audio_files: Audio files to process.
        output_path: Destination passed to ``api.save_predictions``.
        audio_config: Optional audio config path.
        inference_config: Optional inference config path.
        outputs_config: Optional outputs config path.
        logging_config: Optional logging config path.
        batch_size: Forwarded to ``api.process_files``.
        num_workers: Forwarded to ``api.process_files``.
        format_name: Output format name forwarded to ``save_predictions``.
    """
    logger.info("Initiating prediction process...")

    api, audio_conf, inference_conf, outputs_conf = _build_api(
        model_path,
        audio_config,
        inference_config,
        outputs_config,
        logging_config,
    )

    logger.info("Found {num_files} audio files", num_files=len(audio_files))

    predictions = api.process_files(
        audio_files,
        batch_size=batch_size,
        num_workers=num_workers,
        audio_config=audio_conf,
        inference_config=inference_conf,
        output_config=outputs_conf,
    )

    # Use a directory that is an ancestor of *all* inputs rather than only
    # the first file's parent, so files from several directories (file
    # lists, dataset recordings) share one reference directory.
    # NOTE(review): assumes `audio_dir` is the base directory that saved
    # recording paths are expressed against — confirm in save_predictions.
    common_path = _common_audio_dir(audio_files)
    api.save_predictions(
        predictions,
        path=output_path,
        audio_dir=common_path,
        format=format_name,
    )

    logger.info(
        "Inference complete. Results saved to {path}", path=output_path
    )
|
||||||
|
|
||||||
|
|
||||||
|
@predict.command(name="directory")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("audio_dir", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--batch-size", type=int)
@click.option("--workers", "num_workers", type=int, default=0)
@click.option("--format", "format_name", type=str)
def inference_directory_command(
    model_path: Path,
    audio_dir: Path,
    output_path: Path,
    audio_config: Path | None,
    inference_config: Path | None,
    outputs_config: Path | None,
    logging_config: Path | None,
    batch_size: int | None,
    num_workers: int,
    format_name: str | None,
) -> None:
    # `predict directory`: collect the audio files that `get_audio_files`
    # reports for AUDIO_DIR and hand them to the shared inference runner.
    # (Kept as a comment rather than a docstring so the command's --help
    # output is unchanged.)
    discovered_files = get_audio_files(audio_dir)

    _run_inference(
        model_path=model_path,
        audio_files=[*discovered_files],
        output_path=output_path,
        audio_config=audio_config,
        inference_config=inference_config,
        outputs_config=outputs_config,
        logging_config=logging_config,
        batch_size=batch_size,
        num_workers=num_workers,
        format_name=format_name,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@predict.command(name="file_list")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("file_list", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--batch-size", type=int)
@click.option("--workers", "num_workers", type=int, default=0)
@click.option("--format", "format_name", type=str)
def inference_file_list_command(
    model_path: Path,
    file_list: Path,
    output_path: Path,
    audio_config: Path | None,
    inference_config: Path | None,
    outputs_config: Path | None,
    logging_config: Path | None,
    batch_size: int | None,
    num_workers: int,
    format_name: str | None,
) -> None:
    # `predict file_list`: FILE_LIST is a text file with one audio path per
    # line; surrounding whitespace is stripped and blank lines are skipped.
    # (Kept as a comment rather than a docstring so the command's --help
    # output is unchanged.)
    #
    # `click.Path` without `path_type` hands the argument over as a plain
    # `str`, which has no `read_text`; wrap it in `Path` before reading.
    list_file = Path(file_list)
    audio_files = [
        Path(entry)
        for line in list_file.read_text().splitlines()
        if (entry := line.strip())
    ]

    _run_inference(
        model_path=model_path,
        audio_files=audio_files,
        output_path=output_path,
        audio_config=audio_config,
        inference_config=inference_config,
        outputs_config=outputs_config,
        logging_config=logging_config,
        batch_size=batch_size,
        num_workers=num_workers,
        format_name=format_name,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@predict.command(name="dataset")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("dataset_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--batch-size", type=int)
@click.option("--workers", "num_workers", type=int, default=0)
@click.option("--format", "format_name", type=str)
def inference_dataset_command(
    model_path: Path,
    dataset_path: Path,
    output_path: Path,
    audio_config: Path | None,
    inference_config: Path | None,
    outputs_config: Path | None,
    logging_config: Path | None,
    batch_size: int | None,
    num_workers: int,
    format_name: str | None,
) -> None:
    # `predict dataset`: load DATASET_PATH as an annotation set and run
    # inference on each distinct recording path referenced by its clip
    # annotations, in sorted order.
    # (Kept as a comment rather than a docstring so the command's --help
    # output is unchanged.)
    annotation_set = io.load(dataset_path, type="annotation_set")

    # A set removes recordings referenced by more than one clip.
    unique_recordings = set()
    for annotated_clip in annotation_set.clip_annotations:
        unique_recordings.add(Path(annotated_clip.clip.recording.path))
    recording_paths = sorted(unique_recordings)

    _run_inference(
        model_path=model_path,
        audio_files=recording_paths,
        output_path=output_path,
        audio_config=audio_config,
        inference_config=inference_config,
        outputs_config=outputs_config,
        logging_config=logging_config,
        batch_size=batch_size,
        num_workers=num_workers,
        format_name=format_name,
    )
|
||||||
@ -1,11 +1,15 @@
|
|||||||
"""Test the command line interface."""
|
"""Test the command line interface."""
|
||||||
|
|
||||||
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import lightning as L
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
|
||||||
from batdetect2.cli import cli
|
from batdetect2.cli import cli
|
||||||
|
from batdetect2.config import BatDetect2Config
|
||||||
|
from batdetect2.train.lightning import build_training_module
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
@ -26,6 +30,55 @@ def test_cli_detect_command_help():
|
|||||||
assert "Detect bat calls in files in AUDIO_DIR" in result.output
|
assert "Detect bat calls in files in AUDIO_DIR" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_predict_command_help():
    """Check that `predict --help` succeeds and lists every subcommand."""
    result = runner.invoke(cli, ["predict", "--help"])

    assert result.exit_code == 0
    # Each subcommand name must appear in the rendered help text.
    for subcommand in ("directory", "file_list", "dataset"):
        assert subcommand in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_predict_directory_runs_on_real_audio(tmp_path: Path):
    """End-to-end: `predict directory` on a one-file audio directory.

    Copies one example WAV into a temporary directory, saves a checkpoint
    from a module built with the default model config, runs the CLI, and
    expects exactly one JSON file in the output directory.
    """
    # Stage a single example recording in an isolated directory.
    example_wav = next(Path("example_data/audio").glob("*.wav"))
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()
    shutil.copy(example_wav, audio_dir / example_wav.name)

    # Write a checkpoint without training: attach the module to the
    # trainer's strategy and save it directly.
    model_path = tmp_path / "model.ckpt"
    module = build_training_module(model_config=BatDetect2Config().model)
    trainer = L.Trainer(enable_checkpointing=False, logger=False)
    trainer.strategy.connect(module)
    trainer.save_checkpoint(model_path)

    output_path = tmp_path / "predictions"
    cli_args = ["predict", "directory"]
    cli_args += [str(model_path), str(audio_dir), str(output_path)]
    cli_args += ["--batch-size", "1"]
    cli_args += ["--workers", "0"]
    cli_args += ["--format", "batdetect2"]

    result = runner.invoke(cli, cli_args)

    assert result.exit_code == 0
    assert output_path.exists()
    json_outputs = [*output_path.glob("*.json")]
    assert len(json_outputs) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_cli_detect_command_on_test_audio(tmp_path):
|
def test_cli_detect_command_on_test_audio(tmp_path):
|
||||||
"""Test the detect command on test audio."""
|
"""Test the detect command on test audio."""
|
||||||
results_dir = tmp_path / "results"
|
results_dir = tmp_path / "results"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user