mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: restore legacy batdetect2 process outputs
This commit is contained in:
parent
105384d9a2
commit
c4f759e9a3
@ -14,7 +14,7 @@ Current built-in output formats include:
|
|||||||
- `soundevent`:
|
- `soundevent`:
|
||||||
prediction-set JSON for soundevent-style tooling,
|
prediction-set JSON for soundevent-style tooling,
|
||||||
- `batdetect2`:
|
- `batdetect2`:
|
||||||
legacy per-recording JSON output.
|
legacy-compatible per-recording JSON and CSV outputs.
|
||||||
|
|
||||||
## Select a format from the CLI
|
## Select a format from the CLI
|
||||||
|
|
||||||
@ -61,7 +61,29 @@ batdetect2 process directory \
|
|||||||
- Use `raw` if you want the richest output surface and easy round-tripping.
|
- Use `raw` if you want the richest output surface and easy round-tripping.
|
||||||
- Use `parquet` if you want tabular analysis in Python or data-lake workflows.
|
- Use `parquet` if you want tabular analysis in Python or data-lake workflows.
|
||||||
- Use `soundevent` if you want prediction-set JSON.
|
- Use `soundevent` if you want prediction-set JSON.
|
||||||
- Use `batdetect2` only when you need the legacy JSON shape.
|
- Use `batdetect2` when you need legacy BatDetect2-style outputs.
|
||||||
|
|
||||||
|
## Enable legacy CNN feature CSVs
|
||||||
|
|
||||||
|
The `batdetect2` formatter can also write the legacy CNN feature sidecar CSVs.
|
||||||
|
This is controlled through the outputs config.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
format:
|
||||||
|
name: batdetect2
|
||||||
|
write_cnn_features_csv: true
|
||||||
|
transform:
|
||||||
|
detection_transforms: []
|
||||||
|
clip_transforms: []
|
||||||
|
```
|
||||||
|
|
||||||
|
When enabled, BatDetect2 writes:
|
||||||
|
|
||||||
|
- one `.json` file per recording,
|
||||||
|
- one detection `.csv` file per recording,
|
||||||
|
- one `_cnn_features.csv` file per recording when detections are present.
|
||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
|
|||||||
@ -47,17 +47,29 @@ Writes a prediction-set JSON file.
|
|||||||
|
|
||||||
Defined by `BatDetect2OutputConfig`.
|
Defined by `BatDetect2OutputConfig`.
|
||||||
|
|
||||||
This is the legacy BatDetect2-style JSON output.
|
This is the legacy-compatible BatDetect2 formatter.
|
||||||
|
|
||||||
Key fields:
|
Key fields:
|
||||||
|
|
||||||
- `event_name`
|
- `event_name`
|
||||||
- `annotation_note`
|
- `annotation_note`
|
||||||
|
- `write_detection_csv`
|
||||||
|
- `write_cnn_features_csv`
|
||||||
|
- `save_if_empty`
|
||||||
|
- `preserve_audio_tree`
|
||||||
|
- `include_file_path`
|
||||||
|
|
||||||
Writes one `.json` file per recording.
|
By default it writes one `.json` file and one detection `.csv` file per
|
||||||
|
recording, preserving the input audio directory layout under the output root.
|
||||||
|
|
||||||
|
It can also write legacy `_cnn_features.csv` sidecars when
|
||||||
|
`write_cnn_features_csv` is enabled.
|
||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
- Outputs config: {doc}`outputs-config`
|
- Outputs config:
|
||||||
- Save predictions in different output formats: {doc}`../how_to/save-predictions-in-different-output-formats`
|
{doc}`outputs-config`
|
||||||
- Understanding formatted outputs: {doc}`../explanation/interpreting-formatted-outputs`
|
- Save predictions in different output formats:
|
||||||
|
{doc}`../how_to/save-predictions-in-different-output-formats`
|
||||||
|
- Understanding formatted outputs:
|
||||||
|
{doc}`../explanation/interpreting-formatted-outputs`
|
||||||
|
|||||||
@ -24,10 +24,18 @@ The output workflow is:
|
|||||||
|
|
||||||
## Default behavior
|
## Default behavior
|
||||||
|
|
||||||
By default, the current stack uses the raw output formatter unless you override it.
|
By default, the current stack uses the raw output formatter unless you override
|
||||||
|
it.
|
||||||
|
|
||||||
|
For CLI processing commands, omitting `--format` now leaves format selection to
|
||||||
|
the loaded outputs config.
|
||||||
|
If no outputs config is provided, the CLI still uses its command defaults.
|
||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
- Output formats: {doc}`output-formats`
|
- Output formats:
|
||||||
- Output transforms: {doc}`output-transforms`
|
{doc}`output-formats`
|
||||||
- Save predictions in different output formats: {doc}`../how_to/save-predictions-in-different-output-formats`
|
- Output transforms:
|
||||||
|
{doc}`output-transforms`
|
||||||
|
- Save predictions in different output formats:
|
||||||
|
{doc}`../how_to/save-predictions-in-different-output-formats`
|
||||||
|
|||||||
@ -27,6 +27,15 @@ def process() -> None:
|
|||||||
def common_predict_options(func):
|
def common_predict_options(func):
|
||||||
"""Attach options shared by all ``process`` subcommands."""
|
"""Attach options shared by all ``process`` subcommands."""
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--model",
|
||||||
|
"model_path",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"Path to a checkpoint, checkpoint alias, or a Hugging Face "
|
||||||
|
"URI to fine-tune from. Defaults to uk_same"
|
||||||
|
),
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--audio-config",
|
"--audio-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
@ -97,7 +106,7 @@ def common_predict_options(func):
|
|||||||
|
|
||||||
|
|
||||||
def _build_api(
|
def _build_api(
|
||||||
model_path: str,
|
model_path: str | None,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
inference_config: Path | None,
|
inference_config: Path | None,
|
||||||
outputs_config: Path | None,
|
outputs_config: Path | None,
|
||||||
@ -129,7 +138,7 @@ def _build_api(
|
|||||||
)
|
)
|
||||||
|
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
model_path,
|
path=model_path,
|
||||||
audio_config=audio_conf,
|
audio_config=audio_conf,
|
||||||
inference_config=inference_conf,
|
inference_config=inference_conf,
|
||||||
outputs_config=outputs_conf,
|
outputs_config=outputs_conf,
|
||||||
@ -139,7 +148,7 @@ def _build_api(
|
|||||||
|
|
||||||
|
|
||||||
def _run_prediction(
|
def _run_prediction(
|
||||||
model_path: str,
|
model_path: str | None,
|
||||||
audio_files: list[Path],
|
audio_files: list[Path],
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
@ -190,12 +199,11 @@ def _run_prediction(
|
|||||||
name="directory",
|
name="directory",
|
||||||
short_help="Process audio files in a directory.",
|
short_help="Process audio files in a directory.",
|
||||||
)
|
)
|
||||||
@click.argument("model_path", type=str)
|
|
||||||
@click.argument("audio_dir", type=click.Path(exists=True))
|
@click.argument("audio_dir", type=click.Path(exists=True))
|
||||||
@click.argument("output_path", type=click.Path())
|
@click.argument("output_path", type=click.Path())
|
||||||
@common_predict_options
|
@common_predict_options
|
||||||
def predict_directory_command(
|
def predict_directory_command(
|
||||||
model_path: str,
|
model_path: str | None,
|
||||||
audio_dir: Path,
|
audio_dir: Path,
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
@ -234,14 +242,13 @@ def predict_directory_command(
|
|||||||
name="file_list",
|
name="file_list",
|
||||||
short_help="Process paths listed in a text file.",
|
short_help="Process paths listed in a text file.",
|
||||||
)
|
)
|
||||||
@click.argument("model_path", type=str)
|
|
||||||
@click.argument("file_list", type=click.Path(exists=True))
|
@click.argument("file_list", type=click.Path(exists=True))
|
||||||
@click.argument("output_path", type=click.Path())
|
@click.argument("output_path", type=click.Path())
|
||||||
@common_predict_options
|
@common_predict_options
|
||||||
def predict_file_list_command(
|
def predict_file_list_command(
|
||||||
model_path: str,
|
|
||||||
file_list: Path,
|
file_list: Path,
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
|
model_path: str | None,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
inference_config: Path | None,
|
inference_config: Path | None,
|
||||||
outputs_config: Path | None,
|
outputs_config: Path | None,
|
||||||
@ -282,14 +289,13 @@ def predict_file_list_command(
|
|||||||
name="dataset",
|
name="dataset",
|
||||||
short_help="Process recordings from a dataset config.",
|
short_help="Process recordings from a dataset config.",
|
||||||
)
|
)
|
||||||
@click.argument("model_path", type=str)
|
|
||||||
@click.argument("dataset_path", type=click.Path(exists=True))
|
@click.argument("dataset_path", type=click.Path(exists=True))
|
||||||
@click.argument("output_path", type=click.Path())
|
@click.argument("output_path", type=click.Path())
|
||||||
@common_predict_options
|
@common_predict_options
|
||||||
def predict_dataset_command(
|
def predict_dataset_command(
|
||||||
model_path: str,
|
|
||||||
dataset_path: Path,
|
dataset_path: Path,
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
|
model_path: str | None,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
inference_config: Path | None,
|
inference_config: Path | None,
|
||||||
outputs_config: Path | None,
|
outputs_config: Path | None,
|
||||||
|
|||||||
@ -3,7 +3,9 @@ from pathlib import Path
|
|||||||
from typing import List, Literal, Sequence, TypedDict
|
from typing import List, Literal, Sequence, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
from soundevent import terms as soundevent_terms
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
@ -13,7 +15,6 @@ from batdetect2.outputs.formats.base import (
|
|||||||
)
|
)
|
||||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
from batdetect2.targets import terms
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -24,7 +25,7 @@ except ImportError:
|
|||||||
DictWithClass = TypedDict("DictWithClass", {"class": str})
|
DictWithClass = TypedDict("DictWithClass", {"class": str})
|
||||||
|
|
||||||
|
|
||||||
class Annotation(DictWithClass):
|
class Annotation(DictWithClass, total=False):
|
||||||
start_time: float
|
start_time: float
|
||||||
end_time: float
|
end_time: float
|
||||||
low_freq: float
|
low_freq: float
|
||||||
@ -33,6 +34,7 @@ class Annotation(DictWithClass):
|
|||||||
det_prob: float
|
det_prob: float
|
||||||
individual: str
|
individual: str
|
||||||
event: str
|
event: str
|
||||||
|
cnn_features: NotRequired[list[float]] # ty: ignore[invalid-type-form]
|
||||||
|
|
||||||
|
|
||||||
class FileAnnotation(TypedDict):
|
class FileAnnotation(TypedDict):
|
||||||
@ -52,6 +54,11 @@ class BatDetect2OutputConfig(BaseConfig):
|
|||||||
|
|
||||||
event_name: str = "Echolocation"
|
event_name: str = "Echolocation"
|
||||||
annotation_note: str = "Automatically generated."
|
annotation_note: str = "Automatically generated."
|
||||||
|
write_detection_csv: bool = True
|
||||||
|
write_cnn_features_csv: bool = False
|
||||||
|
save_if_empty: bool = False
|
||||||
|
preserve_audio_tree: bool = True
|
||||||
|
include_file_path: bool = False
|
||||||
|
|
||||||
|
|
||||||
class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||||
@ -60,10 +67,20 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
event_name: str,
|
event_name: str,
|
||||||
annotation_note: str,
|
annotation_note: str,
|
||||||
|
write_detection_csv: bool = True,
|
||||||
|
write_cnn_features_csv: bool = False,
|
||||||
|
save_if_empty: bool = False,
|
||||||
|
preserve_audio_tree: bool = True,
|
||||||
|
include_file_path: bool = False,
|
||||||
):
|
):
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.event_name = event_name
|
self.event_name = event_name
|
||||||
self.annotation_note = annotation_note
|
self.annotation_note = annotation_note
|
||||||
|
self.write_detection_csv = write_detection_csv
|
||||||
|
self.write_cnn_features_csv = write_cnn_features_csv
|
||||||
|
self.save_if_empty = save_if_empty
|
||||||
|
self.preserve_audio_tree = preserve_audio_tree
|
||||||
|
self.include_file_path = include_file_path
|
||||||
|
|
||||||
def format(
|
def format(
|
||||||
self, predictions: Sequence[ClipDetections]
|
self, predictions: Sequence[ClipDetections]
|
||||||
@ -84,22 +101,56 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
path.mkdir(parents=True)
|
path.mkdir(parents=True)
|
||||||
|
|
||||||
for prediction in predictions:
|
for prediction in predictions:
|
||||||
pred_path = path / (prediction["id"] + ".json")
|
annotations = prediction["annotation"]
|
||||||
|
|
||||||
if audio_dir is not None and "file_path" in prediction:
|
if not annotations and not self.save_if_empty:
|
||||||
prediction["file_path"] = str(
|
continue
|
||||||
make_path_relative(
|
|
||||||
prediction["file_path"],
|
pred_path = self.get_output_path(prediction, path, audio_dir)
|
||||||
audio_dir,
|
pred_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
)
|
|
||||||
|
# make a copy of the prediction
|
||||||
|
data = dict(prediction)
|
||||||
|
|
||||||
|
raw_file_path = data.get("file_path")
|
||||||
|
if audio_dir is not None and isinstance(raw_file_path, str):
|
||||||
|
data["file_path"] = str(
|
||||||
|
make_path_relative(raw_file_path, audio_dir)
|
||||||
)
|
)
|
||||||
|
|
||||||
pred_path.write_text(json.dumps(prediction))
|
if not self.include_file_path:
|
||||||
|
data.pop("file_path", None)
|
||||||
|
|
||||||
|
data["annotation"] = [
|
||||||
|
{
|
||||||
|
key: value
|
||||||
|
for key, value in annotation.items()
|
||||||
|
if key != "cnn_features"
|
||||||
|
}
|
||||||
|
for annotation in data["annotation"]
|
||||||
|
]
|
||||||
|
|
||||||
|
pred_path.write_text(json.dumps(data, indent=2, sort_keys=True))
|
||||||
|
|
||||||
|
if self.write_detection_csv:
|
||||||
|
self.save_detection_csv(
|
||||||
|
prediction,
|
||||||
|
pred_path.with_suffix(".csv"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.write_cnn_features_csv:
|
||||||
|
self.save_cnn_features_csv(
|
||||||
|
prediction,
|
||||||
|
pred_path.with_name(pred_path.stem + "_cnn_features.csv"),
|
||||||
|
)
|
||||||
|
|
||||||
def load(self, path: data.PathLike) -> List[FileAnnotation]:
|
def load(self, path: data.PathLike) -> List[FileAnnotation]:
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
|
|
||||||
files = list(path.glob("*.json"))
|
if path.is_file():
|
||||||
|
files = [path] if path.suffix == ".json" else []
|
||||||
|
else:
|
||||||
|
files = sorted(path.rglob("*.json"))
|
||||||
|
|
||||||
if not files:
|
if not files:
|
||||||
return []
|
return []
|
||||||
@ -108,12 +159,108 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
json.loads(file.read_text()) for file in files if file.is_file()
|
json.loads(file.read_text()) for file in files if file.is_file()
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_recording_class(self, annotations: List[Annotation]) -> str:
|
def get_output_path(
|
||||||
if not annotations:
|
self,
|
||||||
return ""
|
prediction: FileAnnotation,
|
||||||
|
output_dir: Path,
|
||||||
|
audio_dir: data.PathLike | None,
|
||||||
|
) -> Path:
|
||||||
|
if (
|
||||||
|
self.preserve_audio_tree
|
||||||
|
and audio_dir is not None
|
||||||
|
and "file_path" in prediction
|
||||||
|
):
|
||||||
|
relative_path = make_path_relative(
|
||||||
|
prediction["file_path"],
|
||||||
|
audio_dir,
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
output_dir / relative_path.parent / f"{prediction['id']}.json"
|
||||||
|
)
|
||||||
|
|
||||||
highest_scoring = max(annotations, key=lambda x: x["class_prob"])
|
return output_dir / f"{prediction['id']}.json"
|
||||||
return highest_scoring["class"]
|
|
||||||
|
def save_detection_csv(
|
||||||
|
self,
|
||||||
|
prediction: FileAnnotation,
|
||||||
|
path: Path,
|
||||||
|
) -> None:
|
||||||
|
annotations = prediction["annotation"]
|
||||||
|
if not annotations:
|
||||||
|
return
|
||||||
|
|
||||||
|
preds_df = pd.DataFrame(annotations)[
|
||||||
|
[
|
||||||
|
"det_prob",
|
||||||
|
"start_time",
|
||||||
|
"end_time",
|
||||||
|
"high_freq",
|
||||||
|
"low_freq",
|
||||||
|
"class",
|
||||||
|
"class_prob",
|
||||||
|
]
|
||||||
|
]
|
||||||
|
preds_df.to_csv(path, sep=",")
|
||||||
|
|
||||||
|
def save_cnn_features_csv(
|
||||||
|
self, prediction: FileAnnotation, path: Path
|
||||||
|
) -> None:
|
||||||
|
annotations = prediction["annotation"]
|
||||||
|
|
||||||
|
if not annotations:
|
||||||
|
return
|
||||||
|
|
||||||
|
cnn_features = [
|
||||||
|
annotation["cnn_features"]
|
||||||
|
for annotation in annotations
|
||||||
|
if "cnn_features" in annotation
|
||||||
|
]
|
||||||
|
|
||||||
|
if not cnn_features:
|
||||||
|
return
|
||||||
|
|
||||||
|
cnn_feats_df = pd.DataFrame(
|
||||||
|
cnn_features,
|
||||||
|
columns=[str(ii) for ii in range(len(cnn_features[0]))],
|
||||||
|
)
|
||||||
|
|
||||||
|
cnn_feats_df.to_csv(
|
||||||
|
path,
|
||||||
|
sep=",",
|
||||||
|
index=False,
|
||||||
|
float_format="%.5f",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_class_name(self, class_index: int) -> str:
|
||||||
|
class_name = self.targets.class_names[class_index]
|
||||||
|
tags = self.targets.decode_class(class_name)
|
||||||
|
return data.find_tag_value(
|
||||||
|
tags,
|
||||||
|
term=soundevent_terms.scientific_name,
|
||||||
|
default=class_name,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
def get_recording_class(self, detections: Sequence[Detection]) -> str:
|
||||||
|
if not detections:
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
class_scores = np.stack(
|
||||||
|
[detection.class_scores for detection in detections],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
detection_scores = np.array(
|
||||||
|
[detection.detection_score for detection in detections],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
weighted_scores = (class_scores * detection_scores).sum(axis=1)
|
||||||
|
|
||||||
|
total = weighted_scores.sum()
|
||||||
|
|
||||||
|
if total <= 0:
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
top_class_index = int(np.argmax(weighted_scores / total))
|
||||||
|
return self.get_class_name(top_class_index)
|
||||||
|
|
||||||
def format_prediction(self, prediction: ClipDetections) -> FileAnnotation:
|
def format_prediction(self, prediction: ClipDetections) -> FileAnnotation:
|
||||||
recording = prediction.clip.recording
|
recording = prediction.clip.recording
|
||||||
@ -123,26 +270,19 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
for pred in prediction.detections
|
for pred in prediction.detections
|
||||||
]
|
]
|
||||||
|
|
||||||
return FileAnnotation(
|
file_annotation = FileAnnotation(
|
||||||
id=recording.path.name,
|
id=recording.path.name,
|
||||||
file_path=str(recording.path),
|
|
||||||
annotated=False,
|
annotated=False,
|
||||||
duration=recording.duration,
|
duration=round(float(recording.duration), 4),
|
||||||
issues=False,
|
issues=False,
|
||||||
time_exp=recording.time_expansion,
|
time_exp=recording.time_expansion,
|
||||||
class_name=self.get_recording_class(annotations),
|
class_name=self.get_recording_class(prediction.detections),
|
||||||
notes=self.annotation_note,
|
notes=self.annotation_note,
|
||||||
annotation=annotations,
|
annotation=annotations,
|
||||||
|
file_path=str(recording.path),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_class_name(self, class_index: int) -> str:
|
return file_annotation
|
||||||
class_name = self.targets.class_names[class_index]
|
|
||||||
tags = self.targets.decode_class(class_name)
|
|
||||||
return data.find_tag_value(
|
|
||||||
tags,
|
|
||||||
term=terms.generic_class,
|
|
||||||
default=class_name,
|
|
||||||
) # type: ignore
|
|
||||||
|
|
||||||
def format_sound_event_prediction(
|
def format_sound_event_prediction(
|
||||||
self, prediction: Detection
|
self, prediction: Detection
|
||||||
@ -155,16 +295,20 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
top_class_score = float(prediction.class_scores[top_class_index])
|
top_class_score = float(prediction.class_scores[top_class_index])
|
||||||
top_class = self.get_class_name(top_class_index)
|
top_class = self.get_class_name(top_class_index)
|
||||||
annotation: Annotation = {
|
annotation: Annotation = {
|
||||||
"start_time": start_time,
|
"start_time": round(float(start_time), 4),
|
||||||
"end_time": end_time,
|
"end_time": round(float(end_time), 4),
|
||||||
"low_freq": low_freq,
|
"low_freq": int(low_freq),
|
||||||
"high_freq": high_freq,
|
"high_freq": int(high_freq),
|
||||||
"class_prob": top_class_score,
|
"class_prob": round(top_class_score, 3),
|
||||||
"det_prob": float(prediction.detection_score),
|
"det_prob": round(float(prediction.detection_score), 3),
|
||||||
"individual": "",
|
"individual": "-1",
|
||||||
"event": self.event_name,
|
"event": self.event_name,
|
||||||
"class": top_class,
|
"class": top_class,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.write_cnn_features_csv:
|
||||||
|
annotation["cnn_features"] = prediction.features.tolist() # type: ignore[index]
|
||||||
|
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
@output_formatters.register(BatDetect2OutputConfig)
|
@output_formatters.register(BatDetect2OutputConfig)
|
||||||
@ -174,4 +318,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
targets,
|
targets,
|
||||||
event_name=config.event_name,
|
event_name=config.event_name,
|
||||||
annotation_note=config.annotation_note,
|
annotation_note=config.annotation_note,
|
||||||
|
write_detection_csv=config.write_detection_csv,
|
||||||
|
write_cnn_features_csv=config.write_cnn_features_csv,
|
||||||
|
save_if_empty=config.save_if_empty,
|
||||||
|
preserve_audio_tree=config.preserve_audio_tree,
|
||||||
|
include_file_path=config.include_file_path,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -2,6 +2,7 @@ from pathlib import Path
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
@ -98,6 +99,47 @@ def test_load_predictions_with_format_override(
|
|||||||
assert "annotation" in loaded_item
|
assert "annotation" in loaded_item
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_predictions_with_batdetect2_nested_layout(
|
||||||
|
api_v2: BatDetect2API,
|
||||||
|
example_audio_files: list[Path],
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
output_dir = tmp_path / "batdetect2_nested"
|
||||||
|
predictions = [
|
||||||
|
api_v2.process_file(audio_file) for audio_file in example_audio_files
|
||||||
|
]
|
||||||
|
|
||||||
|
api_v2.save_predictions(
|
||||||
|
predictions,
|
||||||
|
path=output_dir,
|
||||||
|
format="batdetect2",
|
||||||
|
audio_dir=example_audio_files[0].parent,
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = api_v2.load_predictions(output_dir, format="batdetect2")
|
||||||
|
|
||||||
|
assert len(loaded) == len(example_audio_files)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_predictions_with_batdetect2_writes_cnn_feature_csv(
|
||||||
|
api_v2: BatDetect2API,
|
||||||
|
file_prediction,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
output_dir = tmp_path / "batdetect2_cnn"
|
||||||
|
api_v2.save_predictions(
|
||||||
|
[file_prediction],
|
||||||
|
path=output_dir,
|
||||||
|
config=BatDetect2OutputConfig(write_cnn_features_csv=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
cnn_csvs = list(output_dir.rglob("*_cnn_features.csv"))
|
||||||
|
assert len(cnn_csvs) == 1
|
||||||
|
|
||||||
|
loaded_df = pd.read_csv(cnn_csvs[0])
|
||||||
|
assert not loaded_df.empty
|
||||||
|
|
||||||
|
|
||||||
def test_save_predictions_with_soundevent_override(
|
def test_save_predictions_with_soundevent_override(
|
||||||
api_v2: BatDetect2API,
|
api_v2: BatDetect2API,
|
||||||
file_prediction,
|
file_prediction,
|
||||||
|
|||||||
@ -1,12 +1,16 @@
|
|||||||
"""Behavior tests for process CLI workflows."""
|
"""Behavior tests for process CLI workflows."""
|
||||||
|
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
from soundevent import data, io
|
from soundevent import data, io
|
||||||
|
|
||||||
from batdetect2.cli import cli
|
from batdetect2.cli import cli
|
||||||
|
from batdetect2.outputs import OutputsConfig
|
||||||
|
from batdetect2.outputs.formats import BatDetect2OutputConfig
|
||||||
|
|
||||||
|
|
||||||
def test_cli_process_help() -> None:
|
def test_cli_process_help() -> None:
|
||||||
@ -35,6 +39,7 @@ def test_cli_process_directory_runs_on_real_audio(
|
|||||||
[
|
[
|
||||||
"process",
|
"process",
|
||||||
"directory",
|
"directory",
|
||||||
|
"--model",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(single_audio_dir),
|
str(single_audio_dir),
|
||||||
str(output_path),
|
str(output_path),
|
||||||
@ -52,6 +57,158 @@ def test_cli_process_directory_runs_on_real_audio(
|
|||||||
assert len(list(output_path.glob("*.json"))) == 1
|
assert len(list(output_path.glob("*.json"))) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_cli_process_directory_runs_on_example_audio_data(
|
||||||
|
tmp_path: Path,
|
||||||
|
tiny_checkpoint_path: Path,
|
||||||
|
example_audio_dir: Path,
|
||||||
|
example_audio_files: list[Path],
|
||||||
|
) -> None:
|
||||||
|
"""User story: process the bundled example audio directory."""
|
||||||
|
|
||||||
|
output_path = tmp_path / "predictions"
|
||||||
|
|
||||||
|
result = CliRunner().invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"process",
|
||||||
|
"directory",
|
||||||
|
"--model",
|
||||||
|
str(tiny_checkpoint_path),
|
||||||
|
str(example_audio_dir),
|
||||||
|
str(output_path),
|
||||||
|
"--batch-size",
|
||||||
|
"1",
|
||||||
|
"--workers",
|
||||||
|
"0",
|
||||||
|
"--format",
|
||||||
|
"batdetect2",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert output_path.exists()
|
||||||
|
assert len(list(output_path.glob("*.json"))) == len(example_audio_files)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_cli_process_directory_batdetect2_matches_legacy_artifacts(
|
||||||
|
tmp_path: Path,
|
||||||
|
tiny_checkpoint_path: Path,
|
||||||
|
example_audio_dir: Path,
|
||||||
|
example_audio_files: list[Path],
|
||||||
|
example_anns_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""User story: process batdetect2 output matches legacy-style files."""
|
||||||
|
|
||||||
|
output_path = tmp_path / "predictions"
|
||||||
|
|
||||||
|
result = CliRunner().invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"process",
|
||||||
|
"directory",
|
||||||
|
"--model",
|
||||||
|
str(tiny_checkpoint_path),
|
||||||
|
str(example_audio_dir),
|
||||||
|
str(output_path),
|
||||||
|
"--batch-size",
|
||||||
|
"1",
|
||||||
|
"--workers",
|
||||||
|
"0",
|
||||||
|
"--format",
|
||||||
|
"batdetect2",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
json_files = sorted(output_path.rglob("*.json"))
|
||||||
|
csv_files = sorted(output_path.rglob("*.csv"))
|
||||||
|
|
||||||
|
assert len(json_files) == len(example_audio_files)
|
||||||
|
assert len(csv_files) == len(example_audio_files)
|
||||||
|
|
||||||
|
expected_names = sorted(
|
||||||
|
audio_file.name for audio_file in example_audio_files
|
||||||
|
)
|
||||||
|
assert sorted(path.stem for path in json_files) == expected_names
|
||||||
|
assert sorted(path.stem for path in csv_files) == expected_names
|
||||||
|
|
||||||
|
first_output = json.loads(json_files[0].read_text())
|
||||||
|
assert "file_path" not in first_output
|
||||||
|
assert isinstance(first_output["class_name"], str)
|
||||||
|
assert first_output["class_name"]
|
||||||
|
|
||||||
|
first_annotation = first_output["annotation"][0]
|
||||||
|
assert first_annotation["individual"] == "-1"
|
||||||
|
assert isinstance(first_annotation["high_freq"], int)
|
||||||
|
assert isinstance(first_annotation["low_freq"], int)
|
||||||
|
|
||||||
|
expected_json = json.loads(
|
||||||
|
(example_anns_dir / json_files[0].name).read_text()
|
||||||
|
)
|
||||||
|
assert first_output["id"] == expected_json["id"]
|
||||||
|
assert first_output["time_exp"] == expected_json["time_exp"]
|
||||||
|
|
||||||
|
first_csv = pd.read_csv(csv_files[0], index_col=0)
|
||||||
|
assert list(first_csv.columns) == [
|
||||||
|
"det_prob",
|
||||||
|
"start_time",
|
||||||
|
"end_time",
|
||||||
|
"high_freq",
|
||||||
|
"low_freq",
|
||||||
|
"class",
|
||||||
|
"class_prob",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_cli_process_directory_batdetect2_writes_cnn_features_csv_when_enabled(
|
||||||
|
tmp_path: Path,
|
||||||
|
tiny_checkpoint_path: Path,
|
||||||
|
example_audio_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""User story: request legacy CNN feature CSV sidecars via config."""
|
||||||
|
|
||||||
|
output_path = tmp_path / "predictions"
|
||||||
|
outputs_config_path = tmp_path / "outputs.yaml"
|
||||||
|
outputs_config_path.write_text(
|
||||||
|
OutputsConfig(
|
||||||
|
format=BatDetect2OutputConfig(write_cnn_features_csv=True)
|
||||||
|
).to_yaml_string()
|
||||||
|
)
|
||||||
|
|
||||||
|
result = CliRunner().invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"process",
|
||||||
|
"directory",
|
||||||
|
"--model",
|
||||||
|
str(tiny_checkpoint_path),
|
||||||
|
str(example_audio_dir),
|
||||||
|
str(output_path),
|
||||||
|
"--batch-size",
|
||||||
|
"1",
|
||||||
|
"--workers",
|
||||||
|
"0",
|
||||||
|
"--outputs-config",
|
||||||
|
str(outputs_config_path),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
cnn_csvs = sorted(output_path.rglob("*_cnn_features.csv"))
|
||||||
|
assert len(cnn_csvs) == 3
|
||||||
|
|
||||||
|
first_df = pd.read_csv(cnn_csvs[0])
|
||||||
|
assert not first_df.empty
|
||||||
|
assert list(first_df.columns) == [
|
||||||
|
str(ii) for ii in range(len(first_df.columns))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_cli_process_file_list_runs_on_real_audio(
|
def test_cli_process_file_list_runs_on_real_audio(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
tiny_checkpoint_path: Path,
|
tiny_checkpoint_path: Path,
|
||||||
@ -70,6 +227,7 @@ def test_cli_process_file_list_runs_on_real_audio(
|
|||||||
[
|
[
|
||||||
"process",
|
"process",
|
||||||
"file_list",
|
"file_list",
|
||||||
|
"--model",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(file_list),
|
str(file_list),
|
||||||
str(output_path),
|
str(output_path),
|
||||||
@ -117,6 +275,7 @@ def test_cli_process_dataset_runs_on_aoef_metadata(
|
|||||||
[
|
[
|
||||||
"process",
|
"process",
|
||||||
"dataset",
|
"dataset",
|
||||||
|
"--model",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(dataset_path),
|
str(dataset_path),
|
||||||
str(output_path),
|
str(output_path),
|
||||||
@ -159,6 +318,7 @@ def test_cli_process_directory_supports_output_format_override(
|
|||||||
[
|
[
|
||||||
"process",
|
"process",
|
||||||
"directory",
|
"directory",
|
||||||
|
"--model",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(single_audio_dir),
|
str(single_audio_dir),
|
||||||
str(output_path),
|
str(output_path),
|
||||||
@ -217,6 +377,7 @@ def test_cli_process_dataset_deduplicates_recordings(
|
|||||||
[
|
[
|
||||||
"process",
|
"process",
|
||||||
"dataset",
|
"dataset",
|
||||||
|
"--model",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(dataset_path),
|
str(dataset_path),
|
||||||
str(output_path),
|
str(output_path),
|
||||||
@ -247,6 +408,7 @@ def test_cli_process_rejects_unknown_output_format(
|
|||||||
[
|
[
|
||||||
"process",
|
"process",
|
||||||
"directory",
|
"directory",
|
||||||
|
"--model",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(single_audio_dir),
|
str(single_audio_dir),
|
||||||
str(output_path),
|
str(output_path),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user