Store anns and preds instead of evals in EvaluatorModule

mbsantiago 2025-09-30 19:10:07 +01:00
parent 49ec1916ce
commit 2d796394f6
3 changed files with 24 additions and 12 deletions


@@ -17,6 +17,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
 @click.argument("model-path", type=click.Path(exists=True))
 @click.argument("test_dataset", type=click.Path(exists=True))
 @click.option("--config", "config_path", type=click.Path())
+@click.option("--base-dir", type=click.Path(), default=Path.cwd())
 @click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
 @click.option("--experiment-name", type=str)
 @click.option("--run-name", type=str)
@@ -30,6 +31,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
 def evaluate_command(
     model_path: Path,
     test_dataset: Path,
+    base_dir: Path,
     config_path: Optional[Path],
     output_dir: Path = DEFAULT_OUTPUT_DIR,
     num_workers: Optional[int] = None,
@@ -52,7 +54,10 @@ def evaluate_command(
     logger.info("Initiating evaluation process...")
-    test_annotations = load_dataset_from_config(test_dataset)
+    test_annotations = load_dataset_from_config(
+        test_dataset,
+        base_dir=base_dir,
+    )
     logger.debug(
         "Loaded {num_annotations} test examples",
         num_annotations=len(test_annotations),
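The new --base-dir option is threaded through to load_dataset_from_config. The diff does not show what that function does with the argument, but the usual role of such a parameter is to anchor relative paths from the dataset config. A minimal sketch of that resolution pattern, purely illustrative (the helper name and the actual behaviour of load_dataset_from_config are assumptions, not part of this commit):

from pathlib import Path

def resolve_against_base(path_from_config: str, base_dir: Path) -> Path:
    """Resolve a possibly-relative path from a dataset config against base_dir."""
    path = Path(path_from_config)
    return path if path.is_absolute() else base_dir / path

# resolve_against_base("recordings/night_01.wav", Path("/data/bats"))
# -> Path("/data/bats/recordings/night_01.wav")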


@@ -1,6 +1,7 @@
-from typing import Any
+from typing import Any, List

 from lightning import LightningModule
+from soundevent import data
 from torch.utils.data import DataLoader

 from batdetect2.evaluate.dataset import TestDataset, TestExample
@@ -8,6 +9,7 @@ from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
 from batdetect2.typing import EvaluatorProtocol
+from batdetect2.typing.postprocess import RawPrediction


 class EvaluationModule(LightningModule):
@@ -21,9 +23,10 @@ class EvaluationModule(LightningModule):
         self.model = model
         self.evaluator = evaluator
-        self.clip_evaluations = []
+        self.clip_annotations: List[data.ClipAnnotation] = []
+        self.predictions: List[List[RawPrediction]] = []

-    def test_step(self, batch: TestExample):
+    def test_step(self, batch: TestExample, batch_idx: int):
         dataset = self.get_dataset()
         clip_annotations = [
             dataset.clip_annotations[int(example_idx)]
@@ -43,18 +46,22 @@ class EvaluationModule(LightningModule):
             for clip_dets in clip_detections
         ]

-        self.clip_evaluations.extend(
-            self.evaluator.evaluate(clip_annotations, predictions)
-        )
+        self.clip_annotations.extend(clip_annotations)
+        self.predictions.extend(predictions)

     def on_test_epoch_start(self):
-        self.clip_evaluations = []
+        self.clip_annotations = []
+        self.predictions = []

     def on_test_epoch_end(self):
-        self.log_metrics(self.clip_evaluations)
-        self.plot_examples(self.clip_evaluations)
+        clip_evals = self.evaluator.evaluate(
+            self.clip_annotations,
+            self.predictions,
+        )
+        self.log_metrics(clip_evals)
+        self.generate_plots(clip_evals)

-    def plot_examples(self, evaluated_clips: Any):
+    def generate_plots(self, evaluated_clips: Any):
         plotter = get_image_logger(self.logger)  # type: ignore
         if plotter is None:
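The change above switches EvaluationModule from evaluating each batch as it arrives to accumulating clip annotations and raw predictions during test_step and running the evaluator once in on_test_epoch_end, so the evaluator sees the whole test set in a single call. A framework-free sketch of that accumulate-then-evaluate pattern (the class and method names here are illustrative, not the BatDetect2 API):

from typing import Any, List

class AccumulateThenEvaluate:
    """Sketch: collect ground truth and predictions per batch, evaluate once at the end."""

    def __init__(self, evaluator: Any) -> None:
        self.evaluator = evaluator
        self.annotations: List[Any] = []
        self.predictions: List[Any] = []

    def on_epoch_start(self) -> None:
        # Reset buffers so repeated test runs do not mix results.
        self.annotations = []
        self.predictions = []

    def step(self, batch_annotations: List[Any], batch_predictions: List[Any]) -> None:
        # No per-batch evaluation; just accumulate.
        self.annotations.extend(batch_annotations)
        self.predictions.extend(batch_predictions)

    def on_epoch_end(self) -> Any:
        # A single evaluator call over the full test set, so metrics that need
        # global context (e.g. dataset-wide average precision) are computed
        # over all clips rather than averaged over per-batch fragments.
        return self.evaluator.evaluate(self.annotations, self.predictions)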


@ -122,7 +122,7 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
} }
mean_score = float( mean_score = float(
np.mean([v for v in class_scores.values() if v != np.nan]) np.mean([v for v in class_scores.values() if not np.isnan(v)])
) )
return { return {
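The metric fix corrects a NaN-filtering bug: NaN compares unequal to every value, including itself, so the old guard v != np.nan is always True and never removes anything, and a single NaN class score makes the mean NaN. np.isnan is the correct test. A quick illustration with made-up scores:

import numpy as np

class_scores = {"a": 0.75, "b": float("nan"), "c": 0.25}

# Old filter: NaN != NaN evaluates to True, so the NaN slips through.
old = [v for v in class_scores.values() if v != np.nan]
# Fixed filter: np.isnan detects the NaN explicitly.
new = [v for v in class_scores.values() if not np.isnan(v)]

print(np.mean(old))  # nan
print(np.mean(new))  # 0.5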