mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Create save evaluation results
This commit is contained in:
parent
a332c5c3bd
commit
8e35956007
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
@ -16,6 +15,7 @@ from batdetect2.evaluate import (
|
|||||||
EvaluatorProtocol,
|
EvaluatorProtocol,
|
||||||
build_evaluator,
|
build_evaluator,
|
||||||
run_evaluate,
|
run_evaluate,
|
||||||
|
save_evaluation_results,
|
||||||
)
|
)
|
||||||
from batdetect2.inference import process_file_list, run_batch_inference
|
from batdetect2.inference import process_file_list, run_batch_inference
|
||||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||||
@ -148,21 +148,11 @@ class BatDetect2API:
|
|||||||
metrics = self.evaluator.compute_metrics(clip_evals)
|
metrics = self.evaluator.compute_metrics(clip_evals)
|
||||||
|
|
||||||
if output_dir is not None:
|
if output_dir is not None:
|
||||||
output_dir = Path(output_dir)
|
save_evaluation_results(
|
||||||
|
metrics=metrics,
|
||||||
if not output_dir.is_dir():
|
plots=self.evaluator.generate_plots(clip_evals),
|
||||||
output_dir.mkdir(parents=True)
|
output_dir=output_dir,
|
||||||
|
)
|
||||||
metrics_path = output_dir / "metrics.json"
|
|
||||||
metrics_path.write_text(json.dumps(metrics))
|
|
||||||
|
|
||||||
for figure_name, fig in self.evaluator.generate_plots(clip_evals):
|
|
||||||
fig_path = output_dir / figure_name
|
|
||||||
|
|
||||||
if not fig_path.parent.is_dir():
|
|
||||||
fig_path.parent.mkdir(parents=True)
|
|
||||||
|
|
||||||
fig.savefig(fig_path)
|
|
||||||
|
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
@ -175,6 +165,49 @@ class BatDetect2API:
|
|||||||
def load_clip(self, clip: data.Clip) -> np.ndarray:
|
def load_clip(self, clip: data.Clip) -> np.ndarray:
|
||||||
return self.audio_loader.load_clip(clip)
|
return self.audio_loader.load_clip(clip)
|
||||||
|
|
||||||
|
def get_top_class_name(self, detection: Detection) -> str:
    """Return the class name whose score is largest for one detection.

    The position of the maximum entry of ``detection.class_scores`` is
    used to index ``self.targets.class_names``.
    """
    class_scores = detection.class_scores
    winner = int(np.argmax(class_scores))
    return self.targets.class_names[winner]
