From 9c8b8fb200f4ab93ded683ad9bca409e561e4b0f Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 30 Apr 2025 22:51:33 +0100 Subject: [PATCH] Create metrics --- batdetect2/evaluate/metrics.py | 82 +++++++++++++++++++++++++++++----- batdetect2/evaluate/types.py | 4 +- batdetect2/train/__init__.py | 13 +++++- batdetect2/train/callbacks.py | 5 ++- batdetect2/train/lightning.py | 16 +++---- batdetect2/train/train.py | 5 +-- 6 files changed, 96 insertions(+), 29 deletions(-) diff --git a/batdetect2/evaluate/metrics.py b/batdetect2/evaluate/metrics.py index 3cb762e..c1bc924 100644 --- a/batdetect2/evaluate/metrics.py +++ b/batdetect2/evaluate/metrics.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict, List import pandas as pd from sklearn import metrics @@ -10,22 +10,20 @@ __all__ = ["DetectionAveragePrecision"] class DetectionAveragePrecision(MetricsProtocol): - name: str = "detection/average_precision" - - def __call__(self, matches: List[Match]) -> float: + def __call__(self, matches: List[Match]) -> Dict[str, float]: y_true, y_score = zip( *[(match.gt_det, match.pred_score) for match in matches] ) - return float(metrics.average_precision_score(y_true, y_score)) + score = float(metrics.average_precision_score(y_true, y_score)) + return {"detection_AP": score} class ClassificationMeanAveragePrecision(MetricsProtocol): - name: str = "classification/average_precision" - - def __init__(self, class_names: List[str]): + def __init__(self, class_names: List[str], per_class: bool = True): self.class_names = class_names + self.per_class = per_class - def __call__(self, matches: List[Match]) -> float: + def __call__(self, matches: List[Match]) -> Dict[str, float]: y_true = label_binarize( [ match.gt_class if match.gt_class is not None else "__NONE__" @@ -33,7 +31,67 @@ class ClassificationMeanAveragePrecision(MetricsProtocol): ], classes=self.class_names, ) - y_pred = pd.DataFrame([match.class_scores for match in matches]) - return float( - metrics.average_precision_score(y_true, y_pred[self.class_names]) + y_pred = pd.DataFrame( + [ + { + name: match.class_scores.get(name, 0) + for name in self.class_names + } + for match in matches + ] + ).fillna(0) + mAP = metrics.average_precision_score(y_true, y_pred[self.class_names]) + + ret = { + "classification_mAP": float(mAP), + } + + if not self.per_class: + return ret + + for class_index, class_name in enumerate(self.class_names): + y_true_class = y_true[:, class_index] + y_pred_class = y_pred[class_name] + class_ap = metrics.average_precision_score( + y_true_class, + y_pred_class, + ) + ret[f"classification_AP/{class_name}"] = float(class_ap) + + return ret + + +class ClassificationAccuracy(MetricsProtocol): + def __init__(self, class_names: List[str]): + self.class_names = class_names + + def __call__(self, matches: List[Match]) -> Dict[str, float]: + y_true = [ + match.gt_class if match.gt_class is not None else "__NONE__" + for match in matches + ] + + y_pred = pd.DataFrame( + [ + { + name: match.class_scores.get(name, 0) + for name in self.class_names + } + for match in matches + ] + ).fillna(0) + y_pred = y_pred.apply( + lambda row: row.idxmax() + if row.max() >= (1 - row.sum()) + else "__NONE__", + axis=1, ) + + accuracy = metrics.balanced_accuracy_score( + y_true, + y_pred, + ) + + return { + "classification_acc": float(accuracy), + } diff --git a/batdetect2/evaluate/types.py b/batdetect2/evaluate/types.py index 081253f..76e39a6 100644 --- a/batdetect2/evaluate/types.py +++ b/batdetect2/evaluate/types.py @@ -19,6 +19,4 @@ class Match: class MetricsProtocol(Protocol): - name: str - - def __call__(self, matches: List[Match]) -> float: ... + def __call__(self, matches: List[Match]) -> Dict[str, float]: ... diff --git a/batdetect2/train/__init__.py b/batdetect2/train/__init__.py index f15b87c..d0baebc 100644 --- a/batdetect2/train/__init__.py +++ b/batdetect2/train/__init__.py @@ -26,7 +26,14 @@ from batdetect2.train.dataset import ( list_preprocessed_files, ) from batdetect2.train.labels import build_clip_labeler, load_label_config -from batdetect2.train.losses import LossFunction, build_loss +from batdetect2.train.losses import ( + ClassificationLossConfig, + DetectionLossConfig, + LossConfig, + LossFunction, + SizeLossConfig, + build_loss, +) from batdetect2.train.preprocess import ( generate_train_example, preprocess_annotations, @@ -39,11 +46,15 @@ from batdetect2.train.train import ( __all__ = [ "AugmentationsConfig", + "ClassificationLossConfig", + "DetectionLossConfig", "EchoAugmentationConfig", "FrequencyMaskAugmentationConfig", "LabeledDataset", + "LossConfig", "LossFunction", "RandomExampleSource", + "SizeLossConfig", "TimeMaskAugmentationConfig", "TrainExample", "TrainerConfig", diff --git a/batdetect2/train/callbacks.py b/batdetect2/train/callbacks.py index f1e0895..fe30f40 100644 --- a/batdetect2/train/callbacks.py +++ b/batdetect2/train/callbacks.py @@ -28,10 +28,11 @@ class ValidationMetrics(Callback): trainer: Trainer, pl_module: LightningModule, ) -> None: + metrics = {} for metric in self.metrics: - value = metric(self.matches) - pl_module.log(f"val/metric/{metric.name}", value, prog_bar=True) + metrics.update(metric(self.matches).items()) + pl_module.log_dict(metrics) return super().on_validation_epoch_end(trainer, pl_module) def on_validation_epoch_start( diff --git a/batdetect2/train/lightning.py b/batdetect2/train/lightning.py index 515f374..5080a80 100644 --- a/batdetect2/train/lightning.py +++ b/batdetect2/train/lightning.py @@ -51,10 +51,10 @@ class TrainingModule(L.LightningModule): outputs = self.forward(batch.spec) losses = self.loss(outputs, batch) - self.log("train/loss/total", losses.total, prog_bar=True, logger=True) - self.log("train/loss/detection", losses.total, logger=True) - self.log("train/loss/size", losses.total, logger=True) - self.log("train/loss/classification", losses.total, logger=True) + self.log("total_loss/train", losses.total, prog_bar=True, logger=True) + self.log("detection_loss/train", losses.total, logger=True) + self.log("size_loss/train", losses.total, logger=True) + self.log("classification_loss/train", losses.total, logger=True) return losses.total @@ -64,10 +64,10 @@ class TrainingModule(L.LightningModule): outputs = self.forward(batch.spec) losses = self.loss(outputs, batch) - self.log("val/loss/total", losses.total, prog_bar=True, logger=True) - self.log("val/loss/detection", losses.total, logger=True) - self.log("val/loss/size", losses.total, logger=True) - self.log("val/loss/classification", losses.total, logger=True) + self.log("total_loss/val", losses.total, prog_bar=True, logger=True) + self.log("detection_loss/val", losses.total, logger=True) + self.log("size_loss/val", losses.total, logger=True) + self.log("classification_loss/val", losses.total, logger=True) return outputs diff --git a/batdetect2/train/train.py b/batdetect2/train/train.py index e4611ba..757ec99 100644 --- a/batdetect2/train/train.py +++ b/batdetect2/train/train.py @@ -37,6 +37,7 @@ def train( val_examples: Optional[List[data.PathLike]] = None, config: Optional[TrainingConfig] = None, callbacks: Optional[List[Callback]] = None, + **trainer_kwargs, ) -> None: config = config or TrainingConfig() @@ -74,9 +75,7 @@ def train( trainer = Trainer( **config.trainer.model_dump(exclude_none=True), callbacks=callbacks, - num_sanity_val_steps=0, - # enable_model_summary=False, - # enable_progress_bar=False, + **trainer_kwargs, ) train_dataloader = DataLoader(