Create save_evaluation_results

mbsantiago 2026-03-18 16:49:22 +00:00
parent a332c5c3bd
commit 8e35956007
5 changed files with 239 additions and 16 deletions

batdetect2/api_v2.py

@@ -1,4 +1,3 @@
-import json
 from pathlib import Path
 from typing import Sequence
 
@@ -16,6 +15,7 @@ from batdetect2.evaluate import (
     EvaluatorProtocol,
     build_evaluator,
     run_evaluate,
+    save_evaluation_results,
 )
 from batdetect2.inference import process_file_list, run_batch_inference
 from batdetect2.logging import DEFAULT_LOGS_DIR
@@ -148,21 +148,11 @@ class BatDetect2API:
         metrics = self.evaluator.compute_metrics(clip_evals)
 
         if output_dir is not None:
-            output_dir = Path(output_dir)
-
-            if not output_dir.is_dir():
-                output_dir.mkdir(parents=True)
-
-            metrics_path = output_dir / "metrics.json"
-            metrics_path.write_text(json.dumps(metrics))
-
-            for figure_name, fig in self.evaluator.generate_plots(clip_evals):
-                fig_path = output_dir / figure_name
-
-                if not fig_path.parent.is_dir():
-                    fig_path.parent.mkdir(parents=True)
-
-                fig.savefig(fig_path)
+            save_evaluation_results(
+                metrics=metrics,
+                plots=self.evaluator.generate_plots(clip_evals),
+                output_dir=output_dir,
+            )
 
         return metrics
@@ -175,6 +165,49 @@ class BatDetect2API:
     def load_clip(self, clip: data.Clip) -> np.ndarray:
         return self.audio_loader.load_clip(clip)
 
+    def get_top_class_name(self, detection: Detection) -> str:
+        """Get highest-confidence class name for one detection."""
+        top_index = int(np.argmax(detection.class_scores))
+        return self.targets.class_names[top_index]
+
+    def get_class_scores(
+        self,
+        detection: Detection,
+        *,
+        include_top_class: bool = True,
+        sort_descending: bool = True,
+    ) -> list[tuple[str, float]]:
+        """Get class score list as ``(class_name, score)`` pairs."""
+        scores = [
+            (class_name, float(score))
+            for class_name, score in zip(
+                self.targets.class_names,
+                detection.class_scores,
+                strict=True,
+            )
+        ]
+
+        if sort_descending:
+            scores.sort(key=lambda item: item[1], reverse=True)
+
+        if include_top_class:
+            return scores
+
+        top_class_name = self.get_top_class_name(detection)
+        return [
+            (class_name, score)
+            for class_name, score in scores
+            if class_name != top_class_name
+        ]
+
+    @staticmethod
+    def get_detection_features(detection: Detection) -> np.ndarray:
+        """Get extracted feature vector for one detection."""
+        return detection.features
+
     def generate_spectrogram(
         self,
         audio: np.ndarray,
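
Taken together, the accessors added above give per-detection access to class names, scores, and features. A minimal usage sketch, assuming an already constructed BatDetect2API instance named api and a recording path named recording_path (both placeholder names, not part of this commit):

# Sketch only: `api` and `recording_path` are hypothetical placeholders.
prediction = api.process_file(recording_path)

for detection in prediction.detections:
    # Name of the highest-scoring class for this detection.
    top_class = api.get_top_class_name(detection)

    # Remaining (class_name, score) pairs, best first, top class excluded.
    other_scores = api.get_class_scores(detection, include_top_class=False)

    # 1-D feature vector extracted for this detection.
    features = api.get_detection_features(detection)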

batdetect2/evaluate/__init__.py

@@ -1,6 +1,7 @@
 from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
 from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
 from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
+from batdetect2.evaluate.results import save_evaluation_results
 from batdetect2.evaluate.tasks import TaskConfig, build_task
 from batdetect2.evaluate.types import (
     AffinityFunction,
@@ -28,4 +29,5 @@ __all__ = [
     "build_task",
     "load_evaluation_config",
     "run_evaluate",
+    "save_evaluation_results",
 ]
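
With the re-export above, the new helper can be imported from the package root as well as from its defining module:

# Equivalent imports once this commit lands; the second shadows the first.
from batdetect2.evaluate import save_evaluation_results
from batdetect2.evaluate.results import save_evaluation_results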

batdetect2/evaluate/results.py

@@ -0,0 +1,27 @@
+import json
+from pathlib import Path
+from typing import Iterable
+
+from matplotlib.figure import Figure
+from soundevent import data
+
+__all__ = ["save_evaluation_results"]
+
+
+def save_evaluation_results(
+    metrics: dict[str, float],
+    plots: Iterable[tuple[str, Figure]],
+    output_dir: data.PathLike,
+) -> None:
+    """Save evaluation metrics and plots to disk."""
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    metrics_path = output_path / "metrics.json"
+    metrics_path.write_text(json.dumps(metrics))
+
+    for figure_name, figure in plots:
+        figure_path = output_path / figure_name
+        figure_path.parent.mkdir(parents=True, exist_ok=True)
+        figure.savefig(figure_path)
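
A self-contained sketch of calling the new helper directly; the metric name, figure contents, and output directory below are illustrative only:

from matplotlib.figure import Figure

from batdetect2.evaluate.results import save_evaluation_results

fig = Figure()
fig.subplots().plot([0, 1], [0, 1])  # placeholder plot content

save_evaluation_results(
    metrics={"mAP": 0.5},                # illustrative metric value
    plots=[("plots/example.png", fig)],  # nested directories are created
    output_dir="eval_results",           # metrics.json is written here
)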

BatDetect2API tests (file path not shown)

@@ -1,5 +1,6 @@
 from pathlib import Path
 
+import lightning as L
 import numpy as np
 import pytest
 import torch
@@ -7,6 +8,7 @@ from soundevent.geometry import compute_bounds
 from batdetect2.api_v2 import BatDetect2API
 from batdetect2.config import BatDetect2Config
+from batdetect2.train.lightning import build_training_module
 
 
 @pytest.fixture
@@ -113,3 +115,141 @@ def test_process_spectrogram_rejects_batched_input(
     with pytest.raises(ValueError, match="Batched spectrograms not supported"):
         api_v2.process_spectrogram(spec)
+
+
+def test_user_can_read_top_class_and_other_class_scores(
+    api_v2: BatDetect2API,
+    example_audio_files: list[Path],
+) -> None:
+    """User story: inspect top class and all class scores per detection."""
+    prediction = api_v2.process_file(example_audio_files[0])
+    assert len(prediction.detections) > 0
+
+    top_classes = [
+        api_v2.get_top_class_name(det) for det in prediction.detections
+    ]
+    other_class_scores = [
+        api_v2.get_class_scores(det, include_top_class=False)
+        for det in prediction.detections
+    ]
+
+    assert len(top_classes) == len(prediction.detections)
+    assert all(isinstance(class_name, str) for class_name in top_classes)
+    assert len(other_class_scores) == len(prediction.detections)
+    assert all(len(scores) >= 1 for scores in other_class_scores)
+    assert all(
+        all(class_name != top_class for class_name, _ in scores)
+        for top_class, scores in zip(
+            top_classes,
+            other_class_scores,
+            strict=True,
+        )
+    )
+    assert all(
+        all(
+            score_a >= score_b
+            for (_, score_a), (_, score_b) in zip(
+                scores, scores[1:], strict=False
+            )
+        )
+        for scores in other_class_scores
+    )
+
+
+def test_user_can_read_extracted_features_per_detection(
+    api_v2: BatDetect2API,
+    example_audio_files: list[Path],
+) -> None:
+    """User story: inspect extracted feature vectors per detection."""
+    prediction = api_v2.process_file(example_audio_files[0])
+    assert len(prediction.detections) > 0
+
+    feature_vectors = [
+        api_v2.get_detection_features(det) for det in prediction.detections
+    ]
+
+    assert len(feature_vectors) == len(prediction.detections)
+    assert all(vec.ndim == 1 for vec in feature_vectors)
+    assert all(vec.size > 0 for vec in feature_vectors)
+
+
+def test_user_can_load_checkpoint_and_finetune(
+    tmp_path: Path,
+    example_annotations,
+) -> None:
+    """User story: load a checkpoint and continue training from it."""
+    module = build_training_module(model_config=BatDetect2Config().model)
+    trainer = L.Trainer(enable_checkpointing=False, logger=False)
+    checkpoint_path = tmp_path / "base.ckpt"
+    trainer.strategy.connect(module)
+    trainer.save_checkpoint(checkpoint_path)
+
+    config = BatDetect2Config()
+    config.train.trainer.limit_train_batches = 1
+    config.train.trainer.limit_val_batches = 1
+    config.train.trainer.log_every_n_steps = 1
+    config.train.train_loader.batch_size = 1
+    config.train.train_loader.augmentations.enabled = False
+
+    api = BatDetect2API.from_checkpoint(checkpoint_path, config=config)
+
+    finetune_dir = tmp_path / "finetuned"
+    api.train(
+        train_annotations=example_annotations[:1],
+        val_annotations=example_annotations[:1],
+        train_workers=0,
+        val_workers=0,
+        checkpoint_dir=finetune_dir,
+        log_dir=tmp_path / "logs",
+        num_epochs=1,
+        seed=0,
+    )
+
+    checkpoints = list(finetune_dir.rglob("*.ckpt"))
+    assert checkpoints
+
+
+def test_user_can_evaluate_small_dataset_and_get_metrics(
+    api_v2: BatDetect2API,
+    example_annotations,
+    tmp_path: Path,
+) -> None:
+    """User story: run evaluation and receive metrics."""
+    metrics, predictions = api_v2.evaluate(
+        test_annotations=example_annotations[:1],
+        num_workers=0,
+        output_dir=tmp_path / "eval",
+        save_predictions=False,
+    )
+
+    assert isinstance(metrics, list)
+    assert len(metrics) == 1
+    assert isinstance(metrics[0], dict)
+    assert len(metrics[0]) > 0
+    assert isinstance(predictions, list)
+    assert len(predictions) == 1
+
+
+def test_user_can_save_evaluation_results_to_disk(
+    api_v2: BatDetect2API,
+    example_annotations,
+    tmp_path: Path,
+) -> None:
+    """User story: evaluate saved predictions and persist results."""
+    prediction = api_v2.process_file(
+        example_annotations[0].clip.recording.path
+    )
+
+    metrics = api_v2.evaluate_predictions(
+        annotations=[example_annotations[0]],
+        predictions=[prediction],
+        output_dir=tmp_path,
+    )
+
+    assert isinstance(metrics, dict)
+    assert (tmp_path / "metrics.json").exists()

save_evaluation_results tests (file path not shown)

@@ -0,0 +1,21 @@
+import json
+
+from matplotlib.figure import Figure
+
+from batdetect2.evaluate.results import save_evaluation_results
+
+
+def test_save_evaluation_results_writes_metrics_and_plots(tmp_path) -> None:
+    metrics = {"mAP": 0.5}
+    figure = Figure()
+
+    save_evaluation_results(
+        metrics=metrics,
+        plots=[("plots/example.png", figure)],
+        output_dir=tmp_path,
+    )
+
+    metrics_path = tmp_path / "metrics.json"
+    assert metrics_path.exists()
+    assert json.loads(metrics_path.read_text()) == metrics
+    assert (tmp_path / "plots" / "example.png").exists()