mirror of https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00

Compare commits: 76503fbd12...bdb9e18964 (3 commits)
Commits: bdb9e18964, a4498cfd83, 960b9a92e4
@@ -19,6 +19,7 @@ from batdetect2.data.predictions import (
     SoundEventOutputConfig,
     build_output_formatter,
     get_output_formatter,
+    load_predictions,
 )
 from batdetect2.data.summary import (
     compute_class_summary,
@@ -46,4 +47,5 @@ __all__ = [
     "load_dataset",
     "load_dataset_config",
     "load_dataset_from_config",
+    "load_predictions",
 ]
@@ -18,6 +18,14 @@ UNKNOWN_CLASS = "__UNKNOWN__"
 
 NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
 
+CLIP_NAMESPACE = uuid.uuid5(NAMESPACE, "clip")
+CLIP_ANNOTATION_NAMESPACE = uuid.uuid5(NAMESPACE, "clip_annotation")
+RECORDING_NAMESPACE = uuid.uuid5(NAMESPACE, "recording")
+SOUND_EVENT_NAMESPACE = uuid.uuid5(NAMESPACE, "sound_event")
+SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
+    NAMESPACE, "sound_event_annotation"
+)
+
 
 EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
 
@@ -71,8 +79,8 @@ def annotation_to_sound_event(
     """Convert annotation to sound event annotation."""
     sound_event = data.SoundEvent(
         uuid=uuid.uuid5(
-            NAMESPACE,
-            f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
+            SOUND_EVENT_NAMESPACE,
+            f"{recording.uuid}_{annotation.start_time}_{annotation.end_time}",
         ),
         recording=recording,
         geometry=data.BoundingBox(
@@ -86,7 +94,10 @@ def annotation_to_sound_event(
     )
 
     return data.SoundEventAnnotation(
-        uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
+        uuid=uuid.uuid5(
+            SOUND_EVENT_ANNOTATION_NAMESPACE,
+            f"{sound_event.uuid}",
+        ),
         sound_event=sound_event,
         tags=get_sound_event_tags(
             annotation, label_key, event_key, individual_key
@@ -139,12 +150,18 @@ def file_annotation_to_clip(
         time_expansion=file_annotation.time_exp,
         tags=tags,
     )
+    recording.uuid = uuid.uuid5(RECORDING_NAMESPACE, f"{recording.hash}")
 
+    start_time = 0
+    end_time = recording.duration
     return data.Clip(
-        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
+        uuid=uuid.uuid5(
+            CLIP_NAMESPACE,
+            f"{recording.uuid}_{start_time}_{end_time}",
+        ),
         recording=recording,
-        start_time=0,
-        end_time=recording.duration,
+        start_time=start_time,
+        end_time=end_time,
     )
 
 
@@ -165,7 +182,7 @@ def file_annotation_to_clip_annotation(
         tags.append(data.Tag(key=label_key, value=file_annotation.label))
 
     return data.ClipAnnotation(
-        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
+        uuid=uuid.uuid5(CLIP_ANNOTATION_NAMESPACE, f"{clip.uuid}"),
        clip=clip,
        notes=notes,
        tags=tags,
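Note on the change above: deriving per-entity namespaces with uuid5 makes every converted object's UUID a deterministic function of its content, so re-running a conversion yields the same IDs. A minimal sketch (not part of the diff) of that property, reusing the constants added above:

    import uuid

    NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
    CLIP_NAMESPACE = uuid.uuid5(NAMESPACE, "clip")
    RECORDING_NAMESPACE = uuid.uuid5(NAMESPACE, "recording")

    # uuid5 is a pure function of (namespace, name): the same name always
    # maps to the same UUID, and the distinct namespaces keep a clip and a
    # recording derived from the same name apart.
    rec_id = uuid.uuid5(RECORDING_NAMESPACE, "somehash")
    assert rec_id == uuid.uuid5(RECORDING_NAMESPACE, "somehash")
    assert rec_id != uuid.uuid5(CLIP_NAMESPACE, "somehash")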
@@ -1,6 +1,7 @@
 from typing import Annotated, Optional, Union
 
 from pydantic import Field
+from soundevent.data import PathLike
 
 from batdetect2.data.predictions.base import (
     OutputFormatterProtocol,
@@ -21,7 +22,11 @@ __all__ = [
 
 
 OutputFormatConfig = Annotated[
-    Union[BatDetect2OutputConfig, SoundEventOutputConfig, RawOutputConfig],
+    Union[
+        BatDetect2OutputConfig,
+        SoundEventOutputConfig,
+        RawOutputConfig,
+    ],
     Field(discriminator="name"),
 ]
 
@@ -40,13 +45,16 @@ def build_output_formatter(
 
 
 def get_output_formatter(
-    name: str,
+    name: Optional[str] = None,
     targets: Optional[TargetProtocol] = None,
     config: Optional[OutputFormatConfig] = None,
 ) -> OutputFormatterProtocol:
     """Get the output formatter by name."""
 
     if config is None:
+        if name is None:
+            raise ValueError("Either config or name must be provided.")
+
         config_class = prediction_formatters.get_config_type(name)
         config = config_class()  # type: ignore
 
@@ -56,3 +64,17 @@ def get_output_formatter(
     )
 
     return build_output_formatter(targets, config)
+
+
+def load_predictions(
+    path: PathLike,
+    format: Optional[str] = "raw",
+    config: Optional[OutputFormatConfig] = None,
+    targets: Optional[TargetProtocol] = None,
+):
+    """Load predictions from a file."""
+    from batdetect2.targets import build_targets
+
+    targets = targets or build_targets()
+    formatter = get_output_formatter(format, targets, config)
+    return formatter.load(path)
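Note: a minimal usage sketch (not part of the diff) for the new load_predictions helper; the path is hypothetical, and with no targets or config supplied the function falls back to build_targets() and the default config for the requested format:

    from batdetect2.data import load_predictions

    # "outputs/predictions" is a made-up path to a saved predictions file.
    predictions = load_predictions("outputs/predictions", format="raw")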
@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
 
 import numpy as np
 import xarray as xr
+from loguru import logger
 from soundevent import data
 from soundevent.geometry import compute_bounds
 
@@ -36,11 +37,13 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
         include_class_scores: bool = True,
         include_features: bool = True,
         include_geometry: bool = True,
+        parse_full_geometry: bool = False,
     ):
         self.targets = targets
         self.include_class_scores = include_class_scores
         self.include_features = include_features
         self.include_geometry = include_geometry
+        self.parse_full_geometry = parse_full_geometry
 
     def format(
         self,
@@ -169,6 +172,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
         predictions: List[BatDetect2Prediction] = []
 
         for _, clip_data in root.items():
+            logger.debug(f"Loading clip {clip_data.clip_id.item()}")
             recording = data.Recording.model_validate_json(
                 clip_data.attrs["recording"]
             )
@@ -183,37 +187,36 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
 
             sound_events = []
 
-            for detection in clip_data.detection:
-                score = clip_data.score.sel(detection=detection).item()
+            for detection in clip_data.coords["detection"]:
+                detection_data = clip_data.sel(detection=detection)
+                score = detection_data.score.item()
 
-                if "geometry" in clip_data:
+                if "geometry" in clip_data and self.parse_full_geometry:
                     geometry = data.geometry_validate(
-                        clip_data.geometry.sel(detection=detection).item()
+                        detection_data.geometry.item()
                     )
                 else:
-                    start_time = clip_data.start_time.sel(detection=detection)
-                    end_time = clip_data.end_time.sel(detection=detection)
-                    low_freq = clip_data.low_freq.sel(detection=detection)
-                    high_freq = clip_data.high_freq.sel(detection=detection)
-                    geometry = data.BoundingBox(
+                    start_time = detection_data.start_time
+                    end_time = detection_data.end_time
+                    low_freq = detection_data.low_freq
+                    high_freq = detection_data.high_freq
+                    geometry = data.BoundingBox.model_construct(
                         coordinates=[start_time, low_freq, end_time, high_freq]
                     )
 
-                if "class_scores" in clip_data:
-                    class_scores = clip_data.class_scores.sel(
-                        detection=detection
-                    ).data
+                if "class_scores" in detection_data:
+                    class_scores = detection_data.class_scores.data
                 else:
                     class_scores = np.zeros(len(self.targets.class_names))
                     class_index = self.targets.class_names.index(
-                        clip_data.top_class.sel(detection=detection).item()
+                        detection_data.top_class.item()
+                    )
+                    class_scores[class_index] = (
+                        detection_data.top_class_score.item()
                     )
-                    class_scores[class_index] = clip_data.top_class_score.sel(
-                        detection=detection
-                    ).item()
 
-                if "features" in clip_data:
-                    features = clip_data.features.sel(detection=detection).data
+                if "features" in detection_data:
+                    features = detection_data.features.data
                 else:
                     features = np.zeros(0)
 
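Note on the refactor above: selecting each detection once from the whole Dataset (clip_data.sel(detection=...)) replaces repeated per-variable .sel calls, which is shorter and slices every variable consistently. A toy sketch (not part of the diff, with made-up variables) of the pattern:

    import numpy as np
    import xarray as xr

    # Toy dataset with two variables sharing a "detection" dimension.
    ds = xr.Dataset(
        {
            "score": ("detection", np.array([0.9, 0.4])),
            "start_time": ("detection", np.array([0.1, 0.7])),
        },
        coords={"detection": [0, 1]},
    )

    for detection in ds.coords["detection"]:
        # One .sel on the Dataset slices every data variable at once.
        row = ds.sel(detection=detection)
        print(row.score.item(), row.start_time.item())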
@@ -9,6 +9,7 @@ from typing import (
     Mapping,
     Optional,
     Sequence,
+    Tuple,
     Union,
 )
 
@@ -18,7 +19,10 @@ from sklearn import metrics
 from soundevent import data
 
 from batdetect2.core import BaseConfig, Registry
-from batdetect2.evaluate.metrics.common import average_precision
+from batdetect2.evaluate.metrics.common import (
+    average_precision,
+    compute_precision_recall,
+)
 from batdetect2.typing import RawPrediction, TargetProtocol
 
 __all__ = [
@@ -265,3 +269,24 @@ def _extract_per_class_metric_data(
             y_score[class_name].append(m.score)
 
     return y_true, y_score, num_positives
+
+
+def compute_precision_recall_curves(
+    clip_evaluations: Sequence[ClipEval],
+    ignore_non_predictions: bool = True,
+    ignore_generic: bool = True,
+) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
+    y_true, y_score, num_positives = _extract_per_class_metric_data(
+        clip_evaluations,
+        ignore_non_predictions=ignore_non_predictions,
+        ignore_generic=ignore_generic,
+    )
+
+    return {
+        class_name: compute_precision_recall(
+            y_true[class_name],
+            y_score[class_name],
+            num_positives=num_positives[class_name],
+        )
+        for class_name in y_true
+    }
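Note: a minimal usage sketch (not part of the diff) for the new compute_precision_recall_curves helper; clip_evals stands in for evaluation results built elsewhere, and the per-class tuple is assumed to follow compute_precision_recall's order, presumably (precision, recall, thresholds) as in sklearn:

    from batdetect2.evaluate.metrics.classification import (
        compute_precision_recall_curves,
    )

    curves = compute_precision_recall_curves(clip_evals)
    for class_name, curve in curves.items():
        precision, recall, thresholds = curve  # assumed tuple order
        print(class_name, len(thresholds))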
@@ -18,6 +18,7 @@ from soundevent import data
 from batdetect2.core import BaseConfig, Registry
 from batdetect2.evaluate.metrics.common import average_precision
 from batdetect2.typing import RawPrediction
+from batdetect2.typing.targets import TargetProtocol
 
 __all__ = [
     "TopClassMetricConfig",
@@ -312,3 +313,61 @@ TopClassMetricConfig = Annotated[
 
 def build_top_class_metric(config: TopClassMetricConfig):
     return top_class_metrics.build(config)
+
+
+def compute_confusion_matrix(
+    clip_evaluations: Sequence[ClipEval],
+    targets: TargetProtocol,
+    threshold: float = 0.2,
+    normalize: Literal["true", "pred", "all", "none"] = "true",
+    exclude_generic: bool = True,
+    exclude_false_positives: bool = True,
+    exclude_false_negatives: bool = True,
+    noise_class: str = "noise",
+):
+    y_true: List[str] = []
+    y_pred: List[str] = []
+
+    for clip_eval in clip_evaluations:
+        for m in clip_eval.matches:
+            true_class = m.true_class
+            pred_class = m.pred_class
+
+            if not m.is_prediction and exclude_false_negatives:
+                # Ignore matches that don't correspond to a prediction
+                continue
+
+            if not m.is_ground_truth and exclude_false_positives:
+                # Ignore matches that don't correspond to a ground truth
+                continue
+
+            if m.score < threshold:
+                if exclude_false_negatives:
+                    continue
+
+                pred_class = noise_class
+
+            if m.is_generic:
+                if exclude_generic:
+                    # Ignore gt sounds with unknown class
+                    continue
+
+                true_class = targets.detection_class_name
+
+            y_true.append(true_class or noise_class)
+            y_pred.append(pred_class or noise_class)
+
+    labels = sorted(targets.class_names)
+
+    if not exclude_generic:
+        labels.append(targets.detection_class_name)
+
+    if not exclude_false_positives or not exclude_false_negatives:
+        labels.append(noise_class)
+
+    return metrics.confusion_matrix(
+        y_true,
+        y_pred,
+        labels=labels,
+        normalize=normalize,
+    ), labels
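Note: a minimal usage sketch (not part of the diff) for the new compute_confusion_matrix helper; clip_evals and targets stand in for objects built elsewhere in the pipeline. One caveat: sklearn's confusion_matrix accepts normalize in {"true", "pred", "all"} or None, so passing the literal string "none" through would raise; a caller wanting an unnormalized matrix would need to map "none" to None.

    from batdetect2.evaluate.metrics.top_class import compute_confusion_matrix

    cm, labels = compute_confusion_matrix(
        clip_evals,        # Sequence[ClipEval] from an evaluation run
        targets,           # TargetProtocol providing class_names etc.
        threshold=0.2,
        normalize="true",  # each true-class row sums to 1
    )
    print(labels)          # row/column order of the matrix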
@@ -18,8 +18,8 @@ from batdetect2.core import Registry
 from batdetect2.evaluate.metrics.classification import (
     ClipEval,
     _extract_per_class_metric_data,
+    compute_precision_recall_curves,
 )
-from batdetect2.evaluate.metrics.common import compute_precision_recall
 from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
 from batdetect2.plotting.metrics import (
     plot_pr_curve,
@@ -69,21 +69,12 @@ class PRCurve(BasePlot):
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true, y_score, num_positives = _extract_per_class_metric_data(
+        data = compute_precision_recall_curves(
             clip_evaluations,
             ignore_non_predictions=self.ignore_non_predictions,
             ignore_generic=self.ignore_generic,
         )
 
-        data = {
-            class_name: compute_precision_recall(
-                y_true[class_name],
-                y_score[class_name],
-                num_positives=num_positives[class_name],
-            )
-            for class_name in self.targets.class_names
-        }
-
         if not self.separate_figures:
             fig = self.create_figure()
             ax = fig.subplots()
@@ -141,21 +132,12 @@ class ThresholdPrecisionCurve(BasePlot):
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true, y_score, num_positives = _extract_per_class_metric_data(
+        data = compute_precision_recall_curves(
             clip_evaluations,
             ignore_non_predictions=self.ignore_non_predictions,
             ignore_generic=self.ignore_generic,
         )
 
-        data = {
-            class_name: compute_precision_recall(
-                y_true[class_name],
-                y_score[class_name],
-                num_positives[class_name],
-            )
-            for class_name in self.targets.class_names
-        }
-
         if not self.separate_figures:
             fig = self.create_figure()
             ax = fig.subplots()
@@ -223,21 +205,12 @@ class ThresholdRecallCurve(BasePlot):
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true, y_score, num_positives = _extract_per_class_metric_data(
+        data = compute_precision_recall_curves(
             clip_evaluations,
             ignore_non_predictions=self.ignore_non_predictions,
             ignore_generic=self.ignore_generic,
        )
 
-        data = {
-            class_name: compute_precision_recall(
-                y_true[class_name],
-                y_score[class_name],
-                num_positives[class_name],
-            )
-            for class_name in self.targets.class_names
-        }
-
         if not self.separate_figures:
             fig = self.create_figure()
             ax = fig.subplots()
@@ -23,7 +23,11 @@ from sklearn import metrics
 from batdetect2.audio import AudioConfig, build_audio_loader
 from batdetect2.core import Registry
 from batdetect2.evaluate.metrics.common import compute_precision_recall
-from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
+from batdetect2.evaluate.metrics.top_class import (
+    ClipEval,
+    MatchEval,
+    compute_confusion_matrix,
+)
 from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
 from batdetect2.plotting.gallery import plot_match_gallery
 from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
@@ -186,6 +190,8 @@ class ConfusionMatrix(BasePlot):
         self,
         *args,
         exclude_generic: bool = True,
+        exclude_false_positives: bool = True,
+        exclude_false_negatives: bool = True,
         exclude_noise: bool = False,
         noise_class: str = "noise",
         add_colorbar: bool = True,
@@ -196,9 +202,11 @@ class ConfusionMatrix(BasePlot):
     ):
         super().__init__(*args, **kwargs)
         self.exclude_generic = exclude_generic
+        self.exclude_false_positives = exclude_false_positives
+        self.exclude_false_negatives = exclude_false_negatives
         self.exclude_noise = exclude_noise
         self.noise_class = noise_class
-        self.normalize = normalize
+        self.normalize: Literal["true", "pred", "all", "none"] = normalize
         self.add_colorbar = add_colorbar
         self.threshold = threshold
         self.cmap = cmap
@@ -207,58 +215,25 @@ class ConfusionMatrix(BasePlot):
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true: List[str] = []
-        y_pred: List[str] = []
-
-        for clip_eval in clip_evaluations:
-            for m in clip_eval.matches:
-                true_class = m.true_class
-                pred_class = m.pred_class
-
-                if not m.is_prediction and self.exclude_noise:
-                    # Ignore matches that don't correspond to a prediction
-                    continue
-
-                if not m.is_ground_truth and self.exclude_noise:
-                    # Ignore matches that don't correspond to a ground truth
-                    continue
-
-                if m.score < self.threshold:
-                    if self.exclude_noise:
-                        continue
-
-                    pred_class = self.noise_class
-
-                if m.is_generic:
-                    if self.exclude_generic:
-                        # Ignore gt sounds with unknown class
-                        continue
-
-                    true_class = self.targets.detection_class_name
-
-                y_true.append(true_class or self.noise_class)
-                y_pred.append(pred_class or self.noise_class)
+        cm, labels = compute_confusion_matrix(
+            clip_evaluations,
+            self.targets,
+            threshold=self.threshold,
+            normalize=self.normalize,
+            exclude_generic=self.exclude_generic,
+            exclude_false_positives=self.exclude_false_positives,
+            exclude_false_negatives=self.exclude_false_negatives,
+            noise_class=self.noise_class,
+        )
 
         fig = self.create_figure()
         ax = fig.subplots()
 
-        class_names = [*self.targets.class_names]
-
-        if not self.exclude_generic:
-            class_names.append(self.targets.detection_class_name)
-
-        if not self.exclude_noise:
-            class_names.append(self.noise_class)
-
-        metrics.ConfusionMatrixDisplay.from_predictions(
-            y_true,
-            y_pred,
-            labels=class_names,
+        metrics.ConfusionMatrixDisplay(cm, display_labels=labels).plot(
             ax=ax,
             xticks_rotation="vertical",
             cmap=self.cmap,
             colorbar=self.add_colorbar,
-            normalize=self.normalize if self.normalize != "none" else None,
             values_format=".2f",
         )
 
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import seaborn as sns
@@ -34,8 +34,14 @@ def plot_pr_curve(
     thresholds: np.ndarray,
     ax: Optional[axes.Axes] = None,
     figsize: Optional[Tuple[int, int]] = None,
+    color: Union[str, Tuple[float, float, float], None] = None,
     add_labels: bool = True,
     add_legend: bool = False,
+    marker: Union[str, Tuple[int, int, float], None] = "o",
+    markeredgecolor: Union[str, Tuple[float, float, float], None] = None,
+    markersize: Optional[float] = None,
+    linestyle: Union[str, Tuple[int, ...], None] = None,
+    linewidth: Optional[float] = None,
     label: str = "PR Curve",
 ) -> axes.Axes:
     ax = create_ax(ax=ax, figsize=figsize)
@@ -45,9 +51,14 @@ def plot_pr_curve(
     ax.plot(
         recall,
         precision,
+        color=color,
         label=label,
-        marker="o",
+        marker=marker,
+        markeredgecolor=markeredgecolor,
         markevery=_get_marker_positions(thresholds),
+        markersize=markersize,
+        linestyle=linestyle,
+        linewidth=linewidth,
     )
 
     ax.set_xlim(0, 1.05)
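Note: a minimal usage sketch (not part of the diff) for the widened plot_pr_curve styling options; the arrays are made up, and the leading positional parameters are assumed to be (precision, recall, thresholds), since the hunk only shows the signature from thresholds onward:

    import numpy as np
    from batdetect2.plotting.metrics import plot_pr_curve

    precision = np.array([1.0, 0.8, 0.6])
    recall = np.array([0.2, 0.5, 0.9])
    thresholds = np.array([0.9, 0.5, 0.1])

    ax = plot_pr_curve(
        precision,  # assumed positional order
        recall,
        thresholds,
        color="tab:blue",
        linestyle="--",
        linewidth=1.5,
        markersize=4.0,
    )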
@@ -146,14 +146,18 @@ class FrequencyCrop(torch.nn.Module):
         low_index = None
         if min_freq is not None:
             low_index = _frequency_to_index(
-                min_freq, self.samplerate, self.n_fft
+                min_freq,
+                n_fft=self.n_fft,
+                samplerate=self.samplerate,
             )
         self.low_index = low_index
 
         high_index = None
         if max_freq is not None:
             high_index = _frequency_to_index(
-                max_freq, self.samplerate, self.n_fft
+                max_freq,
+                n_fft=self.n_fft,
+                samplerate=self.samplerate,
             )
         self.high_index = high_index
 
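Note on the change above: passing n_fft and samplerate by keyword guards against silently swapping the two, which positional calls invite. The repository's _frequency_to_index implementation is not shown here; the sketch below uses a hypothetical stand-in name and the usual frequency-to-bin mapping, which is an assumption:

    def frequency_to_index(freq: float, n_fft: int, samplerate: float) -> int:
        # Assumed logic: an STFT with n_fft samples yields bins spaced
        # samplerate / n_fft Hz apart, covering 0 .. samplerate / 2, so a
        # frequency in Hz maps to the nearest bin index as follows.
        return int(round(freq * n_fft / samplerate))

    # With samplerate=256000 and n_fft=512, 10 kHz maps to bin 20.
    assert frequency_to_index(10_000, n_fft=512, samplerate=256_000) == 20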