Compare commits

No commits in common. "7336638fa99e110fe723fe4b0c400cd9590bfbc6" and "dbd2d30ead9c9ba4a2b10f7fd976f29884f729e7" have entirely different histories.

7 changed files with 9 additions and 69 deletions

@@ -194,10 +194,10 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
                 detection_data.geometry.item()
             )
         else:
-            start_time = detection_data.start_time.item()
-            end_time = detection_data.end_time.item()
-            low_freq = detection_data.low_freq.item()
-            high_freq = detection_data.high_freq.item()
+            start_time = detection_data.start_time
+            end_time = detection_data.end_time
+            low_freq = detection_data.low_freq
+            high_freq = detection_data.high_freq
             geometry = data.BoundingBox.model_construct(
                 coordinates=[start_time, low_freq, end_time, high_freq]
             )
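
The dropped .item() calls suggest these fields now reach the formatter as plain Python floats rather than 0-d numpy/xarray scalars, on which .item() is required. A minimal sketch of the distinction (variable names here are illustrative, not from the repository):

    import numpy as np

    # .item() unwraps a numpy scalar (or 0-d array) into a native Python float.
    np_scalar = np.float64(0.25)
    assert np_scalar.item() == 0.25

    # The same call on a value that is already a plain float raises, which is
    # why the old lines break once the fields stop being numpy types.
    plain = 0.25
    try:
        plain.item()  # type: ignore[attr-defined]
    except AttributeError as err:
        print(err)  # 'float' object has no attribute 'item'

Note also that data.BoundingBox.model_construct(...) builds the model without running validation (standard pydantic v2 behaviour), so the coordinates are stored exactly as passed in.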

@@ -29,7 +29,6 @@ __all__ = [
     "ClassificationMetric",
     "ClassificationMetricConfig",
     "build_classification_metric",
-    "compute_precision_recall_curves",
 ]

@@ -52,14 +52,6 @@ def average_precision(
         num_positives=num_positives,
     )
     # pascal 12 way
-    return _average_precision(recall, precision)
-
-
-def _average_precision(
-    recall: np.ndarray,
-    precision: np.ndarray,
-) -> float:
-    # pascal 12 way
     mprec = np.hstack((0, precision, 0))
     mrec = np.hstack((0, recall, 1))
@@ -67,4 +59,5 @@ def _average_precision(
     mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
     inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
     ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
     return ave_prec
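
This hunk dissolves the _average_precision helper and inlines its PASCAL VOC 2012 ("pascal 12 way") interpolated average precision into average_precision itself. A self-contained sketch of the algorithm as the hunk shows it (the loop header falls outside the rendered lines, so the range below is an assumption based on the usual VOC formulation):

    import numpy as np

    def pascal_average_precision(recall: np.ndarray, precision: np.ndarray) -> float:
        # Pad the curve: precision 0 at both ends, recall spanning [0, 1].
        mprec = np.hstack((0, precision, 0))
        mrec = np.hstack((0, recall, 1))

        # Right-to-left sweep makes precision monotonically non-increasing:
        # each point takes the max of itself and everything to its right.
        for ii in range(len(mprec) - 2, -1, -1):
            mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])

        # Integrate: sum precision over the steps where recall changes.
        inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
        return float(((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum())

    recall = np.array([0.1, 0.4, 0.8])
    precision = np.array([1.0, 0.75, 0.5])
    # 0.1 * 1.0 + 0.3 * 0.75 + 0.4 * 0.5 + 0.2 * 0.0 = 0.525
    print(pascal_average_precision(recall, precision))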

@ -1,4 +1,3 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import (
Annotated,

@@ -1,9 +1,8 @@
-from typing import Annotated, Optional, Sequence, Union
+from typing import Annotated, Optional, Union
 
 from pydantic import Field
-from soundevent import data
 
-from batdetect2.evaluate.tasks.base import BaseTaskConfig, tasks_registry
+from batdetect2.evaluate.tasks.base import tasks_registry
 from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
 from batdetect2.evaluate.tasks.clip_classification import (
     ClipClassificationTaskConfig,
@@ -12,16 +11,11 @@ from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
 from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
 from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
 from batdetect2.targets import build_targets
-from batdetect2.typing import (
-    BatDetect2Prediction,
-    EvaluatorProtocol,
-    TargetProtocol,
-)
+from batdetect2.typing import EvaluatorProtocol, TargetProtocol
 
 __all__ = [
     "TaskConfig",
     "build_task",
-    "evaluate_task",
 ]
@@ -43,25 +37,3 @@ def build_task(
 ) -> EvaluatorProtocol:
     targets = targets or build_targets()
     return tasks_registry.build(config, targets)
-
-
-def evaluate_task(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    predictions: Sequence[BatDetect2Prediction],
-    task: Optional["str"] = None,
-    targets: Optional[TargetProtocol] = None,
-    config: Optional[Union[TaskConfig, dict]] = None,
-):
-    if isinstance(config, BaseTaskConfig):
-        task_obj = build_task(config, targets)
-        return task_obj.evaluate(clip_annotations, predictions)
-
-    if task is None:
-        raise ValueError(
-            "Task must be specified if a full config is not provided.",
-        )
-
-    config_class = tasks_registry.get_config_type(task)
-    config = config_class.model_validate(config or {})  # type: ignore
-    task_obj = build_task(config, targets)  # type: ignore
-    return task_obj.evaluate(clip_annotations, predictions)
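
With evaluate_task removed, callers are left with the two-step pattern the surviving code supports: obtain a task config, build the evaluator, and call evaluate directly. A hedged sketch of such a call site (the import path and signature details are assumptions inferred from this hunk, not code from the repository):

    from batdetect2.evaluate.tasks import TaskConfig, build_task

    def run_evaluation(config: TaskConfig, clip_annotations, predictions):
        # build_task falls back to build_targets() when targets is None,
        # as the retained body above shows.
        task = build_task(config, targets=None)
        return task.evaluate(clip_annotations, predictions)

The name-plus-dict branch of the deleted helper can be recovered the same way, by calling tasks_registry.get_config_type(task_name).model_validate(raw_config) before build_task.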

@@ -174,22 +174,6 @@ class SelfAttention(nn.Module):
         return op
 
-    def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
-        x = x.squeeze(2).permute(0, 2, 1)
-
-        key = torch.matmul(
-            x, self.key_fun.weight.T
-        ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
-        query = torch.matmul(
-            x, self.query_fun.weight.T
-        ) + self.query_fun.bias.unsqueeze(0).unsqueeze(0)
-
-        kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
-            self.temperature * self.att_dim
-        )
-        att_weights = F.softmax(kk_qq, 1)
-
-        return att_weights
-
 
 class ConvConfig(BaseConfig):
     """Configuration for a basic ConvBlock."""

@@ -5,7 +5,6 @@ import seaborn as sns
 from cycler import cycler
 from matplotlib import axes
 
-from batdetect2.evaluate.metrics.common import _average_precision
 from batdetect2.plotting.common import create_ax
@@ -81,21 +80,15 @@ def plot_pr_curves(
     figsize: Optional[Tuple[int, int]] = None,
     add_legend: bool = True,
     add_labels: bool = True,
-    include_ap: bool = False,
 ) -> axes.Axes:
     ax = create_ax(ax=ax, figsize=figsize)
     ax = set_default_style(ax)
 
     for name, (precision, recall, thresholds) in data.items():
-        label = name
-
-        if include_ap:
-            label += f" (AP={_average_precision(recall, precision):.2f})"
-
         ax.plot(
             recall,
             precision,
-            label=label,
+            label=name,
             markevery=_get_marker_positions(thresholds),
         )
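
With include_ap gone (and the _average_precision import with it), a caller that still wants AP values in the legend can fold them into the curve names before plotting. A fragment sketching that, reusing the pascal_average_precision helper sketched earlier; the (precision, recall, thresholds) tuple layout follows the loop above:

    # curves: dict mapping a curve name to (precision, recall, thresholds) arrays
    labelled = {
        f"{name} (AP={pascal_average_precision(recall, precision):.2f})": (
            precision,
            recall,
            thresholds,
        )
        for name, (precision, recall, thresholds) in curves.items()
    }
    ax = plot_pr_curves(labelled)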