Compare commits

...

3 Commits

Author SHA1 Message Date
mbsantiago
bdb9e18964 Add explicit kwarg name to _freq_to_index to avoid confusion 2025-11-16 23:57:11 +00:00
mbsantiago
a4498cfd83 Add functional versions of metric and plotting utils 2025-11-16 21:37:47 +00:00
mbsantiago
960b9a92e4 Fix legacy import to use reproducible UUIDs 2025-11-16 21:37:33 +00:00
10 changed files with 201 additions and 110 deletions

View File

@ -19,6 +19,7 @@ from batdetect2.data.predictions import (
SoundEventOutputConfig, SoundEventOutputConfig,
build_output_formatter, build_output_formatter,
get_output_formatter, get_output_formatter,
load_predictions,
) )
from batdetect2.data.summary import ( from batdetect2.data.summary import (
compute_class_summary, compute_class_summary,
@ -46,4 +47,5 @@ __all__ = [
"load_dataset", "load_dataset",
"load_dataset_config", "load_dataset_config",
"load_dataset_from_config", "load_dataset_from_config",
"load_predictions",
] ]

View File

@ -18,6 +18,14 @@ UNKNOWN_CLASS = "__UNKNOWN__"
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242") NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
CLIP_NAMESPACE = uuid.uuid5(NAMESPACE, "clip")
CLIP_ANNOTATION_NAMESPACE = uuid.uuid5(NAMESPACE, "clip_annotation")
RECORDING_NAMESPACE = uuid.uuid5(NAMESPACE, "recording")
SOUND_EVENT_NAMESPACE = uuid.uuid5(NAMESPACE, "sound_event")
SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
NAMESPACE, "sound_event_annotation"
)
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]] EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
@ -71,8 +79,8 @@ def annotation_to_sound_event(
"""Convert annotation to sound event annotation.""" """Convert annotation to sound event annotation."""
sound_event = data.SoundEvent( sound_event = data.SoundEvent(
uuid=uuid.uuid5( uuid=uuid.uuid5(
NAMESPACE, SOUND_EVENT_NAMESPACE,
f"{recording.hash}_{annotation.start_time}_{annotation.end_time}", f"{recording.uuid}_{annotation.start_time}_{annotation.end_time}",
), ),
recording=recording, recording=recording,
geometry=data.BoundingBox( geometry=data.BoundingBox(
@ -86,7 +94,10 @@ def annotation_to_sound_event(
) )
return data.SoundEventAnnotation( return data.SoundEventAnnotation(
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"), uuid=uuid.uuid5(
SOUND_EVENT_ANNOTATION_NAMESPACE,
f"{sound_event.uuid}",
),
sound_event=sound_event, sound_event=sound_event,
tags=get_sound_event_tags( tags=get_sound_event_tags(
annotation, label_key, event_key, individual_key annotation, label_key, event_key, individual_key
@ -139,12 +150,18 @@ def file_annotation_to_clip(
time_expansion=file_annotation.time_exp, time_expansion=file_annotation.time_exp,
tags=tags, tags=tags,
) )
recording.uuid = uuid.uuid5(RECORDING_NAMESPACE, f"{recording.hash}")
start_time = 0
end_time = recording.duration
return data.Clip( return data.Clip(
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"), uuid=uuid.uuid5(
CLIP_NAMESPACE,
f"{recording.uuid}_{start_time}_{end_time}",
),
recording=recording, recording=recording,
start_time=0, start_time=start_time,
end_time=recording.duration, end_time=end_time,
) )
@ -165,7 +182,7 @@ def file_annotation_to_clip_annotation(
tags.append(data.Tag(key=label_key, value=file_annotation.label)) tags.append(data.Tag(key=label_key, value=file_annotation.label))
return data.ClipAnnotation( return data.ClipAnnotation(
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"), uuid=uuid.uuid5(CLIP_ANNOTATION_NAMESPACE, f"{clip.uuid}"),
clip=clip, clip=clip,
notes=notes, notes=notes,
tags=tags, tags=tags,

View File

@ -1,6 +1,7 @@
from typing import Annotated, Optional, Union from typing import Annotated, Optional, Union
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike
from batdetect2.data.predictions.base import ( from batdetect2.data.predictions.base import (
OutputFormatterProtocol, OutputFormatterProtocol,
@ -21,7 +22,11 @@ __all__ = [
OutputFormatConfig = Annotated[ OutputFormatConfig = Annotated[
Union[BatDetect2OutputConfig, SoundEventOutputConfig, RawOutputConfig], Union[
BatDetect2OutputConfig,
SoundEventOutputConfig,
RawOutputConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
@ -40,13 +45,16 @@ def build_output_formatter(
def get_output_formatter( def get_output_formatter(
name: str, name: Optional[str] = None,
targets: Optional[TargetProtocol] = None, targets: Optional[TargetProtocol] = None,
config: Optional[OutputFormatConfig] = None, config: Optional[OutputFormatConfig] = None,
) -> OutputFormatterProtocol: ) -> OutputFormatterProtocol:
"""Get the output formatter by name.""" """Get the output formatter by name."""
if config is None: if config is None:
if name is None:
raise ValueError("Either config or name must be provided.")
config_class = prediction_formatters.get_config_type(name) config_class = prediction_formatters.get_config_type(name)
config = config_class() # type: ignore config = config_class() # type: ignore
@ -56,3 +64,17 @@ def get_output_formatter(
) )
return build_output_formatter(targets, config) return build_output_formatter(targets, config)
def load_predictions(
path: PathLike,
format: Optional[str] = "raw",
config: Optional[OutputFormatConfig] = None,
targets: Optional[TargetProtocol] = None,
):
"""Load predictions from a file."""
from batdetect2.targets import build_targets
targets = targets or build_targets()
formatter = get_output_formatter(format, targets, config)
return formatter.load(path)

View File

@ -5,6 +5,7 @@ from uuid import UUID, uuid4
import numpy as np import numpy as np
import xarray as xr import xarray as xr
from loguru import logger
from soundevent import data from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
@ -36,11 +37,13 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
include_class_scores: bool = True, include_class_scores: bool = True,
include_features: bool = True, include_features: bool = True,
include_geometry: bool = True, include_geometry: bool = True,
parse_full_geometry: bool = False,
): ):
self.targets = targets self.targets = targets
self.include_class_scores = include_class_scores self.include_class_scores = include_class_scores
self.include_features = include_features self.include_features = include_features
self.include_geometry = include_geometry self.include_geometry = include_geometry
self.parse_full_geometry = parse_full_geometry
def format( def format(
self, self,
@ -169,6 +172,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
predictions: List[BatDetect2Prediction] = [] predictions: List[BatDetect2Prediction] = []
for _, clip_data in root.items(): for _, clip_data in root.items():
logger.debug(f"Loading clip {clip_data.clip_id.item()}")
recording = data.Recording.model_validate_json( recording = data.Recording.model_validate_json(
clip_data.attrs["recording"] clip_data.attrs["recording"]
) )
@ -183,37 +187,36 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
sound_events = [] sound_events = []
for detection in clip_data.detection: for detection in clip_data.coords["detection"]:
score = clip_data.score.sel(detection=detection).item() detection_data = clip_data.sel(detection=detection)
score = detection_data.score.item()
if "geometry" in clip_data: if "geometry" in clip_data and self.parse_full_geometry:
geometry = data.geometry_validate( geometry = data.geometry_validate(
clip_data.geometry.sel(detection=detection).item() detection_data.geometry.item()
) )
else: else:
start_time = clip_data.start_time.sel(detection=detection) start_time = detection_data.start_time
end_time = clip_data.end_time.sel(detection=detection) end_time = detection_data.end_time
low_freq = clip_data.low_freq.sel(detection=detection) low_freq = detection_data.low_freq
high_freq = clip_data.high_freq.sel(detection=detection) high_freq = detection_data.high_freq
geometry = data.BoundingBox( geometry = data.BoundingBox.model_construct(
coordinates=[start_time, low_freq, end_time, high_freq] coordinates=[start_time, low_freq, end_time, high_freq]
) )
if "class_scores" in clip_data: if "class_scores" in detection_data:
class_scores = clip_data.class_scores.sel( class_scores = detection_data.class_scores.data
detection=detection
).data
else: else:
class_scores = np.zeros(len(self.targets.class_names)) class_scores = np.zeros(len(self.targets.class_names))
class_index = self.targets.class_names.index( class_index = self.targets.class_names.index(
clip_data.top_class.sel(detection=detection).item() detection_data.top_class.item()
)
class_scores[class_index] = (
detection_data.top_class_score.item()
) )
class_scores[class_index] = clip_data.top_class_score.sel(
detection=detection
).item()
if "features" in clip_data: if "features" in detection_data:
features = clip_data.features.sel(detection=detection).data features = detection_data.features.data
else: else:
features = np.zeros(0) features = np.zeros(0)

View File

@ -9,6 +9,7 @@ from typing import (
Mapping, Mapping,
Optional, Optional,
Sequence, Sequence,
Tuple,
Union, Union,
) )
@ -18,7 +19,10 @@ from sklearn import metrics
from soundevent import data from soundevent import data
from batdetect2.core import BaseConfig, Registry from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import (
average_precision,
compute_precision_recall,
)
from batdetect2.typing import RawPrediction, TargetProtocol from batdetect2.typing import RawPrediction, TargetProtocol
__all__ = [ __all__ = [
@ -265,3 +269,24 @@ def _extract_per_class_metric_data(
y_score[class_name].append(m.score) y_score[class_name].append(m.score)
return y_true, y_score, num_positives return y_true, y_score, num_positives
def compute_precision_recall_curves(
clip_evaluations: Sequence[ClipEval],
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=ignore_non_predictions,
ignore_generic=ignore_generic,
)
return {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives=num_positives[class_name],
)
for class_name in y_true
}

View File

@ -18,6 +18,7 @@ from soundevent import data
from batdetect2.core import BaseConfig, Registry from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction from batdetect2.typing import RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [ __all__ = [
"TopClassMetricConfig", "TopClassMetricConfig",
@ -312,3 +313,61 @@ TopClassMetricConfig = Annotated[
def build_top_class_metric(config: TopClassMetricConfig): def build_top_class_metric(config: TopClassMetricConfig):
return top_class_metrics.build(config) return top_class_metrics.build(config)
def compute_confusion_matrix(
clip_evaluations: Sequence[ClipEval],
targets: TargetProtocol,
threshold: float = 0.2,
normalize: Literal["true", "pred", "all", "none"] = "true",
exclude_generic: bool = True,
exclude_false_positives: bool = True,
exclude_false_negatives: bool = True,
noise_class: str = "noise",
):
y_true: List[str] = []
y_pred: List[str] = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
true_class = m.true_class
pred_class = m.pred_class
if not m.is_prediction and exclude_false_negatives:
# Ignore matches that don't correspond to a prediction
continue
if not m.is_ground_truth and exclude_false_positives:
# Ignore matches that don't correspond to a ground truth
continue
if m.score < threshold:
if exclude_false_negatives:
continue
pred_class = noise_class
if m.is_generic:
if exclude_generic:
# Ignore gt sounds with unknown class
continue
true_class = targets.detection_class_name
y_true.append(true_class or noise_class)
y_pred.append(pred_class or noise_class)
labels = sorted(targets.class_names)
if not exclude_generic:
labels.append(targets.detection_class_name)
if not exclude_false_positives or not exclude_false_negatives:
labels.append(noise_class)
return metrics.confusion_matrix(
y_true,
y_pred,
labels=labels,
normalize=normalize,
), labels

View File

@ -18,8 +18,8 @@ from batdetect2.core import Registry
from batdetect2.evaluate.metrics.classification import ( from batdetect2.evaluate.metrics.classification import (
ClipEval, ClipEval,
_extract_per_class_metric_data, _extract_per_class_metric_data,
compute_precision_recall_curves,
) )
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import ( from batdetect2.plotting.metrics import (
plot_pr_curve, plot_pr_curve,
@ -69,21 +69,12 @@ class PRCurve(BasePlot):
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data( data = compute_precision_recall_curves(
clip_evaluations, clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions, ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic, ignore_generic=self.ignore_generic,
) )
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives=num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures: if not self.separate_figures:
fig = self.create_figure() fig = self.create_figure()
ax = fig.subplots() ax = fig.subplots()
@ -141,21 +132,12 @@ class ThresholdPrecisionCurve(BasePlot):
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data( data = compute_precision_recall_curves(
clip_evaluations, clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions, ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic, ignore_generic=self.ignore_generic,
) )
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures: if not self.separate_figures:
fig = self.create_figure() fig = self.create_figure()
ax = fig.subplots() ax = fig.subplots()
@ -223,21 +205,12 @@ class ThresholdRecallCurve(BasePlot):
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data( data = compute_precision_recall_curves(
clip_evaluations, clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions, ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic, ignore_generic=self.ignore_generic,
) )
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures: if not self.separate_figures:
fig = self.create_figure() fig = self.create_figure()
ax = fig.subplots() ax = fig.subplots()

View File

@ -23,7 +23,11 @@ from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval from batdetect2.evaluate.metrics.top_class import (
ClipEval,
MatchEval,
compute_confusion_matrix,
)
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.gallery import plot_match_gallery from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
@ -186,6 +190,8 @@ class ConfusionMatrix(BasePlot):
self, self,
*args, *args,
exclude_generic: bool = True, exclude_generic: bool = True,
exclude_false_positives: bool = True,
exclude_false_negatives: bool = True,
exclude_noise: bool = False, exclude_noise: bool = False,
noise_class: str = "noise", noise_class: str = "noise",
add_colorbar: bool = True, add_colorbar: bool = True,
@ -196,9 +202,11 @@ class ConfusionMatrix(BasePlot):
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.exclude_generic = exclude_generic self.exclude_generic = exclude_generic
self.exclude_false_positives = exclude_false_positives
self.exclude_false_negatives = exclude_false_negatives
self.exclude_noise = exclude_noise self.exclude_noise = exclude_noise
self.noise_class = noise_class self.noise_class = noise_class
self.normalize = normalize self.normalize: Literal["true", "pred", "all", "none"] = normalize
self.add_colorbar = add_colorbar self.add_colorbar = add_colorbar
self.threshold = threshold self.threshold = threshold
self.cmap = cmap self.cmap = cmap
@ -207,58 +215,25 @@ class ConfusionMatrix(BasePlot):
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[Tuple[str, Figure]]:
y_true: List[str] = [] cm, labels = compute_confusion_matrix(
y_pred: List[str] = [] clip_evaluations,
self.targets,
for clip_eval in clip_evaluations: threshold=self.threshold,
for m in clip_eval.matches: normalize=self.normalize,
true_class = m.true_class exclude_generic=self.exclude_generic,
pred_class = m.pred_class exclude_false_positives=self.exclude_false_positives,
exclude_false_negatives=self.exclude_false_negatives,
if not m.is_prediction and self.exclude_noise: noise_class=self.noise_class,
# Ignore matches that don't correspond to a prediction )
continue
if not m.is_ground_truth and self.exclude_noise:
# Ignore matches that don't correspond to a ground truth
continue
if m.score < self.threshold:
if self.exclude_noise:
continue
pred_class = self.noise_class
if m.is_generic:
if self.exclude_generic:
# Ignore gt sounds with unknown class
continue
true_class = self.targets.detection_class_name
y_true.append(true_class or self.noise_class)
y_pred.append(pred_class or self.noise_class)
fig = self.create_figure() fig = self.create_figure()
ax = fig.subplots() ax = fig.subplots()
class_names = [*self.targets.class_names] metrics.ConfusionMatrixDisplay(cm, display_labels=labels).plot(
if not self.exclude_generic:
class_names.append(self.targets.detection_class_name)
if not self.exclude_noise:
class_names.append(self.noise_class)
metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=class_names,
ax=ax, ax=ax,
xticks_rotation="vertical", xticks_rotation="vertical",
cmap=self.cmap, cmap=self.cmap,
colorbar=self.add_colorbar, colorbar=self.add_colorbar,
normalize=self.normalize if self.normalize != "none" else None,
values_format=".2f", values_format=".2f",
) )

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple, Union
import numpy as np import numpy as np
import seaborn as sns import seaborn as sns
@ -34,8 +34,14 @@ def plot_pr_curve(
thresholds: np.ndarray, thresholds: np.ndarray,
ax: Optional[axes.Axes] = None, ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None, figsize: Optional[Tuple[int, int]] = None,
color: Union[str, Tuple[float, float, float], None] = None,
add_labels: bool = True, add_labels: bool = True,
add_legend: bool = False, add_legend: bool = False,
marker: Union[str, Tuple[int, int, float], None] = "o",
markeredgecolor: Union[str, Tuple[float, float, float], None] = None,
markersize: Optional[float] = None,
linestyle: Union[str, Tuple[int, ...], None] = None,
linewidth: Optional[float] = None,
label: str = "PR Curve", label: str = "PR Curve",
) -> axes.Axes: ) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
@ -45,9 +51,14 @@ def plot_pr_curve(
ax.plot( ax.plot(
recall, recall,
precision, precision,
color=color,
label=label, label=label,
marker="o", marker=marker,
markeredgecolor=markeredgecolor,
markevery=_get_marker_positions(thresholds), markevery=_get_marker_positions(thresholds),
markersize=markersize,
linestyle=linestyle,
linewidth=linewidth,
) )
ax.set_xlim(0, 1.05) ax.set_xlim(0, 1.05)

View File

@ -146,14 +146,18 @@ class FrequencyCrop(torch.nn.Module):
low_index = None low_index = None
if min_freq is not None: if min_freq is not None:
low_index = _frequency_to_index( low_index = _frequency_to_index(
min_freq, self.samplerate, self.n_fft min_freq,
n_fft=self.n_fft,
samplerate=self.samplerate,
) )
self.low_index = low_index self.low_index = low_index
high_index = None high_index = None
if max_freq is not None: if max_freq is not None:
high_index = _frequency_to_index( high_index = _frequency_to_index(
max_freq, self.samplerate, self.n_fft max_freq,
n_fft=self.n_fft,
samplerate=self.samplerate,
) )
self.high_index = high_index self.high_index = high_index