Create metrics

mbsantiago 2025-04-30 22:51:33 +01:00
parent bc86c94f8e
commit 9c8b8fb200
6 changed files with 96 additions and 29 deletions

View File

@@ -1,4 +1,4 @@
-from typing import List
+from typing import Dict, List
 
 import pandas as pd
 from sklearn import metrics
@@ -10,22 +10,20 @@ __all__ = ["DetectionAveragePrecision"]
 class DetectionAveragePrecision(MetricsProtocol):
     name: str = "detection/average_precision"
 
-    def __call__(self, matches: List[Match]) -> float:
+    def __call__(self, matches: List[Match]) -> Dict[str, float]:
         y_true, y_score = zip(
             *[(match.gt_det, match.pred_score) for match in matches]
         )
-        return float(metrics.average_precision_score(y_true, y_score))
+        score = float(metrics.average_precision_score(y_true, y_score))
+        return {"detection_AP": score}
 
 
 class ClassificationMeanAveragePrecision(MetricsProtocol):
     name: str = "classification/average_precision"
 
-    def __init__(self, class_names: List[str]):
+    def __init__(self, class_names: List[str], per_class: bool = True):
         self.class_names = class_names
+        self.per_class = per_class
 
-    def __call__(self, matches: List[Match]) -> float:
+    def __call__(self, matches: List[Match]) -> Dict[str, float]:
         y_true = label_binarize(
             [
                 match.gt_class if match.gt_class is not None else "__NONE__"
@@ -33,7 +31,67 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
             ],
             classes=self.class_names,
         )
-        y_pred = pd.DataFrame([match.class_scores for match in matches])
-        return float(
-            metrics.average_precision_score(y_true, y_pred[self.class_names])
-        )
+        y_pred = pd.DataFrame(
+            [
+                {
+                    name: match.class_scores.get(name, 0)
+                    for name in self.class_names
+                }
+                for match in matches
+            ]
+        ).fillna(0)
+        mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])
+
+        ret = {
+            "classification_mAP": float(mAP),
+        }
+
+        if not self.per_class:
+            return ret
+
+        for class_index, class_name in enumerate(self.class_names):
+            y_true_class = y_true[:, class_index]
+            y_pred_class = y_pred[class_name]
+            class_ap = metrics.average_precision_score(
+                y_true_class,
+                y_pred_class,
+            )
+            ret[f"classification_AP/{class_name}"] = float(class_ap)
+
+        return ret
+
+
+class ClassificationAccuracy(MetricsProtocol):
+    def __init__(self, class_names: List[str]):
+        self.class_names = class_names
+
+    def __call__(self, matches: List[Match]) -> Dict[str, float]:
+        y_true = [
+            match.gt_class if match.gt_class is not None else "__NONE__"
+            for match in matches
+        ]
+        y_pred = pd.DataFrame(
+            [
+                {
+                    name: match.class_scores.get(name, 0)
+                    for name in self.class_names
+                }
+                for match in matches
+            ]
+        ).fillna(0)
+        y_pred = y_pred.apply(
+            lambda row: row.idxmax()
+            if row.max() >= (1 - row.sum())
+            else "__NONE__",
+            axis=1,
+        )
+
+        accuracy = metrics.balanced_accuracy_score(
+            y_true,
+            y_pred,
+        )
+
+        return {
+            "classification_acc": float(accuracy),
+        }
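
For orientation, a minimal sketch of how the reworked metrics are consumed after this commit. The stand-in matches and the class name are hypothetical, and the import of DetectionAveragePrecision is omitted because the diff view hides the module path; only the Match attributes visible above (gt_det, pred_score, gt_class, class_scores) are filled in:

from types import SimpleNamespace

# Stand-ins for real Match objects; only the attributes the metrics read.
matches = [
    SimpleNamespace(gt_det=1, pred_score=0.9, gt_class="myotis", class_scores={"myotis": 0.8}),
    SimpleNamespace(gt_det=0, pred_score=0.2, gt_class=None, class_scores={"myotis": 0.1}),
]

# Metrics now return a dict of named scores rather than a bare float.
print(DetectionAveragePrecision()(matches))
# -> {'detection_AP': 1.0}  (the positive match outranks the negative one)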

View File

@@ -19,6 +19,4 @@ class Match:
 class MetricsProtocol(Protocol):
     name: str
 
-    def __call__(self, matches: List[Match]) -> float: ...
+    def __call__(self, matches: List[Match]) -> Dict[str, float]: ...
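
Because MetricsProtocol is a structural Protocol, any object exposing a name and a dict-returning __call__ now conforms. A minimal sketch of a custom metric under the new signature (the 0.5 score threshold is an arbitrary illustrative choice, not from the commit; Match comes from the same module as the protocol):

from typing import Dict, List

class DetectionRecall:
    name: str = "detection/recall"

    def __call__(self, matches: List[Match]) -> Dict[str, float]:
        # Consider ground-truth detections only; pred_score is the
        # model's detection confidence for that match.
        positives = [m for m in matches if m.gt_det]
        if not positives:
            return {"detection_recall": 0.0}
        hits = sum(m.pred_score >= 0.5 for m in positives)
        return {"detection_recall": hits / len(positives)}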

View File

@@ -26,7 +26,14 @@ from batdetect2.train.dataset import (
     list_preprocessed_files,
 )
 from batdetect2.train.labels import build_clip_labeler, load_label_config
-from batdetect2.train.losses import LossFunction, build_loss
+from batdetect2.train.losses import (
+    ClassificationLossConfig,
+    DetectionLossConfig,
+    LossConfig,
+    LossFunction,
+    SizeLossConfig,
+    build_loss,
+)
 from batdetect2.train.preprocess import (
     generate_train_example,
     preprocess_annotations,
@@ -39,11 +46,15 @@ from batdetect2.train.train import (
 __all__ = [
     "AugmentationsConfig",
+    "ClassificationLossConfig",
+    "DetectionLossConfig",
     "EchoAugmentationConfig",
     "FrequencyMaskAugmentationConfig",
     "LabeledDataset",
+    "LossConfig",
     "LossFunction",
     "RandomExampleSource",
+    "SizeLossConfig",
     "TimeMaskAugmentationConfig",
     "TrainExample",
     "TrainerConfig",

View File

@@ -28,10 +28,11 @@ class ValidationMetrics(Callback):
         trainer: Trainer,
         pl_module: LightningModule,
     ) -> None:
+        metrics = {}
         for metric in self.metrics:
-            value = metric(self.matches)
-            pl_module.log(f"val/metric/{metric.name}", value, prog_bar=True)
+            metrics.update(metric(self.matches).items())
+        pl_module.log_dict(metrics)
         return super().on_validation_epoch_end(trainer, pl_module)
 
     def on_validation_epoch_start(
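
Since each metric now returns a mapping, the callback merges them all and logs once via log_dict. Two small observations: dict.update accepts a mapping directly, so the .items() call is redundant (though harmless), and two metrics emitting the same key would silently overwrite one another. A sketch of the merged payload, with hypothetical values and class names:

# metrics dict passed to pl_module.log_dict, assuming DetectionAveragePrecision
# plus ClassificationMeanAveragePrecision(per_class=True) are configured:
# {
#     "detection_AP": 0.91,
#     "classification_mAP": 0.74,
#     "classification_AP/myotis": 0.80,
#     "classification_AP/pipistrellus": 0.68,
# }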

View File

@@ -51,10 +51,10 @@ class TrainingModule(L.LightningModule):
         outputs = self.forward(batch.spec)
         losses = self.loss(outputs, batch)
 
-        self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
-        self.log("train/loss/detection", losses.detection, logger=True)
-        self.log("train/loss/size", losses.size, logger=True)
-        self.log("train/loss/classification", losses.classification, logger=True)
+        self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
+        self.log("detection_loss/train", losses.detection, logger=True)
+        self.log("size_loss/train", losses.size, logger=True)
+        self.log("classification_loss/train", losses.classification, logger=True)
 
         return losses.total
@@ -64,10 +64,10 @@ class TrainingModule(L.LightningModule):
         outputs = self.forward(batch.spec)
         losses = self.loss(outputs, batch)
 
-        self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
-        self.log("val/loss/detection", losses.detection, logger=True)
-        self.log("val/loss/size", losses.size, logger=True)
-        self.log("val/loss/classification", losses.classification, logger=True)
+        self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
+        self.log("detection_loss/val", losses.detection, logger=True)
+        self.log("size_loss/val", losses.size, logger=True)
+        self.log("classification_loss/val", losses.classification, logger=True)
 
         return outputs
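
The rename moves the stage to a suffix ("total_loss/train" rather than "train/loss/total"), so TensorBoard-style loggers group the train and val curves of each loss component under one chart section. Since the four calls are now duplicated across training_step and validation_step, a small helper could factor out the convention; a hypothetical sketch, not part of the commit:

def _log_losses(self, losses, stage: str) -> None:
    # stage is "train" or "val"; keys follow the "<quantity>/<stage>" convention.
    self.log(f"total_loss/{stage}", losses.total, prog_bar=True, logger=True)
    self.log(f"detection_loss/{stage}", losses.detection, logger=True)
    self.log(f"size_loss/{stage}", losses.size, logger=True)
    self.log(f"classification_loss/{stage}", losses.classification, logger=True)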

View File

@@ -37,6 +37,7 @@ def train(
     val_examples: Optional[List[data.PathLike]] = None,
     config: Optional[TrainingConfig] = None,
     callbacks: Optional[List[Callback]] = None,
+    **trainer_kwargs,
 ) -> None:
     config = config or TrainingConfig()
@@ -74,9 +75,7 @@ def train(
     trainer = Trainer(
         **config.trainer.model_dump(exclude_none=True),
         callbacks=callbacks,
-        num_sanity_val_steps=0,
-        # enable_model_summary=False,
-        # enable_progress_bar=False,
+        **trainer_kwargs,
     )
 
     train_dataloader = DataLoader(
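
With the passthrough in place, any lightning.pytorch.Trainer argument can be forwarded from the call site, including the options that were previously hard-coded here. A usage sketch (leading positional arguments abbreviated, since train()'s full signature is not shown in this hunk):

train(
    train_examples,             # leading positional arguments assumed
    num_sanity_val_steps=0,     # forwarded straight to Trainer
    enable_progress_bar=False,  # likewise; both are standard Trainer kwargs
)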