diff --git a/docs/source/how_to/save-predictions-in-different-output-formats.md b/docs/source/how_to/save-predictions-in-different-output-formats.md index 820728a..ee4d5c6 100644 --- a/docs/source/how_to/save-predictions-in-different-output-formats.md +++ b/docs/source/how_to/save-predictions-in-different-output-formats.md @@ -14,7 +14,7 @@ Current built-in output formats include: - `soundevent`: prediction-set JSON for soundevent-style tooling, - `batdetect2`: - legacy per-recording JSON output. + legacy-compatible per-recording JSON and CSV outputs. ## Select a format from the CLI @@ -61,7 +61,29 @@ batdetect2 process directory \ - Use `raw` if you want the richest output surface and easy round-tripping. - Use `parquet` if you want tabular analysis in Python or data-lake workflows. - Use `soundevent` if you want prediction-set JSON. -- Use `batdetect2` only when you need the legacy JSON shape. +- Use `batdetect2` when you need legacy BatDetect2-style outputs. + +## Enable legacy CNN feature CSVs + +The `batdetect2` formatter can also write the legacy CNN feature sidecar CSVs. +This is controlled through the outputs config. + +Example: + +```yaml +format: + name: batdetect2 + write_cnn_features_csv: true +transform: + detection_transforms: [] + clip_transforms: [] +``` + +When enabled, BatDetect2 writes: + +- one `.json` file per recording, +- one detection `.csv` file per recording, +- one `_cnn_features.csv` file per recording when detections are present. ## Related pages diff --git a/docs/source/reference/output-formats.md b/docs/source/reference/output-formats.md index a4780f1..cb92b1e 100644 --- a/docs/source/reference/output-formats.md +++ b/docs/source/reference/output-formats.md @@ -47,17 +47,29 @@ Writes a prediction-set JSON file. Defined by `BatDetect2OutputConfig`. -This is the legacy BatDetect2-style JSON output. +This is the legacy-compatible BatDetect2 formatter. 
Key fields: - `event_name` - `annotation_note` +- `write_detection_csv` +- `write_cnn_features_csv` +- `save_if_empty` +- `preserve_audio_tree` +- `include_file_path` -Writes one `.json` file per recording. +By default it writes one `.json` file and one detection `.csv` file per +recording with detections, preserving the input audio directory layout under the output root. + +It can also write legacy `_cnn_features.csv` sidecars when +`write_cnn_features_csv` is enabled. ## Related pages -- Outputs config: {doc}`outputs-config` -- Save predictions in different output formats: {doc}`../how_to/save-predictions-in-different-output-formats` -- Understanding formatted outputs: {doc}`../explanation/interpreting-formatted-outputs` +- Outputs config: + {doc}`outputs-config` +- Save predictions in different output formats: + {doc}`../how_to/save-predictions-in-different-output-formats` +- Understanding formatted outputs: + {doc}`../explanation/interpreting-formatted-outputs` diff --git a/docs/source/reference/outputs-config.md b/docs/source/reference/outputs-config.md index d548b18..6726d2c 100644 --- a/docs/source/reference/outputs-config.md +++ b/docs/source/reference/outputs-config.md @@ -24,10 +24,18 @@ The output workflow is: ## Default behavior -By default, the current stack uses the raw output formatter unless you override it. +By default, the current stack uses the raw output formatter unless you override +it. + +For CLI processing commands, omitting `--format` now leaves format selection to +the loaded outputs config. +If no outputs config is provided, the CLI still uses its command defaults. 
## Related pages -- Output formats: {doc}`output-formats` -- Output transforms: {doc}`output-transforms` -- Save predictions in different output formats: {doc}`../how_to/save-predictions-in-different-output-formats` +- Output formats: + {doc}`output-formats` +- Output transforms: + {doc}`output-transforms` +- Save predictions in different output formats: + {doc}`../how_to/save-predictions-in-different-output-formats` diff --git a/src/batdetect2/cli/inference.py b/src/batdetect2/cli/inference.py index 84985e6..70ec546 100644 --- a/src/batdetect2/cli/inference.py +++ b/src/batdetect2/cli/inference.py @@ -27,6 +27,15 @@ def process() -> None: def common_predict_options(func): """Attach options shared by all ``process`` subcommands.""" + @click.option( + "--model", + "model_path", + type=str, + help=( + "Path to a checkpoint, checkpoint alias, or a Hugging Face " + "URI to load for inference. Defaults to uk_same." + ), + ) @click.option( "--audio-config", type=click.Path(exists=True), @@ -97,7 +106,7 @@ def common_predict_options(func): def _build_api( - model_path: str, + model_path: str | None, audio_config: Path | None, inference_config: Path | None, outputs_config: Path | None, @@ -129,7 +138,7 @@ def _build_api( ) api = BatDetect2API.from_checkpoint( - model_path, + path=model_path, audio_config=audio_conf, inference_config=inference_conf, outputs_config=outputs_conf, @@ -139,7 +148,7 @@ def _build_api( def _run_prediction( - model_path: str, + model_path: str | None, audio_files: list[Path], output_path: Path, audio_config: Path | None, @@ -190,12 +199,11 @@ def _run_prediction( name="directory", short_help="Process audio files in a directory.", ) -@click.argument("model_path", type=str) @click.argument("audio_dir", type=click.Path(exists=True)) @click.argument("output_path", type=click.Path()) @common_predict_options def predict_directory_command( - model_path: str, + model_path: str | None, audio_dir: Path, output_path: Path, audio_config: Path | None, @@ -234,14 
+242,13 @@ def predict_directory_command( name="file_list", short_help="Process paths listed in a text file.", ) -@click.argument("model_path", type=str) @click.argument("file_list", type=click.Path(exists=True)) @click.argument("output_path", type=click.Path()) @common_predict_options def predict_file_list_command( - model_path: str, file_list: Path, output_path: Path, + model_path: str | None, audio_config: Path | None, inference_config: Path | None, outputs_config: Path | None, @@ -282,14 +289,13 @@ def predict_file_list_command( name="dataset", short_help="Process recordings from a dataset config.", ) -@click.argument("model_path", type=str) @click.argument("dataset_path", type=click.Path(exists=True)) @click.argument("output_path", type=click.Path()) @common_predict_options def predict_dataset_command( - model_path: str, dataset_path: Path, output_path: Path, + model_path: str | None, audio_config: Path | None, inference_config: Path | None, outputs_config: Path | None, diff --git a/src/batdetect2/outputs/formats/batdetect2.py b/src/batdetect2/outputs/formats/batdetect2.py index 115a908..4db5f11 100644 --- a/src/batdetect2/outputs/formats/batdetect2.py +++ b/src/batdetect2/outputs/formats/batdetect2.py @@ -3,7 +3,9 @@ from pathlib import Path from typing import List, Literal, Sequence, TypedDict import numpy as np +import pandas as pd from soundevent import data +from soundevent import terms as soundevent_terms from soundevent.geometry import compute_bounds from batdetect2.core import BaseConfig @@ -13,7 +15,6 @@ from batdetect2.outputs.formats.base import ( ) from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.postprocess.types import ClipDetections, Detection -from batdetect2.targets import terms from batdetect2.targets.types import TargetProtocol try: @@ -24,7 +25,7 @@ except ImportError: DictWithClass = TypedDict("DictWithClass", {"class": str}) -class Annotation(DictWithClass): +class Annotation(DictWithClass, total=False): 
start_time: float end_time: float low_freq: float @@ -33,6 +34,7 @@ class Annotation(DictWithClass): det_prob: float individual: str event: str + cnn_features: NotRequired[list[float]] # ty: ignore[invalid-type-form] class FileAnnotation(TypedDict): @@ -52,6 +54,11 @@ class BatDetect2OutputConfig(BaseConfig): event_name: str = "Echolocation" annotation_note: str = "Automatically generated." + write_detection_csv: bool = True + write_cnn_features_csv: bool = False + save_if_empty: bool = False + preserve_audio_tree: bool = True + include_file_path: bool = False class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): @@ -60,10 +67,20 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): targets: TargetProtocol, event_name: str, annotation_note: str, + write_detection_csv: bool = True, + write_cnn_features_csv: bool = False, + save_if_empty: bool = False, + preserve_audio_tree: bool = True, + include_file_path: bool = False, ): self.targets = targets self.event_name = event_name self.annotation_note = annotation_note + self.write_detection_csv = write_detection_csv + self.write_cnn_features_csv = write_cnn_features_csv + self.save_if_empty = save_if_empty + self.preserve_audio_tree = preserve_audio_tree + self.include_file_path = include_file_path def format( self, predictions: Sequence[ClipDetections] @@ -84,22 +101,56 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): path.mkdir(parents=True) for prediction in predictions: - pred_path = path / (prediction["id"] + ".json") + annotations = prediction["annotation"] - if audio_dir is not None and "file_path" in prediction: - prediction["file_path"] = str( - make_path_relative( - prediction["file_path"], - audio_dir, - ) + if not annotations and not self.save_if_empty: + continue + + pred_path = self.get_output_path(prediction, path, audio_dir) + pred_path.parent.mkdir(parents=True, exist_ok=True) + + # make a copy of the prediction + data = dict(prediction) + + 
raw_file_path = data.get("file_path") + if audio_dir is not None and isinstance(raw_file_path, str): + data["file_path"] = str( + make_path_relative(raw_file_path, audio_dir) ) - pred_path.write_text(json.dumps(prediction)) + if not self.include_file_path: + data.pop("file_path", None) + + data["annotation"] = [ + { + key: value + for key, value in annotation.items() + if key != "cnn_features" + } + for annotation in data["annotation"] + ] + + pred_path.write_text(json.dumps(data, indent=2, sort_keys=True)) + + if self.write_detection_csv: + self.save_detection_csv( + prediction, + pred_path.with_suffix(".csv"), + ) + + if self.write_cnn_features_csv: + self.save_cnn_features_csv( + prediction, + pred_path.with_name(pred_path.stem + "_cnn_features.csv"), + ) def load(self, path: data.PathLike) -> List[FileAnnotation]: path = Path(path) - files = list(path.glob("*.json")) + if path.is_file(): + files = [path] if path.suffix == ".json" else [] + else: + files = sorted(path.rglob("*.json")) if not files: return [] @@ -108,12 +159,108 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): json.loads(file.read_text()) for file in files if file.is_file() ] - def get_recording_class(self, annotations: List[Annotation]) -> str: - if not annotations: - return "" + def get_output_path( + self, + prediction: FileAnnotation, + output_dir: Path, + audio_dir: data.PathLike | None, + ) -> Path: + if ( + self.preserve_audio_tree + and audio_dir is not None + and "file_path" in prediction + ): + relative_path = make_path_relative( + prediction["file_path"], + audio_dir, + ) + return ( + output_dir / relative_path.parent / f"{prediction['id']}.json" + ) - highest_scoring = max(annotations, key=lambda x: x["class_prob"]) - return highest_scoring["class"] + return output_dir / f"{prediction['id']}.json" + + def save_detection_csv( + self, + prediction: FileAnnotation, + path: Path, + ) -> None: + annotations = prediction["annotation"] + if not annotations: + return + + 
preds_df = pd.DataFrame(annotations)[ + [ + "det_prob", + "start_time", + "end_time", + "high_freq", + "low_freq", + "class", + "class_prob", + ] + ] + preds_df.to_csv(path, sep=",") + + def save_cnn_features_csv( + self, prediction: FileAnnotation, path: Path + ) -> None: + annotations = prediction["annotation"] + + if not annotations: + return + + cnn_features = [ + annotation["cnn_features"] + for annotation in annotations + if "cnn_features" in annotation + ] + + if not cnn_features: + return + + cnn_feats_df = pd.DataFrame( + cnn_features, + columns=[str(ii) for ii in range(len(cnn_features[0]))], + ) + + cnn_feats_df.to_csv( + path, + sep=",", + index=False, + float_format="%.5f", + ) + + def get_class_name(self, class_index: int) -> str: + class_name = self.targets.class_names[class_index] + tags = self.targets.decode_class(class_name) + return data.find_tag_value( + tags, + term=soundevent_terms.scientific_name, + default=class_name, + ) # type: ignore + + def get_recording_class(self, detections: Sequence[Detection]) -> str: + if not detections: + return "None" + + class_scores = np.stack( + [detection.class_scores for detection in detections], + axis=1, + ) + detection_scores = np.array( + [detection.detection_score for detection in detections], + dtype=np.float32, + ) + weighted_scores = (class_scores * detection_scores).sum(axis=1) + + total = weighted_scores.sum() + + if total <= 0: + return "None" + + top_class_index = int(np.argmax(weighted_scores / total)) + return self.get_class_name(top_class_index) def format_prediction(self, prediction: ClipDetections) -> FileAnnotation: recording = prediction.clip.recording @@ -123,26 +270,19 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): for pred in prediction.detections ] - return FileAnnotation( + file_annotation = FileAnnotation( id=recording.path.name, - file_path=str(recording.path), annotated=False, - duration=recording.duration, + duration=round(float(recording.duration), 4), 
issues=False, time_exp=recording.time_expansion, - class_name=self.get_recording_class(annotations), + class_name=self.get_recording_class(prediction.detections), notes=self.annotation_note, annotation=annotations, + file_path=str(recording.path), ) - def get_class_name(self, class_index: int) -> str: - class_name = self.targets.class_names[class_index] - tags = self.targets.decode_class(class_name) - return data.find_tag_value( - tags, - term=terms.generic_class, - default=class_name, - ) # type: ignore + return file_annotation def format_sound_event_prediction( self, prediction: Detection @@ -155,16 +295,20 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): top_class_score = float(prediction.class_scores[top_class_index]) top_class = self.get_class_name(top_class_index) annotation: Annotation = { - "start_time": start_time, - "end_time": end_time, - "low_freq": low_freq, - "high_freq": high_freq, - "class_prob": top_class_score, - "det_prob": float(prediction.detection_score), - "individual": "", + "start_time": round(float(start_time), 4), + "end_time": round(float(end_time), 4), + "low_freq": int(low_freq), + "high_freq": int(high_freq), + "class_prob": round(top_class_score, 3), + "det_prob": round(float(prediction.detection_score), 3), + "individual": "-1", "event": self.event_name, "class": top_class, } + + if self.write_cnn_features_csv: + annotation["cnn_features"] = prediction.features.tolist() # type: ignore[index] + return annotation @output_formatters.register(BatDetect2OutputConfig) @@ -174,4 +318,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): targets, event_name=config.event_name, annotation_note=config.annotation_note, + write_detection_csv=config.write_detection_csv, + write_cnn_features_csv=config.write_cnn_features_csv, + save_if_empty=config.save_if_empty, + preserve_audio_tree=config.preserve_audio_tree, + include_file_path=config.include_file_path, ) diff --git 
a/tests/test_api_v2/test_outputs_io.py b/tests/test_api_v2/test_outputs_io.py index 5a62fb8..99c6ade 100644 --- a/tests/test_api_v2/test_outputs_io.py +++ b/tests/test_api_v2/test_outputs_io.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import cast import numpy as np +import pandas as pd import pytest from batdetect2.api_v2 import BatDetect2API @@ -98,6 +99,47 @@ def test_load_predictions_with_format_override( assert "annotation" in loaded_item +def test_load_predictions_with_batdetect2_nested_layout( + api_v2: BatDetect2API, + example_audio_files: list[Path], + tmp_path: Path, +) -> None: + output_dir = tmp_path / "batdetect2_nested" + predictions = [ + api_v2.process_file(audio_file) for audio_file in example_audio_files + ] + + api_v2.save_predictions( + predictions, + path=output_dir, + format="batdetect2", + audio_dir=example_audio_files[0].parent, + ) + + loaded = api_v2.load_predictions(output_dir, format="batdetect2") + + assert len(loaded) == len(example_audio_files) + + +def test_save_predictions_with_batdetect2_writes_cnn_feature_csv( + api_v2: BatDetect2API, + file_prediction, + tmp_path: Path, +) -> None: + output_dir = tmp_path / "batdetect2_cnn" + api_v2.save_predictions( + [file_prediction], + path=output_dir, + config=BatDetect2OutputConfig(write_cnn_features_csv=True), + ) + + cnn_csvs = list(output_dir.rglob("*_cnn_features.csv")) + assert len(cnn_csvs) == 1 + + loaded_df = pd.read_csv(cnn_csvs[0]) + assert not loaded_df.empty + + def test_save_predictions_with_soundevent_override( api_v2: BatDetect2API, file_prediction, diff --git a/tests/test_cli/test_predict.py b/tests/test_cli/test_predict.py index 0fdc91f..5b10bb8 100644 --- a/tests/test_cli/test_predict.py +++ b/tests/test_cli/test_predict.py @@ -1,12 +1,16 @@ """Behavior tests for process CLI workflows.""" +import json from pathlib import Path +import pandas as pd import pytest from click.testing import CliRunner from soundevent import data, io from batdetect2.cli import cli 
+from batdetect2.outputs import OutputsConfig +from batdetect2.outputs.formats import BatDetect2OutputConfig def test_cli_process_help() -> None: @@ -35,6 +39,7 @@ def test_cli_process_directory_runs_on_real_audio( [ "process", "directory", + "--model", str(tiny_checkpoint_path), str(single_audio_dir), str(output_path), @@ -52,6 +57,158 @@ def test_cli_process_directory_runs_on_real_audio( assert len(list(output_path.glob("*.json"))) == 1 +@pytest.mark.slow +def test_cli_process_directory_runs_on_example_audio_data( + tmp_path: Path, + tiny_checkpoint_path: Path, + example_audio_dir: Path, + example_audio_files: list[Path], +) -> None: + """User story: process the bundled example audio directory.""" + + output_path = tmp_path / "predictions" + + result = CliRunner().invoke( + cli, + [ + "process", + "directory", + "--model", + str(tiny_checkpoint_path), + str(example_audio_dir), + str(output_path), + "--batch-size", + "1", + "--workers", + "0", + "--format", + "batdetect2", + ], + ) + + assert result.exit_code == 0 + assert output_path.exists() + assert len(list(output_path.glob("*.json"))) == len(example_audio_files) + + +@pytest.mark.slow +def test_cli_process_directory_batdetect2_matches_legacy_artifacts( + tmp_path: Path, + tiny_checkpoint_path: Path, + example_audio_dir: Path, + example_audio_files: list[Path], + example_anns_dir: Path, +) -> None: + """User story: process batdetect2 output matches legacy-style files.""" + + output_path = tmp_path / "predictions" + + result = CliRunner().invoke( + cli, + [ + "process", + "directory", + "--model", + str(tiny_checkpoint_path), + str(example_audio_dir), + str(output_path), + "--batch-size", + "1", + "--workers", + "0", + "--format", + "batdetect2", + ], + ) + + assert result.exit_code == 0 + + json_files = sorted(output_path.rglob("*.json")) + csv_files = sorted(output_path.rglob("*.csv")) + + assert len(json_files) == len(example_audio_files) + assert len(csv_files) == len(example_audio_files) + + expected_names 
= sorted( + audio_file.name for audio_file in example_audio_files + ) + assert sorted(path.stem for path in json_files) == expected_names + assert sorted(path.stem for path in csv_files) == expected_names + + first_output = json.loads(json_files[0].read_text()) + assert "file_path" not in first_output + assert isinstance(first_output["class_name"], str) + assert first_output["class_name"] + + first_annotation = first_output["annotation"][0] + assert first_annotation["individual"] == "-1" + assert isinstance(first_annotation["high_freq"], int) + assert isinstance(first_annotation["low_freq"], int) + + expected_json = json.loads( + (example_anns_dir / json_files[0].name).read_text() + ) + assert first_output["id"] == expected_json["id"] + assert first_output["time_exp"] == expected_json["time_exp"] + + first_csv = pd.read_csv(csv_files[0], index_col=0) + assert list(first_csv.columns) == [ + "det_prob", + "start_time", + "end_time", + "high_freq", + "low_freq", + "class", + "class_prob", + ] + + +@pytest.mark.slow +def test_cli_process_directory_batdetect2_writes_cnn_features_csv_when_enabled( + tmp_path: Path, + tiny_checkpoint_path: Path, + example_audio_dir: Path, +) -> None: + """User story: request legacy CNN feature CSV sidecars via config.""" + + output_path = tmp_path / "predictions" + outputs_config_path = tmp_path / "outputs.yaml" + outputs_config_path.write_text( + OutputsConfig( + format=BatDetect2OutputConfig(write_cnn_features_csv=True) + ).to_yaml_string() + ) + + result = CliRunner().invoke( + cli, + [ + "process", + "directory", + "--model", + str(tiny_checkpoint_path), + str(example_audio_dir), + str(output_path), + "--batch-size", + "1", + "--workers", + "0", + "--outputs-config", + str(outputs_config_path), + ], + ) + + assert result.exit_code == 0 + + cnn_csvs = sorted(output_path.rglob("*_cnn_features.csv")) + assert len(cnn_csvs) == 3 + + first_df = pd.read_csv(cnn_csvs[0]) + assert not first_df.empty + assert list(first_df.columns) == [ + 
str(ii) for ii in range(len(first_df.columns)) + ] + + def test_cli_process_file_list_runs_on_real_audio( tmp_path: Path, tiny_checkpoint_path: Path, @@ -70,6 +227,7 @@ def test_cli_process_file_list_runs_on_real_audio( [ "process", "file_list", + "--model", str(tiny_checkpoint_path), str(file_list), str(output_path), @@ -117,6 +275,7 @@ def test_cli_process_dataset_runs_on_aoef_metadata( [ "process", "dataset", + "--model", str(tiny_checkpoint_path), str(dataset_path), str(output_path), @@ -159,6 +318,7 @@ def test_cli_process_directory_supports_output_format_override( [ "process", "directory", + "--model", str(tiny_checkpoint_path), str(single_audio_dir), str(output_path), @@ -217,6 +377,7 @@ def test_cli_process_dataset_deduplicates_recordings( [ "process", "dataset", + "--model", str(tiny_checkpoint_path), str(dataset_path), str(output_path), @@ -247,6 +408,7 @@ def test_cli_process_rejects_unknown_output_format( [ "process", "directory", + "--model", str(tiny_checkpoint_path), str(single_audio_dir), str(output_path),