mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Create metrics
This commit is contained in:
parent
bc86c94f8e
commit
9c8b8fb200
@ -1,4 +1,4 @@
|
|||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
@ -10,22 +10,20 @@ __all__ = ["DetectionAveragePrecision"]
|
|||||||
|
|
||||||
|
|
||||||
class DetectionAveragePrecision(MetricsProtocol):
|
class DetectionAveragePrecision(MetricsProtocol):
|
||||||
name: str = "detection/average_precision"
|
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
||||||
|
|
||||||
def __call__(self, matches: List[Match]) -> float:
|
|
||||||
y_true, y_score = zip(
|
y_true, y_score = zip(
|
||||||
*[(match.gt_det, match.pred_score) for match in matches]
|
*[(match.gt_det, match.pred_score) for match in matches]
|
||||||
)
|
)
|
||||||
return float(metrics.average_precision_score(y_true, y_score))
|
score = float(metrics.average_precision_score(y_true, y_score))
|
||||||
|
return {"detection_AP": score}
|
||||||
|
|
||||||
|
|
||||||
class ClassificationMeanAveragePrecision(MetricsProtocol):
|
class ClassificationMeanAveragePrecision(MetricsProtocol):
|
||||||
name: str = "classification/average_precision"
|
def __init__(self, class_names: List[str], per_class: bool = True):
|
||||||
|
|
||||||
def __init__(self, class_names: List[str]):
|
|
||||||
self.class_names = class_names
|
self.class_names = class_names
|
||||||
|
self.per_class = per_class
|
||||||
|
|
||||||
def __call__(self, matches: List[Match]) -> float:
|
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
||||||
y_true = label_binarize(
|
y_true = label_binarize(
|
||||||
[
|
[
|
||||||
match.gt_class if match.gt_class is not None else "__NONE__"
|
match.gt_class if match.gt_class is not None else "__NONE__"
|
||||||
@ -33,7 +31,67 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
|||||||
],
|
],
|
||||||
classes=self.class_names,
|
classes=self.class_names,
|
||||||
)
|
)
|
||||||
y_pred = pd.DataFrame([match.class_scores for match in matches])
|
y_pred = pd.DataFrame(
|
||||||
return float(
|
[
|
||||||
metrics.average_precision_score(y_true, y_pred[self.class_names])
|
{
|
||||||
|
name: match.class_scores.get(name, 0)
|
||||||
|
for name in self.class_names
|
||||||
|
}
|
||||||
|
for match in matches
|
||||||
|
]
|
||||||
|
).fillna(0)
|
||||||
|
mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])
|
||||||
|
|
||||||
|
ret = {
|
||||||
|
"classification_mAP": float(mAP),
|
||||||
|
}
|
||||||
|
|
||||||
|
if not self.per_class:
|
||||||
|
return ret
|
||||||
|
|
||||||
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
|
y_true_class = y_true[:, class_index]
|
||||||
|
y_pred_class = y_pred[class_name]
|
||||||
|
class_ap = metrics.average_precision_score(
|
||||||
|
y_true_class,
|
||||||
|
y_pred_class,
|
||||||
|
)
|
||||||
|
ret[f"classification_AP/{class_name}"] = float(class_ap)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationAccuracy(MetricsProtocol):
|
||||||
|
def __init__(self, class_names: List[str]):
|
||||||
|
self.class_names = class_names
|
||||||
|
|
||||||
|
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
||||||
|
y_true = [
|
||||||
|
match.gt_class if match.gt_class is not None else "__NONE__"
|
||||||
|
for match in matches
|
||||||
|
]
|
||||||
|
|
||||||
|
y_pred = pd.DataFrame(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
name: match.class_scores.get(name, 0)
|
||||||
|
for name in self.class_names
|
||||||
|
}
|
||||||
|
for match in matches
|
||||||
|
]
|
||||||
|
).fillna(0)
|
||||||
|
y_pred = y_pred.apply(
|
||||||
|
lambda row: row.idxmax()
|
||||||
|
if row.max() >= (1 - row.sum())
|
||||||
|
else "__NONE__",
|
||||||
|
axis=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
accuracy = metrics.balanced_accuracy_score(
|
||||||
|
y_true,
|
||||||
|
y_pred,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"classification_acc": float(accuracy),
|
||||||
|
}
|
||||||
|
@ -19,6 +19,4 @@ class Match:
|
|||||||
|
|
||||||
|
|
||||||
class MetricsProtocol(Protocol):
|
class MetricsProtocol(Protocol):
|
||||||
name: str
|
def __call__(self, matches: List[Match]) -> Dict[str, float]: ...
|
||||||
|
|
||||||
def __call__(self, matches: List[Match]) -> float: ...
|
|
||||||
|
@ -26,7 +26,14 @@ from batdetect2.train.dataset import (
|
|||||||
list_preprocessed_files,
|
list_preprocessed_files,
|
||||||
)
|
)
|
||||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||||
from batdetect2.train.losses import LossFunction, build_loss
|
from batdetect2.train.losses import (
|
||||||
|
ClassificationLossConfig,
|
||||||
|
DetectionLossConfig,
|
||||||
|
LossConfig,
|
||||||
|
LossFunction,
|
||||||
|
SizeLossConfig,
|
||||||
|
build_loss,
|
||||||
|
)
|
||||||
from batdetect2.train.preprocess import (
|
from batdetect2.train.preprocess import (
|
||||||
generate_train_example,
|
generate_train_example,
|
||||||
preprocess_annotations,
|
preprocess_annotations,
|
||||||
@ -39,11 +46,15 @@ from batdetect2.train.train import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AugmentationsConfig",
|
"AugmentationsConfig",
|
||||||
|
"ClassificationLossConfig",
|
||||||
|
"DetectionLossConfig",
|
||||||
"EchoAugmentationConfig",
|
"EchoAugmentationConfig",
|
||||||
"FrequencyMaskAugmentationConfig",
|
"FrequencyMaskAugmentationConfig",
|
||||||
"LabeledDataset",
|
"LabeledDataset",
|
||||||
|
"LossConfig",
|
||||||
"LossFunction",
|
"LossFunction",
|
||||||
"RandomExampleSource",
|
"RandomExampleSource",
|
||||||
|
"SizeLossConfig",
|
||||||
"TimeMaskAugmentationConfig",
|
"TimeMaskAugmentationConfig",
|
||||||
"TrainExample",
|
"TrainExample",
|
||||||
"TrainerConfig",
|
"TrainerConfig",
|
||||||
|
@ -28,10 +28,11 @@ class ValidationMetrics(Callback):
|
|||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
metrics = {}
|
||||||
for metric in self.metrics:
|
for metric in self.metrics:
|
||||||
value = metric(self.matches)
|
metrics.update(metric(self.matches).items())
|
||||||
pl_module.log(f"val/metric/{metric.name}", value, prog_bar=True)
|
|
||||||
|
|
||||||
|
pl_module.log_dict(metrics)
|
||||||
return super().on_validation_epoch_end(trainer, pl_module)
|
return super().on_validation_epoch_end(trainer, pl_module)
|
||||||
|
|
||||||
def on_validation_epoch_start(
|
def on_validation_epoch_start(
|
||||||
|
@ -51,10 +51,10 @@ class TrainingModule(L.LightningModule):
|
|||||||
outputs = self.forward(batch.spec)
|
outputs = self.forward(batch.spec)
|
||||||
losses = self.loss(outputs, batch)
|
losses = self.loss(outputs, batch)
|
||||||
|
|
||||||
self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
|
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
||||||
self.log("train/loss/detection", losses.total, logger=True)
|
self.log("detection_loss/train", losses.total, logger=True)
|
||||||
self.log("train/loss/size", losses.total, logger=True)
|
self.log("size_loss/train", losses.total, logger=True)
|
||||||
self.log("train/loss/classification", losses.total, logger=True)
|
self.log("classification_loss/train", losses.total, logger=True)
|
||||||
|
|
||||||
return losses.total
|
return losses.total
|
||||||
|
|
||||||
@ -64,10 +64,10 @@ class TrainingModule(L.LightningModule):
|
|||||||
outputs = self.forward(batch.spec)
|
outputs = self.forward(batch.spec)
|
||||||
losses = self.loss(outputs, batch)
|
losses = self.loss(outputs, batch)
|
||||||
|
|
||||||
self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
|
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
||||||
self.log("val/loss/detection", losses.total, logger=True)
|
self.log("detection_loss/val", losses.total, logger=True)
|
||||||
self.log("val/loss/size", losses.total, logger=True)
|
self.log("size_loss/val", losses.total, logger=True)
|
||||||
self.log("val/loss/classification", losses.total, logger=True)
|
self.log("classification_loss/val", losses.total, logger=True)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@ -37,6 +37,7 @@ def train(
|
|||||||
val_examples: Optional[List[data.PathLike]] = None,
|
val_examples: Optional[List[data.PathLike]] = None,
|
||||||
config: Optional[TrainingConfig] = None,
|
config: Optional[TrainingConfig] = None,
|
||||||
callbacks: Optional[List[Callback]] = None,
|
callbacks: Optional[List[Callback]] = None,
|
||||||
|
**trainer_kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
config = config or TrainingConfig()
|
config = config or TrainingConfig()
|
||||||
|
|
||||||
@ -74,9 +75,7 @@ def train(
|
|||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
**config.trainer.model_dump(exclude_none=True),
|
**config.trainer.model_dump(exclude_none=True),
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
num_sanity_val_steps=0,
|
**trainer_kwargs,
|
||||||
# enable_model_summary=False,
|
|
||||||
# enable_progress_bar=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataloader = DataLoader(
|
train_dataloader = DataLoader(
|
||||||
|
Loading…
Reference in New Issue
Block a user