mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Compare commits
14 Commits
bf6f52a65d
...
2341f822a7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2341f822a7 | ||
|
|
c3d377b6e0 | ||
|
|
d9395d3eeb | ||
|
|
aaec66c15e | ||
|
|
6213238585 | ||
|
|
a485ea4f79 | ||
|
|
3cfceb76b4 | ||
|
|
d877d383a4 | ||
|
|
bb4a9fe645 | ||
|
|
87ce2acd6f | ||
|
|
e1908c35ca | ||
|
|
62923a201b | ||
|
|
d9323a1383 | ||
|
|
1ee9643a61 |
@ -1,3 +1,14 @@
|
||||
datasets:
|
||||
train:
|
||||
name: example dataset
|
||||
description: Only for demonstration purposes
|
||||
sources:
|
||||
- format: batdetect2
|
||||
name: Example Data
|
||||
description: Examples included for testing batdetect2
|
||||
annotations_dir: example_data/anns
|
||||
audio_dir: example_data/audio
|
||||
|
||||
targets:
|
||||
classes:
|
||||
classes:
|
||||
@ -46,7 +57,7 @@ preprocess:
|
||||
max_freq: 120000
|
||||
min_freq: 10000
|
||||
pcen:
|
||||
time_constant: 0.4
|
||||
time_constant: 0.1
|
||||
gain: 0.98
|
||||
bias: 2
|
||||
power: 0.5
|
||||
|
||||
@ -1,10 +0,0 @@
|
||||
datasets:
|
||||
train:
|
||||
name: example dataset
|
||||
description: Only for demonstration purposes
|
||||
sources:
|
||||
- format: batdetect2
|
||||
name: Example Data
|
||||
description: Examples included for testing batdetect2
|
||||
annotations_dir: example_data/anns
|
||||
audio_dir: example_data/audio
|
||||
@ -17,7 +17,7 @@ dependencies = [
|
||||
"torch>=1.13.1,<2.5.0",
|
||||
"torchaudio>=1.13.1,<2.5.0",
|
||||
"torchvision>=0.14.0",
|
||||
"soundevent[audio,geometry,plot]>=2.5.0",
|
||||
"soundevent[audio,geometry,plot]>=2.6.5",
|
||||
"click>=8.1.7",
|
||||
"netcdf4>=1.6.5",
|
||||
"tqdm>=4.66.2",
|
||||
@ -66,10 +66,7 @@ batdetect2 = "batdetect2.cli:cli"
|
||||
|
||||
[dependency-groups]
|
||||
jupyter = ["ipywidgets>=8.1.5", "jupyter>=1.1.1"]
|
||||
marimo = [
|
||||
"marimo>=0.12.2",
|
||||
"pyarrow>=20.0.0",
|
||||
]
|
||||
marimo = ["marimo>=0.12.2", "pyarrow>=20.0.0"]
|
||||
dev = [
|
||||
"debugpy>=1.8.8",
|
||||
"hypothesis>=6.118.7",
|
||||
@ -77,7 +74,7 @@ dev = [
|
||||
"ruff>=0.7.3",
|
||||
"ipykernel>=6.29.4",
|
||||
"setuptools>=69.5.1",
|
||||
"pyright>=1.1.399",
|
||||
"basedpyright>=1.31.0",
|
||||
"myst-parser>=3.0.1",
|
||||
"sphinx-autobuild>=2024.10.3",
|
||||
"numpydoc>=1.8.0",
|
||||
@ -88,12 +85,8 @@ dev = [
|
||||
"ty>=0.0.1a12",
|
||||
"rust-just>=1.40.0",
|
||||
]
|
||||
dvclive = [
|
||||
"dvclive>=3.48.2",
|
||||
]
|
||||
mlflow = [
|
||||
"mlflow>=3.1.1",
|
||||
]
|
||||
dvclive = ["dvclive>=3.48.2"]
|
||||
mlflow = ["mlflow>=3.1.1"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 79
|
||||
|
||||
@ -1,5 +1,10 @@
|
||||
import logging
|
||||
|
||||
from loguru import logger
|
||||
|
||||
logger.disable("batdetect2")
|
||||
|
||||
|
||||
numba_logger = logging.getLogger("numba")
|
||||
numba_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@ -0,0 +1,15 @@
|
||||
from batdetect2.compat.data import (
|
||||
annotation_to_sound_event_annotation,
|
||||
annotation_to_sound_event_prediction,
|
||||
convert_to_annotation_group,
|
||||
file_annotation_to_clip_annotation,
|
||||
load_file_annotation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"annotation_to_sound_event_annotation",
|
||||
"annotation_to_sound_event_prediction",
|
||||
"convert_to_annotation_group",
|
||||
"file_annotation_to_clip_annotation",
|
||||
"load_file_annotation",
|
||||
]
|
||||
@ -1,24 +1,30 @@
|
||||
"""Compatibility functions between old and new data structures."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
from soundevent.types import ClassMapper
|
||||
|
||||
from batdetect2 import types
|
||||
from batdetect2.targets.terms import get_term_from_key
|
||||
from batdetect2.types import (
|
||||
Annotation,
|
||||
AudioLoaderAnnotationGroup,
|
||||
FileAnnotation,
|
||||
)
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
|
||||
__all__ = [
|
||||
"convert_to_annotation_group",
|
||||
"load_file_annotation",
|
||||
"annotation_to_sound_event",
|
||||
"annotation_to_sound_event_annotation",
|
||||
"annotation_to_sound_event_prediction",
|
||||
]
|
||||
|
||||
SPECIES_TAG_KEY = "species"
|
||||
@ -37,7 +43,7 @@ IndividualFn = Callable[[data.SoundEventAnnotation], int]
|
||||
|
||||
def get_recording_class_name(recording: data.Recording) -> str:
|
||||
"""Get the class name for a recording."""
|
||||
tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
|
||||
tag = data.find_tag(recording.tags, label=SPECIES_TAG_KEY)
|
||||
if tag is None:
|
||||
return UNKNOWN_CLASS
|
||||
return tag.value
|
||||
@ -59,7 +65,7 @@ def convert_to_annotation_group(
|
||||
event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
|
||||
class_fn: ClassFn = lambda _: 0,
|
||||
individual_fn: IndividualFn = lambda _: 0,
|
||||
) -> types.AudioLoaderAnnotationGroup:
|
||||
) -> AudioLoaderAnnotationGroup:
|
||||
"""Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
|
||||
recording = annotation.clip.recording
|
||||
|
||||
@ -71,7 +77,7 @@ def convert_to_annotation_group(
|
||||
x_inds = []
|
||||
y_inds = []
|
||||
individual_ids = []
|
||||
annotations: List[types.Annotation] = []
|
||||
annotations: List[Annotation] = []
|
||||
class_id_file = class_fn(recording)
|
||||
|
||||
for sound_event in annotation.sound_events:
|
||||
@ -133,42 +139,13 @@ def convert_to_annotation_group(
|
||||
}
|
||||
|
||||
|
||||
class Annotation(BaseModel):
|
||||
"""Annotation class to hold batdetect annotations."""
|
||||
|
||||
label: str = Field(alias="class")
|
||||
event: str
|
||||
individual: int = 0
|
||||
|
||||
start_time: float
|
||||
end_time: float
|
||||
low_freq: float
|
||||
high_freq: float
|
||||
|
||||
|
||||
class FileAnnotation(BaseModel):
|
||||
"""FileAnnotation class to hold batdetect annotations for a file."""
|
||||
|
||||
id: str
|
||||
duration: float
|
||||
time_exp: float = 1
|
||||
|
||||
label: str = Field(alias="class_name")
|
||||
|
||||
annotation: List[Annotation]
|
||||
|
||||
annotated: bool = False
|
||||
issues: bool = False
|
||||
notes: str = ""
|
||||
|
||||
|
||||
def load_file_annotation(path: PathLike) -> FileAnnotation:
|
||||
"""Load annotation from batdetect format."""
|
||||
path = Path(path)
|
||||
return FileAnnotation.model_validate_json(path.read_text())
|
||||
return json.loads(path.read_text())
|
||||
|
||||
|
||||
def annotation_to_sound_event(
|
||||
def annotation_to_sound_event_annotation(
|
||||
annotation: Annotation,
|
||||
recording: data.Recording,
|
||||
label_key: str = "class",
|
||||
@ -179,15 +156,15 @@ def annotation_to_sound_event(
|
||||
sound_event = data.SoundEvent(
|
||||
uuid=uuid.uuid5(
|
||||
NAMESPACE,
|
||||
f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
|
||||
f"{recording.hash}_{annotation['start_time']}_{annotation['end_time']}",
|
||||
),
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
annotation.start_time,
|
||||
annotation.low_freq,
|
||||
annotation.end_time,
|
||||
annotation.high_freq,
|
||||
annotation["start_time"],
|
||||
annotation["low_freq"],
|
||||
annotation["end_time"],
|
||||
annotation["high_freq"],
|
||||
],
|
||||
),
|
||||
)
|
||||
@ -197,16 +174,62 @@ def annotation_to_sound_event(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=data.term_from_key(label_key),
|
||||
value=annotation.label,
|
||||
term=get_term_from_key(label_key),
|
||||
value=annotation["class"],
|
||||
),
|
||||
data.Tag(
|
||||
term=data.term_from_key(event_key),
|
||||
value=annotation.event,
|
||||
term=get_term_from_key(event_key),
|
||||
value=annotation["event"],
|
||||
),
|
||||
data.Tag(
|
||||
term=data.term_from_key(individual_key),
|
||||
value=str(annotation.individual),
|
||||
term=get_term_from_key(individual_key),
|
||||
value=str(annotation["individual"]),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def annotation_to_sound_event_prediction(
|
||||
annotation: Annotation,
|
||||
recording: data.Recording,
|
||||
label_key: str = "class",
|
||||
event_key: str = "event",
|
||||
) -> data.SoundEventPrediction:
|
||||
"""Convert annotation to sound event annotation."""
|
||||
sound_event = data.SoundEvent(
|
||||
uuid=uuid.uuid5(
|
||||
NAMESPACE,
|
||||
f"{recording.hash}_{annotation['start_time']}_{annotation['end_time']}",
|
||||
),
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
annotation["start_time"],
|
||||
annotation["low_freq"],
|
||||
annotation["end_time"],
|
||||
annotation["high_freq"],
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
return data.SoundEventPrediction(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
|
||||
sound_event=sound_event,
|
||||
score=annotation["det_prob"],
|
||||
tags=[
|
||||
data.PredictedTag(
|
||||
score=annotation["class_prob"],
|
||||
tag=data.Tag(
|
||||
term=get_term_from_key(label_key),
|
||||
value=annotation["class"],
|
||||
),
|
||||
),
|
||||
data.PredictedTag(
|
||||
score=annotation["det_prob"],
|
||||
tag=data.Tag(
|
||||
term=get_term_from_key(event_key),
|
||||
value=annotation["event"],
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -220,24 +243,24 @@ def file_annotation_to_clip(
|
||||
"""Convert file annotation to recording."""
|
||||
audio_dir = audio_dir or Path.cwd()
|
||||
|
||||
full_path = Path(audio_dir) / file_annotation.id
|
||||
full_path = Path(audio_dir) / file_annotation["id"]
|
||||
|
||||
if not full_path.exists():
|
||||
raise FileNotFoundError(f"File {full_path} not found.")
|
||||
|
||||
recording = data.Recording.from_file(
|
||||
full_path,
|
||||
time_expansion=file_annotation.time_exp,
|
||||
time_expansion=file_annotation["time_exp"],
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=data.term_from_key(label_key),
|
||||
value=file_annotation.label,
|
||||
value=file_annotation["class_name"],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return data.Clip(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation['id']}_clip"),
|
||||
recording=recording,
|
||||
start_time=0,
|
||||
end_time=recording.duration,
|
||||
@ -253,27 +276,28 @@ def file_annotation_to_clip_annotation(
|
||||
) -> data.ClipAnnotation:
|
||||
"""Convert file annotation to clip annotation."""
|
||||
notes = []
|
||||
if file_annotation.notes:
|
||||
notes.append(data.Note(message=file_annotation.notes))
|
||||
if file_annotation["notes"]:
|
||||
notes.append(data.Note(message=file_annotation["notes"]))
|
||||
|
||||
return data.ClipAnnotation(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation['id']}_clip_annotation"),
|
||||
clip=clip,
|
||||
notes=notes,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=data.term_from_key(label_key), value=file_annotation.label
|
||||
term=data.term_from_key(label_key),
|
||||
value=file_annotation["class_name"],
|
||||
)
|
||||
],
|
||||
sound_events=[
|
||||
annotation_to_sound_event(
|
||||
annotation_to_sound_event_annotation(
|
||||
annotation,
|
||||
clip.recording,
|
||||
label_key=label_key,
|
||||
event_key=event_key,
|
||||
individual_key=individual_key,
|
||||
)
|
||||
for annotation in file_annotation.annotation
|
||||
for annotation in file_annotation["annotation"]
|
||||
],
|
||||
)
|
||||
|
||||
@ -284,17 +308,17 @@ def file_annotation_to_annotation_task(
|
||||
) -> data.AnnotationTask:
|
||||
status_badges = []
|
||||
|
||||
if file_annotation.issues:
|
||||
if file_annotation["issues"]:
|
||||
status_badges.append(
|
||||
data.StatusBadge(state=data.AnnotationState.rejected)
|
||||
)
|
||||
elif file_annotation.annotated:
|
||||
elif file_annotation["annotated"]:
|
||||
status_badges.append(
|
||||
data.StatusBadge(state=data.AnnotationState.completed)
|
||||
)
|
||||
|
||||
return data.AnnotationTask(
|
||||
uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation.id}_task"),
|
||||
uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation['id']}_task"),
|
||||
clip=clip,
|
||||
status_badges=status_badges,
|
||||
)
|
||||
|
||||
@ -1,13 +1,9 @@
|
||||
from batdetect2.evaluate.evaluate import (
|
||||
compute_error_auc,
|
||||
)
|
||||
from batdetect2.evaluate.match import (
|
||||
match_predictions_and_annotations,
|
||||
match_sound_events_and_raw_predictions,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"compute_error_auc",
|
||||
"match_predictions_and_annotations",
|
||||
"match_sound_events_and_raw_predictions",
|
||||
"match_predictions_and_annotations",
|
||||
]
|
||||
|
||||
@ -1,54 +1,133 @@
|
||||
from typing import List
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_geometries
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.evaluate.types import Match
|
||||
from batdetect2.postprocess.types import RawPrediction
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.evaluate.types import MatchEvaluation
|
||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.utils.arrays import iterate_over_array
|
||||
|
||||
|
||||
class BBoxMatchConfig(BaseConfig):
|
||||
match_method: Literal["BBoxIOU"] = "BBoxIOU"
|
||||
affinity_threshold: float = 0.5
|
||||
time_buffer: float = 0.01
|
||||
frequency_buffer: float = 1_000
|
||||
|
||||
|
||||
class IntervalMatchConfig(BaseConfig):
|
||||
match_method: Literal["IntervalIOU"] = "IntervalIOU"
|
||||
affinity_threshold: float = 0.5
|
||||
time_buffer: float = 0.01
|
||||
|
||||
|
||||
class StartTimeMatchConfig(BaseConfig):
|
||||
match_method: Literal["StartTime"] = "StartTime"
|
||||
time_buffer: float = 0.01
|
||||
|
||||
|
||||
MatchConfig = Annotated[
|
||||
Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
|
||||
Field(discriminator="match_method"),
|
||||
]
|
||||
|
||||
|
||||
DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
|
||||
|
||||
|
||||
def prepare_geometry(
|
||||
geometry: data.Geometry, config: MatchConfig
|
||||
) -> data.Geometry:
|
||||
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
|
||||
|
||||
if config.match_method == "BBoxIOU":
|
||||
return data.BoundingBox(
|
||||
coordinates=[start_time, low_freq, end_time, high_freq]
|
||||
)
|
||||
|
||||
if config.match_method == "IntervalIOU":
|
||||
return data.TimeInterval(coordinates=[start_time, end_time])
|
||||
|
||||
if config.match_method == "StartTime":
|
||||
return data.TimeStamp(coordinates=start_time)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"Invalid matching configuration. Unknown match method: {config.match_method}"
|
||||
)
|
||||
|
||||
|
||||
def _get_frequency_buffer(config: MatchConfig) -> float:
|
||||
if config.match_method == "BBoxIOU":
|
||||
return config.frequency_buffer
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _get_affinity_threshold(config: MatchConfig) -> float:
|
||||
if (
|
||||
config.match_method == "BBoxIOU"
|
||||
or config.match_method == "IntervalIOU"
|
||||
):
|
||||
return config.affinity_threshold
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def match_sound_events_and_raw_predictions(
|
||||
sound_events: List[data.SoundEventAnnotation],
|
||||
raw_predictions: List[RawPrediction],
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
raw_predictions: List[BatDetect2Prediction],
|
||||
targets: TargetProtocol,
|
||||
) -> List[Match]:
|
||||
config: Optional[MatchConfig] = None,
|
||||
) -> List[MatchEvaluation]:
|
||||
config = config or DEFAULT_MATCH_CONFIG
|
||||
|
||||
target_sound_events = [
|
||||
targets.transform(sound_event_annotation)
|
||||
for sound_event_annotation in sound_events
|
||||
for sound_event_annotation in clip_annotation.sound_events
|
||||
if targets.filter(sound_event_annotation)
|
||||
and sound_event_annotation.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
target_geometries: List[data.Geometry] = [ # type: ignore
|
||||
sound_event_annotation.sound_event.geometry
|
||||
prepare_geometry(
|
||||
sound_event_annotation.sound_event.geometry,
|
||||
config=config,
|
||||
)
|
||||
for sound_event_annotation in target_sound_events
|
||||
if sound_event_annotation.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
predicted_geometries = [
|
||||
raw_prediction.geometry for raw_prediction in raw_predictions
|
||||
prepare_geometry(raw_prediction.raw.geometry, config=config)
|
||||
for raw_prediction in raw_predictions
|
||||
]
|
||||
|
||||
matches = []
|
||||
|
||||
for id1, id2, affinity in match_geometries(
|
||||
target_geometries,
|
||||
predicted_geometries,
|
||||
time_buffer=config.time_buffer,
|
||||
freq_buffer=_get_frequency_buffer(config),
|
||||
affinity_threshold=_get_affinity_threshold(config),
|
||||
):
|
||||
target = target_sound_events[id1] if id1 is not None else None
|
||||
prediction = raw_predictions[id2] if id2 is not None else None
|
||||
|
||||
gt_uuid = target.uuid if target is not None else None
|
||||
gt_det = target is not None
|
||||
gt_class = targets.encode_class(target) if target is not None else None
|
||||
|
||||
pred_score = float(prediction.detection_score) if prediction else 0
|
||||
pred_score = float(prediction.raw.detection_score) if prediction else 0
|
||||
|
||||
class_scores = (
|
||||
{
|
||||
str(class_name): float(score)
|
||||
for class_name, score in iterate_over_array(
|
||||
prediction.class_scores
|
||||
prediction.raw.class_scores
|
||||
)
|
||||
}
|
||||
if prediction is not None
|
||||
@ -56,13 +135,18 @@ def match_sound_events_and_raw_predictions(
|
||||
)
|
||||
|
||||
matches.append(
|
||||
Match(
|
||||
gt_uuid=gt_uuid,
|
||||
MatchEvaluation(
|
||||
match=data.Match(
|
||||
source=None
|
||||
if prediction is None
|
||||
else prediction.sound_event_prediction,
|
||||
target=target,
|
||||
affinity=affinity,
|
||||
),
|
||||
gt_det=gt_det,
|
||||
gt_class=gt_class,
|
||||
pred_score=pred_score,
|
||||
affinity=affinity,
|
||||
class_scores=class_scores,
|
||||
pred_class_scores=class_scores,
|
||||
)
|
||||
)
|
||||
|
||||
@ -72,7 +156,10 @@ def match_sound_events_and_raw_predictions(
|
||||
def match_predictions_and_annotations(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
clip_prediction: data.ClipPrediction,
|
||||
config: Optional[MatchConfig] = None,
|
||||
) -> List[data.Match]:
|
||||
config = config or DEFAULT_MATCH_CONFIG
|
||||
|
||||
annotated_sound_events = [
|
||||
sound_event_annotation
|
||||
for sound_event_annotation in clip_annotation.sound_events
|
||||
@ -86,13 +173,13 @@ def match_predictions_and_annotations(
|
||||
]
|
||||
|
||||
annotated_geometries: List[data.Geometry] = [
|
||||
sound_event.sound_event.geometry
|
||||
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
||||
for sound_event in annotated_sound_events
|
||||
if sound_event.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
predicted_geometries: List[data.Geometry] = [
|
||||
sound_event.sound_event.geometry
|
||||
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
||||
for sound_event in predicted_sound_events
|
||||
if sound_event.sound_event.geometry is not None
|
||||
]
|
||||
@ -101,6 +188,9 @@ def match_predictions_and_annotations(
|
||||
for id1, id2, affinity in match_geometries(
|
||||
annotated_geometries,
|
||||
predicted_geometries,
|
||||
time_buffer=config.time_buffer,
|
||||
freq_buffer=_get_frequency_buffer(config),
|
||||
affinity_threshold=_get_affinity_threshold(config),
|
||||
):
|
||||
target = annotated_sound_events[id1] if id1 is not None else None
|
||||
source = predicted_sound_events[id2] if id2 is not None else None
|
||||
|
||||
@ -4,13 +4,13 @@ import pandas as pd
|
||||
from sklearn import metrics
|
||||
from sklearn.preprocessing import label_binarize
|
||||
|
||||
from batdetect2.evaluate.types import Match, MetricsProtocol
|
||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||
|
||||
__all__ = ["DetectionAveragePrecision"]
|
||||
|
||||
|
||||
class DetectionAveragePrecision(MetricsProtocol):
|
||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
||||
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
||||
y_true, y_score = zip(
|
||||
*[(match.gt_det, match.pred_score) for match in matches]
|
||||
)
|
||||
@ -23,7 +23,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
||||
self.class_names = class_names
|
||||
self.per_class = per_class
|
||||
|
||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
||||
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
||||
y_true = label_binarize(
|
||||
[
|
||||
match.gt_class if match.gt_class is not None else "__NONE__"
|
||||
@ -34,7 +34,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
||||
y_pred = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
name: match.class_scores.get(name, 0)
|
||||
name: match.pred_class_scores.get(name, 0)
|
||||
for name in self.class_names
|
||||
}
|
||||
for match in matches
|
||||
@ -65,7 +65,7 @@ class ClassificationAccuracy(MetricsProtocol):
|
||||
def __init__(self, class_names: List[str]):
|
||||
self.class_names = class_names
|
||||
|
||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
||||
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
||||
y_true = [
|
||||
match.gt_class if match.gt_class is not None else "__NONE__"
|
||||
for match in matches
|
||||
@ -74,7 +74,7 @@ class ClassificationAccuracy(MetricsProtocol):
|
||||
y_pred = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
name: match.class_scores.get(name, 0)
|
||||
name: match.pred_class_scores.get(name, 0)
|
||||
for name in self.class_names
|
||||
}
|
||||
for match in matches
|
||||
|
||||
@ -1,22 +1,40 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"MetricsProtocol",
|
||||
"Match",
|
||||
"MatchEvaluation",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Match:
|
||||
gt_uuid: Optional[UUID]
|
||||
class MatchEvaluation:
|
||||
match: data.Match
|
||||
|
||||
gt_det: bool
|
||||
gt_class: Optional[str]
|
||||
|
||||
pred_score: float
|
||||
affinity: float
|
||||
class_scores: Dict[str, float]
|
||||
pred_class_scores: Dict[str, float]
|
||||
|
||||
@property
|
||||
def pred_class(self) -> Optional[str]:
|
||||
if not self.pred_class_scores:
|
||||
return None
|
||||
|
||||
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
|
||||
|
||||
@property
|
||||
def pred_class_score(self) -> float:
|
||||
pred_class = self.pred_class
|
||||
|
||||
if pred_class is None:
|
||||
return 0
|
||||
|
||||
return self.pred_class_scores[pred_class]
|
||||
|
||||
|
||||
class MetricsProtocol(Protocol):
|
||||
def __call__(self, matches: List[Match]) -> Dict[str, float]: ...
|
||||
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
|
||||
|
||||
@ -0,0 +1,21 @@
|
||||
from batdetect2.plotting.clip_annotations import plot_clip_annotation
|
||||
from batdetect2.plotting.clip_predictions import plot_clip_prediction
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.plotting.matches import (
|
||||
plot_cross_trigger_match,
|
||||
plot_false_negative_match,
|
||||
plot_false_positive_match,
|
||||
plot_matches,
|
||||
plot_true_positive_match,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"plot_clip",
|
||||
"plot_clip_annotation",
|
||||
"plot_clip_prediction",
|
||||
"plot_matches",
|
||||
"plot_false_positive_match",
|
||||
"plot_true_positive_match",
|
||||
"plot_false_negative_match",
|
||||
"plot_cross_trigger_match",
|
||||
]
|
||||
49
src/batdetect2/plotting/clip_annotations.py
Normal file
49
src/batdetect2/plotting/clip_annotations.py
Normal file
@ -0,0 +1,49 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from soundevent import data, plot
|
||||
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.preprocess import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"plot_clip_annotation",
|
||||
]
|
||||
|
||||
|
||||
def plot_clip_annotation(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
add_points: bool = False,
|
||||
cmap: str = "gray",
|
||||
alpha: float = 1,
|
||||
linewidth: float = 1,
|
||||
fill: bool = False,
|
||||
) -> Axes:
|
||||
ax = plot_clip(
|
||||
clip_annotation.clip,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
add_colorbar=add_colorbar,
|
||||
add_labels=add_labels,
|
||||
spec_cmap=cmap,
|
||||
)
|
||||
|
||||
plot.plot_annotations(
|
||||
clip_annotation.sound_events,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
alpha=alpha,
|
||||
linewidth=linewidth,
|
||||
facecolor="none" if not fill else None,
|
||||
)
|
||||
return ax
|
||||
141
src/batdetect2/plotting/clip_predictions.py
Normal file
141
src/batdetect2/plotting/clip_predictions.py
Normal file
@ -0,0 +1,141 @@
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from soundevent import data
|
||||
from soundevent.geometry.operations import Positions, get_geometry_point
|
||||
from soundevent.plot.common import create_axes
|
||||
from soundevent.plot.geometries import plot_geometry
|
||||
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
|
||||
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.preprocess import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"plot_clip_prediction",
|
||||
]
|
||||
|
||||
|
||||
def plot_clip_prediction(
|
||||
clip_prediction: data.ClipPrediction,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
add_legend: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
linewidth: float = 1,
|
||||
fill: bool = False,
|
||||
) -> Axes:
|
||||
ax = plot_clip(
|
||||
clip_prediction.clip,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
add_colorbar=add_colorbar,
|
||||
add_labels=add_labels,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
plot_predictions(
|
||||
clip_prediction.sound_events,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=False,
|
||||
linewidth=linewidth,
|
||||
facecolor="none" if not fill else None,
|
||||
legend=add_legend,
|
||||
)
|
||||
return ax
|
||||
|
||||
|
||||
def plot_predictions(
|
||||
predictions: Iterable[data.SoundEventPrediction],
|
||||
ax: Optional[Axes] = None,
|
||||
position: Positions = "top-right",
|
||||
color_mapper: Optional[TagColorMapper] = None,
|
||||
time_offset: float = 0.001,
|
||||
freq_offset: float = 1000,
|
||||
legend: bool = True,
|
||||
max_alpha: float = 0.5,
|
||||
color: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Plot an prediction."""
|
||||
if ax is None:
|
||||
ax = create_axes(**kwargs)
|
||||
|
||||
if color_mapper is None:
|
||||
color_mapper = TagColorMapper()
|
||||
|
||||
for prediction in predictions:
|
||||
ax = plot_prediction(
|
||||
prediction,
|
||||
ax=ax,
|
||||
position=position,
|
||||
color_mapper=color_mapper,
|
||||
time_offset=time_offset,
|
||||
freq_offset=freq_offset,
|
||||
max_alpha=max_alpha,
|
||||
color=color,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if legend:
|
||||
ax = add_tags_legend(ax, color_mapper)
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def plot_prediction(
|
||||
prediction: data.SoundEventPrediction,
|
||||
ax: Optional[Axes] = None,
|
||||
position: Positions = "top-right",
|
||||
color_mapper: Optional[TagColorMapper] = None,
|
||||
time_offset: float = 0.001,
|
||||
freq_offset: float = 1000,
|
||||
max_alpha: float = 0.5,
|
||||
alpha: Optional[float] = None,
|
||||
color: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Axes:
|
||||
"""Plot an annotation."""
|
||||
geometry = prediction.sound_event.geometry
|
||||
|
||||
if geometry is None:
|
||||
raise ValueError("Annotation does not have a geometry.")
|
||||
|
||||
if ax is None:
|
||||
ax = create_axes(**kwargs)
|
||||
|
||||
if color_mapper is None:
|
||||
color_mapper = TagColorMapper()
|
||||
|
||||
if alpha is None:
|
||||
alpha = min(prediction.score * max_alpha, 1)
|
||||
|
||||
ax = plot_geometry(
|
||||
geometry,
|
||||
ax=ax,
|
||||
color=color,
|
||||
alpha=alpha,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
x, y = get_geometry_point(geometry, position=position)
|
||||
|
||||
for index, tag in enumerate(prediction.tags):
|
||||
color = color_mapper.get_color(tag.tag)
|
||||
ax = plot_tag(
|
||||
time=x + time_offset,
|
||||
frequency=y - index * freq_offset,
|
||||
color=color,
|
||||
ax=ax,
|
||||
alpha=min(tag.score, prediction.score),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return ax
|
||||
44
src/batdetect2/plotting/clips.py
Normal file
44
src/batdetect2/plotting/clips.py
Normal file
@ -0,0 +1,44 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.axes import Axes
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.preprocess import (
|
||||
PreprocessorProtocol,
|
||||
get_default_preprocessor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"plot_clip",
|
||||
]
|
||||
|
||||
|
||||
def plot_clip(
|
||||
clip: data.Clip,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
) -> Axes:
|
||||
if ax is None:
|
||||
_, ax = plt.subplots(figsize=figsize)
|
||||
|
||||
if preprocessor is None:
|
||||
preprocessor = get_default_preprocessor()
|
||||
|
||||
spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir)
|
||||
|
||||
spec.plot( # type: ignore
|
||||
ax=ax,
|
||||
add_colorbar=add_colorbar,
|
||||
cmap=spec_cmap,
|
||||
add_labels=add_labels,
|
||||
vmin=spec.min().item(),
|
||||
vmax=spec.max().item(),
|
||||
)
|
||||
|
||||
return ax
|
||||
160
src/batdetect2/plotting/evaluation.py
Normal file
160
src/batdetect2/plotting/evaluation.py
Normal file
@ -0,0 +1,160 @@
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
from batdetect2 import plotting
|
||||
from batdetect2.evaluate.types import MatchEvaluation
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassExamples:
|
||||
false_positives: List[MatchEvaluation] = field(default_factory=list)
|
||||
false_negatives: List[MatchEvaluation] = field(default_factory=list)
|
||||
true_positives: List[MatchEvaluation] = field(default_factory=list)
|
||||
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
||||
|
||||
|
||||
def plot_examples(
|
||||
matches: List[MatchEvaluation],
|
||||
preprocessor: PreprocessorProtocol,
|
||||
n_examples: int = 5,
|
||||
):
|
||||
class_examples = defaultdict(ClassExamples)
|
||||
|
||||
for match in matches:
|
||||
gt_class = match.gt_class
|
||||
pred_class = match.pred_class
|
||||
|
||||
if pred_class is None:
|
||||
class_examples[gt_class].false_negatives.append(match)
|
||||
continue
|
||||
|
||||
if gt_class is None:
|
||||
class_examples[pred_class].false_positives.append(match)
|
||||
continue
|
||||
|
||||
if gt_class != pred_class:
|
||||
class_examples[gt_class].cross_triggers.append(match)
|
||||
class_examples[pred_class].cross_triggers.append(match)
|
||||
continue
|
||||
|
||||
class_examples[gt_class].true_positives.append(match)
|
||||
|
||||
for class_name, examples in class_examples.items():
|
||||
true_positives = get_binned_sample(
|
||||
examples.true_positives,
|
||||
n_examples=n_examples,
|
||||
)
|
||||
|
||||
false_positives = get_binned_sample(
|
||||
examples.false_positives,
|
||||
n_examples=n_examples,
|
||||
)
|
||||
|
||||
false_negatives = random.sample(
|
||||
examples.false_negatives,
|
||||
k=min(n_examples, len(examples.false_negatives)),
|
||||
)
|
||||
|
||||
cross_triggers = get_binned_sample(
|
||||
examples.cross_triggers,
|
||||
n_examples=n_examples,
|
||||
)
|
||||
|
||||
fig = plot_class_examples(
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
cross_triggers,
|
||||
preprocessor=preprocessor,
|
||||
n_examples=n_examples,
|
||||
)
|
||||
|
||||
yield class_name, fig
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def plot_class_examples(
|
||||
true_positives: List[MatchEvaluation],
|
||||
false_positives: List[MatchEvaluation],
|
||||
false_negatives: List[MatchEvaluation],
|
||||
cross_triggers: List[MatchEvaluation],
|
||||
preprocessor: PreprocessorProtocol,
|
||||
n_examples: int = 5,
|
||||
duration: float = 0.1,
|
||||
):
|
||||
fig = plt.figure(figsize=(20, 20))
|
||||
|
||||
for index, match in enumerate(true_positives):
|
||||
ax = plt.subplot(4, n_examples, index + 1)
|
||||
try:
|
||||
plotting.plot_true_positive_match(
|
||||
match,
|
||||
ax=ax,
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
for index, match in enumerate(false_positives):
|
||||
ax = plt.subplot(4, n_examples, n_examples + index + 1)
|
||||
try:
|
||||
plotting.plot_false_positive_match(
|
||||
match,
|
||||
ax=ax,
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
for index, match in enumerate(false_negatives):
|
||||
ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
|
||||
try:
|
||||
plotting.plot_false_negative_match(
|
||||
match,
|
||||
ax=ax,
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
for index, match in enumerate(cross_triggers):
|
||||
ax = plt.subplot(4, n_examples, 4 * n_examples + index + 1)
|
||||
try:
|
||||
plotting.plot_cross_trigger_match(
|
||||
match,
|
||||
ax=ax,
|
||||
preprocessor=preprocessor,
|
||||
duration=duration,
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
|
||||
if len(matches) < n_examples:
|
||||
return matches
|
||||
|
||||
indices, pred_scores = zip(
|
||||
*[
|
||||
(index, match.pred_class_scores[pred_class])
|
||||
for index, match in enumerate(matches)
|
||||
if (pred_class := match.pred_class) is not None
|
||||
]
|
||||
)
|
||||
|
||||
bins = pd.qcut(pred_scores, q=n_examples, labels=False)
|
||||
df = pd.DataFrame({"indices": indices, "bins": bins})
|
||||
sample = df.groupby("bins").apply(lambda x: x.sample(1))
|
||||
return [matches[ind] for ind in sample["indices"]]
|
||||
417
src/batdetect2/plotting/matches.py
Normal file
417
src/batdetect2/plotting/matches.py
Normal file
@ -0,0 +1,417 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.axes import Axes
|
||||
from soundevent import data, plot
|
||||
from soundevent.geometry import compute_bounds
|
||||
from soundevent.plot.tags import TagColorMapper
|
||||
|
||||
from batdetect2.evaluate.types import MatchEvaluation
|
||||
from batdetect2.plotting.clip_predictions import plot_prediction
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.preprocess import (
|
||||
PreprocessorProtocol,
|
||||
get_default_preprocessor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"plot_matches",
|
||||
"plot_false_positive_match",
|
||||
"plot_true_positive_match",
|
||||
"plot_false_negative_match",
|
||||
"plot_cross_trigger_match",
|
||||
]
|
||||
|
||||
|
||||
DEFAULT_DURATION = 0.05
|
||||
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
|
||||
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
|
||||
DEFAULT_TRUE_POSITIVE_COLOR = "green"
|
||||
DEFAULT_CROSS_TRIGGER_COLOR = "orange"
|
||||
DEFAULT_ANNOTATION_LINE_STYLE = "-"
|
||||
DEFAULT_PREDICTION_LINE_STYLE = "--"
|
||||
|
||||
|
||||
def plot_matches(
|
||||
matches: List[data.Match],
|
||||
clip: data.Clip,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
color_mapper: Optional[TagColorMapper] = None,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
add_points: bool = False,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
||||
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||
) -> Axes:
|
||||
if preprocessor is None:
|
||||
preprocessor = get_default_preprocessor()
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
ax=ax,
|
||||
figsize=figsize,
|
||||
audio_dir=audio_dir,
|
||||
add_colorbar=add_colorbar,
|
||||
add_labels=add_labels,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
if color_mapper is None:
|
||||
color_mapper = TagColorMapper()
|
||||
|
||||
for match in matches:
|
||||
if match.source is None and match.target is not None:
|
||||
plot.plot_annotation(
|
||||
annotation=match.target,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=false_negative_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=annotation_linestyle,
|
||||
)
|
||||
elif match.target is None and match.source is not None:
|
||||
plot_prediction(
|
||||
prediction=match.source,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=false_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=prediction_linestyle,
|
||||
)
|
||||
elif match.target is not None and match.source is not None:
|
||||
plot.plot_annotation(
|
||||
annotation=match.target,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=true_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=annotation_linestyle,
|
||||
)
|
||||
plot_prediction(
|
||||
prediction=match.source,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=true_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=prediction_linestyle,
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def plot_false_positive_match(
|
||||
match: MatchEvaluation,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
add_points: bool = False,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
time_offset: float = 0,
|
||||
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
||||
fontsize: Union[float, str] = "small",
|
||||
) -> Axes:
|
||||
assert match.match.source is not None
|
||||
assert match.match.target is None
|
||||
sound_event = match.match.source.sound_event
|
||||
geometry = sound_event.geometry
|
||||
assert geometry is not None
|
||||
|
||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
||||
|
||||
clip = data.Clip(
|
||||
start_time=max(start_time - duration / 2, 0),
|
||||
end_time=min(
|
||||
start_time + duration / 2,
|
||||
sound_event.recording.duration,
|
||||
),
|
||||
recording=sound_event.recording,
|
||||
)
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
add_colorbar=add_colorbar,
|
||||
add_labels=add_labels,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
plot_prediction(
|
||||
match.match.source,
|
||||
ax=ax,
|
||||
time_offset=time_offset,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
alpha=1,
|
||||
color=color,
|
||||
)
|
||||
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def plot_false_negative_match(
|
||||
match: MatchEvaluation,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
add_points: bool = False,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||
fontsize: Union[float, str] = "small",
|
||||
) -> Axes:
|
||||
assert match.match.source is None
|
||||
assert match.match.target is not None
|
||||
sound_event = match.match.target.sound_event
|
||||
geometry = sound_event.geometry
|
||||
assert geometry is not None
|
||||
|
||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
||||
|
||||
clip = data.Clip(
|
||||
start_time=max(start_time - duration / 2, 0),
|
||||
end_time=min(
|
||||
start_time + duration / 2, sound_event.recording.duration
|
||||
),
|
||||
recording=sound_event.recording,
|
||||
)
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
add_colorbar=add_colorbar,
|
||||
add_labels=add_labels,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
plot.plot_annotation(
|
||||
match.match.target,
|
||||
ax=ax,
|
||||
time_offset=0.001,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
alpha=1,
|
||||
color=color,
|
||||
)
|
||||
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"False Negative \nClass: {match.gt_class} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def plot_true_positive_match(
|
||||
match: MatchEvaluation,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
add_points: bool = False,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||
fontsize: Union[float, str] = "small",
|
||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||
) -> Axes:
|
||||
assert match.match.source is not None
|
||||
assert match.match.target is not None
|
||||
sound_event = match.match.target.sound_event
|
||||
geometry = sound_event.geometry
|
||||
assert geometry is not None
|
||||
|
||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
||||
|
||||
clip = data.Clip(
|
||||
start_time=max(start_time - duration / 2, 0),
|
||||
end_time=min(
|
||||
start_time + duration / 2, sound_event.recording.duration
|
||||
),
|
||||
recording=sound_event.recording,
|
||||
)
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
add_colorbar=add_colorbar,
|
||||
add_labels=add_labels,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
plot.plot_annotation(
|
||||
match.match.target,
|
||||
ax=ax,
|
||||
time_offset=0.001,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
alpha=1,
|
||||
color=color,
|
||||
linestyle=annotation_linestyle,
|
||||
)
|
||||
|
||||
plot_prediction(
|
||||
match.match.source,
|
||||
ax=ax,
|
||||
time_offset=0.001,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
alpha=1,
|
||||
color=color,
|
||||
linestyle=prediction_linestyle,
|
||||
)
|
||||
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def plot_cross_trigger_match(
|
||||
match: MatchEvaluation,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
add_points: bool = False,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
||||
fontsize: Union[float, str] = "small",
|
||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||
) -> Axes:
|
||||
assert match.match.source is not None
|
||||
assert match.match.target is not None
|
||||
sound_event = match.match.source.sound_event
|
||||
geometry = sound_event.geometry
|
||||
assert geometry is not None
|
||||
|
||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
||||
|
||||
clip = data.Clip(
|
||||
start_time=max(start_time - duration / 2, 0),
|
||||
end_time=min(
|
||||
start_time + duration / 2, sound_event.recording.duration
|
||||
),
|
||||
recording=sound_event.recording,
|
||||
)
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
add_colorbar=add_colorbar,
|
||||
add_labels=add_labels,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
plot.plot_annotation(
|
||||
match.match.target,
|
||||
ax=ax,
|
||||
time_offset=0.001,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
alpha=1,
|
||||
color=color,
|
||||
linestyle=annotation_linestyle,
|
||||
)
|
||||
|
||||
plot_prediction(
|
||||
match.match.source,
|
||||
ax=ax,
|
||||
time_offset=0.001,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
alpha=1,
|
||||
color=color,
|
||||
linestyle=prediction_linestyle,
|
||||
)
|
||||
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
|
||||
return ax
|
||||
@ -39,6 +39,7 @@ from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.postprocess.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_raw_prediction_to_sound_event_prediction,
|
||||
convert_raw_predictions_to_clip_prediction,
|
||||
convert_xr_dataset_to_raw_prediction,
|
||||
)
|
||||
@ -61,7 +62,11 @@ from batdetect2.postprocess.remapping import (
|
||||
features_to_xarray,
|
||||
sizes_to_xarray,
|
||||
)
|
||||
from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction
|
||||
from batdetect2.postprocess.types import (
|
||||
BatDetect2Prediction,
|
||||
PostprocessorProtocol,
|
||||
RawPrediction,
|
||||
)
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
@ -537,6 +542,27 @@ class Postprocessor(PostprocessorProtocol):
|
||||
for dataset in detection_datasets
|
||||
]
|
||||
|
||||
def get_sound_event_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[List[BatDetect2Prediction]]:
|
||||
raw_predictions = self.get_raw_predictions(output, clips)
|
||||
return [
|
||||
[
|
||||
BatDetect2Prediction(
|
||||
raw=raw,
|
||||
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
|
||||
raw,
|
||||
recording=clip.recording,
|
||||
sound_event_decoder=self.targets.decode_class,
|
||||
generic_class_tags=self.targets.generic_class_tags,
|
||||
classification_threshold=self.config.classification_threshold,
|
||||
),
|
||||
)
|
||||
for raw in predictions
|
||||
]
|
||||
for predictions, clip in zip(raw_predictions, clips)
|
||||
]
|
||||
|
||||
def get_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[data.ClipPrediction]:
|
||||
|
||||
@ -95,7 +95,6 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
for det_num in range(detection_dataset.sizes["detection"]):
|
||||
det_info = detection_dataset.sel(detection=det_num)
|
||||
|
||||
# TODO: Maybe clean this up
|
||||
highest_scoring_class = det_info.coords["category"][
|
||||
det_info["classes"].argmax()
|
||||
].item()
|
||||
|
||||
@ -11,6 +11,7 @@ modularity and consistent interaction between different parts of the BatDetect2
|
||||
system that deal with model predictions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional, Protocol
|
||||
|
||||
import xarray as xr
|
||||
@ -75,6 +76,12 @@ class RawPrediction(NamedTuple):
|
||||
features: xr.DataArray
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatDetect2Prediction:
|
||||
raw: RawPrediction
|
||||
sound_event_prediction: data.SoundEventPrediction
|
||||
|
||||
|
||||
class PostprocessorProtocol(Protocol):
|
||||
"""Protocol defining the interface for the full postprocessing pipeline.
|
||||
|
||||
@ -254,6 +261,10 @@ class PostprocessorProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_sound_event_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[List[BatDetect2Prediction]]: ...
|
||||
|
||||
def get_predictions(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
|
||||
@ -86,6 +86,7 @@ __all__ = [
|
||||
"build_spectrogram_builder",
|
||||
"get_spectrogram_resolution",
|
||||
"load_preprocessing_config",
|
||||
"get_default_preprocessor",
|
||||
]
|
||||
|
||||
|
||||
@ -451,3 +452,7 @@ def build_preprocessor(
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
)
|
||||
|
||||
|
||||
def get_default_preprocessor():
|
||||
return build_preprocessor()
|
||||
|
||||
@ -2,11 +2,13 @@ from typing import List
|
||||
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from soundevent import data
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
|
||||
from batdetect2.evaluate.types import Match, MetricsProtocol
|
||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||
from batdetect2.plotting.evaluation import plot_examples
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
@ -14,25 +16,55 @@ from batdetect2.train.types import ModelOutput
|
||||
|
||||
|
||||
class ValidationMetrics(Callback):
|
||||
def __init__(self, metrics: List[MetricsProtocol]):
|
||||
def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
|
||||
super().__init__()
|
||||
|
||||
if len(metrics) == 0:
|
||||
raise ValueError("At least one metric needs to be provided")
|
||||
|
||||
self.matches: List[Match] = []
|
||||
self.matches: List[MatchEvaluation] = []
|
||||
self.metrics = metrics
|
||||
self.plot = plot
|
||||
|
||||
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
|
||||
dataloaders = trainer.val_dataloaders
|
||||
assert isinstance(dataloaders, DataLoader)
|
||||
dataset = dataloaders.dataset
|
||||
assert isinstance(dataset, LabeledDataset)
|
||||
return dataset
|
||||
|
||||
def plot_examples(self, pl_module: LightningModule):
|
||||
if not isinstance(pl_module.logger, TensorBoardLogger):
|
||||
return
|
||||
|
||||
for class_name, fig in plot_examples(
|
||||
self.matches,
|
||||
preprocessor=pl_module.preprocessor,
|
||||
n_examples=5,
|
||||
):
|
||||
pl_module.logger.experiment.add_figure(
|
||||
f"{class_name}/examples",
|
||||
fig,
|
||||
pl_module.global_step,
|
||||
)
|
||||
|
||||
def log_metrics(self, pl_module: LightningModule):
|
||||
metrics = {}
|
||||
for metric in self.metrics:
|
||||
metrics.update(metric(self.matches).items())
|
||||
|
||||
pl_module.log_dict(metrics)
|
||||
|
||||
def on_validation_epoch_end(
|
||||
self,
|
||||
trainer: Trainer,
|
||||
pl_module: LightningModule,
|
||||
) -> None:
|
||||
metrics = {}
|
||||
for metric in self.metrics:
|
||||
metrics.update(metric(self.matches).items())
|
||||
self.log_metrics(pl_module)
|
||||
|
||||
if self.plot:
|
||||
self.plot_examples(pl_module)
|
||||
|
||||
pl_module.log_dict(metrics)
|
||||
return super().on_validation_epoch_end(trainer, pl_module)
|
||||
|
||||
def on_validation_epoch_start(
|
||||
@ -52,11 +84,7 @@ class ValidationMetrics(Callback):
|
||||
batch_idx: int,
|
||||
dataloader_idx: int = 0,
|
||||
) -> None:
|
||||
dataloaders = trainer.val_dataloaders
|
||||
assert isinstance(dataloaders, DataLoader)
|
||||
|
||||
dataset = dataloaders.dataset
|
||||
assert isinstance(dataset, LabeledDataset)
|
||||
dataset = self.get_dataset(trainer)
|
||||
|
||||
clip_annotations = [
|
||||
_get_subclip(
|
||||
@ -74,7 +102,7 @@ class ValidationMetrics(Callback):
|
||||
|
||||
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
||||
|
||||
raw_predictions = pl_module.postprocessor.get_raw_predictions(
|
||||
raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
|
||||
outputs,
|
||||
clips,
|
||||
)
|
||||
@ -84,7 +112,7 @@ class ValidationMetrics(Callback):
|
||||
):
|
||||
self.matches.extend(
|
||||
match_sound_events_and_raw_predictions(
|
||||
sound_events=clip_annotation.sound_events,
|
||||
clip_annotation=clip_annotation,
|
||||
raw_predictions=clip_predictions,
|
||||
targets=pl_module.targets,
|
||||
)
|
||||
|
||||
@ -48,20 +48,19 @@ class TrainingModule(L.LightningModule):
|
||||
def training_step(self, batch: TrainExample):
|
||||
outputs = self.forward(batch.spec)
|
||||
losses = self.loss(outputs, batch)
|
||||
|
||||
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
||||
self.log("detection_loss/train", losses.total, logger=True)
|
||||
self.log("size_loss/train", losses.total, logger=True)
|
||||
self.log("classification_loss/train", losses.total, logger=True)
|
||||
|
||||
return losses.total
|
||||
|
||||
def validation_step( # type: ignore
|
||||
self, batch: TrainExample, batch_idx: int
|
||||
self,
|
||||
batch: TrainExample,
|
||||
batch_idx: int,
|
||||
) -> ModelOutput:
|
||||
outputs = self.forward(batch.spec)
|
||||
losses = self.loss(outputs, batch)
|
||||
|
||||
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
||||
self.log("detection_loss/val", losses.total, logger=True)
|
||||
self.log("size_loss/val", losses.total, logger=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user