From 2d796394f69e5ada5fa36ae5a1a867c538447786 Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Tue, 30 Sep 2025 19:10:07 +0100
Subject: [PATCH] Store anns and preds instead of evals in EvaluatorModule

---
 src/batdetect2/cli/evaluate.py                 |  7 ++++-
 src/batdetect2/evaluate/lightning.py           | 27 ++++++++++++-------
 .../evaluate/metrics/classification.py         |  2 +-
 3 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/src/batdetect2/cli/evaluate.py b/src/batdetect2/cli/evaluate.py
index 4da81dd..7fa2631 100644
--- a/src/batdetect2/cli/evaluate.py
+++ b/src/batdetect2/cli/evaluate.py
@@ -17,6 +17,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
 @click.argument("model-path", type=click.Path(exists=True))
 @click.argument("test_dataset", type=click.Path(exists=True))
 @click.option("--config", "config_path", type=click.Path())
+@click.option("--base-dir", type=click.Path(), default=Path.cwd())
 @click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
 @click.option("--experiment-name", type=str)
 @click.option("--run-name", type=str)
@@ -30,6 +31,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
 def evaluate_command(
     model_path: Path,
     test_dataset: Path,
+    base_dir: Path,
     config_path: Optional[Path],
     output_dir: Path = DEFAULT_OUTPUT_DIR,
     num_workers: Optional[int] = None,
@@ -52,7 +54,10 @@ def evaluate_command(
 
     logger.info("Initiating evaluation process...")
 
-    test_annotations = load_dataset_from_config(test_dataset)
+    test_annotations = load_dataset_from_config(
+        test_dataset,
+        base_dir=base_dir,
+    )
     logger.debug(
         "Loaded {num_annotations} test examples",
         num_annotations=len(test_annotations),
diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py
index 6a02d5f..da2e3e3 100644
--- a/src/batdetect2/evaluate/lightning.py
+++ b/src/batdetect2/evaluate/lightning.py
@@ -1,6 +1,7 @@
-from typing import Any
+from typing import Any, List
 
 from lightning import LightningModule
+from soundevent import data
 from torch.utils.data import DataLoader
 
 from batdetect2.evaluate.dataset import TestDataset, TestExample
@@ -8,6 +9,7 @@ from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
 from batdetect2.typing import EvaluatorProtocol
+from batdetect2.typing.postprocess import RawPrediction
 
 
 class EvaluationModule(LightningModule):
@@ -21,9 +23,10 @@ class EvaluationModule(LightningModule):
         self.model = model
         self.evaluator = evaluator
 
-        self.clip_evaluations = []
+        self.clip_annotations: List[data.ClipAnnotation] = []
+        self.predictions: List[List[RawPrediction]] = []
 
-    def test_step(self, batch: TestExample):
+    def test_step(self, batch: TestExample, batch_idx: int):
         dataset = self.get_dataset()
         clip_annotations = [
             dataset.clip_annotations[int(example_idx)]
@@ -43,18 +46,22 @@ class EvaluationModule(LightningModule):
             for clip_dets in clip_detections
         ]
 
-        self.clip_evaluations.extend(
-            self.evaluator.evaluate(clip_annotations, predictions)
-        )
+        self.clip_annotations.extend(clip_annotations)
+        self.predictions.extend(predictions)
 
     def on_test_epoch_start(self):
-        self.clip_evaluations = []
+        self.clip_annotations = []
+        self.predictions = []
 
     def on_test_epoch_end(self):
-        self.log_metrics(self.clip_evaluations)
-        self.plot_examples(self.clip_evaluations)
+        clip_evals = self.evaluator.evaluate(
+            self.clip_annotations,
+            self.predictions,
+        )
+        self.log_metrics(clip_evals)
+        self.generate_plots(clip_evals)
 
-    def plot_examples(self, evaluated_clips: Any):
+    def generate_plots(self, evaluated_clips: Any):
         plotter = get_image_logger(self.logger)  # type: ignore
 
         if plotter is None:
diff --git a/src/batdetect2/evaluate/metrics/classification.py b/src/batdetect2/evaluate/metrics/classification.py
index 602600b..404c57f 100644
--- a/src/batdetect2/evaluate/metrics/classification.py
+++ b/src/batdetect2/evaluate/metrics/classification.py
@@ -122,7 +122,7 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
         }
 
         mean_score = float(
-            np.mean([v for v in class_scores.values() if v != np.nan])
+            np.mean([v for v in class_scores.values() if not np.isnan(v)])
         )
 
         return {
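
Note on the classification.py hunk: the old filter `v != np.nan` never removes anything, because IEEE 754 NaN compares unequal to every value, including itself, so the condition is always True and a single NaN class score makes the mean NaN. The snippet below is only an illustration of that behaviour; the dictionary contents are made up and do not come from the patch.

    import numpy as np

    # Hypothetical per-class average-precision scores, one of them undefined.
    class_scores = {"species_a": 0.8, "species_b": float("nan")}

    # Old filter: NaN != NaN evaluates to True, so the NaN slips through
    # and the resulting mean is NaN.
    print(np.mean([v for v in class_scores.values() if v != np.nan]))      # nan

    # New filter: np.isnan() actually detects the NaN, so only valid
    # scores contribute to the mean.
    print(np.mean([v for v in class_scores.values() if not np.isnan(v)]))  # 0.8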