From 8e359560070db5c08b04bf67676f4d8bda60bddd Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 18 Mar 2026 16:49:22 +0000 Subject: [PATCH] Create save evaluation results --- src/batdetect2/api_v2.py | 65 +++++++++---- src/batdetect2/evaluate/__init__.py | 2 + src/batdetect2/evaluate/results.py | 27 ++++++ tests/test_api_v2/test_api_v2.py | 140 ++++++++++++++++++++++++++++ tests/test_evaluate/test_results.py | 21 +++++ 5 files changed, 239 insertions(+), 16 deletions(-) create mode 100644 src/batdetect2/evaluate/results.py create mode 100644 tests/test_evaluate/test_results.py diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index bed0924..ecf0964 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -1,4 +1,3 @@ -import json from pathlib import Path from typing import Sequence @@ -16,6 +15,7 @@ from batdetect2.evaluate import ( EvaluatorProtocol, build_evaluator, run_evaluate, + save_evaluation_results, ) from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.logging import DEFAULT_LOGS_DIR @@ -148,21 +148,11 @@ class BatDetect2API: metrics = self.evaluator.compute_metrics(clip_evals) if output_dir is not None: - output_dir = Path(output_dir) - - if not output_dir.is_dir(): - output_dir.mkdir(parents=True) - - metrics_path = output_dir / "metrics.json" - metrics_path.write_text(json.dumps(metrics)) - - for figure_name, fig in self.evaluator.generate_plots(clip_evals): - fig_path = output_dir / figure_name - - if not fig_path.parent.is_dir(): - fig_path.parent.mkdir(parents=True) - - fig.savefig(fig_path) + save_evaluation_results( + metrics=metrics, + plots=self.evaluator.generate_plots(clip_evals), + output_dir=output_dir, + ) return metrics @@ -175,6 +165,49 @@ class BatDetect2API: def load_clip(self, clip: data.Clip) -> np.ndarray: return self.audio_loader.load_clip(clip) + def get_top_class_name(self, detection: Detection) -> str: + """Get highest-confidence class name for one detection.""" + + top_index = int(np.argmax(detection.class_scores)) + return self.targets.class_names[top_index] + + def get_class_scores( + self, + detection: Detection, + *, + include_top_class: bool = True, + sort_descending: bool = True, + ) -> list[tuple[str, float]]: + """Get class score list as ``(class_name, score)`` pairs.""" + + scores = [ + (class_name, float(score)) + for class_name, score in zip( + self.targets.class_names, + detection.class_scores, + strict=True, + ) + ] + + if sort_descending: + scores.sort(key=lambda item: item[1], reverse=True) + + if include_top_class: + return scores + + top_class_name = self.get_top_class_name(detection) + return [ + (class_name, score) + for class_name, score in scores + if class_name != top_class_name + ] + + @staticmethod + def get_detection_features(detection: Detection) -> np.ndarray: + """Get extracted feature vector for one detection.""" + + return detection.features + def generate_spectrogram( self, audio: np.ndarray, diff --git a/src/batdetect2/evaluate/__init__.py b/src/batdetect2/evaluate/__init__.py index 25463b5..20ee7b4 100644 --- a/src/batdetect2/evaluate/__init__.py +++ b/src/batdetect2/evaluate/__init__.py @@ -1,6 +1,7 @@ from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate from batdetect2.evaluate.evaluator import Evaluator, build_evaluator +from batdetect2.evaluate.results import save_evaluation_results from batdetect2.evaluate.tasks import TaskConfig, build_task from batdetect2.evaluate.types import ( AffinityFunction, @@ -28,4 +29,5 @@ __all__ = [ "build_task", "load_evaluation_config", "run_evaluate", + "save_evaluation_results", ] diff --git a/src/batdetect2/evaluate/results.py b/src/batdetect2/evaluate/results.py new file mode 100644 index 0000000..81af653 --- /dev/null +++ b/src/batdetect2/evaluate/results.py @@ -0,0 +1,27 @@ +import json +from pathlib import Path +from typing import Iterable + +from matplotlib.figure import Figure +from soundevent import data + +__all__ = ["save_evaluation_results"] + + +def save_evaluation_results( + metrics: dict[str, float], + plots: Iterable[tuple[str, Figure]], + output_dir: data.PathLike, +) -> None: + """Save evaluation metrics and plots to disk.""" + + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + metrics_path = output_path / "metrics.json" + metrics_path.write_text(json.dumps(metrics)) + + for figure_name, figure in plots: + figure_path = output_path / figure_name + figure_path.parent.mkdir(parents=True, exist_ok=True) + figure.savefig(figure_path) diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index a3910f7..72e2f3e 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -1,5 +1,6 @@ from pathlib import Path +import lightning as L import numpy as np import pytest import torch @@ -7,6 +8,7 @@ from soundevent.geometry import compute_bounds from batdetect2.api_v2 import BatDetect2API from batdetect2.config import BatDetect2Config +from batdetect2.train.lightning import build_training_module @pytest.fixture @@ -113,3 +115,141 @@ def test_process_spectrogram_rejects_batched_input( with pytest.raises(ValueError, match="Batched spectrograms not supported"): api_v2.process_spectrogram(spec) + + +def test_user_can_read_top_class_and_other_class_scores( + api_v2: BatDetect2API, + example_audio_files: list[Path], +) -> None: + """User story: inspect top class and all class scores per detection.""" + + prediction = api_v2.process_file(example_audio_files[0]) + + assert len(prediction.detections) > 0 + + top_classes = [ + api_v2.get_top_class_name(det) for det in prediction.detections + ] + other_class_scores = [ + api_v2.get_class_scores(det, include_top_class=False) + for det in prediction.detections + ] + + assert len(top_classes) == len(prediction.detections) + assert all(isinstance(class_name, str) for class_name in top_classes) + assert len(other_class_scores) == len(prediction.detections) + assert all(len(scores) >= 1 for scores in other_class_scores) + assert all( + all(class_name != top_class for class_name, _ in scores) + for top_class, scores in zip( + top_classes, + other_class_scores, + strict=True, + ) + ) + assert all( + all( + score_a >= score_b + for (_, score_a), (_, score_b) in zip( + scores, scores[1:], strict=False + ) + ) + for scores in other_class_scores + ) + + +def test_user_can_read_extracted_features_per_detection( + api_v2: BatDetect2API, + example_audio_files: list[Path], +) -> None: + """User story: inspect extracted feature vectors per detection.""" + + prediction = api_v2.process_file(example_audio_files[0]) + + assert len(prediction.detections) > 0 + + feature_vectors = [ + api_v2.get_detection_features(det) for det in prediction.detections + ] + assert len(feature_vectors) == len(prediction.detections) + assert all(vec.ndim == 1 for vec in feature_vectors) + assert all(vec.size > 0 for vec in feature_vectors) + + +def test_user_can_load_checkpoint_and_finetune( + tmp_path: Path, + example_annotations, +) -> None: + """User story: load a checkpoint and continue training from it.""" + + module = build_training_module(model_config=BatDetect2Config().model) + trainer = L.Trainer(enable_checkpointing=False, logger=False) + checkpoint_path = tmp_path / "base.ckpt" + trainer.strategy.connect(module) + trainer.save_checkpoint(checkpoint_path) + + config = BatDetect2Config() + config.train.trainer.limit_train_batches = 1 + config.train.trainer.limit_val_batches = 1 + config.train.trainer.log_every_n_steps = 1 + config.train.train_loader.batch_size = 1 + config.train.train_loader.augmentations.enabled = False + + api = BatDetect2API.from_checkpoint(checkpoint_path, config=config) + finetune_dir = tmp_path / "finetuned" + + api.train( + train_annotations=example_annotations[:1], + val_annotations=example_annotations[:1], + train_workers=0, + val_workers=0, + checkpoint_dir=finetune_dir, + log_dir=tmp_path / "logs", + num_epochs=1, + seed=0, + ) + + checkpoints = list(finetune_dir.rglob("*.ckpt")) + assert checkpoints + + +def test_user_can_evaluate_small_dataset_and_get_metrics( + api_v2: BatDetect2API, + example_annotations, + tmp_path: Path, +) -> None: + """User story: run evaluation and receive metrics.""" + + metrics, predictions = api_v2.evaluate( + test_annotations=example_annotations[:1], + num_workers=0, + output_dir=tmp_path / "eval", + save_predictions=False, + ) + + assert isinstance(metrics, list) + assert len(metrics) == 1 + assert isinstance(metrics[0], dict) + assert len(metrics[0]) > 0 + assert isinstance(predictions, list) + assert len(predictions) == 1 + + +def test_user_can_save_evaluation_results_to_disk( + api_v2: BatDetect2API, + example_annotations, + tmp_path: Path, +) -> None: + """User story: evaluate saved predictions and persist results.""" + + prediction = api_v2.process_file( + example_annotations[0].clip.recording.path + ) + metrics = api_v2.evaluate_predictions( + annotations=[example_annotations[0]], + predictions=[prediction], + output_dir=tmp_path, + ) + + assert isinstance(metrics, dict) + assert (tmp_path / "metrics.json").exists() diff --git a/tests/test_evaluate/test_results.py b/tests/test_evaluate/test_results.py new file mode 100644 index 0000000..5702647 --- /dev/null +++ b/tests/test_evaluate/test_results.py @@ -0,0 +1,21 @@ +import json + +from matplotlib.figure import Figure + +from batdetect2.evaluate.results import save_evaluation_results + + +def test_save_evaluation_results_writes_metrics_and_plots(tmp_path) -> None: + metrics = {"mAP": 0.5} + figure = Figure() + + save_evaluation_results( + metrics=metrics, + plots=[("plots/example.png", figure)], + output_dir=tmp_path, + ) + + metrics_path = tmp_path / "metrics.json" + assert metrics_path.exists() + assert json.loads(metrics_path.read_text()) == metrics + assert (tmp_path / "plots" / "example.png").exists()