Store anns and preds instead of evals in EvaluatorModule
commit 2d796394f6 (parent 49ec1916ce)
@@ -17,6 +17,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
 @click.argument("model-path", type=click.Path(exists=True))
 @click.argument("test_dataset", type=click.Path(exists=True))
 @click.option("--config", "config_path", type=click.Path())
+@click.option("--base-dir", type=click.Path(), default=Path.cwd())
 @click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
 @click.option("--experiment-name", type=str)
 @click.option("--run-name", type=str)
@@ -30,6 +31,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
 def evaluate_command(
     model_path: Path,
     test_dataset: Path,
+    base_dir: Path,
     config_path: Optional[Path],
     output_dir: Path = DEFAULT_OUTPUT_DIR,
     num_workers: Optional[int] = None,
@@ -52,7 +54,10 @@ def evaluate_command(
 
     logger.info("Initiating evaluation process...")
 
-    test_annotations = load_dataset_from_config(test_dataset)
+    test_annotations = load_dataset_from_config(
+        test_dataset,
+        base_dir=base_dir,
+    )
     logger.debug(
         "Loaded {num_annotations} test examples",
         num_annotations=len(test_annotations),
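Note: the new --base-dir option is threaded through to load_dataset_from_config, presumably so that relative paths in the test dataset file resolve against a chosen root (defaulting to the current working directory). Below is a minimal sketch of exercising the updated command with click's test runner; the import path of evaluate_command and the file names are placeholders, not part of this commit — only the argument and option names come from the diff above.

# Hypothetical invocation of the updated CLI via click's CliRunner.
from click.testing import CliRunner

from batdetect2.evaluate.cli import evaluate_command  # assumed import path

runner = CliRunner()
result = runner.invoke(
    evaluate_command,
    [
        "model.ckpt",                  # MODEL-PATH argument (placeholder file)
        "test_dataset.yaml",           # TEST_DATASET argument (placeholder file)
        "--base-dir", "/data/audio",   # option added in this commit
    ],
)
print(result.exit_code, result.output)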
@@ -1,6 +1,7 @@
-from typing import Any
+from typing import Any, List
 
 from lightning import LightningModule
+from soundevent import data
 from torch.utils.data import DataLoader
 
 from batdetect2.evaluate.dataset import TestDataset, TestExample
@@ -8,6 +9,7 @@ from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
 from batdetect2.typing import EvaluatorProtocol
+from batdetect2.typing.postprocess import RawPrediction
 
 
 class EvaluationModule(LightningModule):
@@ -21,9 +23,10 @@ class EvaluationModule(LightningModule):
         self.model = model
         self.evaluator = evaluator
 
-        self.clip_evaluations = []
+        self.clip_annotations: List[data.ClipAnnotation] = []
+        self.predictions: List[List[RawPrediction]] = []
 
-    def test_step(self, batch: TestExample):
+    def test_step(self, batch: TestExample, batch_idx: int):
         dataset = self.get_dataset()
         clip_annotations = [
             dataset.clip_annotations[int(example_idx)]
@@ -43,18 +46,22 @@ class EvaluationModule(LightningModule):
             for clip_dets in clip_detections
         ]
 
-        self.clip_evaluations.extend(
-            self.evaluator.evaluate(clip_annotations, predictions)
-        )
+        self.clip_annotations.extend(clip_annotations)
+        self.predictions.extend(predictions)
 
     def on_test_epoch_start(self):
-        self.clip_evaluations = []
+        self.clip_annotations = []
+        self.predictions = []
 
     def on_test_epoch_end(self):
-        self.log_metrics(self.clip_evaluations)
-        self.plot_examples(self.clip_evaluations)
+        clip_evals = self.evaluator.evaluate(
+            self.clip_annotations,
+            self.predictions,
+        )
+        self.log_metrics(clip_evals)
+        self.generate_plots(clip_evals)
 
-    def plot_examples(self, evaluated_clips: Any):
+    def generate_plots(self, evaluated_clips: Any):
         plotter = get_image_logger(self.logger)  # type: ignore
 
         if plotter is None:
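Note: a plausible motivation for storing annotations and raw predictions and calling evaluator.evaluate only once in on_test_epoch_end is that dataset-level metrics (such as the average precision touched in the next hunk) cannot be recovered by averaging per-batch evaluation results. A standalone toy illustration, unrelated to the repository's own classes; it only assumes numpy and scikit-learn.

import numpy as np
from sklearn.metrics import average_precision_score

# Toy labels and scores split into two "batches".
y_true = np.array([1, 0, 1, 1, 0, 0])
scores = np.array([0.9, 0.8, 0.7, 0.4, 0.3, 0.2])

per_batch = [
    average_precision_score(y_true[:3], scores[:3]),
    average_precision_score(y_true[3:], scores[3:]),
]
pooled = average_precision_score(y_true, scores)

# Averaging per-batch APs (~0.92) does not match AP over the pooled
# predictions (~0.81), so the evaluator needs the full test set at once.
print(np.mean(per_batch), pooled)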
@@ -122,7 +122,7 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
         }
 
         mean_score = float(
-            np.mean([v for v in class_scores.values() if v != np.nan])
+            np.mean([v for v in class_scores.values() if not np.isnan(v)])
         )
 
         return {
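Note: the metric fix above matters because comparisons against NaN never behave like an equality test: x != np.nan is True for every value, including NaN itself, so the old filter removed nothing and the mean could come out as NaN. A small self-contained check (np.nanmean over the raw values would be an equivalent alternative):

import numpy as np

values = [0.5, np.nan, 0.75]

print([v for v in values if v != np.nan])      # [0.5, nan, 0.75] -- NaN slips through
print([v for v in values if not np.isnan(v)])  # [0.5, 0.75]
print(float(np.mean([0.5, 0.75])))             # 0.625, the intended class-mean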