feat: restore legacy batdetect2 process outputs

This commit is contained in:
mbsantiago 2026-05-06 18:57:52 +01:00
parent 105384d9a2
commit c4f759e9a3
7 changed files with 456 additions and 55 deletions

View File

@ -14,7 +14,7 @@ Current built-in output formats include:
- `soundevent`: - `soundevent`:
prediction-set JSON for soundevent-style tooling, prediction-set JSON for soundevent-style tooling,
- `batdetect2`: - `batdetect2`:
legacy per-recording JSON output. legacy-compatible per-recording JSON and CSV outputs.
## Select a format from the CLI ## Select a format from the CLI
@ -61,7 +61,29 @@ batdetect2 process directory \
- Use `raw` if you want the richest output surface and easy round-tripping. - Use `raw` if you want the richest output surface and easy round-tripping.
- Use `parquet` if you want tabular analysis in Python or data-lake workflows. - Use `parquet` if you want tabular analysis in Python or data-lake workflows.
- Use `soundevent` if you want prediction-set JSON. - Use `soundevent` if you want prediction-set JSON.
- Use `batdetect2` only when you need the legacy JSON shape. - Use `batdetect2` when you need legacy BatDetect2-style outputs.
## Enable legacy CNN feature CSVs
The `batdetect2` formatter can also write the legacy CNN feature sidecar CSVs.
This is controlled through the outputs config.
Example:
```yaml
format:
name: batdetect2
write_cnn_features_csv: true
transform:
detection_transforms: []
clip_transforms: []
```
When enabled, BatDetect2 writes:
- one `.json` file per recording,
- one detection `.csv` file per recording,
- one `_cnn_features.csv` file per recording when detections are present.
## Related pages ## Related pages

View File

@ -47,17 +47,29 @@ Writes a prediction-set JSON file.
Defined by `BatDetect2OutputConfig`. Defined by `BatDetect2OutputConfig`.
This is the legacy BatDetect2-style JSON output. This is the legacy-compatible BatDetect2 formatter.
Key fields: Key fields:
- `event_name` - `event_name`
- `annotation_note` - `annotation_note`
- `write_detection_csv`
- `write_cnn_features_csv`
- `save_if_empty`
- `preserve_audio_tree`
- `include_file_path`
Writes one `.json` file per recording. By default it writes one `.json` file and one detection `.csv` file per
recording, preserving the input audio directory layout under the output root.
It can also write legacy `_cnn_features.csv` sidecars when
`write_cnn_features_csv` is enabled.
## Related pages ## Related pages
- Outputs config: {doc}`outputs-config` - Outputs config:
- Save predictions in different output formats: {doc}`../how_to/save-predictions-in-different-output-formats` {doc}`outputs-config`
- Understanding formatted outputs: {doc}`../explanation/interpreting-formatted-outputs` - Save predictions in different output formats:
{doc}`../how_to/save-predictions-in-different-output-formats`
- Understanding formatted outputs:
{doc}`../explanation/interpreting-formatted-outputs`

View File

@ -24,10 +24,18 @@ The output workflow is:
## Default behavior ## Default behavior
By default, the current stack uses the raw output formatter unless you override it. By default, the current stack uses the raw output formatter unless you override
it.
For CLI processing commands, omitting `--format` now leaves format selection to
the loaded outputs config.
If no outputs config is provided, the CLI still uses its command defaults.
## Related pages ## Related pages
- Output formats: {doc}`output-formats` - Output formats:
- Output transforms: {doc}`output-transforms` {doc}`output-formats`
- Save predictions in different output formats: {doc}`../how_to/save-predictions-in-different-output-formats` - Output transforms:
{doc}`output-transforms`
- Save predictions in different output formats:
{doc}`../how_to/save-predictions-in-different-output-formats`

View File

