mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Compare commits
3 Commits
dbd2d30ead
...
7336638fa9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7336638fa9 | ||
|
|
16c401b1da | ||
|
|
4ecbc2b734 |
@ -194,10 +194,10 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
detection_data.geometry.item()
|
detection_data.geometry.item()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
start_time = detection_data.start_time
|
start_time = detection_data.start_time.item()
|
||||||
end_time = detection_data.end_time
|
end_time = detection_data.end_time.item()
|
||||||
low_freq = detection_data.low_freq
|
low_freq = detection_data.low_freq.item()
|
||||||
high_freq = detection_data.high_freq
|
high_freq = detection_data.high_freq.item()
|
||||||
geometry = data.BoundingBox.model_construct(
|
geometry = data.BoundingBox.model_construct(
|
||||||
coordinates=[start_time, low_freq, end_time, high_freq]
|
coordinates=[start_time, low_freq, end_time, high_freq]
|
||||||
)
|
)
|
||||||
|
|||||||
@ -29,6 +29,7 @@ __all__ = [
|
|||||||
"ClassificationMetric",
|
"ClassificationMetric",
|
||||||
"ClassificationMetricConfig",
|
"ClassificationMetricConfig",
|
||||||
"build_classification_metric",
|
"build_classification_metric",
|
||||||
|
"compute_precision_recall_curves",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -52,6 +52,14 @@ def average_precision(
|
|||||||
num_positives=num_positives,
|
num_positives=num_positives,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pascal 12 way
|
||||||
|
return _average_precision(recall, precision)
|
||||||
|
|
||||||
|
|
||||||
|
def _average_precision(
|
||||||
|
recall: np.ndarray,
|
||||||
|
precision: np.ndarray,
|
||||||
|
) -> float:
|
||||||
# pascal 12 way
|
# pascal 12 way
|
||||||
mprec = np.hstack((0, precision, 0))
|
mprec = np.hstack((0, precision, 0))
|
||||||
mrec = np.hstack((0, recall, 1))
|
mrec = np.hstack((0, recall, 1))
|
||||||
@ -59,5 +67,4 @@ def average_precision(
|
|||||||
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
|
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
|
||||||
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
|
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
|
||||||
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
|
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
|
||||||
|
|
||||||
return ave_prec
|
return ave_prec
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
from typing import Annotated, Optional, Union
|
from typing import Annotated, Optional, Sequence, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.evaluate.tasks.base import tasks_registry
|
from batdetect2.evaluate.tasks.base import BaseTaskConfig, tasks_registry
|
||||||
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
||||||
from batdetect2.evaluate.tasks.clip_classification import (
|
from batdetect2.evaluate.tasks.clip_classification import (
|
||||||
ClipClassificationTaskConfig,
|
ClipClassificationTaskConfig,
|
||||||
@ -11,11 +12,16 @@ from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
|
|||||||
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||||
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
|
from batdetect2.typing import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
EvaluatorProtocol,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TaskConfig",
|
"TaskConfig",
|
||||||
"build_task",
|
"build_task",
|
||||||
|
"evaluate_task",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -37,3 +43,25 @@ def build_task(
|
|||||||
) -> EvaluatorProtocol:
|
) -> EvaluatorProtocol:
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
return tasks_registry.build(config, targets)
|
return tasks_registry.build(config, targets)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_task(
|
||||||
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
predictions: Sequence[BatDetect2Prediction],
|
||||||
|
task: Optional["str"] = None,
|
||||||
|
targets: Optional[TargetProtocol] = None,
|
||||||
|
config: Optional[Union[TaskConfig, dict]] = None,
|
||||||
|
):
|
||||||
|
if isinstance(config, BaseTaskConfig):
|
||||||
|
task_obj = build_task(config, targets)
|
||||||
|
return task_obj.evaluate(clip_annotations, predictions)
|
||||||
|
|
||||||
|
if task is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Task must be specified if a full config is not provided.",
|
||||||
|
)
|
||||||
|
|
||||||
|
config_class = tasks_registry.get_config_type(task)
|
||||||
|
config = config_class.model_validate(config or {}) # type: ignore
|
||||||
|
task_obj = build_task(config, targets) # type: ignore
|
||||||
|
return task_obj.evaluate(clip_annotations, predictions)
|
||||||
|
|||||||
@ -174,6 +174,22 @@ class SelfAttention(nn.Module):
|
|||||||
|
|
||||||
return op
|
return op
|
||||||
|
|
||||||
|
def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = x.squeeze(2).permute(0, 2, 1)
|
||||||
|
|
||||||
|
key = torch.matmul(
|
||||||
|
x, self.key_fun.weight.T
|
||||||
|
) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
|
||||||
|
query = torch.matmul(
|
||||||
|
x, self.query_fun.weight.T
|
||||||
|
) + self.query_fun.bias.unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
|
kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
|
||||||
|
self.temperature * self.att_dim
|
||||||
|
)
|
||||||
|
att_weights = F.softmax(kk_qq, 1)
|
||||||
|
return att_weights
|
||||||
|
|
||||||
|
|
||||||
class ConvConfig(BaseConfig):
|
class ConvConfig(BaseConfig):
|
||||||
"""Configuration for a basic ConvBlock."""
|
"""Configuration for a basic ConvBlock."""
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import seaborn as sns
|
|||||||
from cycler import cycler
|
from cycler import cycler
|
||||||
from matplotlib import axes
|
from matplotlib import axes
|
||||||
|
|
||||||
|
from batdetect2.evaluate.metrics.common import _average_precision
|
||||||
from batdetect2.plotting.common import create_ax
|
from batdetect2.plotting.common import create_ax
|
||||||
|
|
||||||
|
|
||||||
@ -80,15 +81,21 @@ def plot_pr_curves(
|
|||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
add_legend: bool = True,
|
add_legend: bool = True,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
|
include_ap: bool = False,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
ax = create_ax(ax=ax, figsize=figsize)
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
ax = set_default_style(ax)
|
ax = set_default_style(ax)
|
||||||
|
|
||||||
for name, (precision, recall, thresholds) in data.items():
|
for name, (precision, recall, thresholds) in data.items():
|
||||||
|
label = name
|
||||||
|
|
||||||
|
if include_ap:
|
||||||
|
label += f" (AP={_average_precision(recall, precision):.2f})"
|
||||||
|
|
||||||
ax.plot(
|
ax.plot(
|
||||||
recall,
|
recall,
|
||||||
precision,
|
precision,
|
||||||
label=name,
|
label=label,
|
||||||
markevery=_get_marker_positions(thresholds),
|
markevery=_get_marker_positions(thresholds),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user