Compare commits


No commits in common. "2341f822a7522e780a44ae37ced59615160c8351" and "bf6f52a65d5de8d64fe4394857c037d48eba0e54" have entirely different histories.

24 changed files with 140 additions and 1182 deletions

View File

@@ -1,14 +1,3 @@
-datasets:
-  train:
-    name: example dataset
-    description: Only for demonstration purposes
-    sources:
-      - format: batdetect2
-        name: Example Data
-        description: Examples included for testing batdetect2
-        annotations_dir: example_data/anns
-        audio_dir: example_data/audio
 targets:
   classes:
     classes:
@@ -57,7 +46,7 @@ preprocess:
     max_freq: 120000
     min_freq: 10000
     pcen:
-      time_constant: 0.1
+      time_constant: 0.4
       gain: 0.98
       bias: 2
       power: 0.5

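The only functional change in this hunk is the PCEN `time_constant` (0.1 to 0.4). If batdetect2's PCEN step is parameterised like librosa's (an assumption; the actual implementation may differ), the new value coincides with librosa's default smoothing. A minimal sketch:

```python
import numpy as np
import librosa

# Hypothetical illustration: PCEN with the updated config values.
# The mapping of the batdetect2 config onto librosa.pcen is an assumption.
sr = 256_000  # typical bat-recording sample rate (assumed)
audio = np.random.randn(sr)  # stand-in for a real recording
spec = np.abs(librosa.stft(audio, n_fft=512, hop_length=256)) ** 2

smoothed = librosa.pcen(
    spec * (2**31),  # librosa expects magnitudes at integer-PCM scale
    sr=sr,
    hop_length=256,
    time_constant=0.4,  # was 0.1 in the old config
    gain=0.98,
    bias=2,
    power=0.5,
)
```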
View File

@@ -0,0 +1,10 @@
+datasets:
+  train:
+    name: example dataset
+    description: Only for demonstration purposes
+    sources:
+      - format: batdetect2
+        name: Example Data
+        description: Examples included for testing batdetect2
+        annotations_dir: example_data/anns
+        audio_dir: example_data/audio

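The added file is a minimal dataset config. A sketch of reading it with plain PyYAML (whether batdetect2 validates it further, e.g. with pydantic, is not shown in this diff; the file name below is hypothetical):

```python
import yaml

# Sketch: load the example dataset config shown above.
with open("config.yaml") as fh:
    config = yaml.safe_load(fh)

train = config["datasets"]["train"]
print(train["name"])  # "example dataset"
for source in train["sources"]:
    print(source["format"], source["annotations_dir"])
```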
View File

@@ -17,7 +17,7 @@ dependencies = [
     "torch>=1.13.1,<2.5.0",
     "torchaudio>=1.13.1,<2.5.0",
     "torchvision>=0.14.0",
-    "soundevent[audio,geometry,plot]>=2.6.5",
+    "soundevent[audio,geometry,plot]>=2.5.0",
     "click>=8.1.7",
     "netcdf4>=1.6.5",
     "tqdm>=4.66.2",
@@ -66,7 +66,10 @@ batdetect2 = "batdetect2.cli:cli"
 [dependency-groups]
 jupyter = ["ipywidgets>=8.1.5", "jupyter>=1.1.1"]
-marimo = ["marimo>=0.12.2", "pyarrow>=20.0.0"]
+marimo = [
+    "marimo>=0.12.2",
+    "pyarrow>=20.0.0",
+]
 dev = [
     "debugpy>=1.8.8",
     "hypothesis>=6.118.7",
@@ -74,7 +77,7 @@ dev = [
     "ruff>=0.7.3",
     "ipykernel>=6.29.4",
     "setuptools>=69.5.1",
-    "basedpyright>=1.31.0",
+    "pyright>=1.1.399",
     "myst-parser>=3.0.1",
     "sphinx-autobuild>=2024.10.3",
     "numpydoc>=1.8.0",
@@ -85,8 +88,12 @@ dev = [
     "ty>=0.0.1a12",
     "rust-just>=1.40.0",
 ]
-dvclive = ["dvclive>=3.48.2"]
-mlflow = ["mlflow>=3.1.1"]
+dvclive = [
+    "dvclive>=3.48.2",
+]
+mlflow = [
+    "mlflow>=3.1.1",
+]

 [tool.ruff]
 line-length = 79

View File

@@ -1,10 +1,5 @@
 import logging

-from loguru import logger
-
-logger.disable("batdetect2")
-
 numba_logger = logging.getLogger("numba")
 numba_logger.setLevel(logging.WARNING)

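The removed lines silenced the package's loguru output at import time, the pattern loguru documents for libraries. For reference, a sketch of what the deleted lines did and how applications opted back in:

```python
from loguru import logger

# What the deleted lines did: mute all loguru messages emitted
# under the "batdetect2" namespace when the package is imported.
logger.disable("batdetect2")

# Applications that wanted the logs back would re-enable them:
logger.enable("batdetect2")
```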
View File

@@ -1,15 +0,0 @@
-from batdetect2.compat.data import (
-    annotation_to_sound_event_annotation,
-    annotation_to_sound_event_prediction,
-    convert_to_annotation_group,
-    file_annotation_to_clip_annotation,
-    load_file_annotation,
-)
-
-__all__ = [
-    "annotation_to_sound_event_annotation",
-    "annotation_to_sound_event_prediction",
-    "convert_to_annotation_group",
-    "file_annotation_to_clip_annotation",
-    "load_file_annotation",
-]

View File

@@ -1,30 +1,24 @@
 """Compatibility functions between old and new data structures."""

-import json
 import os
 import uuid
 from pathlib import Path
 from typing import Callable, List, Optional, Union

 import numpy as np
+from pydantic import BaseModel, Field
 from soundevent import data
 from soundevent.geometry import compute_bounds
 from soundevent.types import ClassMapper

-from batdetect2.targets.terms import get_term_from_key
-from batdetect2.types import (
-    Annotation,
-    AudioLoaderAnnotationGroup,
-    FileAnnotation,
-)
+from batdetect2 import types

 PathLike = Union[Path, str, os.PathLike]

 __all__ = [
     "convert_to_annotation_group",
     "load_file_annotation",
-    "annotation_to_sound_event_annotation",
-    "annotation_to_sound_event_prediction",
+    "annotation_to_sound_event",
 ]

 SPECIES_TAG_KEY = "species"
@@ -43,7 +37,7 @@ IndividualFn = Callable[[data.SoundEventAnnotation], int]

 def get_recording_class_name(recording: data.Recording) -> str:
     """Get the class name for a recording."""
-    tag = data.find_tag(recording.tags, label=SPECIES_TAG_KEY)
+    tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
     if tag is None:
         return UNKNOWN_CLASS
     return tag.value
@@ -65,7 +59,7 @@ def convert_to_annotation_group(
     event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
     class_fn: ClassFn = lambda _: 0,
     individual_fn: IndividualFn = lambda _: 0,
-) -> AudioLoaderAnnotationGroup:
+) -> types.AudioLoaderAnnotationGroup:
     """Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
     recording = annotation.clip.recording

@@ -77,7 +71,7 @@ def convert_to_annotation_group(
     x_inds = []
     y_inds = []
     individual_ids = []
-    annotations: List[Annotation] = []
+    annotations: List[types.Annotation] = []
     class_id_file = class_fn(recording)

     for sound_event in annotation.sound_events:
@@ -139,13 +133,42 @@ def convert_to_annotation_group(
     }


+class Annotation(BaseModel):
+    """Annotation class to hold batdetect annotations."""
+
+    label: str = Field(alias="class")
+    event: str
+    individual: int = 0
+    start_time: float
+    end_time: float
+    low_freq: float
+    high_freq: float
+
+
+class FileAnnotation(BaseModel):
+    """FileAnnotation class to hold batdetect annotations for a file."""
+
+    id: str
+    duration: float
+    time_exp: float = 1
+    label: str = Field(alias="class_name")
+    annotation: List[Annotation]
+    annotated: bool = False
+    issues: bool = False
+    notes: str = ""
+
+
 def load_file_annotation(path: PathLike) -> FileAnnotation:
     """Load annotation from batdetect format."""
     path = Path(path)
-    return json.loads(path.read_text())
+    return FileAnnotation.model_validate_json(path.read_text())


-def annotation_to_sound_event_annotation(
+def annotation_to_sound_event(
     annotation: Annotation,
     recording: data.Recording,
     label_key: str = "class",
@@ -156,15 +179,15 @@ def annotation_to_sound_event_annotation(
     sound_event = data.SoundEvent(
         uuid=uuid.uuid5(
             NAMESPACE,
-            f"{recording.hash}_{annotation['start_time']}_{annotation['end_time']}",
+            f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
         ),
         recording=recording,
         geometry=data.BoundingBox(
             coordinates=[
-                annotation["start_time"],
-                annotation["low_freq"],
-                annotation["end_time"],
-                annotation["high_freq"],
+                annotation.start_time,
+                annotation.low_freq,
+                annotation.end_time,
+                annotation.high_freq,
             ],
         ),
     )
@@ -174,62 +197,16 @@ def annotation_to_sound_event_annotation(
         sound_event=sound_event,
         tags=[
             data.Tag(
-                term=get_term_from_key(label_key),
-                value=annotation["class"],
+                term=data.term_from_key(label_key),
+                value=annotation.label,
             ),
             data.Tag(
-                term=get_term_from_key(event_key),
-                value=annotation["event"],
+                term=data.term_from_key(event_key),
+                value=annotation.event,
             ),
             data.Tag(
-                term=get_term_from_key(individual_key),
-                value=str(annotation["individual"]),
-            ),
-        ],
-    )
-
-
-def annotation_to_sound_event_prediction(
-    annotation: Annotation,
-    recording: data.Recording,
-    label_key: str = "class",
-    event_key: str = "event",
-) -> data.SoundEventPrediction:
-    """Convert annotation to sound event annotation."""
-    sound_event = data.SoundEvent(
-        uuid=uuid.uuid5(
-            NAMESPACE,
-            f"{recording.hash}_{annotation['start_time']}_{annotation['end_time']}",
-        ),
-        recording=recording,
-        geometry=data.BoundingBox(
-            coordinates=[
-                annotation["start_time"],
-                annotation["low_freq"],
-                annotation["end_time"],
-                annotation["high_freq"],
-            ],
-        ),
-    )
-
-    return data.SoundEventPrediction(
-        uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
-        sound_event=sound_event,
-        score=annotation["det_prob"],
-        tags=[
-            data.PredictedTag(
-                score=annotation["class_prob"],
-                tag=data.Tag(
-                    term=get_term_from_key(label_key),
-                    value=annotation["class"],
-                ),
-            ),
-            data.PredictedTag(
-                score=annotation["det_prob"],
-                tag=data.Tag(
-                    term=get_term_from_key(event_key),
-                    value=annotation["event"],
-                ),
+                term=data.term_from_key(individual_key),
+                value=str(annotation.individual),
             ),
         ],
     )
@@ -243,24 +220,24 @@ def file_annotation_to_clip(
     """Convert file annotation to recording."""
     audio_dir = audio_dir or Path.cwd()

-    full_path = Path(audio_dir) / file_annotation["id"]
+    full_path = Path(audio_dir) / file_annotation.id

     if not full_path.exists():
         raise FileNotFoundError(f"File {full_path} not found.")

     recording = data.Recording.from_file(
         full_path,
-        time_expansion=file_annotation["time_exp"],
+        time_expansion=file_annotation.time_exp,
         tags=[
             data.Tag(
                 term=data.term_from_key(label_key),
-                value=file_annotation["class_name"],
+                value=file_annotation.label,
             )
         ],
     )

     return data.Clip(
-        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation['id']}_clip"),
+        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
         recording=recording,
         start_time=0,
         end_time=recording.duration,
@@ -276,28 +253,27 @@ def file_annotation_to_clip_annotation(
 ) -> data.ClipAnnotation:
     """Convert file annotation to clip annotation."""
     notes = []
-    if file_annotation["notes"]:
-        notes.append(data.Note(message=file_annotation["notes"]))
+    if file_annotation.notes:
+        notes.append(data.Note(message=file_annotation.notes))

     return data.ClipAnnotation(
-        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation['id']}_clip_annotation"),
+        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
         clip=clip,
         notes=notes,
         tags=[
             data.Tag(
-                term=data.term_from_key(label_key),
-                value=file_annotation["class_name"],
+                term=data.term_from_key(label_key), value=file_annotation.label
             )
         ],
         sound_events=[
-            annotation_to_sound_event_annotation(
+            annotation_to_sound_event(
                 annotation,
                 clip.recording,
                 label_key=label_key,
                 event_key=event_key,
                 individual_key=individual_key,
             )
-            for annotation in file_annotation["annotation"]
+            for annotation in file_annotation.annotation
         ],
     )
@@ -308,17 +284,17 @@ def file_annotation_to_annotation_task(
 ) -> data.AnnotationTask:
     status_badges = []

-    if file_annotation["issues"]:
+    if file_annotation.issues:
         status_badges.append(
             data.StatusBadge(state=data.AnnotationState.rejected)
         )
-    elif file_annotation["annotated"]:
+    elif file_annotation.annotated:
         status_badges.append(
             data.StatusBadge(state=data.AnnotationState.completed)
         )

     return data.AnnotationTask(
-        uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation['id']}_task"),
+        uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation.id}_task"),
         clip=clip,
         status_badges=status_badges,
     )

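The switch from `json.loads` to `FileAnnotation.model_validate_json` means the legacy JSON is now validated and the field aliases (`class`, `class_name`) are resolved at load time. A standalone sketch of the same pattern, with the models re-declared in trimmed form and invented data values, purely for illustration:

```python
from typing import List

from pydantic import BaseModel, Field


class Annotation(BaseModel):
    label: str = Field(alias="class")  # "class" is a keyword, hence the alias
    event: str
    start_time: float
    end_time: float
    low_freq: float
    high_freq: float


class FileAnnotation(BaseModel):
    id: str
    duration: float
    label: str = Field(alias="class_name")
    annotation: List[Annotation]


raw = """
{
  "id": "rec001.wav",
  "duration": 3.0,
  "class_name": "Myotis",
  "annotation": [
    {"class": "Myotis", "event": "Echolocation",
     "start_time": 0.1, "end_time": 0.11,
     "low_freq": 30000, "high_freq": 80000}
  ]
}
"""

file_annotation = FileAnnotation.model_validate_json(raw)
print(file_annotation.annotation[0].label)  # "Myotis" -- alias resolved
```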
View File

@@ -1,9 +1,13 @@
+from batdetect2.evaluate.evaluate import (
+    compute_error_auc,
+)
 from batdetect2.evaluate.match import (
     match_predictions_and_annotations,
     match_sound_events_and_raw_predictions,
 )

 __all__ = [
-    "match_sound_events_and_raw_predictions",
+    "compute_error_auc",
     "match_predictions_and_annotations",
+    "match_sound_events_and_raw_predictions",
 ]

View File

@@ -1,133 +1,54 @@
-from typing import Annotated, List, Literal, Optional, Union
+from typing import List

-from pydantic import Field
 from soundevent import data
 from soundevent.evaluation import match_geometries
-from soundevent.geometry import compute_bounds

-from batdetect2.configs import BaseConfig
-from batdetect2.evaluate.types import MatchEvaluation
-from batdetect2.postprocess.types import BatDetect2Prediction
+from batdetect2.evaluate.types import Match
+from batdetect2.postprocess.types import RawPrediction
 from batdetect2.targets.types import TargetProtocol
 from batdetect2.utils.arrays import iterate_over_array


-class BBoxMatchConfig(BaseConfig):
-    match_method: Literal["BBoxIOU"] = "BBoxIOU"
-    affinity_threshold: float = 0.5
-    time_buffer: float = 0.01
-    frequency_buffer: float = 1_000
-
-
-class IntervalMatchConfig(BaseConfig):
-    match_method: Literal["IntervalIOU"] = "IntervalIOU"
-    affinity_threshold: float = 0.5
-    time_buffer: float = 0.01
-
-
-class StartTimeMatchConfig(BaseConfig):
-    match_method: Literal["StartTime"] = "StartTime"
-    time_buffer: float = 0.01
-
-
-MatchConfig = Annotated[
-    Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
-    Field(discriminator="match_method"),
-]
-
-DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
-
-
-def prepare_geometry(
-    geometry: data.Geometry, config: MatchConfig
-) -> data.Geometry:
-    start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
-
-    if config.match_method == "BBoxIOU":
-        return data.BoundingBox(
-            coordinates=[start_time, low_freq, end_time, high_freq]
-        )
-
-    if config.match_method == "IntervalIOU":
-        return data.TimeInterval(coordinates=[start_time, end_time])
-
-    if config.match_method == "StartTime":
-        return data.TimeStamp(coordinates=start_time)
-
-    raise NotImplementedError(
-        f"Invalid matching configuration. Unknown match method: {config.match_method}"
-    )
-
-
-def _get_frequency_buffer(config: MatchConfig) -> float:
-    if config.match_method == "BBoxIOU":
-        return config.frequency_buffer
-    return 0
-
-
-def _get_affinity_threshold(config: MatchConfig) -> float:
-    if (
-        config.match_method == "BBoxIOU"
-        or config.match_method == "IntervalIOU"
-    ):
-        return config.affinity_threshold
-    return 0
-
-
 def match_sound_events_and_raw_predictions(
-    clip_annotation: data.ClipAnnotation,
-    raw_predictions: List[BatDetect2Prediction],
+    sound_events: List[data.SoundEventAnnotation],
+    raw_predictions: List[RawPrediction],
     targets: TargetProtocol,
-    config: Optional[MatchConfig] = None,
-) -> List[MatchEvaluation]:
-    config = config or DEFAULT_MATCH_CONFIG
-
+) -> List[Match]:
     target_sound_events = [
         targets.transform(sound_event_annotation)
-        for sound_event_annotation in clip_annotation.sound_events
+        for sound_event_annotation in sound_events
         if targets.filter(sound_event_annotation)
         and sound_event_annotation.sound_event.geometry is not None
     ]

     target_geometries: List[data.Geometry] = [  # type: ignore
-        prepare_geometry(
-            sound_event_annotation.sound_event.geometry,
-            config=config,
-        )
+        sound_event_annotation.sound_event.geometry
         for sound_event_annotation in target_sound_events
-        if sound_event_annotation.sound_event.geometry is not None
     ]

     predicted_geometries = [
-        prepare_geometry(raw_prediction.raw.geometry, config=config)
-        for raw_prediction in raw_predictions
+        raw_prediction.geometry for raw_prediction in raw_predictions
     ]

     matches = []
     for id1, id2, affinity in match_geometries(
         target_geometries,
         predicted_geometries,
-        time_buffer=config.time_buffer,
-        freq_buffer=_get_frequency_buffer(config),
-        affinity_threshold=_get_affinity_threshold(config),
     ):
         target = target_sound_events[id1] if id1 is not None else None
         prediction = raw_predictions[id2] if id2 is not None else None

+        gt_uuid = target.uuid if target is not None else None
         gt_det = target is not None
         gt_class = targets.encode_class(target) if target is not None else None
-        pred_score = float(prediction.raw.detection_score) if prediction else 0
+        pred_score = float(prediction.detection_score) if prediction else 0
         class_scores = (
             {
                 str(class_name): float(score)
                 for class_name, score in iterate_over_array(
-                    prediction.raw.class_scores
+                    prediction.class_scores
                 )
             }
             if prediction is not None
@@ -135,18 +56,13 @@ def match_sound_events_and_raw_predictions(
         )

         matches.append(
-            MatchEvaluation(
-                match=data.Match(
-                    source=None
-                    if prediction is None
-                    else prediction.sound_event_prediction,
-                    target=target,
-                    affinity=affinity,
-                ),
+            Match(
+                gt_uuid=gt_uuid,
                 gt_det=gt_det,
                 gt_class=gt_class,
                 pred_score=pred_score,
-                pred_class_scores=class_scores,
+                affinity=affinity,
+                class_scores=class_scores,
             )
         )
@@ -156,10 +72,7 @@ def match_sound_events_and_raw_predictions(
 def match_predictions_and_annotations(
     clip_annotation: data.ClipAnnotation,
     clip_prediction: data.ClipPrediction,
-    config: Optional[MatchConfig] = None,
 ) -> List[data.Match]:
-    config = config or DEFAULT_MATCH_CONFIG
-
     annotated_sound_events = [
         sound_event_annotation
         for sound_event_annotation in clip_annotation.sound_events
@@ -173,13 +86,13 @@ def match_predictions_and_annotations(
     ]

     annotated_geometries: List[data.Geometry] = [
-        prepare_geometry(sound_event.sound_event.geometry, config=config)
+        sound_event.sound_event.geometry
         for sound_event in annotated_sound_events
         if sound_event.sound_event.geometry is not None
     ]

     predicted_geometries: List[data.Geometry] = [
-        prepare_geometry(sound_event.sound_event.geometry, config=config)
+        sound_event.sound_event.geometry
         for sound_event in predicted_sound_events
         if sound_event.sound_event.geometry is not None
     ]
@@ -188,9 +101,6 @@ def match_predictions_and_annotations(
     for id1, id2, affinity in match_geometries(
         annotated_geometries,
         predicted_geometries,
-        time_buffer=config.time_buffer,
-        freq_buffer=_get_frequency_buffer(config),
-        affinity_threshold=_get_affinity_threshold(config),
     ):
         target = annotated_sound_events[id1] if id1 is not None else None
         source = predicted_sound_events[id2] if id2 is not None else None

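The rewritten matcher now passes raw geometries straight to `soundevent.evaluation.match_geometries`, dropping the configurable IoU/interval/start-time strategies. As a rough illustration of what bounding-box matching computes, here is a toy greedy IoU matcher (not the soundevent implementation, just the concept):

```python
from itertools import product
from typing import List, Optional, Tuple

Box = Tuple[float, float, float, float]  # start_time, low_freq, end_time, high_freq


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two time-frequency boxes."""
    t0, f0 = max(a[0], b[0]), max(a[1], b[1])
    t1, f1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, t1 - t0) * max(0.0, f1 - f0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_match(
    targets: List[Box], predictions: List[Box], threshold: float = 0.5
) -> List[Tuple[Optional[int], Optional[int], float]]:
    """Return (target_idx, prediction_idx, affinity) triples, pairing the
    highest-IoU combinations first; unmatched items come back with None on
    the other side, mirroring the id1/id2 loop above."""
    pairs = sorted(
        (
            (iou(t, p), i, j)
            for (i, t), (j, p) in product(enumerate(targets), enumerate(predictions))
        ),
        reverse=True,
    )
    matched_t, matched_p, out = set(), set(), []
    for affinity, i, j in pairs:
        if affinity < threshold or i in matched_t or j in matched_p:
            continue
        matched_t.add(i)
        matched_p.add(j)
        out.append((i, j, affinity))
    for i in range(len(targets)):
        if i not in matched_t:
            out.append((i, None, 0.0))
    for j in range(len(predictions)):
        if j not in matched_p:
            out.append((None, j, 0.0))
    return out
```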
View File

@@ -4,13 +4,13 @@ import pandas as pd
 from sklearn import metrics
 from sklearn.preprocessing import label_binarize

-from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
+from batdetect2.evaluate.types import Match, MetricsProtocol

 __all__ = ["DetectionAveragePrecision"]


 class DetectionAveragePrecision(MetricsProtocol):
-    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
+    def __call__(self, matches: List[Match]) -> Dict[str, float]:
         y_true, y_score = zip(
             *[(match.gt_det, match.pred_score) for match in matches]
         )
@@ -23,7 +23,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
         self.class_names = class_names
         self.per_class = per_class

-    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
+    def __call__(self, matches: List[Match]) -> Dict[str, float]:
         y_true = label_binarize(
             [
                 match.gt_class if match.gt_class is not None else "__NONE__"
@@ -34,7 +34,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
         y_pred = pd.DataFrame(
             [
                 {
-                    name: match.pred_class_scores.get(name, 0)
+                    name: match.class_scores.get(name, 0)
                     for name in self.class_names
                 }
                 for match in matches
@@ -65,7 +65,7 @@ class ClassificationAccuracy(MetricsProtocol):
     def __init__(self, class_names: List[str]):
         self.class_names = class_names

-    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
+    def __call__(self, matches: List[Match]) -> Dict[str, float]:
         y_true = [
             match.gt_class if match.gt_class is not None else "__NONE__"
             for match in matches
@@ -74,7 +74,7 @@ class ClassificationAccuracy(MetricsProtocol):
         y_pred = pd.DataFrame(
             [
                 {
-                    name: match.pred_class_scores.get(name, 0)
+                    name: match.class_scores.get(name, 0)
                     for name in self.class_names
                 }
                 for match in matches

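`DetectionAveragePrecision` reduces each match to a `(gt_det, pred_score)` pair; given the `sklearn.metrics` import, these presumably feed average precision. A minimal sketch of that computation with invented values:

```python
from sklearn import metrics

# Sketch: detection AP from (gt_det, pred_score) pairs like those
# unpacked in DetectionAveragePrecision.__call__. Values are made up.
y_true = [True, True, False, True, False]  # was there a ground-truth event?
y_score = [0.9, 0.75, 0.6, 0.3, 0.1]       # detector confidence

ap = metrics.average_precision_score(y_true, y_score)
print(f"detection_AP: {ap:.3f}")
```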
View File

@@ -1,40 +1,22 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Protocol
+from uuid import UUID

-from soundevent import data

 __all__ = [
     "MetricsProtocol",
-    "MatchEvaluation",
+    "Match",
 ]


 @dataclass
-class MatchEvaluation:
-    match: data.Match
+class Match:
+    gt_uuid: Optional[UUID]
     gt_det: bool
     gt_class: Optional[str]
     pred_score: float
-    pred_class_scores: Dict[str, float]
-
-    @property
-    def pred_class(self) -> Optional[str]:
-        if not self.pred_class_scores:
-            return None
-        return max(self.pred_class_scores, key=self.pred_class_scores.get)  # type: ignore
-
-    @property
-    def pred_class_score(self) -> float:
-        pred_class = self.pred_class
-        if pred_class is None:
-            return 0
-        return self.pred_class_scores[pred_class]
+    affinity: float
+    class_scores: Dict[str, float]


 class MetricsProtocol(Protocol):
-    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
+    def __call__(self, matches: List[Match]) -> Dict[str, float]: ...

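The removed `pred_class`/`pred_class_score` properties can still be recovered from the new `class_scores` dict at the call site; a one-liner sketch (scores invented):

```python
# Recovering the behaviour of the deleted properties from Match.class_scores.
class_scores = {"Myotis": 0.7, "Pipistrellus": 0.2}

pred_class = max(class_scores, key=class_scores.get) if class_scores else None
pred_class_score = class_scores[pred_class] if pred_class is not None else 0.0

print(pred_class, pred_class_score)  # Myotis 0.7
```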
View File

@@ -1,21 +0,0 @@
-from batdetect2.plotting.clip_annotations import plot_clip_annotation
-from batdetect2.plotting.clip_predictions import plot_clip_prediction
-from batdetect2.plotting.clips import plot_clip
-from batdetect2.plotting.matches import (
-    plot_cross_trigger_match,
-    plot_false_negative_match,
-    plot_false_positive_match,
-    plot_matches,
-    plot_true_positive_match,
-)
-
-__all__ = [
-    "plot_clip",
-    "plot_clip_annotation",
-    "plot_clip_prediction",
-    "plot_matches",
-    "plot_false_positive_match",
-    "plot_true_positive_match",
-    "plot_false_negative_match",
-    "plot_cross_trigger_match",
-]

View File

@@ -1,49 +0,0 @@
-from typing import Optional, Tuple
-
-from matplotlib.axes import Axes
-from soundevent import data, plot
-
-from batdetect2.plotting.clips import plot_clip
-from batdetect2.preprocess import PreprocessorProtocol
-
-__all__ = [
-    "plot_clip_annotation",
-]
-
-
-def plot_clip_annotation(
-    clip_annotation: data.ClipAnnotation,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    ax: Optional[Axes] = None,
-    audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
-    add_points: bool = False,
-    cmap: str = "gray",
-    alpha: float = 1,
-    linewidth: float = 1,
-    fill: bool = False,
-) -> Axes:
-    ax = plot_clip(
-        clip_annotation.clip,
-        preprocessor=preprocessor,
-        figsize=figsize,
-        ax=ax,
-        audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
-        spec_cmap=cmap,
-    )
-
-    plot.plot_annotations(
-        clip_annotation.sound_events,
-        ax=ax,
-        time_offset=0.004,
-        freq_offset=2_000,
-        add_points=add_points,
-        alpha=alpha,
-        linewidth=linewidth,
-        facecolor="none" if not fill else None,
-    )
-
-    return ax

View File

@@ -1,141 +0,0 @@
-from typing import Iterable, Optional, Tuple
-
-from matplotlib.axes import Axes
-from soundevent import data
-from soundevent.geometry.operations import Positions, get_geometry_point
-from soundevent.plot.common import create_axes
-from soundevent.plot.geometries import plot_geometry
-from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
-
-from batdetect2.plotting.clips import plot_clip
-from batdetect2.preprocess import PreprocessorProtocol
-
-__all__ = [
-    "plot_clip_prediction",
-]
-
-
-def plot_clip_prediction(
-    clip_prediction: data.ClipPrediction,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    ax: Optional[Axes] = None,
-    audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
-    add_legend: bool = False,
-    spec_cmap: str = "gray",
-    linewidth: float = 1,
-    fill: bool = False,
-) -> Axes:
-    ax = plot_clip(
-        clip_prediction.clip,
-        preprocessor=preprocessor,
-        figsize=figsize,
-        ax=ax,
-        audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
-        spec_cmap=spec_cmap,
-    )
-
-    plot_predictions(
-        clip_prediction.sound_events,
-        ax=ax,
-        time_offset=0.004,
-        freq_offset=2_000,
-        add_points=False,
-        linewidth=linewidth,
-        facecolor="none" if not fill else None,
-        legend=add_legend,
-    )
-
-    return ax
-
-
-def plot_predictions(
-    predictions: Iterable[data.SoundEventPrediction],
-    ax: Optional[Axes] = None,
-    position: Positions = "top-right",
-    color_mapper: Optional[TagColorMapper] = None,
-    time_offset: float = 0.001,
-    freq_offset: float = 1000,
-    legend: bool = True,
-    max_alpha: float = 0.5,
-    color: Optional[str] = None,
-    **kwargs,
-):
-    """Plot a set of predictions."""
-    if ax is None:
-        ax = create_axes(**kwargs)
-
-    if color_mapper is None:
-        color_mapper = TagColorMapper()
-
-    for prediction in predictions:
-        ax = plot_prediction(
-            prediction,
-            ax=ax,
-            position=position,
-            color_mapper=color_mapper,
-            time_offset=time_offset,
-            freq_offset=freq_offset,
-            max_alpha=max_alpha,
-            color=color,
-            **kwargs,
-        )
-
-    if legend:
-        ax = add_tags_legend(ax, color_mapper)
-
-    return ax
-
-
-def plot_prediction(
-    prediction: data.SoundEventPrediction,
-    ax: Optional[Axes] = None,
-    position: Positions = "top-right",
-    color_mapper: Optional[TagColorMapper] = None,
-    time_offset: float = 0.001,
-    freq_offset: float = 1000,
-    max_alpha: float = 0.5,
-    alpha: Optional[float] = None,
-    color: Optional[str] = None,
-    **kwargs,
-) -> Axes:
-    """Plot a single prediction."""
-    geometry = prediction.sound_event.geometry
-
-    if geometry is None:
-        raise ValueError("Annotation does not have a geometry.")
-
-    if ax is None:
-        ax = create_axes(**kwargs)
-
-    if color_mapper is None:
-        color_mapper = TagColorMapper()
-
-    if alpha is None:
-        alpha = min(prediction.score * max_alpha, 1)
-
-    ax = plot_geometry(
-        geometry,
-        ax=ax,
-        color=color,
-        alpha=alpha,
-        **kwargs,
-    )
-
-    x, y = get_geometry_point(geometry, position=position)
-
-    for index, tag in enumerate(prediction.tags):
-        color = color_mapper.get_color(tag.tag)
-
-        ax = plot_tag(
-            time=x + time_offset,
-            frequency=y - index * freq_offset,
-            color=color,
-            ax=ax,
-            alpha=min(tag.score, prediction.score),
-            **kwargs,
-        )
-
-    return ax

View File

@@ -1,44 +0,0 @@
-from typing import Optional, Tuple
-
-import matplotlib.pyplot as plt
-from matplotlib.axes import Axes
-from soundevent import data
-
-from batdetect2.preprocess import (
-    PreprocessorProtocol,
-    get_default_preprocessor,
-)
-
-__all__ = [
-    "plot_clip",
-]
-
-
-def plot_clip(
-    clip: data.Clip,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    ax: Optional[Axes] = None,
-    audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
-    spec_cmap: str = "gray",
-) -> Axes:
-    if ax is None:
-        _, ax = plt.subplots(figsize=figsize)
-
-    if preprocessor is None:
-        preprocessor = get_default_preprocessor()
-
-    spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir)
-
-    spec.plot(  # type: ignore
-        ax=ax,
-        add_colorbar=add_colorbar,
-        cmap=spec_cmap,
-        add_labels=add_labels,
-        vmin=spec.min().item(),
-        vmax=spec.max().item(),
-    )
-
-    return ax

View File

@@ -1,160 +0,0 @@
-import random
-from collections import defaultdict
-from dataclasses import dataclass, field
-from typing import List
-
-import matplotlib.pyplot as plt
-import pandas as pd
-
-from batdetect2 import plotting
-from batdetect2.evaluate.types import MatchEvaluation
-from batdetect2.preprocess.types import PreprocessorProtocol
-
-
-@dataclass
-class ClassExamples:
-    false_positives: List[MatchEvaluation] = field(default_factory=list)
-    false_negatives: List[MatchEvaluation] = field(default_factory=list)
-    true_positives: List[MatchEvaluation] = field(default_factory=list)
-    cross_triggers: List[MatchEvaluation] = field(default_factory=list)
-
-
-def plot_examples(
-    matches: List[MatchEvaluation],
-    preprocessor: PreprocessorProtocol,
-    n_examples: int = 5,
-):
-    class_examples = defaultdict(ClassExamples)
-
-    for match in matches:
-        gt_class = match.gt_class
-        pred_class = match.pred_class
-
-        if pred_class is None:
-            class_examples[gt_class].false_negatives.append(match)
-            continue
-
-        if gt_class is None:
-            class_examples[pred_class].false_positives.append(match)
-            continue
-
-        if gt_class != pred_class:
-            class_examples[gt_class].cross_triggers.append(match)
-            class_examples[pred_class].cross_triggers.append(match)
-            continue
-
-        class_examples[gt_class].true_positives.append(match)
-
-    for class_name, examples in class_examples.items():
-        true_positives = get_binned_sample(
-            examples.true_positives,
-            n_examples=n_examples,
-        )
-        false_positives = get_binned_sample(
-            examples.false_positives,
-            n_examples=n_examples,
-        )
-        false_negatives = random.sample(
-            examples.false_negatives,
-            k=min(n_examples, len(examples.false_negatives)),
-        )
-        cross_triggers = get_binned_sample(
-            examples.cross_triggers,
-            n_examples=n_examples,
-        )
-        fig = plot_class_examples(
-            true_positives,
-            false_positives,
-            false_negatives,
-            cross_triggers,
-            preprocessor=preprocessor,
-            n_examples=n_examples,
-        )
-        yield class_name, fig
-        plt.close(fig)
-
-
-def plot_class_examples(
-    true_positives: List[MatchEvaluation],
-    false_positives: List[MatchEvaluation],
-    false_negatives: List[MatchEvaluation],
-    cross_triggers: List[MatchEvaluation],
-    preprocessor: PreprocessorProtocol,
-    n_examples: int = 5,
-    duration: float = 0.1,
-):
-    fig = plt.figure(figsize=(20, 20))
-
-    for index, match in enumerate(true_positives):
-        ax = plt.subplot(4, n_examples, index + 1)
-        try:
-            plotting.plot_true_positive_match(
-                match,
-                ax=ax,
-                preprocessor=preprocessor,
-                duration=duration,
-            )
-        except ValueError:
-            continue
-
-    for index, match in enumerate(false_positives):
-        ax = plt.subplot(4, n_examples, n_examples + index + 1)
-        try:
-            plotting.plot_false_positive_match(
-                match,
-                ax=ax,
-                preprocessor=preprocessor,
-                duration=duration,
-            )
-        except ValueError:
-            continue
-
-    for index, match in enumerate(false_negatives):
-        ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
-        try:
-            plotting.plot_false_negative_match(
-                match,
-                ax=ax,
-                preprocessor=preprocessor,
-                duration=duration,
-            )
-        except ValueError:
-            continue
-
-    for index, match in enumerate(cross_triggers):
-        ax = plt.subplot(4, n_examples, 4 * n_examples + index + 1)
-        try:
-            plotting.plot_cross_trigger_match(
-                match,
-                ax=ax,
-                preprocessor=preprocessor,
-                duration=duration,
-            )
-        except ValueError:
-            continue
-
-    return fig
-
-
-def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
-    if len(matches) < n_examples:
-        return matches
-
-    indices, pred_scores = zip(
-        *[
-            (index, match.pred_class_scores[pred_class])
-            for index, match in enumerate(matches)
-            if (pred_class := match.pred_class) is not None
-        ]
-    )
-
-    bins = pd.qcut(pred_scores, q=n_examples, labels=False)
-    df = pd.DataFrame({"indices": indices, "bins": bins})
-    sample = df.groupby("bins").apply(lambda x: x.sample(1))
-    return [matches[ind] for ind in sample["indices"]]

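The deleted `get_binned_sample` stratified its examples by prediction score with `pd.qcut`, sampling one match per score quantile. The core trick in isolation (scores invented for the sketch):

```python
import pandas as pd

# Sketch of the score-stratified sampling used by the deleted
# get_binned_sample: one example per score quantile.
scores = [0.05, 0.12, 0.33, 0.41, 0.55, 0.62, 0.78, 0.81, 0.9, 0.97]

df = pd.DataFrame({"index": range(len(scores)), "score": scores})
df["bin"] = pd.qcut(df["score"], q=5, labels=False)

# One representative per quantile bin, like the groupby/apply above.
sample = df.groupby("bin").sample(1, random_state=0)
print(sample["index"].tolist())
```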
View File

@@ -1,417 +0,0 @@
-from typing import List, Optional, Tuple, Union
-
-import matplotlib.pyplot as plt
-from matplotlib.axes import Axes
-from soundevent import data, plot
-from soundevent.geometry import compute_bounds
-from soundevent.plot.tags import TagColorMapper
-
-from batdetect2.evaluate.types import MatchEvaluation
-from batdetect2.plotting.clip_predictions import plot_prediction
-from batdetect2.plotting.clips import plot_clip
-from batdetect2.preprocess import (
-    PreprocessorProtocol,
-    get_default_preprocessor,
-)
-
-__all__ = [
-    "plot_matches",
-    "plot_false_positive_match",
-    "plot_true_positive_match",
-    "plot_false_negative_match",
-    "plot_cross_trigger_match",
-]
-
-DEFAULT_DURATION = 0.05
-DEFAULT_FALSE_POSITIVE_COLOR = "orange"
-DEFAULT_FALSE_NEGATIVE_COLOR = "red"
-DEFAULT_TRUE_POSITIVE_COLOR = "green"
-DEFAULT_CROSS_TRIGGER_COLOR = "orange"
-DEFAULT_ANNOTATION_LINE_STYLE = "-"
-DEFAULT_PREDICTION_LINE_STYLE = "--"
-
-
-def plot_matches(
-    matches: List[data.Match],
-    clip: data.Clip,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    ax: Optional[Axes] = None,
-    audio_dir: Optional[data.PathLike] = None,
-    color_mapper: Optional[TagColorMapper] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
-    add_points: bool = False,
-    fill: bool = False,
-    spec_cmap: str = "gray",
-    false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
-    false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
-    true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
-    annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
-    prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
-) -> Axes:
-    if preprocessor is None:
-        preprocessor = get_default_preprocessor()
-
-    ax = plot_clip(
-        clip,
-        ax=ax,
-        figsize=figsize,
-        audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
-        spec_cmap=spec_cmap,
-    )
-
-    if color_mapper is None:
-        color_mapper = TagColorMapper()
-
-    for match in matches:
-        if match.source is None and match.target is not None:
-            plot.plot_annotation(
-                annotation=match.target,
-                ax=ax,
-                time_offset=0.004,
-                freq_offset=2_000,
-                add_points=add_points,
-                facecolor="none" if not fill else None,
-                color=false_negative_color,
-                color_mapper=color_mapper,
-                linestyle=annotation_linestyle,
-            )
-        elif match.target is None and match.source is not None:
-            plot_prediction(
-                prediction=match.source,
-                ax=ax,
-                time_offset=0.004,
-                freq_offset=2_000,
-                add_points=add_points,
-                facecolor="none" if not fill else None,
-                color=false_positive_color,
-                color_mapper=color_mapper,
-                linestyle=prediction_linestyle,
-            )
-        elif match.target is not None and match.source is not None:
-            plot.plot_annotation(
-                annotation=match.target,
-                ax=ax,
-                time_offset=0.004,
-                freq_offset=2_000,
-                add_points=add_points,
-                facecolor="none" if not fill else None,
-                color=true_positive_color,
-                color_mapper=color_mapper,
-                linestyle=annotation_linestyle,
-            )
-            plot_prediction(
-                prediction=match.source,
-                ax=ax,
-                time_offset=0.004,
-                freq_offset=2_000,
-                add_points=add_points,
-                facecolor="none" if not fill else None,
-                color=true_positive_color,
-                color_mapper=color_mapper,
-                linestyle=prediction_linestyle,
-            )
-        else:
-            continue
-
-    return ax
-
-
-def plot_false_positive_match(
-    match: MatchEvaluation,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    ax: Optional[Axes] = None,
-    audio_dir: Optional[data.PathLike] = None,
-    duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
-    add_points: bool = False,
-    fill: bool = False,
-    spec_cmap: str = "gray",
-    time_offset: float = 0,
-    color: str = DEFAULT_FALSE_POSITIVE_COLOR,
-    fontsize: Union[float, str] = "small",
-) -> Axes:
-    assert match.match.source is not None
-    assert match.match.target is None
-
-    sound_event = match.match.source.sound_event
-    geometry = sound_event.geometry
-    assert geometry is not None
-
-    start_time, _, _, high_freq = compute_bounds(geometry)
-
-    clip = data.Clip(
-        start_time=max(start_time - duration / 2, 0),
-        end_time=min(
-            start_time + duration / 2,
-            sound_event.recording.duration,
-        ),
-        recording=sound_event.recording,
-    )
-
-    ax = plot_clip(
-        clip,
-        preprocessor=preprocessor,
-        figsize=figsize,
-        ax=ax,
-        audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
-        spec_cmap=spec_cmap,
-    )
-
-    plot_prediction(
-        match.match.source,
-        ax=ax,
-        time_offset=time_offset,
-        freq_offset=2_000,
-        add_points=add_points,
-        facecolor="none" if not fill else None,
-        alpha=1,
-        color=color,
-    )
-
-    plt.text(
-        start_time,
-        high_freq,
-        f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ",
-        va="top",
-        ha="right",
-        color=color,
-        fontsize=fontsize,
-    )
-
-    return ax
-
-
-def plot_false_negative_match(
-    match: MatchEvaluation,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    ax: Optional[Axes] = None,
-    audio_dir: Optional[data.PathLike] = None,
-    duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
-    add_points: bool = False,
-    fill: bool = False,
-    spec_cmap: str = "gray",
-    color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
-    fontsize: Union[float, str] = "small",
-) -> Axes:
-    assert match.match.source is None
-    assert match.match.target is not None
-
-    sound_event = match.match.target.sound_event
-    geometry = sound_event.geometry
-    assert geometry is not None
-
-    start_time, _, _, high_freq = compute_bounds(geometry)
-
-    clip = data.Clip(
-        start_time=max(start_time - duration / 2, 0),
-        end_time=min(
-            start_time + duration / 2, sound_event.recording.duration
-        ),
-        recording=sound_event.recording,
-    )
-
-    ax = plot_clip(
-        clip,
-        preprocessor=preprocessor,
-        figsize=figsize,
-        ax=ax,
-        audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
-        spec_cmap=spec_cmap,
-    )
-
-    plot.plot_annotation(
-        match.match.target,
-        ax=ax,
-        time_offset=0.001,
-        freq_offset=2_000,
-        add_points=add_points,
-        facecolor="none" if not fill else None,
-        alpha=1,
-        color=color,
-    )
-
-    plt.text(
-        start_time,
-        high_freq,
-        f"False Negative \nClass: {match.gt_class} ",
-        va="top",
-        ha="right",
-        color=color,
-        fontsize=fontsize,
-    )
-
-    return ax
-
-
-def plot_true_positive_match(
-    match: MatchEvaluation,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    ax: Optional[Axes] = None,
-    audio_dir: Optional[data.PathLike] = None,
-    duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
-    add_points: bool = False,
-    fill: bool = False,
-    spec_cmap: str = "gray",
-    color: str = DEFAULT_TRUE_POSITIVE_COLOR,
-    fontsize: Union[float, str] = "small",
-    annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
-    prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
-) -> Axes:
-    assert match.match.source is not None
-    assert match.match.target is not None
-
-    sound_event = match.match.target.sound_event
-    geometry = sound_event.geometry
-    assert geometry is not None
-
-    start_time, _, _, high_freq = compute_bounds(geometry)
-
-    clip = data.Clip(
-        start_time=max(start_time - duration / 2, 0),
-        end_time=min(
-            start_time + duration / 2, sound_event.recording.duration
-        ),
-        recording=sound_event.recording,
-    )
-
-    ax = plot_clip(
-        clip,
-        preprocessor=preprocessor,
-        figsize=figsize,
-        ax=ax,
-        audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
-        spec_cmap=spec_cmap,
-    )
-
-    plot.plot_annotation(
-        match.match.target,
-        ax=ax,
-        time_offset=0.001,
-        freq_offset=2_000,
-        add_points=add_points,
-        facecolor="none" if not fill else None,
-        alpha=1,
-        color=color,
-        linestyle=annotation_linestyle,
-    )
-
-    plot_prediction(
-        match.match.source,
-        ax=ax,
-        time_offset=0.001,
-        freq_offset=2_000,
-        add_points=add_points,
-        facecolor="none" if not fill else None,
-        alpha=1,
-        color=color,
-        linestyle=prediction_linestyle,
-    )
-
-    plt.text(
-        start_time,
-        high_freq,
-        f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
-        va="top",
-        ha="right",
-        color=color,
-        fontsize=fontsize,
-    )
-
-    return ax
-
-
-def plot_cross_trigger_match(
-    match: MatchEvaluation,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    ax: Optional[Axes] = None,
-    audio_dir: Optional[data.PathLike] = None,
-    duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
-    add_points: bool = False,
-    fill: bool = False,
-    spec_cmap: str = "gray",
-    color: str = DEFAULT_CROSS_TRIGGER_COLOR,
-    fontsize: Union[float, str] = "small",
-    annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
-    prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
-) -> Axes:
-    assert match.match.source is not None
-    assert match.match.target is not None
-
-    sound_event = match.match.source.sound_event
-    geometry = sound_event.geometry
-    assert geometry is not None
-
-    start_time, _, _, high_freq = compute_bounds(geometry)
-
-    clip = data.Clip(
-        start_time=max(start_time - duration / 2, 0),
-        end_time=min(
-            start_time + duration / 2, sound_event.recording.duration
-        ),
-        recording=sound_event.recording,
-    )
-
-    ax = plot_clip(
-        clip,
-        preprocessor=preprocessor,
-        figsize=figsize,
-        ax=ax,
-        audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
-        spec_cmap=spec_cmap,
-    )
-
-    plot.plot_annotation(
-        match.match.target,
-        ax=ax,
-        time_offset=0.001,
-        freq_offset=2_000,
-        add_points=add_points,
-        facecolor="none" if not fill else None,
-        alpha=1,
-        color=color,
-        linestyle=annotation_linestyle,
-    )
-
-    plot_prediction(
-        match.match.source,
-        ax=ax,
-        time_offset=0.001,
-        freq_offset=2_000,
-        add_points=add_points,
-        facecolor="none" if not fill else None,
-        alpha=1,
-        color=color,
-        linestyle=prediction_linestyle,
-    )
-
-    plt.text(
-        start_time,
-        high_freq,
-        f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
-        va="top",
-        ha="right",
-        color=color,
-        fontsize=fontsize,
-    )
-
-    return ax

View File

@@ -39,7 +39,6 @@ from batdetect2.configs import BaseConfig, load_config
 from batdetect2.models.types import ModelOutput
 from batdetect2.postprocess.decoding import (
     DEFAULT_CLASSIFICATION_THRESHOLD,
-    convert_raw_prediction_to_sound_event_prediction,
     convert_raw_predictions_to_clip_prediction,
     convert_xr_dataset_to_raw_prediction,
 )
@@ -62,11 +61,7 @@ from batdetect2.postprocess.remapping import (
     features_to_xarray,
     sizes_to_xarray,
 )
-from batdetect2.postprocess.types import (
-    BatDetect2Prediction,
-    PostprocessorProtocol,
-    RawPrediction,
-)
+from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction
 from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
 from batdetect2.targets.types import TargetProtocol
@@ -542,27 +537,6 @@ class Postprocessor(PostprocessorProtocol):
             for dataset in detection_datasets
         ]

-    def get_sound_event_predictions(
-        self, output: ModelOutput, clips: List[data.Clip]
-    ) -> List[List[BatDetect2Prediction]]:
-        raw_predictions = self.get_raw_predictions(output, clips)
-        return [
-            [
-                BatDetect2Prediction(
-                    raw=raw,
-                    sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
-                        raw,
-                        recording=clip.recording,
-                        sound_event_decoder=self.targets.decode_class,
-                        generic_class_tags=self.targets.generic_class_tags,
-                        classification_threshold=self.config.classification_threshold,
-                    ),
-                )
-                for raw in predictions
-            ]
-            for predictions, clip in zip(raw_predictions, clips)
-        ]
-
     def get_predictions(
         self, output: ModelOutput, clips: List[data.Clip]
     ) -> List[data.ClipPrediction]:

View File

@@ -95,6 +95,7 @@ def convert_xr_dataset_to_raw_prediction(
     for det_num in range(detection_dataset.sizes["detection"]):
         det_info = detection_dataset.sel(detection=det_num)

+        # TODO: Maybe clean this up
         highest_scoring_class = det_info.coords["category"][
             det_info["classes"].argmax()
         ].item()

View File

@@ -11,7 +11,6 @@ modularity and consistent interaction between different parts of the BatDetect2
 system that deal with model predictions.
 """

-from dataclasses import dataclass
 from typing import List, NamedTuple, Optional, Protocol

 import xarray as xr
@@ -76,12 +75,6 @@ class RawPrediction(NamedTuple):
     features: xr.DataArray


-@dataclass
-class BatDetect2Prediction:
-    raw: RawPrediction
-    sound_event_prediction: data.SoundEventPrediction
-
-
 class PostprocessorProtocol(Protocol):
     """Protocol defining the interface for the full postprocessing pipeline.
@@ -261,10 +254,6 @@ class PostprocessorProtocol(Protocol):
         """
         ...

-    def get_sound_event_predictions(
-        self, output: ModelOutput, clips: List[data.Clip]
-    ) -> List[List[BatDetect2Prediction]]: ...
-
     def get_predictions(
         self,
         output: ModelOutput,

View File

@@ -86,7 +86,6 @@ __all__ = [
     "build_spectrogram_builder",
     "get_spectrogram_resolution",
     "load_preprocessing_config",
-    "get_default_preprocessor",
 ]
@@ -452,7 +451,3 @@ def build_preprocessor(
         min_freq=min_freq,
         max_freq=max_freq,
     )
-
-
-def get_default_preprocessor():
-    return build_preprocessor()

View File

@@ -2,13 +2,11 @@ from typing import List

 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.loggers import TensorBoardLogger
 from soundevent import data
 from torch.utils.data import DataLoader

 from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
-from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
-from batdetect2.plotting.evaluation import plot_examples
+from batdetect2.evaluate.types import Match, MetricsProtocol
 from batdetect2.targets.types import TargetProtocol
 from batdetect2.train.dataset import LabeledDataset, TrainExample
 from batdetect2.train.lightning import TrainingModule
@@ -16,55 +14,25 @@ from batdetect2.train.types import ModelOutput

 class ValidationMetrics(Callback):
-    def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
+    def __init__(self, metrics: List[MetricsProtocol]):
         super().__init__()

         if len(metrics) == 0:
             raise ValueError("At least one metric needs to be provided")

-        self.matches: List[MatchEvaluation] = []
+        self.matches: List[Match] = []
         self.metrics = metrics
-        self.plot = plot
-
-    def get_dataset(self, trainer: Trainer) -> LabeledDataset:
-        dataloaders = trainer.val_dataloaders
-        assert isinstance(dataloaders, DataLoader)
-        dataset = dataloaders.dataset
-        assert isinstance(dataset, LabeledDataset)
-        return dataset
-
-    def plot_examples(self, pl_module: LightningModule):
-        if not isinstance(pl_module.logger, TensorBoardLogger):
-            return
-
-        for class_name, fig in plot_examples(
-            self.matches,
-            preprocessor=pl_module.preprocessor,
-            n_examples=5,
-        ):
-            pl_module.logger.experiment.add_figure(
-                f"{class_name}/examples",
-                fig,
-                pl_module.global_step,
-            )
-
-    def log_metrics(self, pl_module: LightningModule):
-        metrics = {}
-        for metric in self.metrics:
-            metrics.update(metric(self.matches).items())
-        pl_module.log_dict(metrics)
-
     def on_validation_epoch_end(
         self,
         trainer: Trainer,
         pl_module: LightningModule,
     ) -> None:
-        self.log_metrics(pl_module)
-
-        if self.plot:
-            self.plot_examples(pl_module)
-
+        metrics = {}
+        for metric in self.metrics:
+            metrics.update(metric(self.matches).items())
+        pl_module.log_dict(metrics)
         return super().on_validation_epoch_end(trainer, pl_module)

     def on_validation_epoch_start(
@@ -84,7 +52,11 @@ class ValidationMetrics(Callback):
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        dataset = self.get_dataset(trainer)
+        dataloaders = trainer.val_dataloaders
+        assert isinstance(dataloaders, DataLoader)
+        dataset = dataloaders.dataset
+        assert isinstance(dataset, LabeledDataset)

         clip_annotations = [
             _get_subclip(
@@ -102,7 +74,7 @@ class ValidationMetrics(Callback):
         clips = [clip_annotation.clip for clip_annotation in clip_annotations]

-        raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
+        raw_predictions = pl_module.postprocessor.get_raw_predictions(
             outputs,
             clips,
         )
@@ -112,7 +84,7 @@ class ValidationMetrics(Callback):
         ):
             self.matches.extend(
                 match_sound_events_and_raw_predictions(
-                    clip_annotation=clip_annotation,
+                    sound_events=clip_annotation.sound_events,
                     raw_predictions=clip_predictions,
                     targets=pl_module.targets,
                 )

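With the plotting hooks gone, `ValidationMetrics` is now purely a metrics aggregator. Wiring it up stays the usual Lightning pattern; a sketch, noting that the module paths below are inferred from this diff rather than confirmed by it, and the class names passed in are placeholders:

```python
from lightning import Trainer

# Import paths are assumptions based on the files changed in this diff.
from batdetect2.evaluate.metrics import (
    ClassificationAccuracy,
    DetectionAveragePrecision,
)
from batdetect2.train.callbacks import ValidationMetrics

callback = ValidationMetrics(
    metrics=[
        DetectionAveragePrecision(),
        ClassificationAccuracy(class_names=["Myotis", "Pipistrellus"]),
    ]
)
trainer = Trainer(callbacks=[callback])
```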
View File

@@ -48,19 +48,20 @@ class TrainingModule(L.LightningModule):
     def training_step(self, batch: TrainExample):
         outputs = self.forward(batch.spec)
         losses = self.loss(outputs, batch)
+
         self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
         self.log("detection_loss/train", losses.total, logger=True)
         self.log("size_loss/train", losses.total, logger=True)
         self.log("classification_loss/train", losses.total, logger=True)
+
         return losses.total

     def validation_step(  # type: ignore
-        self,
-        batch: TrainExample,
-        batch_idx: int,
+        self, batch: TrainExample, batch_idx: int
     ) -> ModelOutput:
         outputs = self.forward(batch.spec)
         losses = self.loss(outputs, batch)
+
         self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
         self.log("detection_loss/val", losses.total, logger=True)
         self.log("size_loss/val", losses.total, logger=True)