Store annotations and predictions instead of evaluations in EvaluationModule

mbsantiago 2025-09-30 19:10:07 +01:00
parent 49ec1916ce
commit 2d796394f6
3 changed files with 24 additions and 12 deletions

View File

@@ -17,6 +17,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--config", "config_path", type=click.Path())
@click.option("--base-dir", type=click.Path(), default=Path.cwd())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@@ -30,6 +31,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
def evaluate_command(
model_path: Path,
test_dataset: Path,
base_dir: Path,
config_path: Optional[Path],
output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: Optional[int] = None,
@@ -52,7 +54,10 @@ def evaluate_command(
logger.info("Initiating evaluation process...")
test_annotations = load_dataset_from_config(test_dataset)
test_annotations = load_dataset_from_config(
test_dataset,
base_dir=base_dir,
)
logger.debug(
"Loaded {num_annotations} test examples",
num_annotations=len(test_annotations),
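
The new --base-dir option (defaulting to the current working directory) is passed through to load_dataset_from_config, presumably so that paths referenced by the dataset config can be resolved against an explicit root rather than wherever the command happens to be run from. A minimal sketch of the equivalent call; the import path and file names below are placeholders, not taken from this diff:

# Sketch only: assumed import location and placeholder paths.
from pathlib import Path

from batdetect2.data import load_dataset_from_config  # assumed module path

test_annotations = load_dataset_from_config(
    "configs/test_dataset.yaml",        # placeholder dataset config
    base_dir=Path("/data/recordings"),  # root against which dataset paths resolve
)
print(f"Loaded {len(test_annotations)} test examples")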

View File

@@ -1,6 +1,7 @@
from typing import Any
from typing import Any, List
from lightning import LightningModule
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
@@ -8,6 +9,7 @@ from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import RawPrediction
class EvaluationModule(LightningModule):
@@ -21,9 +23,10 @@ class EvaluationModule(LightningModule):
self.model = model
self.evaluator = evaluator
self.clip_evaluations = []
self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[List[RawPrediction]] = []
def test_step(self, batch: TestExample):
def test_step(self, batch: TestExample, batch_idx: int):
dataset = self.get_dataset()
clip_annotations = [
dataset.clip_annotations[int(example_idx)]
@@ -43,18 +46,22 @@ class EvaluationModule(LightningModule):
for clip_dets in clip_detections
]
self.clip_evaluations.extend(
self.evaluator.evaluate(clip_annotations, predictions)
)
self.clip_annotations.extend(clip_annotations)
self.predictions.extend(predictions)
def on_test_epoch_start(self):
self.clip_evaluations = []
self.clip_annotations = []
self.predictions = []
def on_test_epoch_end(self):
self.log_metrics(self.clip_evaluations)
self.plot_examples(self.clip_evaluations)
clip_evals = self.evaluator.evaluate(
self.clip_annotations,
self.predictions,
)
self.log_metrics(clip_evals)
self.generate_plots(clip_evals)
def plot_examples(self, evaluated_clips: Any):
def generate_plots(self, evaluated_clips: Any):
plotter = get_image_logger(self.logger) # type: ignore
if plotter is None:
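
Taken together, the EvaluationModule changes move the evaluator call out of test_step: each batch now only accumulates the ground-truth clip annotations and the raw predictions, and on_test_epoch_end runs a single evaluator.evaluate pass over the whole test set before logging metrics and generating plots. A condensed sketch of the resulting flow; the __init__ signature and the _predict helper are placeholders, and the per-batch prediction building (unchanged by this commit) is elided:

# Condensed sketch of the new accumulation pattern, using the imports shown above.
from typing import List

from lightning import LightningModule
from soundevent import data

from batdetect2.evaluate.dataset import TestExample
from batdetect2.models import Model
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import RawPrediction


class EvaluationModule(LightningModule):
    def __init__(self, model: Model, evaluator: EvaluatorProtocol):  # assumed signature
        super().__init__()
        self.model = model
        self.evaluator = evaluator
        self.clip_annotations: List[data.ClipAnnotation] = []
        self.predictions: List[List[RawPrediction]] = []

    def test_step(self, batch: TestExample, batch_idx: int):
        # _predict is a placeholder for the per-batch annotation/prediction building
        # shown in the diff above; only the accumulation changes in this commit.
        clip_annotations, predictions = self._predict(batch)
        self.clip_annotations.extend(clip_annotations)
        self.predictions.extend(predictions)

    def on_test_epoch_start(self):
        self.clip_annotations = []
        self.predictions = []

    def on_test_epoch_end(self):
        # One evaluation pass over the whole test set, replacing per-batch evaluation.
        clip_evals = self.evaluator.evaluate(self.clip_annotations, self.predictions)
        self.log_metrics(clip_evals)
        self.generate_plots(clip_evals)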

View File

@@ -122,7 +122,7 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
}
mean_score = float(
np.mean([v for v in class_scores.values() if v != np.nan])
np.mean([v for v in class_scores.values() if not np.isnan(v)])
)
return {
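
The metrics change fixes a classic NaN pitfall: NaN never compares equal to anything, including itself, so `v != np.nan` is True for every element and NaN class scores were never filtered out, leaving the mean as NaN. np.isnan is the correct test. A small demonstration with hypothetical scores:

# Why the filter changed: comparisons against np.nan are always "not equal",
# so the old list comprehension kept the NaN values.
import numpy as np

class_scores = [0.8, np.nan, 0.6]

print(np.mean([v for v in class_scores if v != np.nan]))      # nan  (nothing filtered)
print(np.mean([v for v in class_scores if not np.isnan(v)]))  # ~0.7 (NaN dropped)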