mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: restore legacy batdetect2 process outputs
This commit is contained in:
parent
105384d9a2
commit
c4f759e9a3
@ -14,7 +14,7 @@ Current built-in output formats include:
|
||||
- `soundevent`:
|
||||
prediction-set JSON for soundevent-style tooling,
|
||||
- `batdetect2`:
|
||||
legacy per-recording JSON output.
|
||||
legacy-compatible per-recording JSON and CSV outputs.
|
||||
|
||||
## Select a format from the CLI
|
||||
|
||||
@ -61,7 +61,29 @@ batdetect2 process directory \
|
||||
- Use `raw` if you want the richest output surface and easy round-tripping.
|
||||
- Use `parquet` if you want tabular analysis in Python or data-lake workflows.
|
||||
- Use `soundevent` if you want prediction-set JSON.
|
||||
- Use `batdetect2` only when you need the legacy JSON shape.
|
||||
- Use `batdetect2` when you need legacy BatDetect2-style outputs.
|
||||
|
||||
## Enable legacy CNN feature CSVs
|
||||
|
||||
The `batdetect2` formatter can also write the legacy CNN feature sidecar CSVs.
|
||||
This is controlled through the outputs config.
|
||||
|
||||
Example:
|
||||
|
||||
```yaml
|
||||
format:
|
||||
name: batdetect2
|
||||
write_cnn_features_csv: true
|
||||
transform:
|
||||
detection_transforms: []
|
||||
clip_transforms: []
|
||||
```
|
||||
|
||||
When enabled, BatDetect2 writes:
|
||||
|
||||
- one `.json` file per recording,
|
||||
- one detection `.csv` file per recording,
|
||||
- one `_cnn_features.csv` file per recording when detections are present.
|
||||
|
||||
## Related pages
|
||||
|
||||
|
||||
@ -47,17 +47,29 @@ Writes a prediction-set JSON file.
|
||||
|
||||
Defined by `BatDetect2OutputConfig`.
|
||||
|
||||
This is the legacy BatDetect2-style JSON output.
|
||||
This is the legacy-compatible BatDetect2 formatter.
|
||||
|
||||
Key fields:
|
||||
|
||||
- `event_name`
|
||||
- `annotation_note`
|
||||
- `write_detection_csv`
|
||||
- `write_cnn_features_csv`
|
||||
- `save_if_empty`
|
||||
- `preserve_audio_tree`
|
||||
- `include_file_path`
|
||||
|
||||
Writes one `.json` file per recording.
|
||||
By default it writes one `.json` file and one detection `.csv` file per
|
||||
recording, preserving the input audio directory layout under the output root.
|
||||
|
||||
It can also write legacy `_cnn_features.csv` sidecars when
|
||||
`write_cnn_features_csv` is enabled.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Outputs config: {doc}`outputs-config`
|
||||
- Save predictions in different output formats: {doc}`../how_to/save-predictions-in-different-output-formats`
|
||||
- Understanding formatted outputs: {doc}`../explanation/interpreting-formatted-outputs`
|
||||
- Outputs config:
|
||||
{doc}`outputs-config`
|
||||
- Save predictions in different output formats:
|
||||
{doc}`../how_to/save-predictions-in-different-output-formats`
|
||||
- Understanding formatted outputs:
|
||||
{doc}`../explanation/interpreting-formatted-outputs`
|
||||
|
||||
@ -24,10 +24,18 @@ The output workflow is:
|
||||
|
||||
## Default behavior
|
||||
|
||||
By default, the current stack uses the raw output formatter unless you override it.
|
||||
By default, the current stack uses the raw output formatter unless you override
|
||||
it.
|
||||
|
||||
For CLI processing commands, omitting `--format` now leaves format selection to
|
||||
the loaded outputs config.
|
||||
If no outputs config is provided, the CLI still uses its command defaults.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Output formats: {doc}`output-formats`
|
||||
- Output transforms: {doc}`output-transforms`
|
||||
- Save predictions in different output formats: {doc}`../how_to/save-predictions-in-different-output-formats`
|
||||
- Output formats:
|
||||
{doc}`output-formats`
|
||||
- Output transforms:
|
||||
{doc}`output-transforms`
|
||||
- Save predictions in different output formats:
|
||||
{doc}`../how_to/save-predictions-in-different-output-formats`
|
||||
|
||||
@ -27,6 +27,15 @@ def process() -> None:
|
||||
def common_predict_options(func):
|
||||
"""Attach options shared by all ``process`` subcommands."""
|
||||
|
||||
@click.option(
|
||||
"--model",
|
||||
"model_path",
|
||||
type=str,
|
||||
help=(
|
||||
"Path to a checkpoint, checkpoint alias, or a Hugging Face "
|
||||
"URI to fine-tune from. Defaults to uk_same"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--audio-config",
|
||||
type=click.Path(exists=True),
|
||||
@ -97,7 +106,7 @@ def common_predict_options(func):
|
||||
|
||||
|
||||
def _build_api(
|
||||
model_path: str,
|
||||
model_path: str | None,
|
||||
audio_config: Path | None,
|
||||
inference_config: Path | None,
|
||||
outputs_config: Path | None,
|
||||
@ -129,7 +138,7 @@ def _build_api(
|
||||
)
|
||||
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
path=model_path,
|
||||
audio_config=audio_conf,
|
||||
inference_config=inference_conf,
|
||||
outputs_config=outputs_conf,
|
||||
@ -139,7 +148,7 @@ def _build_api(
|
||||
|
||||
|
||||
def _run_prediction(
|
||||
model_path: str,
|
||||
model_path: str | None,
|
||||
audio_files: list[Path],
|
||||
output_path: Path,
|
||||
audio_config: Path | None,
|
||||
@ -190,12 +199,11 @@ def _run_prediction(
|
||||
name="directory",
|
||||
short_help="Process audio files in a directory.",
|
||||
)
|
||||
@click.argument("model_path", type=str)
|
||||
@click.argument("audio_dir", type=click.Path(exists=True))
|
||||
@click.argument("output_path", type=click.Path())
|
||||
@common_predict_options
|
||||
def predict_directory_command(
|
||||
model_path: str,
|
||||
model_path: str | None,
|
||||
audio_dir: Path,
|
||||
output_path: Path,
|
||||
audio_config: Path | None,
|
||||
@ -234,14 +242,13 @@ def predict_directory_command(
|
||||
name="file_list",
|
||||
short_help="Process paths listed in a text file.",
|
||||
)
|
||||
@click.argument("model_path", type=str)
|
||||
@click.argument("file_list", type=click.Path(exists=True))
|
||||
@click.argument("output_path", type=click.Path())
|
||||
@common_predict_options
|
||||
def predict_file_list_command(
|
||||
model_path: str,
|
||||
file_list: Path,
|
||||
output_path: Path,
|
||||
model_path: str | None,
|
||||
audio_config: Path | None,
|
||||
inference_config: Path | None,
|
||||
outputs_config: Path | None,
|
||||
@ -282,14 +289,13 @@ def predict_file_list_command(
|
||||
name="dataset",
|
||||
short_help="Process recordings from a dataset config.",
|
||||
)
|
||||
@click.argument("model_path", type=str)
|
||||
@click.argument("dataset_path", type=click.Path(exists=True))
|
||||
@click.argument("output_path", type=click.Path())
|
||||
@common_predict_options
|
||||
def predict_dataset_command(
|
||||
model_path: str,
|
||||
dataset_path: Path,
|
||||
output_path: Path,
|
||||
model_path: str | None,
|
||||
audio_config: Path | None,
|
||||
inference_config: Path | None,
|
||||
outputs_config: Path | None,
|
||||
|
||||
@ -3,7 +3,9 @@ from pathlib import Path
|
||||
from typing import List, Literal, Sequence, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from soundevent import data
|
||||
from soundevent import terms as soundevent_terms
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
@ -13,7 +15,6 @@ from batdetect2.outputs.formats.base import (
|
||||
)
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets import terms
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
try:
|
||||
@ -24,7 +25,7 @@ except ImportError:
|
||||
DictWithClass = TypedDict("DictWithClass", {"class": str})
|
||||
|
||||
|
||||
class Annotation(DictWithClass):
|
||||
class Annotation(DictWithClass, total=False):
|
||||
start_time: float
|
||||
end_time: float
|
||||
low_freq: float
|
||||
@ -33,6 +34,7 @@ class Annotation(DictWithClass):
|
||||
det_prob: float
|
||||
individual: str
|
||||
event: str
|
||||
cnn_features: NotRequired[list[float]] # ty: ignore[invalid-type-form]
|
||||
|
||||
|
||||
class FileAnnotation(TypedDict):
|
||||
@ -52,6 +54,11 @@ class BatDetect2OutputConfig(BaseConfig):
|
||||
|
||||
event_name: str = "Echolocation"
|
||||
annotation_note: str = "Automatically generated."
|
||||
write_detection_csv: bool = True
|
||||
write_cnn_features_csv: bool = False
|
||||
save_if_empty: bool = False
|
||||
preserve_audio_tree: bool = True
|
||||
include_file_path: bool = False
|
||||
|
||||
|
||||
class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
@ -60,10 +67,20 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
targets: TargetProtocol,
|
||||
event_name: str,
|
||||
annotation_note: str,
|
||||
write_detection_csv: bool = True,
|
||||
write_cnn_features_csv: bool = False,
|
||||
save_if_empty: bool = False,
|
||||
preserve_audio_tree: bool = True,
|
||||
include_file_path: bool = False,
|
||||
):
|
||||
self.targets = targets
|
||||
self.event_name = event_name
|
||||
self.annotation_note = annotation_note
|
||||
self.write_detection_csv = write_detection_csv
|
||||
self.write_cnn_features_csv = write_cnn_features_csv
|
||||
self.save_if_empty = save_if_empty
|
||||
self.preserve_audio_tree = preserve_audio_tree
|
||||
self.include_file_path = include_file_path
|
||||
|
||||
def format(
|
||||
self, predictions: Sequence[ClipDetections]
|
||||
@ -84,22 +101,56 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
path.mkdir(parents=True)
|
||||
|
||||
for prediction in predictions:
|
||||
pred_path = path / (prediction["id"] + ".json")
|
||||
annotations = prediction["annotation"]
|
||||
|
||||
if audio_dir is not None and "file_path" in prediction:
|
||||
prediction["file_path"] = str(
|
||||
make_path_relative(
|
||||
prediction["file_path"],
|
||||
audio_dir,
|
||||
)
|
||||
if not annotations and not self.save_if_empty:
|
||||
continue
|
||||
|
||||
pred_path = self.get_output_path(prediction, path, audio_dir)
|
||||
pred_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# make a copy of the prediction
|
||||
data = dict(prediction)
|
||||
|
||||
raw_file_path = data.get("file_path")
|
||||
if audio_dir is not None and isinstance(raw_file_path, str):
|
||||
data["file_path"] = str(
|
||||
make_path_relative(raw_file_path, audio_dir)
|
||||
)
|
||||
|
||||
pred_path.write_text(json.dumps(prediction))
|
||||
if not self.include_file_path:
|
||||
data.pop("file_path", None)
|
||||
|
||||
data["annotation"] = [
|
||||
{
|
||||
key: value
|
||||
for key, value in annotation.items()
|
||||
if key != "cnn_features"
|
||||
}
|
||||
for annotation in data["annotation"]
|
||||
]
|
||||
|
||||
pred_path.write_text(json.dumps(data, indent=2, sort_keys=True))
|
||||
|
||||
if self.write_detection_csv:
|
||||
self.save_detection_csv(
|
||||
prediction,
|
||||
pred_path.with_suffix(".csv"),
|
||||
)
|
||||
|
||||
if self.write_cnn_features_csv:
|
||||
self.save_cnn_features_csv(
|
||||
prediction,
|
||||
pred_path.with_name(pred_path.stem + "_cnn_features.csv"),
|
||||
)
|
||||
|
||||
def load(self, path: data.PathLike) -> List[FileAnnotation]:
|
||||
path = Path(path)
|
||||
|
||||
files = list(path.glob("*.json"))
|
||||
if path.is_file():
|
||||
files = [path] if path.suffix == ".json" else []
|
||||
else:
|
||||
files = sorted(path.rglob("*.json"))
|
||||
|
||||
if not files:
|
||||
return []
|
||||
@ -108,12 +159,108 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
json.loads(file.read_text()) for file in files if file.is_file()
|
||||
]
|
||||
|
||||
def get_recording_class(self, annotations: List[Annotation]) -> str:
|
||||
if not annotations:
|
||||
return ""
|
||||
def get_output_path(
    self,
    prediction: FileAnnotation,
    output_dir: Path,
    audio_dir: data.PathLike | None,
) -> Path:
    """Return the JSON output path for a single formatted prediction.

    Parameters
    ----------
    prediction
        Formatted prediction; its ``"id"`` entry becomes the file name and
        its optional ``"file_path"`` entry is used to mirror the audio tree.
    output_dir
        Root directory under which outputs are written.
    audio_dir
        Root audio directory used to relativise ``"file_path"``. May be
        ``None``, in which case the flat layout is used.

    Returns
    -------
    Path
        ``output_dir/<relative parent>/<id>.json`` when
        ``preserve_audio_tree`` is enabled, ``audio_dir`` is given and the
        prediction carries a ``"file_path"``; otherwise
        ``output_dir/<id>.json``.
    """
    filename = f"{prediction['id']}.json"

    mirror_audio_tree = (
        self.preserve_audio_tree
        and audio_dir is not None
        and "file_path" in prediction
    )
    if not mirror_audio_tree:
        # Flat layout: every recording lands directly in the output root.
        return output_dir / filename

    relative_path = make_path_relative(
        prediction["file_path"],
        audio_dir,
    )
    # Only the parent directory is reused; the file name comes from the id.
    return output_dir / relative_path.parent / filename
|
||||
|
||||
def save_detection_csv(
    self,
    prediction: FileAnnotation,
    path: Path,
) -> None:
    """Write the detection table of one recording to ``path`` as CSV.

    Nothing is written when the prediction has no annotations. Only the
    detection columns listed below are exported, in that order; any other
    annotation keys are dropped. The DataFrame row index is written as the
    first CSV column because ``index`` is left at its default.
    """
    detections = prediction["annotation"]

    if detections:
        legacy_columns = [
            "det_prob",
            "start_time",
            "end_time",
            "high_freq",
            "low_freq",
            "class",
            "class_prob",
        ]
        table = pd.DataFrame(detections)
        table[legacy_columns].to_csv(path, sep=",")
|
||||
|
||||
def save_cnn_features_csv(
    self, prediction: FileAnnotation, path: Path
) -> None:
    """Write the CNN feature vectors of one recording to ``path`` as CSV.

    One row is written per annotation that carries a ``"cnn_features"``
    entry. Columns are named ``"0"``, ``"1"``, ... up to the length of the
    first feature vector, the index is omitted and values are formatted
    with five decimals. No file is written when there are no annotations
    or none of them has features.
    """
    entries = prediction["annotation"]
    if not entries:
        return

    feature_rows = []
    for entry in entries:
        if "cnn_features" in entry:
            feature_rows.append(entry["cnn_features"])

    if not feature_rows:
        return

    # Column count is taken from the first feature vector.
    column_names = list(map(str, range(len(feature_rows[0]))))
    frame = pd.DataFrame(feature_rows, columns=column_names)
    frame.to_csv(path, sep=",", index=False, float_format="%.5f")
|
||||
|
||||
def get_class_name(self, class_index: int) -> str:
    """Return the exported name for the class at ``class_index``.

    The target class name is decoded into tags and the value of the
    scientific-name tag is returned; when no such tag is found the raw
    target class name is used instead.
    """
    fallback_name = self.targets.class_names[class_index]
    decoded_tags = self.targets.decode_class(fallback_name)
    return data.find_tag_value(
        decoded_tags,
        term=soundevent_terms.scientific_name,
        default=fallback_name,
    )  # type: ignore
|
||||
|
||||
def get_recording_class(self, detections: Sequence[Detection]) -> str:
    """Return the recording-level class name for a set of detections.

    Each detection's class scores are weighted by its detection score and
    summed per class; the class with the largest share of the total is
    reported through ``get_class_name``. The literal string ``"None"`` is
    returned when there are no detections or the total weighted score is
    not positive.
    """
    if not detections:
        return "None"

    # Stacking along axis 1 gives one column per detection.
    score_matrix = np.stack(
        [item.class_scores for item in detections],
        axis=1,
    )
    weights = np.array(
        [item.detection_score for item in detections],
        dtype=np.float32,
    )

    per_class = (score_matrix * weights).sum(axis=1)
    overall = per_class.sum()

    if overall <= 0:
        return "None"

    winner = int(np.argmax(per_class / overall))
    return self.get_class_name(winner)
|
||||
|
||||
def format_prediction(self, prediction: ClipDetections) -> FileAnnotation:
|
||||
recording = prediction.clip.recording
|
||||
@ -123,26 +270,19 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
for pred in prediction.detections
|
||||
]
|
||||
|
||||
return FileAnnotation(
|
||||
file_annotation = FileAnnotation(
|
||||
id=recording.path.name,
|
||||
file_path=str(recording.path),
|
||||
annotated=False,
|
||||
duration=recording.duration,
|
||||
duration=round(float(recording.duration), 4),
|
||||
issues=False,
|
||||
time_exp=recording.time_expansion,
|
||||
class_name=self.get_recording_class(annotations),
|
||||
class_name=self.get_recording_class(prediction.detections),
|
||||
notes=self.annotation_note,
|
||||
annotation=annotations,
|
||||
file_path=str(recording.path),
|
||||
)
|
||||
|
||||
def get_class_name(self, class_index: int) -> str:
|
||||
class_name = self.targets.class_names[class_index]
|
||||
tags = self.targets.decode_class(class_name)
|
||||
return data.find_tag_value(
|
||||
tags,
|
||||
term=terms.generic_class,
|
||||
default=class_name,
|
||||
) # type: ignore
|
||||
return file_annotation
|
||||
|
||||
def format_sound_event_prediction(
|
||||
self, prediction: Detection
|
||||
@ -155,16 +295,20 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
top_class_score = float(prediction.class_scores[top_class_index])
|
||||
top_class = self.get_class_name(top_class_index)
|
||||
annotation: Annotation = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"low_freq": low_freq,
|
||||
"high_freq": high_freq,
|
||||
"class_prob": top_class_score,
|
||||
"det_prob": float(prediction.detection_score),
|
||||
"individual": "",
|
||||
"start_time": round(float(start_time), 4),
|
||||
"end_time": round(float(end_time), 4),
|
||||
"low_freq": int(low_freq),
|
||||
"high_freq": int(high_freq),
|
||||
"class_prob": round(top_class_score, 3),
|
||||
"det_prob": round(float(prediction.detection_score), 3),
|
||||
"individual": "-1",
|
||||
"event": self.event_name,
|
||||
"class": top_class,
|
||||
}
|
||||
|
||||
if self.write_cnn_features_csv:
|
||||
annotation["cnn_features"] = prediction.features.tolist() # type: ignore[index]
|
||||
|
||||
return annotation
|
||||
|
||||
@output_formatters.register(BatDetect2OutputConfig)
|
||||
@ -174,4 +318,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
targets,
|
||||
event_name=config.event_name,
|
||||
annotation_note=config.annotation_note,
|
||||
write_detection_csv=config.write_detection_csv,
|
||||
write_cnn_features_csv=config.write_cnn_features_csv,
|
||||
save_if_empty=config.save_if_empty,
|
||||
preserve_audio_tree=config.preserve_audio_tree,
|
||||
include_file_path=config.include_file_path,
|
||||
)
|
||||
|
||||
@ -2,6 +2,7 @@ from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
@ -98,6 +99,47 @@ def test_load_predictions_with_format_override(
|
||||
assert "annotation" in loaded_item
|
||||
|
||||
|
||||
def test_load_predictions_with_batdetect2_nested_layout(
    api_v2: BatDetect2API,
    example_audio_files: list[Path],
    tmp_path: Path,
) -> None:
    """Saving with the ``batdetect2`` format reloads one item per audio file."""
    destination = tmp_path / "batdetect2_nested"

    clip_predictions = []
    for recording_path in example_audio_files:
        clip_predictions.append(api_v2.process_file(recording_path))

    api_v2.save_predictions(
        clip_predictions,
        path=destination,
        format="batdetect2",
        audio_dir=example_audio_files[0].parent,
    )

    reloaded = api_v2.load_predictions(destination, format="batdetect2")

    assert len(reloaded) == len(example_audio_files)
|
||||
|
||||
|
||||
def test_save_predictions_with_batdetect2_writes_cnn_feature_csv(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Enabling ``write_cnn_features_csv`` yields one non-empty sidecar CSV."""
    destination = tmp_path / "batdetect2_cnn"
    formatter_config = BatDetect2OutputConfig(write_cnn_features_csv=True)

    api_v2.save_predictions(
        [file_prediction],
        path=destination,
        config=formatter_config,
    )

    sidecars = [*destination.rglob("*_cnn_features.csv")]
    assert len(sidecars) == 1

    features = pd.read_csv(sidecars[0])
    assert not features.empty
|
||||
|
||||
|
||||
def test_save_predictions_with_soundevent_override(
|
||||
api_v2: BatDetect2API,
|
||||
file_prediction,
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
"""Behavior tests for process CLI workflows."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from soundevent import data, io
|
||||
|
||||
from batdetect2.cli import cli
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.outputs.formats import BatDetect2OutputConfig
|
||||
|
||||
|
||||
def test_cli_process_help() -> None:
|
||||
@ -35,6 +39,7 @@ def test_cli_process_directory_runs_on_real_audio(
|
||||
[
|
||||
"process",
|
||||
"directory",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
str(single_audio_dir),
|
||||
str(output_path),
|
||||
@ -52,6 +57,158 @@ def test_cli_process_directory_runs_on_real_audio(
|
||||
assert len(list(output_path.glob("*.json"))) == 1
|
||||
|
||||
|
||||
@pytest.mark.slow
def test_cli_process_directory_runs_on_example_audio_data(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    example_audio_dir: Path,
    example_audio_files: list[Path],
) -> None:
    """User story: process the bundled example audio directory."""

    output_path = tmp_path / "predictions"
    cli_args = [
        "process",
        "directory",
        "--model",
        str(tiny_checkpoint_path),
        str(example_audio_dir),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    result = CliRunner().invoke(cli, cli_args)

    assert result.exit_code == 0
    assert output_path.exists()

    # One top-level JSON file is expected per example recording.
    written = list(output_path.glob("*.json"))
    assert len(written) == len(example_audio_files)
|
||||
|
||||
|
||||
@pytest.mark.slow
def test_cli_process_directory_batdetect2_matches_legacy_artifacts(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    example_audio_dir: Path,
    example_audio_files: list[Path],
    example_anns_dir: Path,
) -> None:
    """User story: process batdetect2 output matches legacy-style files."""

    output_path = tmp_path / "predictions"
    cli_args = [
        "process",
        "directory",
        "--model",
        str(tiny_checkpoint_path),
        str(example_audio_dir),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    result = CliRunner().invoke(cli, cli_args)

    assert result.exit_code == 0

    json_files = sorted(output_path.rglob("*.json"))
    csv_files = sorted(output_path.rglob("*.csv"))
    n_recordings = len(example_audio_files)

    assert len(json_files) == n_recordings
    assert len(csv_files) == n_recordings

    # Output stems are the full audio file names, extension included.
    expected_names = sorted(
        audio_file.name for audio_file in example_audio_files
    )
    assert sorted(path.stem for path in json_files) == expected_names
    assert sorted(path.stem for path in csv_files) == expected_names

    first_json = json_files[0]
    first_output = json.loads(first_json.read_text())
    assert "file_path" not in first_output
    assert isinstance(first_output["class_name"], str)
    assert first_output["class_name"]

    first_annotation = first_output["annotation"][0]
    assert first_annotation["individual"] == "-1"
    assert isinstance(first_annotation["high_freq"], int)
    assert isinstance(first_annotation["low_freq"], int)

    reference_path = example_anns_dir / first_json.name
    expected_json = json.loads(reference_path.read_text())
    assert first_output["id"] == expected_json["id"]
    assert first_output["time_exp"] == expected_json["time_exp"]

    expected_columns = [
        "det_prob",
        "start_time",
        "end_time",
        "high_freq",
        "low_freq",
        "class",
        "class_prob",
    ]
    first_csv = pd.read_csv(csv_files[0], index_col=0)
    assert list(first_csv.columns) == expected_columns
|
||||
|
||||
|
||||
@pytest.mark.slow
def test_cli_process_directory_batdetect2_writes_cnn_features_csv_when_enabled(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    example_audio_dir: Path,
) -> None:
    """User story: request legacy CNN feature CSV sidecars via config."""

    output_path = tmp_path / "predictions"

    # Serialise an outputs config that switches the sidecar CSVs on.
    outputs_config = OutputsConfig(
        format=BatDetect2OutputConfig(write_cnn_features_csv=True)
    )
    outputs_config_path = tmp_path / "outputs.yaml"
    outputs_config_path.write_text(outputs_config.to_yaml_string())

    cli_args = [
        "process",
        "directory",
        "--model",
        str(tiny_checkpoint_path),
        str(example_audio_dir),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--outputs-config",
        str(outputs_config_path),
    ]

    result = CliRunner().invoke(cli, cli_args)

    assert result.exit_code == 0

    cnn_csvs = sorted(output_path.rglob("*_cnn_features.csv"))
    assert len(cnn_csvs) == 3

    first_df = pd.read_csv(cnn_csvs[0])
    assert not first_df.empty

    # Feature columns are named by their position: "0", "1", ...
    positional_names = [str(ii) for ii in range(len(first_df.columns))]
    assert list(first_df.columns) == positional_names
|
||||
|
||||
|
||||
def test_cli_process_file_list_runs_on_real_audio(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
@ -70,6 +227,7 @@ def test_cli_process_file_list_runs_on_real_audio(
|
||||
[
|
||||
"process",
|
||||
"file_list",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
str(file_list),
|
||||
str(output_path),
|
||||
@ -117,6 +275,7 @@ def test_cli_process_dataset_runs_on_aoef_metadata(
|
||||
[
|
||||
"process",
|
||||
"dataset",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
str(dataset_path),
|
||||
str(output_path),
|
||||
@ -159,6 +318,7 @@ def test_cli_process_directory_supports_output_format_override(
|
||||
[
|
||||
"process",
|
||||
"directory",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
str(single_audio_dir),
|
||||
str(output_path),
|
||||
@ -217,6 +377,7 @@ def test_cli_process_dataset_deduplicates_recordings(
|
||||
[
|
||||
"process",
|
||||
"dataset",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
str(dataset_path),
|
||||
str(output_path),
|
||||
@ -247,6 +408,7 @@ def test_cli_process_rejects_unknown_output_format(
|
||||
[
|
||||
"process",
|
||||
"directory",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
str(single_audio_dir),
|
||||
str(output_path),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user