@ -27,6 +27,15 @@ def process() -> None:
def common_predict_options(func): def common_predict_options(func):
"""Attach options shared by all ``process`` subcommands.""" """Attach options shared by all ``process`` subcommands."""
@click.option(
"--model",
"model_path",
type=str,
help=(
"Path to a checkpoint, checkpoint alias, or a Hugging Face "
"URI to fine-tune from. Defaults to uk_same"
),
)
@click.option( @click.option(
"--audio-config", "--audio-config",
type=click.Path(exists=True), type=click.Path(exists=True),
@ -97,7 +106,7 @@ def common_predict_options(func):
def _build_api( def _build_api(
model_path: str, model_path: str | None,
audio_config: Path | None, audio_config: Path | None,
inference_config: Path | None, inference_config: Path | None,
outputs_config: Path | None, outputs_config: Path | None,
@ -129,7 +138,7 @@ def _build_api(
) )
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
model_path, path=model_path,
audio_config=audio_conf, audio_config=audio_conf,
inference_config=inference_conf, inference_config=inference_conf,
outputs_config=outputs_conf, outputs_config=outputs_conf,
@ -139,7 +148,7 @@ def _build_api(
def _run_prediction( def _run_prediction(
model_path: str, model_path: str | None,
audio_files: list[Path], audio_files: list[Path],
output_path: Path, output_path: Path,
audio_config: Path | None, audio_config: Path | None,
@ -190,12 +199,11 @@ def _run_prediction(
name="directory", name="directory",
short_help="Process audio files in a directory.", short_help="Process audio files in a directory.",
) )
@click.argument("model_path", type=str)
@click.argument("audio_dir", type=click.Path(exists=True)) @click.argument("audio_dir", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path()) @click.argument("output_path", type=click.Path())
@common_predict_options @common_predict_options
def predict_directory_command( def predict_directory_command(
model_path: str, model_path: str | None,
audio_dir: Path, audio_dir: Path,
output_path: Path, output_path: Path,
audio_config: Path | None, audio_config: Path | None,
@ -234,14 +242,13 @@ def predict_directory_command(
name="file_list", name="file_list",
short_help="Process paths listed in a text file.", short_help="Process paths listed in a text file.",
) )
@click.argument("model_path", type=str)
@click.argument("file_list", type=click.Path(exists=True)) @click.argument("file_list", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path()) @click.argument("output_path", type=click.Path())
@common_predict_options @common_predict_options
def predict_file_list_command( def predict_file_list_command(
model_path: str,
file_list: Path, file_list: Path,
output_path: Path, output_path: Path,
model_path: str | None,
audio_config: Path | None, audio_config: Path | None,
inference_config: Path | None, inference_config: Path | None,
outputs_config: Path | None, outputs_config: Path | None,
@ -282,14 +289,13 @@ def predict_file_list_command(
name="dataset", name="dataset",
short_help="Process recordings from a dataset config.", short_help="Process recordings from a dataset config.",
) )
@click.argument("model_path", type=str)
@click.argument("dataset_path", type=click.Path(exists=True)) @click.argument("dataset_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path()) @click.argument("output_path", type=click.Path())
@common_predict_options @common_predict_options
def predict_dataset_command( def predict_dataset_command(
model_path: str,
dataset_path: Path, dataset_path: Path,
output_path: Path, output_path: Path,
model_path: str | None,
audio_config: Path | None, audio_config: Path | None,
inference_config: Path | None, inference_config: Path | None,
outputs_config: Path | None, outputs_config: Path | None,

View File

@ -3,7 +3,9 @@ from pathlib import Path
from typing import List, Literal, Sequence, TypedDict from typing import List, Literal, Sequence, TypedDict
import numpy as np import numpy as np
import pandas as pd
from soundevent import data from soundevent import data
from soundevent import terms as soundevent_terms
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig from batdetect2.core import BaseConfig
@ -13,7 +15,6 @@ from batdetect2.outputs.formats.base import (
) )
from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections, Detection from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets import terms
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
try: try:
@ -24,7 +25,7 @@ except ImportError:
DictWithClass = TypedDict("DictWithClass", {"class": str}) DictWithClass = TypedDict("DictWithClass", {"class": str})
class Annotation(DictWithClass): class Annotation(DictWithClass, total=False):
start_time: float start_time: float
end_time: float end_time: float
low_freq: float low_freq: float
@ -33,6 +34,7 @@ class Annotation(DictWithClass):
det_prob: float det_prob: float
individual: str individual: str
event: str event: str
cnn_features: NotRequired[list[float]] # ty: ignore[invalid-type-form]
class FileAnnotation(TypedDict): class FileAnnotation(TypedDict):
@ -52,6 +54,11 @@ class BatDetect2OutputConfig(BaseConfig):
event_name: str = "Echolocation" event_name: str = "Echolocation"
annotation_note: str = "Automatically generated." annotation_note: str = "Automatically generated."
write_detection_csv: bool = True
write_cnn_features_csv: bool = False
save_if_empty: bool = False
preserve_audio_tree: bool = True
include_file_path: bool = False
class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
@ -60,10 +67,20 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
targets: TargetProtocol, targets: TargetProtocol,
event_name: str, event_name: str,
annotation_note: str, annotation_note: str,
write_detection_csv: bool = True,
write_cnn_features_csv: bool = False,
save_if_empty: bool = False,
preserve_audio_tree: bool = True,
include_file_path: bool = False,
): ):
self.targets = targets self.targets = targets
self.event_name = event_name self.event_name = event_name
self.annotation_note = annotation_note self.annotation_note = annotation_note
self.write_detection_csv = write_detection_csv
self.write_cnn_features_csv = write_cnn_features_csv
self.save_if_empty = save_if_empty
self.preserve_audio_tree = preserve_audio_tree
self.include_file_path = include_file_path
def format( def format(
self, predictions: Sequence[ClipDetections] self, predictions: Sequence[ClipDetections]
@ -84,22 +101,56 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
path.mkdir(parents=True) path.mkdir(parents=True)
for prediction in predictions: for prediction in predictions:
pred_path = path / (prediction["id"] + ".json") annotations = prediction["annotation"]
if audio_dir is not None and "file_path" in prediction: if not annotations and not self.save_if_empty:
prediction["file_path"] = str( continue
make_path_relative(
prediction["file_path"], pred_path = self.get_output_path(prediction, path, audio_dir)
audio_dir, pred_path.parent.mkdir(parents=True, exist_ok=True)
)
# make a copy of the prediction
data = dict(prediction)
raw_file_path = data.get("file_path")
if audio_dir is not None and isinstance(raw_file_path, str):
data["file_path"] = str(
make_path_relative(raw_file_path, audio_dir)
) )
pred_path.write_text(json.dumps(prediction)) if not self.include_file_path:
data.pop("file_path", None)
data["annotation"] = [
{
key: value
for key, value in annotation.items()
if key != "cnn_features"
}
for annotation in data["annotation"]
]
pred_path.write_text(json.dumps(data, indent=2, sort_keys=True))
if self.write_detection_csv:
self.save_detection_csv(
prediction,
pred_path.with_suffix(".csv"),
)
if self.write_cnn_features_csv:
self.save_cnn_features_csv(
prediction,
pred_path.with_name(pred_path.stem + "_cnn_features.csv"),
)
def load(self, path: data.PathLike) -> List[FileAnnotation]: def load(self, path: data.PathLike) -> List[FileAnnotation]:
path = Path(path) path = Path(path)
files = list(path.glob("*.json")) if path.is_file():
files = [path] if path.suffix == ".json" else []
else:
files = sorted(path.rglob("*.json"))
if not files: if not files:
return [] return []
@ -108,12 +159,108 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
json.loads(file.read_text()) for file in files if file.is_file() json.loads(file.read_text()) for file in files if file.is_file()
] ]
def get_recording_class(self, annotations: List[Annotation]) -> str: def get_output_path(
if not annotations: self,
return "" prediction: FileAnnotation,
output_dir: Path,
audio_dir: data.PathLike | None,
) -> Path:
if (
self.preserve_audio_tree
and audio_dir is not None
and "file_path" in prediction
):
relative_path = make_path_relative(
prediction["file_path"],
audio_dir,
)
return (
output_dir / relative_path.parent / f"{prediction['id']}.json"
)
highest_scoring = max(annotations, key=lambda x: x["class_prob"]) return output_dir / f"{prediction['id']}.json"
return highest_scoring["class"]
def save_detection_csv(
self,
prediction: FileAnnotation,
path: Path,
) -> None:
annotations = prediction["annotation"]
if not annotations:
return
preds_df = pd.DataFrame(annotations)[
[
"det_prob",
"start_time",
"end_time",
"high_freq",
"low_freq",
"class",
"class_prob",
]
]
preds_df.to_csv(path, sep=",")
def save_cnn_features_csv(
self, prediction: FileAnnotation, path: Path
) -> None:
annotations = prediction["annotation"]
if not annotations:
return
cnn_features = [
annotation["cnn_features"]
for annotation in annotations
if "cnn_features" in annotation
]
if not cnn_features:
return
cnn_feats_df = pd.DataFrame(
cnn_features,
columns=[str(ii) for ii in range(len(cnn_features[0]))],
)
cnn_feats_df.to_csv(
path,
sep=",",
index=False,
float_format="%.5f",
)
def get_class_name(self, class_index: int) -> str:
class_name = self.targets.class_names[class_index]
tags = self.targets.decode_class(class_name)
return data.find_tag_value(
tags,
term=soundevent_terms.scientific_name,
default=class_name,
) # type: ignore
def get_recording_class(self, detections: Sequence[Detection]) -> str:
if not detections:
return "None"
class_scores = np.stack(
[detection.class_scores for detection in detections],
axis=1,
)
detection_scores = np.array(
[detection.detection_score for detection in detections],
dtype=np.float32,
)
weighted_scores = (class_scores * detection_scores).sum(axis=1)
total = weighted_scores.sum()
if total <= 0:
return "None"
top_class_index = int(np.argmax(weighted_scores / total))
return self.get_class_name(top_class_index)
def format_prediction(self, prediction: ClipDetections) -> FileAnnotation: def format_prediction(self, prediction: ClipDetections) -> FileAnnotation:
recording = prediction.clip.recording recording = prediction.clip.recording
@ -123,26 +270,19 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
for pred in prediction.detections for pred in prediction.detections
] ]
return FileAnnotation( file_annotation = FileAnnotation(
id=recording.path.name, id=recording.path.name,
file_path=str(recording.path),
annotated=False, annotated=False,
duration=recording.duration, duration=round(float(recording.duration), 4),
issues=False, issues=False,
time_exp=recording.time_expansion, time_exp=recording.time_expansion,
class_name=self.get_recording_class(annotations), class_name=self.get_recording_class(prediction.detections),
notes=self.annotation_note, notes=self.annotation_note,
annotation=annotations, annotation=annotations,
file_path=str(recording.path),
) )
def get_class_name(self, class_index: int) -> str: return file_annotation
class_name = self.targets.class_names[class_index]
tags = self.targets.decode_class(class_name)
return data.find_tag_value(
tags,
term=terms.generic_class,
default=class_name,
) # type: ignore
def format_sound_event_prediction( def format_sound_event_prediction(
self, prediction: Detection self, prediction: Detection
@ -155,16 +295,20 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
top_class_score = float(prediction.class_scores[top_class_index]) top_class_score = float(prediction.class_scores[top_class_index])
top_class = self.get_class_name(top_class_index) top_class = self.get_class_name(top_class_index)
annotation: Annotation = { annotation: Annotation = {
"start_time": start_time, "start_time": round(float(start_time), 4),
"end_time": end_time, "end_time": round(float(end_time), 4),
"low_freq": low_freq, "low_freq": int(low_freq),
"high_freq": high_freq, "high_freq": int(high_freq),
"class_prob": top_class_score, "class_prob": round(top_class_score, 3),
"det_prob": float(prediction.detection_score), "det_prob": round(float(prediction.detection_score), 3),
"individual": "", "individual": "-1",
"event": self.event_name, "event": self.event_name,
"class": top_class, "class": top_class,
} }
if self.write_cnn_features_csv:
annotation["cnn_features"] = prediction.features.tolist() # type: ignore[index]
return annotation return annotation
@output_formatters.register(BatDetect2OutputConfig) @output_formatters.register(BatDetect2OutputConfig)
@ -174,4 +318,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
targets, targets,
event_name=config.event_name, event_name=config.event_name,
annotation_note=config.annotation_note, annotation_note=config.annotation_note,
write_detection_csv=config.write_detection_csv,
write_cnn_features_csv=config.write_cnn_features_csv,
save_if_empty=config.save_if_empty,
preserve_audio_tree=config.preserve_audio_tree,
include_file_path=config.include_file_path,
) )

View File

@ -2,6 +2,7 @@ from pathlib import Path
from typing import cast from typing import cast
import numpy as np import numpy as np
import pandas as pd
import pytest import pytest
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
@ -98,6 +99,47 @@ def test_load_predictions_with_format_override(
assert "annotation" in loaded_item assert "annotation" in loaded_item
def test_load_predictions_with_batdetect2_nested_layout(
api_v2: BatDetect2API,
example_audio_files: list[Path],
tmp_path: Path,
) -> None:
output_dir = tmp_path / "batdetect2_nested"
predictions = [
api_v2.process_file(audio_file) for audio_file in example_audio_files
]
api_v2.save_predictions(
predictions,
path=output_dir,
format="batdetect2",
audio_dir=example_audio_files[0].parent,
)
loaded = api_v2.load_predictions(output_dir, format="batdetect2")
assert len(loaded) == len(example_audio_files)
def test_save_predictions_with_batdetect2_writes_cnn_feature_csv(
api_v2: BatDetect2API,
file_prediction,
tmp_path: Path,
) -> None:
output_dir = tmp_path / "batdetect2_cnn"
api_v2.save_predictions(
[file_prediction],
path=output_dir,
config=BatDetect2OutputConfig(write_cnn_features_csv=True),
)
cnn_csvs = list(output_dir.rglob("*_cnn_features.csv"))
assert len(cnn_csvs) == 1
loaded_df = pd.read_csv(cnn_csvs[0])
assert not loaded_df.empty
def test_save_predictions_with_soundevent_override( def test_save_predictions_with_soundevent_override(
api_v2: BatDetect2API, api_v2: BatDetect2API,
file_prediction, file_prediction,

View File

@ -1,12 +1,16 @@
"""Behavior tests for process CLI workflows.""" """Behavior tests for process CLI workflows."""
import json
from pathlib import Path from pathlib import Path
import pandas as pd
import pytest import pytest
from click.testing import CliRunner from click.testing import CliRunner
from soundevent import data, io from soundevent import data, io
from batdetect2.cli import cli from batdetect2.cli import cli
from batdetect2.outputs import OutputsConfig
from batdetect2.outputs.formats import BatDetect2OutputConfig
def test_cli_process_help() -> None: def test_cli_process_help() -> None:
@ -35,6 +39,7 @@ def test_cli_process_directory_runs_on_real_audio(
[ [
"process", "process",
"directory", "directory",
"--model",
str(tiny_checkpoint_path), str(tiny_checkpoint_path),
str(single_audio_dir), str(single_audio_dir),
str(output_path), str(output_path),
@ -52,6 +57,158 @@ def test_cli_process_directory_runs_on_real_audio(
assert len(list(output_path.glob("*.json"))) == 1 assert len(list(output_path.glob("*.json"))) == 1
@pytest.mark.slow
def test_cli_process_directory_runs_on_example_audio_data(
tmp_path: Path,
tiny_checkpoint_path: Path,
example_audio_dir: Path,
example_audio_files: list[Path],
) -> None:
"""User story: process the bundled example audio directory."""
output_path = tmp_path / "predictions"
result = CliRunner().invoke(
cli,
[
"process",
"directory",
"--model",
str(tiny_checkpoint_path),
str(example_audio_dir),
str(output_path),
"--batch-size",
"1",
"--workers",
"0",
"--format",
"batdetect2",
],
)
assert result.exit_code == 0
assert output_path.exists()
assert len(list(output_path.glob("*.json"))) == len(example_audio_files)
@pytest.mark.slow
def test_cli_process_directory_batdetect2_matches_legacy_artifacts(
tmp_path: Path,
tiny_checkpoint_path: Path,
example_audio_dir: Path,
example_audio_files: list[Path],
example_anns_dir: Path,
) -> None:
"""User story: process batdetect2 output matches legacy-style files."""
output_path = tmp_path / "predictions"
result = CliRunner().invoke(
cli,
[
"process",
"directory",
"--model",
str(tiny_checkpoint_path),
str(example_audio_dir),
str(output_path),
"--batch-size",
"1",
"--workers",
"0",
"--format",
"batdetect2",
],
)
assert result.exit_code == 0
json_files = sorted(output_path.rglob("*.json"))
csv_files = sorted(output_path.rglob("*.csv"))
assert len(json_files) == len(example_audio_files)
assert len(csv_files) == len(example_audio_files)
expected_names = sorted(
audio_file.name for audio_file in example_audio_files
)
assert sorted(path.stem for path in json_files) == expected_names
assert sorted(path.stem for path in csv_files) == expected_names
first_output = json.loads(json_files[0].read_text())
assert "file_path" not in first_output
assert isinstance(first_output["class_name"], str)
assert first_output["class_name"]
first_annotation = first_output["annotation"][0]
assert first_annotation["individual"] == "-1"
assert isinstance(first_annotation["high_freq"], int)
assert isinstance(first_annotation["low_freq"], int)
expected_json = json.loads(
(example_anns_dir / json_files[0].name).read_text()
)
assert first_output["id"] == expected_json["id"]
assert first_output["time_exp"] == expected_json["time_exp"]
first_csv = pd.read_csv(csv_files[0], index_col=0)
assert list(first_csv.columns) == [
"det_prob",
"start_time",
"end_time",
"high_freq",
"low_freq",
"class",
"class_prob",
]
@pytest.mark.slow
def test_cli_process_directory_batdetect2_writes_cnn_features_csv_when_enabled(
tmp_path: Path,
tiny_checkpoint_path: Path,
example_audio_dir: Path,
) -> None:
"""User story: request legacy CNN feature CSV sidecars via config."""
output_path = tmp_path / "predictions"
outputs_config_path = tmp_path / "outputs.yaml"
outputs_config_path.write_text(
OutputsConfig(
format=BatDetect2OutputConfig(write_cnn_features_csv=True)
).to_yaml_string()
)
result = CliRunner().invoke(
cli,
[
"process",
"directory",
"--model",
str(tiny_checkpoint_path),
str(example_audio_dir),
str(output_path),
"--batch-size",
"1",
"--workers",
"0",
"--outputs-config",
str(outputs_config_path),
],
)
assert result.exit_code == 0
cnn_csvs = sorted(output_path.rglob("*_cnn_features.csv"))
assert len(cnn_csvs) == 3
first_df = pd.read_csv(cnn_csvs[0])
assert not first_df.empty
assert list(first_df.columns) == [
str(ii) for ii in range(len(first_df.columns))
]
def test_cli_process_file_list_runs_on_real_audio( def test_cli_process_file_list_runs_on_real_audio(
tmp_path: Path, tmp_path: Path,
tiny_checkpoint_path: Path, tiny_checkpoint_path: Path,
@ -70,6 +227,7 @@ def test_cli_process_file_list_runs_on_real_audio(
[ [
"process", "process",
"file_list", "file_list",
"--model",
str(tiny_checkpoint_path), str(tiny_checkpoint_path),
str(file_list), str(file_list),
str(output_path), str(output_path),
@ -117,6 +275,7 @@ def test_cli_process_dataset_runs_on_aoef_metadata(
[ [
"process", "process",
"dataset", "dataset",
"--model",
str(tiny_checkpoint_path), str(tiny_checkpoint_path),
str(dataset_path), str(dataset_path),
str(output_path), str(output_path),
@ -159,6 +318,7 @@ def test_cli_process_directory_supports_output_format_override(
[ [
"process", "process",
"directory", "directory",
"--model",
str(tiny_checkpoint_path), str(tiny_checkpoint_path),
str(single_audio_dir), str(single_audio_dir),
str(output_path), str(output_path),
@ -217,6 +377,7 @@ def test_cli_process_dataset_deduplicates_recordings(
[ [
"process", "process",
"dataset", "dataset",
"--model",
str(tiny_checkpoint_path), str(tiny_checkpoint_path),
str(dataset_path), str(dataset_path),
str(output_path), str(output_path),
@ -247,6 +408,7 @@ def test_cli_process_rejects_unknown_output_format(
[ [
"process", "process",
"directory", "directory",
"--model",
str(tiny_checkpoint_path), str(tiny_checkpoint_path),
str(single_audio_dir), str(single_audio_dir),
str(output_path), str(output_path),