Compare commits

...

14 Commits

Author SHA1 Message Date
mbsantiago 2341f822a7 Ignore plotting failures in gallery plot 2025-08-08 13:30:44 +01:00
mbsantiago c3d377b6e0 Update soundevent 2025-08-08 13:07:12 +01:00
mbsantiago d9395d3eeb Updated callback to include plotting 2025-08-08 13:06:28 +01:00
mbsantiago aaec66c15e Added BatDetect2Prediction wrapper and method 2025-08-08 13:06:10 +01:00
mbsantiago 6213238585 Added matching configs 2025-08-08 13:05:50 +01:00
mbsantiago a485ea4f79 Add get_default_preprocessor function 2025-08-08 13:03:53 +01:00
mbsantiago 3cfceb76b4 Added plotting example gallery function 2025-08-08 13:03:42 +01:00
mbsantiago d877d383a4 Disable logs by default 2025-08-08 12:25:57 +01:00
mbsantiago bb4a9fe645 Move legacy plot module to legacy folder 2025-08-08 12:25:49 +01:00
mbsantiago 87ce2acd6f Adding plotting functions 2025-08-08 12:25:26 +01:00
mbsantiago e1908c35ca Update compat module to use new term module 2025-08-08 12:25:16 +01:00
mbsantiago 62923a201b Move dataset example to config 2025-08-08 12:24:59 +01:00
mbsantiago d9323a1383 Move legacy evaluate code to legacy folder 2025-08-07 16:20:16 +01:00
mbsantiago 1ee9643a61 Update soundevent package to get match fixes 2025-08-06 21:39:42 +01:00
24 changed files with 1182 additions and 140 deletions

View File

@@ -1,3 +1,14 @@
datasets:
train:
name: example dataset
description: Only for demonstration purposes
sources:
- format: batdetect2
name: Example Data
description: Examples included for testing batdetect2
annotations_dir: example_data/anns
audio_dir: example_data/audio
targets:
classes:
classes:
@@ -46,7 +57,7 @@ preprocess:
max_freq: 120000
min_freq: 10000
pcen:
time_constant: 0.4
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5

View File

@@ -1,10 +0,0 @@
datasets:
train:
name: example dataset
description: Only for demonstration purposes
sources:
- format: batdetect2
name: Example Data
description: Examples included for testing batdetect2
annotations_dir: example_data/anns
audio_dir: example_data/audio

View File

@@ -17,7 +17,7 @@ dependencies = [
"torch>=1.13.1,<2.5.0",
"torchaudio>=1.13.1,<2.5.0",
"torchvision>=0.14.0",
"soundevent[audio,geometry,plot]>=2.5.0",
"soundevent[audio,geometry,plot]>=2.6.5",
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",
@@ -66,10 +66,7 @@ batdetect2 = "batdetect2.cli:cli"
[dependency-groups]
jupyter = ["ipywidgets>=8.1.5", "jupyter>=1.1.1"]
marimo = [
"marimo>=0.12.2",
"pyarrow>=20.0.0",
]
marimo = ["marimo>=0.12.2", "pyarrow>=20.0.0"]
dev = [
"debugpy>=1.8.8",
"hypothesis>=6.118.7",
@@ -77,7 +74,7 @@ dev = [
"ruff>=0.7.3",
"ipykernel>=6.29.4",
"setuptools>=69.5.1",
"pyright>=1.1.399",
"basedpyright>=1.31.0",
"myst-parser>=3.0.1",
"sphinx-autobuild>=2024.10.3",
"numpydoc>=1.8.0",
@@ -88,12 +85,8 @@ dev = [
"ty>=0.0.1a12",
"rust-just>=1.40.0",
]
dvclive = [
"dvclive>=3.48.2",
]
mlflow = [
"mlflow>=3.1.1",
]
dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"]
[tool.ruff]
line-length = 79

View File

@@ -1,5 +1,10 @@
import logging
from loguru import logger
logger.disable("batdetect2")
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)
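
Since this change silences the package logger on import (assuming this snippet lives in the package __init__), users who want BatDetect2's log output back can opt in with loguru's standard enable call. A minimal sketch:

import batdetect2  # package import applies the disable-by-default above
from loguru import logger

# Re-enable BatDetect2 log messages for this process.
logger.enable("batdetect2")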

View File

@@ -0,0 +1,15 @@
from batdetect2.compat.data import (
annotation_to_sound_event_annotation,
annotation_to_sound_event_prediction,
convert_to_annotation_group,
file_annotation_to_clip_annotation,
load_file_annotation,
)
__all__ = [
"annotation_to_sound_event_annotation",
"annotation_to_sound_event_prediction",
"convert_to_annotation_group",
"file_annotation_to_clip_annotation",
"load_file_annotation",
]

View File

@@ -1,24 +1,30 @@
"""Compatibility functions between old and new data structures."""
import json
import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union
import numpy as np
from pydantic import BaseModel, Field
from soundevent import data
from soundevent.geometry import compute_bounds
from soundevent.types import ClassMapper
from batdetect2 import types
from batdetect2.targets.terms import get_term_from_key
from batdetect2.types import (
Annotation,
AudioLoaderAnnotationGroup,
FileAnnotation,
)
PathLike = Union[Path, str, os.PathLike]
__all__ = [
"convert_to_annotation_group",
"load_file_annotation",
"annotation_to_sound_event",
"annotation_to_sound_event_annotation",
"annotation_to_sound_event_prediction",
]
SPECIES_TAG_KEY = "species"
@@ -37,7 +43,7 @@ IndividualFn = Callable[[data.SoundEventAnnotation], int]
def get_recording_class_name(recording: data.Recording) -> str:
"""Get the class name for a recording."""
tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
tag = data.find_tag(recording.tags, label=SPECIES_TAG_KEY)
if tag is None:
return UNKNOWN_CLASS
return tag.value
@@ -59,7 +65,7 @@ def convert_to_annotation_group(
event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
class_fn: ClassFn = lambda _: 0,
individual_fn: IndividualFn = lambda _: 0,
) -> types.AudioLoaderAnnotationGroup:
) -> AudioLoaderAnnotationGroup:
"""Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
recording = annotation.clip.recording
@@ -71,7 +77,7 @@
x_inds = []
y_inds = []
individual_ids = []
annotations: List[types.Annotation] = []
annotations: List[Annotation] = []
class_id_file = class_fn(recording)
for sound_event in annotation.sound_events:
@@ -133,42 +139,13 @@
}
class Annotation(BaseModel):
"""Annotation class to hold batdetect annotations."""
label: str = Field(alias="class")
event: str
individual: int = 0
start_time: float
end_time: float
low_freq: float
high_freq: float
class FileAnnotation(BaseModel):
"""FileAnnotation class to hold batdetect annotations for a file."""
id: str
duration: float
time_exp: float = 1
label: str = Field(alias="class_name")
annotation: List[Annotation]
annotated: bool = False
issues: bool = False
notes: str = ""
def load_file_annotation(path: PathLike) -> FileAnnotation:
"""Load annotation from batdetect format."""
path = Path(path)
return FileAnnotation.model_validate_json(path.read_text())
return json.loads(path.read_text())
def annotation_to_sound_event(
def annotation_to_sound_event_annotation(
annotation: Annotation,
recording: data.Recording,
label_key: str = "class",
@@ -179,15 +156,15 @@ def annotation_to_sound_event(
sound_event = data.SoundEvent(
uuid=uuid.uuid5(
NAMESPACE,
f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
f"{recording.hash}_{annotation['start_time']}_{annotation['end_time']}",
),
recording=recording,
geometry=data.BoundingBox(
coordinates=[
annotation.start_time,
annotation.low_freq,
annotation.end_time,
annotation.high_freq,
annotation["start_time"],
annotation["low_freq"],
annotation["end_time"],
annotation["high_freq"],
],
),
)
@@ -197,16 +174,62 @@
sound_event=sound_event,
tags=[
data.Tag(
term=data.term_from_key(label_key),
value=annotation.label,
term=get_term_from_key(label_key),
value=annotation["class"],
),
data.Tag(
term=data.term_from_key(event_key),
value=annotation.event,
term=get_term_from_key(event_key),
value=annotation["event"],
),
data.Tag(
term=data.term_from_key(individual_key),
value=str(annotation.individual),
term=get_term_from_key(individual_key),
value=str(annotation["individual"]),
),
],
)
def annotation_to_sound_event_prediction(
annotation: Annotation,
recording: data.Recording,
label_key: str = "class",
event_key: str = "event",
) -> data.SoundEventPrediction:
"""Convert annotation to sound event annotation."""
sound_event = data.SoundEvent(
uuid=uuid.uuid5(
NAMESPACE,
f"{recording.hash}_{annotation['start_time']}_{annotation['end_time']}",
),
recording=recording,
geometry=data.BoundingBox(
coordinates=[
annotation["start_time"],
annotation["low_freq"],
annotation["end_time"],
annotation["high_freq"],
],
),
)
return data.SoundEventPrediction(
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
sound_event=sound_event,
score=annotation["det_prob"],
tags=[
data.PredictedTag(
score=annotation["class_prob"],
tag=data.Tag(
term=get_term_from_key(label_key),
value=annotation["class"],
),
),
data.PredictedTag(
score=annotation["det_prob"],
tag=data.Tag(
term=get_term_from_key(event_key),
value=annotation["event"],
),
),
],
)
@@ -220,24 +243,24 @@ def file_annotation_to_clip(
"""Convert file annotation to recording."""
audio_dir = audio_dir or Path.cwd()
full_path = Path(audio_dir) / file_annotation.id
full_path = Path(audio_dir) / file_annotation["id"]
if not full_path.exists():
raise FileNotFoundError(f"File {full_path} not found.")
recording = data.Recording.from_file(
full_path,
time_expansion=file_annotation.time_exp,
time_expansion=file_annotation["time_exp"],
tags=[
data.Tag(
term=data.term_from_key(label_key),
value=file_annotation.label,
value=file_annotation["class_name"],
)
],
)
return data.Clip(
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation['id']}_clip"),
recording=recording,
start_time=0,
end_time=recording.duration,
@@ -253,27 +276,28 @@ def file_annotation_to_clip_annotation(
) -> data.ClipAnnotation:
"""Convert file annotation to clip annotation."""
notes = []
if file_annotation.notes:
notes.append(data.Note(message=file_annotation.notes))
if file_annotation["notes"]:
notes.append(data.Note(message=file_annotation["notes"]))
return data.ClipAnnotation(
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation['id']}_clip_annotation"),
clip=clip,
notes=notes,
tags=[
data.Tag(
term=data.term_from_key(label_key), value=file_annotation.label
term=data.term_from_key(label_key),
value=file_annotation["class_name"],
)
],
sound_events=[
annotation_to_sound_event(
annotation_to_sound_event_annotation(
annotation,
clip.recording,
label_key=label_key,
event_key=event_key,
individual_key=individual_key,
)
for annotation in file_annotation.annotation
for annotation in file_annotation["annotation"]
],
)
@@ -284,17 +308,17 @@ def file_annotation_to_annotation_task(
) -> data.AnnotationTask:
status_badges = []
if file_annotation.issues:
if file_annotation["issues"]:
status_badges.append(
data.StatusBadge(state=data.AnnotationState.rejected)
)
elif file_annotation.annotated:
elif file_annotation["annotated"]:
status_badges.append(
data.StatusBadge(state=data.AnnotationState.completed)
)
return data.AnnotationTask(
uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation.id}_task"),
uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation['id']}_task"),
clip=clip,
status_badges=status_badges,
)
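
Since load_file_annotation now returns the parsed JSON dict directly instead of a Pydantic model, a legacy annotation file can be lifted into soundevent structures roughly as below. This is a sketch: the file path is illustrative and the argument order is inferred from the signatures visible in this diff.

from batdetect2.compat.data import (
    file_annotation_to_clip,
    file_annotation_to_clip_annotation,
    load_file_annotation,
)

# Load a legacy-format annotation file (path is illustrative).
file_annotation = load_file_annotation("example_data/anns/recording.wav.json")

# Resolve the audio file relative to audio_dir and build a soundevent Clip.
clip = file_annotation_to_clip(file_annotation, audio_dir="example_data/audio")

# Convert the legacy annotations into a soundevent ClipAnnotation.
clip_annotation = file_annotation_to_clip_annotation(file_annotation, clip)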

View File

@@ -1,13 +1,9 @@
from batdetect2.evaluate.evaluate import (
compute_error_auc,
)
from batdetect2.evaluate.match import (
match_predictions_and_annotations,
match_sound_events_and_raw_predictions,
)
__all__ = [
"compute_error_auc",
"match_predictions_and_annotations",
"match_sound_events_and_raw_predictions",
"match_predictions_and_annotations",
]

View File

@@ -1,54 +1,133 @@
from typing import List
from typing import Annotated, List, Literal, Optional, Union
from pydantic import Field
from soundevent import data
from soundevent.evaluation import match_geometries
from soundevent.geometry import compute_bounds
from batdetect2.evaluate.types import Match
from batdetect2.postprocess.types import RawPrediction
from batdetect2.configs import BaseConfig
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.utils.arrays import iterate_over_array
class BBoxMatchConfig(BaseConfig):
match_method: Literal["BBoxIOU"] = "BBoxIOU"
affinity_threshold: float = 0.5
time_buffer: float = 0.01
frequency_buffer: float = 1_000
class IntervalMatchConfig(BaseConfig):
match_method: Literal["IntervalIOU"] = "IntervalIOU"
affinity_threshold: float = 0.5
time_buffer: float = 0.01
class StartTimeMatchConfig(BaseConfig):
match_method: Literal["StartTime"] = "StartTime"
time_buffer: float = 0.01
MatchConfig = Annotated[
Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
Field(discriminator="match_method"),
]
DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
def prepare_geometry(
geometry: data.Geometry, config: MatchConfig
) -> data.Geometry:
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
if config.match_method == "BBoxIOU":
return data.BoundingBox(
coordinates=[start_time, low_freq, end_time, high_freq]
)
if config.match_method == "IntervalIOU":
return data.TimeInterval(coordinates=[start_time, end_time])
if config.match_method == "StartTime":
return data.TimeStamp(coordinates=start_time)
raise NotImplementedError(
f"Invalid matching configuration. Unknown match method: {config.match_method}"
)
def _get_frequency_buffer(config: MatchConfig) -> float:
if config.match_method == "BBoxIOU":
return config.frequency_buffer
return 0
def _get_affinity_threshold(config: MatchConfig) -> float:
if (
config.match_method == "BBoxIOU"
or config.match_method == "IntervalIOU"
):
return config.affinity_threshold
return 0
def match_sound_events_and_raw_predictions(
sound_events: List[data.SoundEventAnnotation],
raw_predictions: List[RawPrediction],
clip_annotation: data.ClipAnnotation,
raw_predictions: List[BatDetect2Prediction],
targets: TargetProtocol,
) -> List[Match]:
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
config = config or DEFAULT_MATCH_CONFIG
target_sound_events = [
targets.transform(sound_event_annotation)
for sound_event_annotation in sound_events
for sound_event_annotation in clip_annotation.sound_events
if targets.filter(sound_event_annotation)
and sound_event_annotation.sound_event.geometry is not None
]
target_geometries: List[data.Geometry] = [ # type: ignore
sound_event_annotation.sound_event.geometry
prepare_geometry(
sound_event_annotation.sound_event.geometry,
config=config,
)
for sound_event_annotation in target_sound_events
if sound_event_annotation.sound_event.geometry is not None
]
predicted_geometries = [
raw_prediction.geometry for raw_prediction in raw_predictions
prepare_geometry(raw_prediction.raw.geometry, config=config)
for raw_prediction in raw_predictions
]
matches = []
for id1, id2, affinity in match_geometries(
target_geometries,
predicted_geometries,
time_buffer=config.time_buffer,
freq_buffer=_get_frequency_buffer(config),
affinity_threshold=_get_affinity_threshold(config),
):
target = target_sound_events[id1] if id1 is not None else None
prediction = raw_predictions[id2] if id2 is not None else None
gt_uuid = target.uuid if target is not None else None
gt_det = target is not None
gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.detection_score) if prediction else 0
pred_score = float(prediction.raw.detection_score) if prediction else 0
class_scores = (
{
str(class_name): float(score)
for class_name, score in iterate_over_array(
prediction.class_scores
prediction.raw.class_scores
)
}
if prediction is not None
@@ -56,13 +135,18 @@ def match_sound_events_and_raw_predictions(
)
matches.append(
Match(
gt_uuid=gt_uuid,
MatchEvaluation(
match=data.Match(
source=None
if prediction is None
else prediction.sound_event_prediction,
target=target,
affinity=affinity,
),
gt_det=gt_det,
gt_class=gt_class,
pred_score=pred_score,
affinity=affinity,
class_scores=class_scores,
pred_class_scores=class_scores,
)
)
@@ -72,7 +156,10 @@
def match_predictions_and_annotations(
clip_annotation: data.ClipAnnotation,
clip_prediction: data.ClipPrediction,
config: Optional[MatchConfig] = None,
) -> List[data.Match]:
config = config or DEFAULT_MATCH_CONFIG
annotated_sound_events = [
sound_event_annotation
for sound_event_annotation in clip_annotation.sound_events
@@ -86,13 +173,13 @@
]
annotated_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry
prepare_geometry(sound_event.sound_event.geometry, config=config)
for sound_event in annotated_sound_events
if sound_event.sound_event.geometry is not None
]
predicted_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry
prepare_geometry(sound_event.sound_event.geometry, config=config)
for sound_event in predicted_sound_events
if sound_event.sound_event.geometry is not None
]
@@ -101,6 +188,9 @@
for id1, id2, affinity in match_geometries(
annotated_geometries,
predicted_geometries,
time_buffer=config.time_buffer,
freq_buffer=_get_frequency_buffer(config),
affinity_threshold=_get_affinity_threshold(config),
):
target = annotated_sound_events[id1] if id1 is not None else None
source = predicted_sound_events[id2] if id2 is not None else None
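
The matching strategy is now selected through a small discriminated union keyed on match_method. A sketch of building these configs, assuming standard Pydantic v2 behaviour for the Annotated union:

from pydantic import TypeAdapter

from batdetect2.evaluate.match import IntervalMatchConfig, MatchConfig

# Construct a variant directly...
config = IntervalMatchConfig(affinity_threshold=0.4, time_buffer=0.005)

# ...or parse plain data; the match_method field picks the variant.
config = TypeAdapter(MatchConfig).validate_python(
    {"match_method": "StartTime", "time_buffer": 0.02}
)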

View File

@@ -4,13 +4,13 @@ import pandas as pd
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.evaluate.types import Match, MetricsProtocol
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
__all__ = ["DetectionAveragePrecision"]
class DetectionAveragePrecision(MetricsProtocol):
def __call__(self, matches: List[Match]) -> Dict[str, float]:
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true, y_score = zip(
*[(match.gt_det, match.pred_score) for match in matches]
)
@@ -23,7 +23,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
self.class_names = class_names
self.per_class = per_class
def __call__(self, matches: List[Match]) -> Dict[str, float]:
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = label_binarize(
[
match.gt_class if match.gt_class is not None else "__NONE__"
@@ -34,7 +34,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
y_pred = pd.DataFrame(
[
{
name: match.class_scores.get(name, 0)
name: match.pred_class_scores.get(name, 0)
for name in self.class_names
}
for match in matches
@@ -65,7 +65,7 @@ class ClassificationAccuracy(MetricsProtocol):
def __init__(self, class_names: List[str]):
self.class_names = class_names
def __call__(self, matches: List[Match]) -> Dict[str, float]:
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = [
match.gt_class if match.gt_class is not None else "__NONE__"
for match in matches
@@ -74,7 +74,7 @@ class ClassificationAccuracy(MetricsProtocol):
y_pred = pd.DataFrame(
[
{
name: match.class_scores.get(name, 0)
name: match.pred_class_scores.get(name, 0)
for name in self.class_names
}
for match in matches

View File

@@ -1,22 +1,40 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol
from uuid import UUID
from soundevent import data
__all__ = [
"MetricsProtocol",
"Match",
"MatchEvaluation",
]
@dataclass
class Match:
gt_uuid: Optional[UUID]
class MatchEvaluation:
match: data.Match
gt_det: bool
gt_class: Optional[str]
pred_score: float
affinity: float
class_scores: Dict[str, float]
pred_class_scores: Dict[str, float]
@property
def pred_class(self) -> Optional[str]:
if not self.pred_class_scores:
return None
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
@property
def pred_class_score(self) -> float:
pred_class = self.pred_class
if pred_class is None:
return 0
return self.pred_class_scores[pred_class]
class MetricsProtocol(Protocol):
def __call__(self, matches: List[Match]) -> Dict[str, float]: ...
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
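
For reference, a sketch of how the renamed MatchEvaluation flows into a metric; the class names are made up, and data.Match is constructed with the same fields used in the matching code above:

from soundevent import data

from batdetect2.evaluate.metrics import DetectionAveragePrecision
from batdetect2.evaluate.types import MatchEvaluation

match = MatchEvaluation(
    match=data.Match(source=None, target=None, affinity=0.9),
    gt_det=True,
    gt_class="Myotis myotis",
    pred_score=0.8,
    pred_class_scores={"Myotis myotis": 0.7, "Pipistrellus pipistrellus": 0.1},
)

assert match.pred_class == "Myotis myotis"  # highest-scoring class wins
assert match.pred_class_score == 0.7

metric = DetectionAveragePrecision()
scores = metric([match])  # Dict[str, float]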

View File

@@ -0,0 +1,21 @@
from batdetect2.plotting.clip_annotations import plot_clip_annotation
from batdetect2.plotting.clip_predictions import plot_clip_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.matches import (
plot_cross_trigger_match,
plot_false_negative_match,
plot_false_positive_match,
plot_matches,
plot_true_positive_match,
)
__all__ = [
"plot_clip",
"plot_clip_annotation",
"plot_clip_prediction",
"plot_matches",
"plot_false_positive_match",
"plot_true_positive_match",
"plot_false_negative_match",
"plot_cross_trigger_match",
]

View File

@@ -0,0 +1,49 @@
from typing import Optional, Tuple
from matplotlib.axes import Axes
from soundevent import data, plot
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol
__all__ = [
"plot_clip_annotation",
]
def plot_clip_annotation(
clip_annotation: data.ClipAnnotation,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
cmap: str = "gray",
alpha: float = 1,
linewidth: float = 1,
fill: bool = False,
) -> Axes:
ax = plot_clip(
clip_annotation.clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=cmap,
)
plot.plot_annotations(
clip_annotation.sound_events,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
alpha=alpha,
linewidth=linewidth,
facecolor="none" if not fill else None,
)
return ax

View File

@@ -0,0 +1,141 @@
from typing import Iterable, Optional, Tuple
from matplotlib.axes import Axes
from soundevent import data
from soundevent.geometry.operations import Positions, get_geometry_point
from soundevent.plot.common import create_axes
from soundevent.plot.geometries import plot_geometry
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol
__all__ = [
"plot_clip_prediction",
]
def plot_clip_prediction(
clip_prediction: data.ClipPrediction,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
add_colorbar: bool = False,
add_labels: bool = False,
add_legend: bool = False,
spec_cmap: str = "gray",
linewidth: float = 1,
fill: bool = False,
) -> Axes:
ax = plot_clip(
clip_prediction.clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
plot_predictions(
clip_prediction.sound_events,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=False,
linewidth=linewidth,
facecolor="none" if not fill else None,
legend=add_legend,
)
return ax
def plot_predictions(
predictions: Iterable[data.SoundEventPrediction],
ax: Optional[Axes] = None,
position: Positions = "top-right",
color_mapper: Optional[TagColorMapper] = None,
time_offset: float = 0.001,
freq_offset: float = 1000,
legend: bool = True,
max_alpha: float = 0.5,
color: Optional[str] = None,
**kwargs,
):
"""Plot an prediction."""
if ax is None:
ax = create_axes(**kwargs)
if color_mapper is None:
color_mapper = TagColorMapper()
for prediction in predictions:
ax = plot_prediction(
prediction,
ax=ax,
position=position,
color_mapper=color_mapper,
time_offset=time_offset,
freq_offset=freq_offset,
max_alpha=max_alpha,
color=color,
**kwargs,
)
if legend:
ax = add_tags_legend(ax, color_mapper)
return ax
def plot_prediction(
prediction: data.SoundEventPrediction,
ax: Optional[Axes] = None,
position: Positions = "top-right",
color_mapper: Optional[TagColorMapper] = None,
time_offset: float = 0.001,
freq_offset: float = 1000,
max_alpha: float = 0.5,
alpha: Optional[float] = None,
color: Optional[str] = None,
**kwargs,
) -> Axes:
"""Plot an annotation."""
geometry = prediction.sound_event.geometry
if geometry is None:
raise ValueError("Annotation does not have a geometry.")
if ax is None:
ax = create_axes(**kwargs)
if color_mapper is None:
color_mapper = TagColorMapper()
if alpha is None:
alpha = min(prediction.score * max_alpha, 1)
ax = plot_geometry(
geometry,
ax=ax,
color=color,
alpha=alpha,
**kwargs,
)
x, y = get_geometry_point(geometry, position=position)
for index, tag in enumerate(prediction.tags):
color = color_mapper.get_color(tag.tag)
ax = plot_tag(
time=x + time_offset,
frequency=y - index * freq_offset,
color=color,
ax=ax,
alpha=min(tag.score, prediction.score),
**kwargs,
)
return ax

View File

@@ -0,0 +1,44 @@
from typing import Optional, Tuple
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from soundevent import data
from batdetect2.preprocess import (
PreprocessorProtocol,
get_default_preprocessor,
)
__all__ = [
"plot_clip",
]
def plot_clip(
clip: data.Clip,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
add_colorbar: bool = False,
add_labels: bool = False,
spec_cmap: str = "gray",
) -> Axes:
if ax is None:
_, ax = plt.subplots(figsize=figsize)
if preprocessor is None:
preprocessor = get_default_preprocessor()
spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir)
spec.plot( # type: ignore
ax=ax,
add_colorbar=add_colorbar,
cmap=spec_cmap,
add_labels=add_labels,
vmin=spec.min().item(),
vmax=spec.max().item(),
)
return ax
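
A usage sketch for the new plot_clip helper; the audio path is illustrative:

import matplotlib.pyplot as plt
from soundevent import data

from batdetect2.plotting import plot_clip

recording = data.Recording.from_file("example_data/audio/recording.wav")
clip = data.Clip(recording=recording, start_time=0.0, end_time=0.5)

# With no preprocessor given, the default preprocessor computes the spectrogram.
ax = plot_clip(clip, add_labels=True, add_colorbar=True)
plt.show()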

View File

@@ -0,0 +1,160 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
from batdetect2 import plotting
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.preprocess.types import PreprocessorProtocol
@dataclass
class ClassExamples:
false_positives: List[MatchEvaluation] = field(default_factory=list)
false_negatives: List[MatchEvaluation] = field(default_factory=list)
true_positives: List[MatchEvaluation] = field(default_factory=list)
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def plot_examples(
matches: List[MatchEvaluation],
preprocessor: PreprocessorProtocol,
n_examples: int = 5,
):
class_examples = defaultdict(ClassExamples)
for match in matches:
gt_class = match.gt_class
pred_class = match.pred_class
if pred_class is None:
class_examples[gt_class].false_negatives.append(match)
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[gt_class].cross_triggers.append(match)
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
for class_name, examples in class_examples.items():
true_positives = get_binned_sample(
examples.true_positives,
n_examples=n_examples,
)
false_positives = get_binned_sample(
examples.false_positives,
n_examples=n_examples,
)
false_negatives = random.sample(
examples.false_negatives,
k=min(n_examples, len(examples.false_negatives)),
)
cross_triggers = get_binned_sample(
examples.cross_triggers,
n_examples=n_examples,
)
fig = plot_class_examples(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=preprocessor,
n_examples=n_examples,
)
yield class_name, fig
plt.close(fig)
def plot_class_examples(
true_positives: List[MatchEvaluation],
false_positives: List[MatchEvaluation],
false_negatives: List[MatchEvaluation],
cross_triggers: List[MatchEvaluation],
preprocessor: PreprocessorProtocol,
n_examples: int = 5,
duration: float = 0.1,
):
fig = plt.figure(figsize=(20, 20))
for index, match in enumerate(true_positives):
ax = plt.subplot(4, n_examples, index + 1)
try:
plotting.plot_true_positive_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except ValueError:
continue
for index, match in enumerate(false_positives):
ax = plt.subplot(4, n_examples, n_examples + index + 1)
try:
plotting.plot_false_positive_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except ValueError:
continue
for index, match in enumerate(false_negatives):
ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
try:
plotting.plot_false_negative_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except ValueError:
continue
for index, match in enumerate(cross_triggers):
ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
try:
plotting.plot_cross_trigger_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except ValueError:
continue
return fig
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[
(index, match.pred_class_scores[pred_class])
for index, match in enumerate(matches)
if (pred_class := match.pred_class) is not None
]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False)
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").apply(lambda x: x.sample(1))
return [matches[ind] for ind in sample["indices"]]
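
Note that plot_examples yields each figure and closes it once the consumer resumes, so figures must be used inside the loop. A sketch, with matches and preprocessor assumed in scope:

for class_name, fig in plot_examples(
    matches,  # List[MatchEvaluation]
    preprocessor=preprocessor,
    n_examples=5,
):
    # Save (or log) the figure before advancing; the generator closes it next.
    fig.savefig(f"examples_{class_name}.png")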

View File

@@ -0,0 +1,417 @@
from typing import List, Optional, Tuple, Union
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from soundevent import data, plot
from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.plotting.clip_predictions import plot_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import (
PreprocessorProtocol,
get_default_preprocessor,
)
__all__ = [
"plot_matches",
"plot_false_positive_match",
"plot_true_positive_match",
"plot_false_negative_match",
"plot_cross_trigger_match",
]
DEFAULT_DURATION = 0.05
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
DEFAULT_TRUE_POSITIVE_COLOR = "green"
DEFAULT_CROSS_TRIGGER_COLOR = "orange"
DEFAULT_ANNOTATION_LINE_STYLE = "-"
DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_matches(
matches: List[data.Match],
clip: data.Clip,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
color_mapper: Optional[TagColorMapper] = None,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
if preprocessor is None:
preprocessor = get_default_preprocessor()
ax = plot_clip(
clip,
ax=ax,
figsize=figsize,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
if color_mapper is None:
color_mapper = TagColorMapper()
for match in matches:
if match.source is None and match.target is not None:
plot.plot_annotation(
annotation=match.target,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=false_negative_color,
color_mapper=color_mapper,
linestyle=annotation_linestyle,
)
elif match.target is None and match.source is not None:
plot_prediction(
prediction=match.source,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=false_positive_color,
color_mapper=color_mapper,
linestyle=prediction_linestyle,
)
elif match.target is not None and match.source is not None:
plot.plot_annotation(
annotation=match.target,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=true_positive_color,
color_mapper=color_mapper,
linestyle=annotation_linestyle,
)
plot_prediction(
prediction=match.source,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=true_positive_color,
color_mapper=color_mapper,
linestyle=prediction_linestyle,
)
else:
continue
return ax
def plot_false_positive_match(
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
time_offset: float = 0,
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.match.source is not None
assert match.match.target is None
sound_event = match.match.source.sound_event
geometry = sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2,
sound_event.recording.duration,
),
recording=sound_event.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
plot_prediction(
match.match.source,
ax=ax,
time_offset=time_offset,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
)
plt.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
def plot_false_negative_match(
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.match.source is None
assert match.match.target is not None
sound_event = match.match.target.sound_event
geometry = sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
),
recording=sound_event.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
match.match.target,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
)
plt.text(
start_time,
high_freq,
f"False Negative \nClass: {match.gt_class} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
def plot_true_positive_match(
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
assert match.match.source is not None
assert match.match.target is not None
sound_event = match.match.target.sound_event
geometry = sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
),
recording=sound_event.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
match.match.target,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
linestyle=annotation_linestyle,
)
plot_prediction(
match.match.source,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
linestyle=prediction_linestyle,
)
plt.text(
start_time,
high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
def plot_cross_trigger_match(
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
fontsize: Union[float, str] = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
assert match.match.source is not None
assert match.match.target is not None
sound_event = match.match.source.sound_event
geometry = sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
),
recording=sound_event.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
match.match.target,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
linestyle=annotation_linestyle,
)
plot_prediction(
match.match.source,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
linestyle=prediction_linestyle,
)
plt.text(
start_time,
high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
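
Tying the new matching and plotting pieces together, a sketch that overlays matched annotations and predictions on a clip spectrogram; clip_annotation and clip_prediction are assumed in scope:

from batdetect2.evaluate.match import match_predictions_and_annotations
from batdetect2.plotting import plot_matches

matches = match_predictions_and_annotations(clip_annotation, clip_prediction)

# Missed annotations are drawn in red, spurious predictions in orange,
# and matched pairs in green, per the defaults above.
ax = plot_matches(matches, clip_annotation.clip, add_labels=True)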

View File

@@ -39,6 +39,7 @@ from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.types import ModelOutput
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,
convert_xr_dataset_to_raw_prediction,
)
@@ -61,7 +62,11 @@ from batdetect2.postprocess.remapping import (
features_to_xarray,
sizes_to_xarray,
)
from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction
from batdetect2.postprocess.types import (
BatDetect2Prediction,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets.types import TargetProtocol
@@ -537,6 +542,27 @@ class Postprocessor(PostprocessorProtocol):
for dataset in detection_datasets
]
def get_sound_event_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[BatDetect2Prediction]]:
raw_predictions = self.get_raw_predictions(output, clips)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold,
),
)
for raw in predictions
]
for predictions, clip in zip(raw_predictions, clips)
]
def get_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[data.ClipPrediction]:

View File

@@ -95,7 +95,6 @@ def convert_xr_dataset_to_raw_prediction(
for det_num in range(detection_dataset.sizes["detection"]):
det_info = detection_dataset.sel(detection=det_num)
# TODO: Maybe clean this up
highest_scoring_class = det_info.coords["category"][
det_info["classes"].argmax()
].item()

View File

@@ -11,6 +11,7 @@ modularity and consistent interaction between different parts of the BatDetect2
system that deal with model predictions.
"""
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol
import xarray as xr
@@ -75,6 +76,12 @@ class RawPrediction(NamedTuple):
features: xr.DataArray
@dataclass
class BatDetect2Prediction:
raw: RawPrediction
sound_event_prediction: data.SoundEventPrediction
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline.
@@ -254,6 +261,10 @@ class PostprocessorProtocol(Protocol):
"""
...
def get_sound_event_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[BatDetect2Prediction]]: ...
def get_predictions(
self,
output: ModelOutput,
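
A sketch of consuming the new method, which pairs each low-level RawPrediction with its decoded soundevent counterpart; postprocessor, output, and clips are assumed in scope:

for clip_predictions in postprocessor.get_sound_event_predictions(output, clips):
    for prediction in clip_predictions:
        # Numeric outputs stay on the raw side...
        score = prediction.raw.detection_score
        # ...while decoded tags and geometry live on the soundevent side.
        tags = prediction.sound_event_prediction.tags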

View File

@@ -86,6 +86,7 @@ __all__ = [
"build_spectrogram_builder",
"get_spectrogram_resolution",
"load_preprocessing_config",
"get_default_preprocessor",
]
@@ -451,3 +452,7 @@ def build_preprocessor(
min_freq=min_freq,
max_freq=max_freq,
)
def get_default_preprocessor():
return build_preprocessor()
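
The new helper just wraps build_preprocessor with its defaults. A brief sketch; the clip is assumed to exist, and the returned spectrogram is an xarray.DataArray judging by the plotting code above:

from batdetect2.preprocess import get_default_preprocessor

preprocessor = get_default_preprocessor()
spec = preprocessor.preprocess_clip(clip)  # clip: soundevent.data.Clip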

View File

@@ -2,11 +2,13 @@ from typing import List
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import TensorBoardLogger
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
from batdetect2.evaluate.types import Match, MetricsProtocol
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
from batdetect2.plotting.evaluation import plot_examples
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.lightning import TrainingModule
@@ -14,25 +16,55 @@ from batdetect2.train.types import ModelOutput
class ValidationMetrics(Callback):
def __init__(self, metrics: List[MetricsProtocol]):
def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
super().__init__()
if len(metrics) == 0:
raise ValueError("At least one metric needs to be provided")
self.matches: List[Match] = []
self.matches: List[MatchEvaluation] = []
self.metrics = metrics
self.plot = plot
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset)
return dataset
def plot_examples(self, pl_module: LightningModule):
if not isinstance(pl_module.logger, TensorBoardLogger):
return
for class_name, fig in plot_examples(
self.matches,
preprocessor=pl_module.preprocessor,
n_examples=5,
):
pl_module.logger.experiment.add_figure(
f"{class_name}/examples",
fig,
pl_module.global_step,
)
def log_metrics(self, pl_module: LightningModule):
metrics = {}
for metric in self.metrics:
metrics.update(metric(self.matches).items())
pl_module.log_dict(metrics)
def on_validation_epoch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
) -> None:
metrics = {}
for metric in self.metrics:
metrics.update(metric(self.matches).items())
self.log_metrics(pl_module)
if self.plot:
self.plot_examples(pl_module)
pl_module.log_dict(metrics)
return super().on_validation_epoch_end(trainer, pl_module)
def on_validation_epoch_start(
@@ -52,11 +84,7 @@ class ValidationMetrics(Callback):
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset)
dataset = self.get_dataset(trainer)
clip_annotations = [
_get_subclip(
@@ -74,7 +102,7 @@
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
raw_predictions = pl_module.postprocessor.get_raw_predictions(
raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
outputs,
clips,
)
@@ -84,7 +112,7 @@
):
self.matches.extend(
match_sound_events_and_raw_predictions(
sound_events=clip_annotation.sound_events,
clip_annotation=clip_annotation,
raw_predictions=clip_predictions,
targets=pl_module.targets,
)
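
A sketch of wiring the extended callback into a Lightning Trainer; the callback's module path and the metric constructors are assumptions based on this diff:

from lightning import Trainer

from batdetect2.evaluate.metrics import (
    ClassificationMeanAveragePrecision,
    DetectionAveragePrecision,
)
from batdetect2.train.callbacks import ValidationMetrics  # module path assumed

callback = ValidationMetrics(
    metrics=[
        DetectionAveragePrecision(),
        ClassificationMeanAveragePrecision(class_names=["Myotis myotis"]),
    ],
    plot=True,  # also log per-class example plots to TensorBoard
)
trainer = Trainer(callbacks=[callback])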

View File

@@ -48,20 +48,19 @@ class TrainingModule(L.LightningModule):
def training_step(self, batch: TrainExample):
outputs = self.forward(batch.spec)
losses = self.loss(outputs, batch)
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/train", losses.total, logger=True)
self.log("size_loss/train", losses.total, logger=True)
self.log("classification_loss/train", losses.total, logger=True)
return losses.total
def validation_step( # type: ignore
self, batch: TrainExample, batch_idx: int
self,
batch: TrainExample,
batch_idx: int,
) -> ModelOutput:
outputs = self.forward(batch.spec)
losses = self.loss(outputs, batch)
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/val", losses.total, logger=True)
self.log("size_loss/val", losses.total, logger=True)