|
||||||
|
|
||||||
|
def get_class_scores(
|
||||||
|
self,
|
||||||
|
detection: Detection,
|
||||||
|
*,
|
||||||
|
include_top_class: bool = True,
|
||||||
|
sort_descending: bool = True,
|
||||||
|
) -> list[tuple[str, float]]:
|
||||||
|
"""Get class score list as ``(class_name, score)`` pairs."""
|
||||||
|
|
||||||
|
scores = [
|
||||||
|
(class_name, float(score))
|
||||||
|
for class_name, score in zip(
|
||||||
|
self.targets.class_names,
|
||||||
|
detection.class_scores,
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
if sort_descending:
|
||||||
|
scores.sort(key=lambda item: item[1], reverse=True)
|
||||||
|
|
||||||
|
if include_top_class:
|
||||||
|
return scores
|
||||||
|
|
||||||
|
top_class_name = self.get_top_class_name(detection)
|
||||||
|
return [
|
||||||
|
(class_name, score)
|
||||||
|
for class_name, score in scores
|
||||||
|
if class_name != top_class_name
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_detection_features(detection: Detection) -> np.ndarray:
|
||||||
|
"""Get extracted feature vector for one detection."""
|
||||||
|
|
||||||
|
return detection.features
|
||||||
|
|
||||||
def generate_spectrogram(
|
def generate_spectrogram(
|
||||||
self,
|
self,
|
||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
||||||
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
|
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
|
||||||
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
||||||
|
from batdetect2.evaluate.results import save_evaluation_results
|
||||||
from batdetect2.evaluate.tasks import TaskConfig, build_task
|
from batdetect2.evaluate.tasks import TaskConfig, build_task
|
||||||
from batdetect2.evaluate.types import (
|
from batdetect2.evaluate.types import (
|
||||||
AffinityFunction,
|
AffinityFunction,
|
||||||
@ -28,4 +29,5 @@ __all__ = [
|
|||||||
"build_task",
|
"build_task",
|
||||||
"load_evaluation_config",
|
"load_evaluation_config",
|
||||||
"run_evaluate",
|
"run_evaluate",
|
||||||
|
"save_evaluation_results",
|
||||||
]
|
]
|
||||||
|
|||||||
27
src/batdetect2/evaluate/results.py
Normal file
27
src/batdetect2/evaluate/results.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
__all__ = ["save_evaluation_results"]
|
||||||
|
|
||||||
|
|
||||||
|
def save_evaluation_results(
    metrics: dict[str, float],
    plots: Iterable[tuple[str, Figure]],
    output_dir: data.PathLike,
) -> None:
    """Write evaluation metrics and figures below ``output_dir``.

    The metrics mapping is serialised with ``json.dumps`` into
    ``metrics.json`` directly inside ``output_dir``. Each ``(name, figure)``
    pair from ``plots`` is saved to ``output_dir / name`` via
    ``figure.savefig``. The output directory and any parent directories
    of a figure path are created when missing.
    """
    root = Path(output_dir)
    root.mkdir(parents=True, exist_ok=True)

    (root / "metrics.json").write_text(json.dumps(metrics))

    for relative_name, fig in plots:
        destination = root / relative_name
        # Figure names may contain sub-directories (e.g. "plots/x.png").
        destination.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(destination)
|
||||||
@ -1,5 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import lightning as L
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -7,6 +8,7 @@ from soundevent.geometry import compute_bounds
|
|||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
|
from batdetect2.train.lightning import build_training_module
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -113,3 +115,141 @@ def test_process_spectrogram_rejects_batched_input(
|
|||||||
|
|
||||||
with pytest.raises(ValueError, match="Batched spectrograms not supported"):
|
with pytest.raises(ValueError, match="Batched spectrograms not supported"):
|
||||||
api_v2.process_spectrogram(spec)
|
api_v2.process_spectrogram(spec)
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_can_read_top_class_and_other_class_scores(
    api_v2: BatDetect2API,
    example_audio_files: list[Path],
) -> None:
    """User story: inspect top class and all class scores per detection.

    Checks that every detection yields a string top-class name, and that
    the scores requested without the top class are non-empty, omit that
    class, and are ordered from highest to lowest.
    """
    prediction = api_v2.process_file(example_audio_files[0])
    detections = prediction.detections

    assert len(detections) > 0

    winners = [api_v2.get_top_class_name(det) for det in detections]
    remaining = [
        api_v2.get_class_scores(det, include_top_class=False)
        for det in detections
    ]

    assert len(winners) == len(detections)
    for name in winners:
        assert isinstance(name, str)

    assert len(remaining) == len(detections)
    for scores in remaining:
        assert len(scores) >= 1

    # The top class must never appear among the "other" scores.
    for winner, scores in zip(winners, remaining, strict=True):
        for name, _ in scores:
            assert name != winner

    # Scores are expected in non-increasing order.
    for scores in remaining:
        for (_, higher), (_, lower) in zip(scores, scores[1:], strict=False):
            assert higher >= lower
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_can_read_extracted_features_per_detection(
    api_v2: BatDetect2API,
    example_audio_files: list[Path],
) -> None:
    """User story: inspect extracted feature vectors per detection.

    Each detection must provide a non-empty one-dimensional feature array.
    """
    prediction = api_v2.process_file(example_audio_files[0])
    detections = prediction.detections

    assert len(detections) > 0

    vectors = [api_v2.get_detection_features(det) for det in detections]

    assert len(vectors) == len(detections)
    for vector in vectors:
        assert vector.ndim == 1
    for vector in vectors:
        assert vector.size > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_can_load_checkpoint_and_finetune(
    tmp_path: Path,
    example_annotations,
) -> None:
    """User story: load a checkpoint and continue training from it.

    A checkpoint is written from a freshly built training module, loaded
    through ``BatDetect2API.from_checkpoint`` with a minimal training
    config, and trained for one epoch. The test passes when at least one
    ``.ckpt`` file appears under the fine-tuning directory.
    """
    # Produce a base checkpoint without running any training.
    base_module = build_training_module(model_config=BatDetect2Config().model)
    base_trainer = L.Trainer(enable_checkpointing=False, logger=False)
    base_checkpoint = tmp_path / "base.ckpt"
    base_trainer.strategy.connect(base_module)
    base_trainer.save_checkpoint(base_checkpoint)

    # Keep the fine-tuning run as small as possible.
    config = BatDetect2Config()
    trainer_cfg = config.train.trainer
    trainer_cfg.limit_train_batches = 1
    trainer_cfg.limit_val_batches = 1
    trainer_cfg.log_every_n_steps = 1
    loader_cfg = config.train.train_loader
    loader_cfg.batch_size = 1
    loader_cfg.augmentations.enabled = False

    api = BatDetect2API.from_checkpoint(base_checkpoint, config=config)
    finetune_dir = tmp_path / "finetuned"
    subset = example_annotations[:1]

    api.train(
        train_annotations=subset,
        val_annotations=example_annotations[:1],
        train_workers=0,
        val_workers=0,
        checkpoint_dir=finetune_dir,
        log_dir=tmp_path / "logs",
        num_epochs=1,
        seed=0,
    )

    written = list(finetune_dir.rglob("*.ckpt"))
    assert len(written) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_can_evaluate_small_dataset_and_get_metrics(
    api_v2: BatDetect2API,
    example_annotations,
    tmp_path: Path,
) -> None:
    """User story: run evaluation and receive metrics.

    Evaluating a single annotation must return a one-element list holding
    a non-empty metrics dict, together with a one-element predictions list.
    """
    eval_dir = tmp_path / "eval"

    outcome = api_v2.evaluate(
        test_annotations=example_annotations[:1],
        num_workers=0,
        output_dir=eval_dir,
        save_predictions=False,
    )
    metrics, predictions = outcome

    assert isinstance(metrics, list)
    assert len(metrics) == 1

    first_metrics = metrics[0]
    assert isinstance(first_metrics, dict)
    assert len(first_metrics) > 0

    assert isinstance(predictions, list)
    assert len(predictions) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_can_save_evaluation_results_to_disk(
    api_v2: BatDetect2API,
    example_annotations,
    tmp_path: Path,
) -> None:
    """User story: evaluate saved predictions and persist results.

    Runs inference on the recording behind the first annotation, evaluates
    that prediction against the annotation, and expects a ``metrics.json``
    file to be written into the output directory.
    """
    annotation = example_annotations[0]
    recording_path = annotation.clip.recording.path
    prediction = api_v2.process_file(recording_path)

    metrics = api_v2.evaluate_predictions(
        annotations=[example_annotations[0]],
        predictions=[prediction],
        output_dir=tmp_path,
    )

    assert isinstance(metrics, dict)

    metrics_file = tmp_path / "metrics.json"
    assert metrics_file.exists()
|
||||||
|
|||||||
21
tests/test_evaluate/test_results.py
Normal file
21
tests/test_evaluate/test_results.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
|
from batdetect2.evaluate.results import save_evaluation_results
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_evaluation_results_writes_metrics_and_plots(tmp_path) -> None:
    """Saving results writes ``metrics.json`` and each figure file.

    The metrics file must round-trip through JSON to the input mapping,
    and a figure named with a sub-directory must land in that
    sub-directory of the output path.
    """
    expected = {"mAP": 0.5}
    blank_figure = Figure()

    save_evaluation_results(
        metrics=expected,
        plots=[("plots/example.png", blank_figure)],
        output_dir=tmp_path,
    )

    written_metrics = tmp_path / "metrics.json"
    assert written_metrics.exists()
    assert json.loads(written_metrics.read_text()) == expected

    written_plot = tmp_path / "plots" / "example.png"
    assert written_plot.exists()
|
||||||
Loading…
Reference in New Issue
Block a user