Store anns and preds instead of evals in EvaluatorModule
This commit is contained in: commit 2d796394f6 (parent 49ec1916ce)
@@ -17,6 +17,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
 @click.argument("model-path", type=click.Path(exists=True))
 @click.argument("test_dataset", type=click.Path(exists=True))
 @click.option("--config", "config_path", type=click.Path())
+@click.option("--base-dir", type=click.Path(), default=Path.cwd())
 @click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
 @click.option("--experiment-name", type=str)
 @click.option("--run-name", type=str)
@@ -30,6 +31,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
 def evaluate_command(
     model_path: Path,
     test_dataset: Path,
+    base_dir: Path,
     config_path: Optional[Path],
     output_dir: Path = DEFAULT_OUTPUT_DIR,
     num_workers: Optional[int] = None,
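The first two hunks add a --base-dir option and thread it into evaluate_command as a new parameter. As a minimal sketch of how such a click option reaches the command function (the show_base command and its echo are illustrative, not part of batdetect2):

from pathlib import Path

import click


@click.command()
@click.option("--base-dir", type=click.Path(), default=Path.cwd())
def show_base(base_dir):
    # click.Path() passes values through as strings unless
    # path_type=Path is requested, so normalise before use.
    base_dir = Path(base_dir)
    click.echo(f"Resolving dataset paths against {base_dir.resolve()}")


if __name__ == "__main__":
    show_base()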
@@ -52,7 +54,10 @@ def evaluate_command(
 
     logger.info("Initiating evaluation process...")
 
-    test_annotations = load_dataset_from_config(test_dataset)
+    test_annotations = load_dataset_from_config(
+        test_dataset,
+        base_dir=base_dir,
+    )
     logger.debug(
         "Loaded {num_annotations} test examples",
         num_annotations=len(test_annotations),
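load_dataset_from_config now also receives base_dir, presumably so that relative paths in the dataset description resolve against an explicit root rather than the process working directory. A sketch of that resolution logic under that assumption (resolve_entry and load_audio_paths are hypothetical helpers, not batdetect2's loader):

from pathlib import Path
from typing import List


def resolve_entry(relative: str, base_dir: Path) -> Path:
    # Hypothetical helper: anchor a relative entry from the dataset
    # config at the supplied base directory; absolute paths pass through.
    path = Path(relative)
    return path if path.is_absolute() else base_dir / path


def load_audio_paths(entries: List[str], base_dir: Path) -> List[Path]:
    # Sketch of the effect of passing base_dir: every relative entry
    # in the dataset description is resolved against the same root.
    return [resolve_entry(entry, base_dir) for entry in entries]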
@@ -1,6 +1,7 @@
-from typing import Any
+from typing import Any, List
 
 from lightning import LightningModule
+from soundevent import data
 from torch.utils.data import DataLoader
 
 from batdetect2.evaluate.dataset import TestDataset, TestExample
@@ -8,6 +9,7 @@ from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
 from batdetect2.typing import EvaluatorProtocol
+from batdetect2.typing.postprocess import RawPrediction
 
 
 class EvaluationModule(LightningModule):
@@ -21,9 +23,10 @@ class EvaluationModule(LightningModule):
         self.model = model
         self.evaluator = evaluator
 
-        self.clip_evaluations = []
+        self.clip_annotations: List[data.ClipAnnotation] = []
+        self.predictions: List[List[RawPrediction]] = []
 
-    def test_step(self, batch: TestExample):
+    def test_step(self, batch: TestExample, batch_idx: int):
        dataset = self.get_dataset()
         clip_annotations = [
             dataset.clip_annotations[int(example_idx)]
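Two things change in the module setup: the single clip_evaluations buffer becomes two explicitly typed buffers (annotations and raw predictions), and test_step gains the batch_idx argument that Lightning supplies as the second positional parameter. A reduced sketch of that shape (types simplified to Any; not the real batdetect2 classes):

from typing import Any, List

from lightning import LightningModule


class AccumulatingModule(LightningModule):
    # Minimal sketch: buffers are annotated up front so static
    # checkers know what test_step is allowed to append to them.
    def __init__(self) -> None:
        super().__init__()
        self.clip_annotations: List[Any] = []
        self.predictions: List[List[Any]] = []

    def test_step(self, batch: Any, batch_idx: int) -> None:
        # Lightning calls test_step(batch, batch_idx), hence the
        # extra positional parameter in the new signature.
        self.clip_annotations.append(batch)
        self.predictions.append([])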
@@ -43,18 +46,22 @@ class EvaluationModule(LightningModule):
             for clip_dets in clip_detections
         ]
 
-        self.clip_evaluations.extend(
-            self.evaluator.evaluate(clip_annotations, predictions)
-        )
+        self.clip_annotations.extend(clip_annotations)
+        self.predictions.extend(predictions)
 
     def on_test_epoch_start(self):
-        self.clip_evaluations = []
+        self.clip_annotations = []
+        self.predictions = []
 
     def on_test_epoch_end(self):
-        self.log_metrics(self.clip_evaluations)
-        self.plot_examples(self.clip_evaluations)
+        clip_evals = self.evaluator.evaluate(
+            self.clip_annotations,
+            self.predictions,
+        )
+        self.log_metrics(clip_evals)
+        self.generate_plots(clip_evals)
 
-    def plot_examples(self, evaluated_clips: Any):
+    def generate_plots(self, evaluated_clips: Any):
         plotter = get_image_logger(self.logger)  # type: ignore
 
         if plotter is None:
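With the buffers in place, the hooks follow an accumulate-then-evaluate pattern: reset on epoch start, collect in each test step, and run the evaluator once over the entire test set at epoch end, so metrics are computed globally rather than per batch. A stripped-down sketch of that ordering (the stand-in evaluator returning a metrics dict is an assumption for illustration, not batdetect2's EvaluatorProtocol):

from typing import Any, List

from lightning import LightningModule


class DeferredEvaluation(LightningModule):
    # Illustrative sketch of the accumulate-then-evaluate pattern.
    def __init__(self, evaluator: Any) -> None:
        super().__init__()
        self.evaluator = evaluator
        self.annotations: List[Any] = []
        self.predictions: List[Any] = []

    def on_test_epoch_start(self) -> None:
        # Fresh buffers for every evaluation run.
        self.annotations = []
        self.predictions = []

    def test_step(self, batch: Any, batch_idx: int) -> None:
        # Only collect; no metrics are computed per batch.
        annotations, predictions = batch
        self.annotations.extend(annotations)
        self.predictions.extend(predictions)

    def on_test_epoch_end(self) -> None:
        # The evaluator sees every clip at once, so dataset-level
        # metrics (e.g. average precision) are computed over the
        # full test set instead of being averaged over batches.
        results = self.evaluator.evaluate(self.annotations, self.predictions)
        self.log_dict({key: float(value) for key, value in results.items()})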
@@ -122,7 +122,7 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
         }
 
         mean_score = float(
-            np.mean([v for v in class_scores.values() if v != np.nan])
+            np.mean([v for v in class_scores.values() if not np.isnan(v)])
         )
 
         return {
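The metrics fix deserves a note: v != np.nan is always True because NaN compares unequal to everything, itself included, so the old filter never removed anything and the mean could silently become NaN. Testing with np.isnan (or using np.nanmean) is the correct approach; a quick demonstration:

import numpy as np

scores = {"a": 0.9, "b": float("nan"), "c": 0.7}

# NaN != NaN, so this filter is a no-op and the mean becomes NaN.
broken = np.mean([v for v in scores.values() if v != np.nan])

# Explicitly testing with np.isnan drops the undefined class score.
fixed = np.mean([v for v in scores.values() if not np.isnan(v)])

assert np.isnan(broken)
# np.nanmean is an equivalent one-liner.
assert np.isclose(fixed, np.nanmean(list(scores.values())))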