Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)

Commit 5b9a5a968f (parent 356be57f62): Refactor eval code
@@ -2,14 +2,10 @@ from batdetect2.evaluate.config import (
     EvaluationConfig,
     load_evaluation_config,
 )
-from batdetect2.evaluate.match import (
-    match_predictions_and_annotations,
-    match_sound_events_and_raw_predictions,
-)
+from batdetect2.evaluate.match import match_predictions_and_annotations
 
 __all__ = [
     "EvaluationConfig",
     "load_evaluation_config",
     "match_predictions_and_annotations",
-    "match_sound_events_and_raw_predictions",
 ]
@@ -3,6 +3,7 @@ from dataclasses import dataclass, field
 from typing import List, Literal, Optional, Protocol, Tuple
 
 import numpy as np
+from loguru import logger
 from soundevent import data
 from soundevent.evaluation import compute_affinity
 from soundevent.evaluation import match_geometries as optimal_match
@@ -10,10 +11,10 @@ from soundevent.geometry import compute_bounds
 
 from batdetect2.configs import BaseConfig
 from batdetect2.typing import (
-    BatDetect2Prediction,
     MatchEvaluation,
     TargetProtocol,
 )
+from batdetect2.typing.postprocess import RawPrediction
 
 MatchingStrategy = Literal["greedy", "optimal"]
 """The type of matching algorithm to use: 'greedy' or 'optimal'."""
@@ -274,7 +275,7 @@ def greedy_match(
 
 def match_sound_events_and_raw_predictions(
     clip_annotation: data.ClipAnnotation,
-    raw_predictions: List[BatDetect2Prediction],
+    raw_predictions: List[RawPrediction],
     targets: TargetProtocol,
     config: Optional[MatchConfig] = None,
 ) -> List[MatchEvaluation]:
@@ -294,12 +295,11 @@ def match_sound_events_and_raw_predictions(
     ]
 
     predicted_geometries = [
-        raw_prediction.raw.geometry for raw_prediction in raw_predictions
+        raw_prediction.geometry for raw_prediction in raw_predictions
     ]
 
     scores = [
-        raw_prediction.raw.detection_score
-        for raw_prediction in raw_predictions
+        raw_prediction.detection_score for raw_prediction in raw_predictions
     ]
 
     matches = []
@@ -320,14 +320,20 @@ def match_sound_events_and_raw_predictions(
         gt_det = target is not None
         gt_class = targets.encode_class(target) if target is not None else None
 
-        pred_score = float(prediction.raw.detection_score) if prediction else 0
+        pred_score = float(prediction.detection_score) if prediction else 0
+
+        pred_geometry = (
+            predicted_geometries[source_idx]
+            if source_idx is not None
+            else None
+        )
 
         class_scores = (
             {
                 str(class_name): float(score)
                 for class_name, score in zip(
                     targets.class_names,
-                    prediction.raw.class_scores,
+                    prediction.class_scores,
                 )
             }
             if prediction is not None
@@ -336,17 +342,14 @@ def match_sound_events_and_raw_predictions(
 
         matches.append(
             MatchEvaluation(
-                match=data.Match(
-                    source=None
-                    if prediction is None
-                    else prediction.sound_event_prediction,
-                    target=target,
-                    affinity=affinity,
-                ),
+                clip=clip_annotation.clip,
+                sound_event_annotation=target,
                 gt_det=gt_det,
                 gt_class=gt_class,
                 pred_score=pred_score,
                 pred_class_scores=class_scores,
+                pred_geometry=pred_geometry,
+                affinity=affinity,
             )
         )
 
@@ -418,6 +421,28 @@ def match_predictions_and_annotations(
     return matches
 
 
+def match_all_predictions(
+    clip_annotations: List[data.ClipAnnotation],
+    predictions: List[List[RawPrediction]],
+    targets: TargetProtocol,
+    config: Optional[MatchConfig] = None,
+) -> List[MatchEvaluation]:
+    logger.info("Matching all annotations and predictions...")
+    return [
+        match
+        for clip_annotation, raw_predictions in zip(
+            clip_annotations,
+            predictions,
+        )
+        for match in match_sound_events_and_raw_predictions(
+            clip_annotation,
+            raw_predictions,
+            targets=targets,
+            config=config,
+        )
+    ]
+
+
 @dataclass
 class ClassExamples:
     false_positives: List[MatchEvaluation] = field(default_factory=list)
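Note: the new match_all_predictions helper flattens the per-clip matches into a single list. A minimal usage sketch (the surrounding variables are assumptions, not part of this diff):

    # Hypothetical inputs: clip annotations paired one-to-one with the
    # raw predictions produced for each clip.
    matches = match_all_predictions(
        clip_annotations,       # List[data.ClipAnnotation]
        predictions_per_clip,   # List[List[RawPrediction]], same length
        targets=targets,        # TargetProtocol
    )
    # A missed annotation shows up with no predicted geometry:
    false_negatives = [m for m in matches if m.pred_geometry is None]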
@@ -68,7 +68,7 @@ from batdetect2.postprocess import PostprocessConfig, build_postprocessor
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
 from batdetect2.targets import TargetConfig, build_targets
 from batdetect2.typing.models import DetectionModel
-from batdetect2.typing.postprocess import Detections, PostprocessorProtocol
+from batdetect2.typing.postprocess import DetectionsArray, PostprocessorProtocol
 from batdetect2.typing.preprocess import PreprocessorProtocol
 from batdetect2.typing.targets import TargetProtocol
 
@@ -122,7 +122,7 @@ class Model(LightningModule):
         self.targets = targets
         self.save_hyperparameters()
 
-    def forward(self, wav: torch.Tensor) -> List[Detections]:
+    def forward(self, wav: torch.Tensor) -> List[DetectionsArray]:
         spec = self.preprocessor(wav)
         outputs = self.detector(spec)
         return self.postprocessor(outputs)
@@ -124,25 +124,21 @@ def plot_false_positive_match(
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
-    time_offset: float = 0,
     color: str = DEFAULT_FALSE_POSITIVE_COLOR,
     fontsize: Union[float, str] = "small",
 ) -> Axes:
-    assert match.match.source is not None
-    assert match.match.target is None
-    sound_event = match.match.source.sound_event
-    geometry = sound_event.geometry
-    assert geometry is not None
+    assert match.pred_geometry is not None
+    assert match.sound_event_annotation is None
 
-    start_time, _, _, high_freq = compute_bounds(geometry)
+    start_time, _, _, high_freq = compute_bounds(match.pred_geometry)
 
     clip = data.Clip(
         start_time=max(start_time - duration / 2, 0),
         end_time=min(
             start_time + duration / 2,
-            sound_event.recording.duration,
+            match.clip.end_time,
         ),
-        recording=sound_event.recording,
+        recording=match.clip.recording,
     )
 
     ax = plot_clip(
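Note: the cropping window above is now derived entirely from the match itself: compute_bounds on the predicted geometry anchors the window, and the clip (not the whole recording) clamps it. A sketch of that arithmetic, assuming a match with pred_geometry set:

    from soundevent import data
    from soundevent.geometry import compute_bounds

    def plotting_window(match, duration: float) -> data.Clip:
        # Centre a fixed-duration window on the prediction's start time,
        # clamped to the clip boundaries.
        start_time, _, _, _ = compute_bounds(match.pred_geometry)
        return data.Clip(
            recording=match.clip.recording,
            start_time=max(start_time - duration / 2, 0),
            end_time=min(start_time + duration / 2, match.clip.end_time),
        )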
@@ -154,11 +150,9 @@ def plot_false_positive_match(
         spec_cmap=spec_cmap,
     )
 
-    plot_prediction(
-        match.match.source,
+    plot.plot_geometry(
+        match.pred_geometry,
         ax=ax,
-        time_offset=time_offset,
-        freq_offset=2_000,
         add_points=add_points,
         facecolor="none" if not fill else None,
         alpha=1,
@@ -191,9 +185,9 @@ def plot_false_negative_match(
     color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
     fontsize: Union[float, str] = "small",
 ) -> Axes:
-    assert match.match.source is None
-    assert match.match.target is not None
-    sound_event = match.match.target.sound_event
+    assert match.pred_geometry is None
+    assert match.sound_event_annotation is not None
+    sound_event = match.sound_event_annotation.sound_event
     geometry = sound_event.geometry
     assert geometry is not None
 
@@ -217,7 +211,7 @@ def plot_false_negative_match(
     )
 
     plot.plot_annotation(
-        match.match.target,
+        match.sound_event_annotation,
         ax=ax,
         time_offset=0.001,
         freq_offset=2_000,
@@ -255,9 +249,9 @@ def plot_true_positive_match(
     annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
     prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
 ) -> Axes:
-    assert match.match.source is not None
-    assert match.match.target is not None
-    sound_event = match.match.target.sound_event
+    assert match.sound_event_annotation is not None
+    assert match.pred_geometry is not None
+    sound_event = match.sound_event_annotation.sound_event
     geometry = sound_event.geometry
     assert geometry is not None
 
@@ -281,7 +275,7 @@ def plot_true_positive_match(
     )
 
     plot.plot_annotation(
-        match.match.target,
+        match.sound_event_annotation,
         ax=ax,
         time_offset=0.001,
         freq_offset=2_000,
@@ -292,11 +286,9 @@ def plot_true_positive_match(
         linestyle=annotation_linestyle,
     )
 
-    plot_prediction(
-        match.match.source,
+    plot.plot_geometry(
+        match.pred_geometry,
         ax=ax,
-        time_offset=0.001,
-        freq_offset=2_000,
         add_points=add_points,
         facecolor="none" if not fill else None,
         alpha=1,
@@ -332,9 +324,9 @@ def plot_cross_trigger_match(
     annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
     prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
 ) -> Axes:
-    assert match.match.source is not None
-    assert match.match.target is not None
-    sound_event = match.match.source.sound_event
+    assert match.sound_event_annotation is not None
+    assert match.pred_geometry is not None
+    sound_event = match.sound_event_annotation.sound_event
     geometry = sound_event.geometry
     assert geometry is not None
 
@@ -358,7 +350,7 @@ def plot_cross_trigger_match(
     )
 
     plot.plot_annotation(
-        match.match.target,
+        match.sound_event_annotation,
         ax=ax,
         time_offset=0.001,
         freq_offset=2_000,
@@ -369,11 +361,9 @@ def plot_cross_trigger_match(
         linestyle=annotation_linestyle,
     )
 
-    plot_prediction(
-        match.match.source,
+    plot.plot_geometry(
+        match.pred_geometry,
         ax=ax,
-        time_offset=0.001,
-        freq_offset=2_000,
         add_points=add_points,
         facecolor="none" if not fill else None,
         alpha=1,
@@ -10,9 +10,9 @@ from soundevent import data
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.postprocess.decoding import (
     DEFAULT_CLASSIFICATION_THRESHOLD,
-    convert_detections_to_raw_predictions,
     convert_raw_prediction_to_sound_event_prediction,
     convert_raw_predictions_to_clip_prediction,
+    to_raw_predictions,
 )
 from batdetect2.postprocess.extraction import extract_prediction_tensor
 from batdetect2.postprocess.nms import (
@@ -24,7 +24,8 @@ from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
 from batdetect2.typing import ModelOutput
 from batdetect2.typing.postprocess import (
     BatDetect2Prediction,
-    Detections,
+    DetectionsArray,
+    DetectionsTensor,
     PostprocessorProtocol,
     RawPrediction,
 )
@@ -43,7 +44,7 @@ __all__ = [
     "TOP_K_PER_SEC",
     "build_postprocessor",
     "convert_raw_predictions_to_clip_prediction",
-    "convert_detections_to_raw_predictions",
+    "to_raw_predictions",
     "load_postprocess_config",
     "non_max_suppression",
 ]
@@ -168,7 +169,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         self.top_k_per_sec = top_k_per_sec
         self.detection_threshold = detection_threshold
 
-    def forward(self, output: ModelOutput) -> List[Detections]:
+    def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
         width = output.detection_probs.shape[-1]
         duration = width / self.samplerate
         max_detections = int(self.top_k_per_sec * duration)
@@ -192,7 +193,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         self,
         output: ModelOutput,
         clips: Optional[List[data.Clip]] = None,
-    ) -> List[Detections]:
+    ) -> List[DetectionsTensor]:
         width = output.detection_probs.shape[-1]
         duration = width / self.samplerate
         max_detections = int(self.top_k_per_sec * duration)
@@ -245,11 +246,8 @@ def get_raw_predictions(
     """
     detections = postprocessor.get_detections(output, clips)
     return [
-        convert_detections_to_raw_predictions(
-            dataset,
-            targets=targets,
-        )
-        for dataset in detections
+        to_raw_predictions(detection.numpy(), targets=targets)
+        for detection in detections
     ]
 
 
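Note: get_raw_predictions now keeps the postprocessor in tensor land and hands decoding plain NumPy arrays, so the torch-to-NumPy crossing happens exactly once. A sketch of the resulting boundary (names from this diff, assumed already in scope):

    # One DetectionsTensor per clip comes off the postprocessor;
    # .numpy() crosses into NumPy before any decoding work starts.
    detections = postprocessor.get_detections(output, clips)
    raw_predictions = [
        to_raw_predictions(detection.numpy(), targets=targets)
        for detection in detections
    ]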
@@ -6,13 +6,13 @@ import numpy as np
 from soundevent import data
 
 from batdetect2.typing.postprocess import (
-    Detections,
+    DetectionsArray,
     RawPrediction,
 )
 from batdetect2.typing.targets import TargetProtocol
 
 __all__ = [
-    "convert_detections_to_raw_predictions",
+    "to_raw_predictions",
     "convert_raw_predictions_to_clip_prediction",
     "convert_raw_prediction_to_sound_event_prediction",
     "DEFAULT_CLASSIFICATION_THRESHOLD",
@@ -27,19 +27,19 @@ decoding.
 """
 
 
-def convert_detections_to_raw_predictions(
-    detections: Detections,
+def to_raw_predictions(
+    detections: DetectionsArray,
     targets: TargetProtocol,
 ) -> List[RawPrediction]:
     predictions = []
 
     for score, class_scores, time, freq, dims, feats in zip(
-        detections.scores.cpu().numpy(),
-        detections.class_scores.cpu().numpy(),
-        detections.times.cpu().numpy(),
-        detections.frequencies.cpu().numpy(),
-        detections.sizes.cpu().numpy(),
-        detections.features.cpu().numpy(),
+        detections.scores,
+        detections.class_scores,
+        detections.times,
+        detections.frequencies,
+        detections.sizes,
+        detections.features,
     ):
         highest_scoring_class = targets.class_names[class_scores.argmax()]
 
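Note: to_raw_predictions now assumes its inputs are already NumPy arrays (the .cpu().numpy() calls moved to the caller via DetectionsTensor.numpy()). A hand-built example, with field shapes inferred from the zip above (an assumption; they are not spelled out in this diff, and `targets` is assumed in scope):

    import numpy as np

    detections = DetectionsArray(
        scores=np.array([0.9, 0.4]),                # one score per detection
        sizes=np.array([[10.0, 8.0], [6.0, 5.0]]),  # per-detection size dims
        class_scores=np.array([[0.7, 0.2, 0.1],
                               [0.1, 0.8, 0.1]]),   # per-class scores
        times=np.array([0.12, 0.34]),
        frequencies=np.array([45_000.0, 32_000.0]),
        features=np.zeros((2, 32)),
    )
    raw_predictions = to_raw_predictions(detections, targets=targets)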
@@ -20,7 +20,10 @@ from typing import List, Optional, Tuple, Union
 import torch
 
 from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
-from batdetect2.typing.postprocess import Detections, ModelOutput
+from batdetect2.typing.postprocess import (
+    DetectionsTensor,
+    ModelOutput,
+)
 
 __all__ = [
     "extract_prediction_tensor",
@@ -32,7 +35,7 @@ def extract_prediction_tensor(
     max_detections: int = 200,
     threshold: Optional[float] = None,
     nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
-) -> List[Detections]:
+) -> List[DetectionsTensor]:
     detection_heatmap = non_max_suppression(
         output.detection_probs.detach(),
         kernel_size=nms_kernel_size,
@@ -78,7 +81,7 @@ def extract_prediction_tensor(
         class_scores = class_scores[mask]
 
         predictions.append(
-            Detections(
+            DetectionsTensor(
                 scores=detection_scores,
                 sizes=sizes,
                 features=features,
@@ -20,7 +20,7 @@ import xarray as xr
 from soundevent.arrays import Dimensions
 
 from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
-from batdetect2.typing.postprocess import Detections
+from batdetect2.typing.postprocess import DetectionsTensor
 
 __all__ = [
     "features_to_xarray",
@@ -31,15 +31,15 @@ __all__ = [
 
 
 def map_detection_to_clip(
-    detections: Detections,
+    detections: DetectionsTensor,
     start_time: float,
     end_time: float,
     min_freq: float,
     max_freq: float,
-) -> Detections:
+) -> DetectionsTensor:
     duration = end_time - start_time
     bandwidth = max_freq - min_freq
-    return Detections(
+    return DetectionsTensor(
         scores=detections.scores,
         sizes=detections.sizes,
         features=detections.features,
@@ -21,7 +21,7 @@ configured processing steps. The main way to create a functional `Targets`
 object is via the `build_targets` or `load_targets` functions.
 """
 
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 from loguru import logger
 from pydantic import Field
@@ -675,3 +675,24 @@ def load_targets(
         term_registry=term_registry,
         derivation_registry=derivation_registry,
     )
+
+
+def iterate_encoded_sound_events(
+    sound_events: Iterable[data.SoundEventAnnotation],
+    targets: TargetProtocol,
+) -> Iterable[Tuple[Optional[str], Position, Size]]:
+    for sound_event in sound_events:
+        if not targets.filter(sound_event):
+            continue
+
+        geometry = sound_event.sound_event.geometry
+
+        if geometry is None:
+            continue
+
+        sound_event = targets.transform(sound_event)
+
+        class_name = targets.encode_class(sound_event)
+        position, size = targets.encode_roi(sound_event)
+
+        yield class_name, position, size
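Note: iterate_encoded_sound_events centralises the filter -> geometry check -> transform -> encode sequence that the heatmap generator used to inline (see the labels.py hunks below). A usage sketch, with clip_annotation and targets assumed in scope:

    for class_name, (time, frequency), size in iterate_encoded_sound_events(
        clip_annotation.sound_events,
        targets,
    ):
        # class_name can be None for detection-only events that carry
        # no class label; position and size are always present.
        print(class_name, time, frequency, size)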
@@ -1,32 +1,26 @@
-import io
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
-import numpy as np
 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.loggers import Logger, TensorBoardLogger
-from lightning.pytorch.loggers.mlflow import MLFlowLogger
-from loguru import logger
 from soundevent import data
 from torch.utils.data import DataLoader
 
 from batdetect2.evaluate.match import (
     MatchConfig,
-    match_sound_events_and_raw_predictions,
+    match_all_predictions,
 )
-from batdetect2.models import Model
 from batdetect2.plotting.evaluation import plot_example_gallery
-from batdetect2.postprocess import get_sound_event_predictions
-from batdetect2.train.dataset import TrainingDataset
+from batdetect2.postprocess import get_raw_predictions
+from batdetect2.train.dataset import ValidationDataset
 from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.logging import get_image_plotter
 from batdetect2.typing import (
-    BatDetect2Prediction,
     MatchEvaluation,
     MetricsProtocol,
-    ModelOutput,
-    TargetProtocol,
-    TrainExample,
 )
+from batdetect2.typing.models import ModelOutput
+from batdetect2.typing.postprocess import RawPrediction
+from batdetect2.typing.train import TrainExample
 
 
 class ValidationMetrics(Callback):
@@ -45,15 +39,14 @@ class ValidationMetrics(Callback):
         self.metrics = metrics
         self.plot = plot
 
-        self._matches: List[
-            Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
-        ] = []
+        self._clip_annotations: List[data.ClipAnnotation] = []
+        self._predictions: List[List[RawPrediction]] = []
 
-    def get_dataset(self, trainer: Trainer) -> TrainingDataset:
+    def get_dataset(self, trainer: Trainer) -> ValidationDataset:
         dataloaders = trainer.val_dataloaders
         assert isinstance(dataloaders, DataLoader)
         dataset = dataloaders.dataset
-        assert isinstance(dataset, TrainingDataset)
+        assert isinstance(dataset, ValidationDataset)
         return dataset
 
     def plot_examples(
@@ -61,7 +54,7 @@ class ValidationMetrics(Callback):
         pl_module: LightningModule,
         matches: List[MatchEvaluation],
     ):
-        plotter = _get_image_plotter(pl_module.logger)  # type: ignore
+        plotter = get_image_plotter(pl_module.logger)  # type: ignore
 
         if plotter is None:
             return
@@ -93,9 +86,10 @@ class ValidationMetrics(Callback):
         trainer: Trainer,
         pl_module: LightningModule,
     ) -> None:
-        matches = _match_all_collected_examples(
-            self._matches,
-            pl_module.model.targets,
+        matches = match_all_predictions(
+            self._clip_annotations,
+            self._predictions,
+            targets=pl_module.model.targets,
             config=self.match_config,
         )
 
@@ -123,133 +117,23 @@ class ValidationMetrics(Callback):
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        self._matches.extend(
-            _get_batch_clips_and_predictions(
-                batch,
-                outputs,
-                dataset=self.get_dataset(trainer),
-                model=pl_module.model,
-            )
-        )
-
-
-def _get_batch_clips_and_predictions(
-    batch: TrainExample,
-    outputs: ModelOutput,
-    dataset: TrainingDataset,
-    model: Model,
-) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
-    clip_annotations = [
-        _get_subclip(
-            dataset.clip_annotations[int(example_id)],
-            start_time=start_time.item(),
-            end_time=end_time.item(),
-            targets=model.targets,
-        )
-        for example_id, start_time, end_time in zip(
-            batch.idx,
-            batch.start_time,
-            batch.end_time,
-        )
-    ]
-
-    clips = [clip_annotation.clip for clip_annotation in clip_annotations]
-
-    raw_predictions = get_sound_event_predictions(
-        outputs,
-        clips,
-        targets=model.targets,
-        postprocessor=model.postprocessor
-    )
-
-    return [
-        (clip_annotation, clip_predictions)
-        for clip_annotation, clip_predictions in zip(
-            clip_annotations, raw_predictions
-        )
-    ]
-
-
-def _match_all_collected_examples(
-    pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]],
-    targets: TargetProtocol,
-    config: Optional[MatchConfig] = None,
-) -> List[MatchEvaluation]:
-    logger.info("Matching all annotations and predictions...")
-    return [
-        match
-        for clip_annotation, raw_predictions in pre_matches
-        for match in match_sound_events_and_raw_predictions(
-            clip_annotation,
-            raw_predictions,
-            targets=targets,
-            config=config,
-        )
-    ]
-
-
-def _is_in_subclip(
-    sound_event_annotation: data.SoundEventAnnotation,
-    targets: TargetProtocol,
-    start_time: float,
-    end_time: float,
-) -> bool:
-    (time, _), _ = targets.encode_roi(sound_event_annotation)
-    return start_time <= time <= end_time
-
-
-def _get_subclip(
-    clip_annotation: data.ClipAnnotation,
-    start_time: float,
-    end_time: float,
-    targets: TargetProtocol,
-) -> data.ClipAnnotation:
-    return data.ClipAnnotation(
-        clip=data.Clip(
-            recording=clip_annotation.clip.recording,
-            start_time=start_time,
-            end_time=end_time,
-        ),
-        sound_events=[
-            sound_event_annotation
-            for sound_event_annotation in clip_annotation.sound_events
-            if _is_in_subclip(
-                sound_event_annotation,
-                targets,
-                start_time=start_time,
-                end_time=end_time,
-            )
-        ],
-    )
-
-
-def _get_image_plotter(logger: Logger):
-    if isinstance(logger, TensorBoardLogger):
-
-        def plot_figure(name, figure, step):
-            return logger.experiment.add_figure(name, figure, step)
-
-        return plot_figure
-
-    if isinstance(logger, MLFlowLogger):
-
-        def plot_figure(name, figure, step):
-            image = _convert_figure_to_image(figure)
-            return logger.experiment.log_image(
-                run_id=logger.run_id,
-                image=image,
-                key=name,
-                step=step,
-            )
-
-        return plot_figure
-
-
-def _convert_figure_to_image(figure):
-    with io.BytesIO() as buff:
-        figure.savefig(buff, format="raw")
-        buff.seek(0)
-        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
-        w, h = figure.canvas.get_width_height()
-        im = data.reshape((int(h), int(w), -1))
-        return im
+        postprocessor = pl_module.model.postprocessor
+        targets = pl_module.model.targets
+        dataset = self.get_dataset(trainer)
+
+        clip_annotations = [
+            dataset.clip_annotations[int(example_idx)]
+            for example_idx in batch.idx
+        ]
+
+        predictions = get_raw_predictions(
+            outputs,
+            clips=[
+                clip_annotation.clip for clip_annotation in clip_annotations
+            ],
+            targets=targets,
+            postprocessor=postprocessor,
+        )
+
+        self._clip_annotations.extend(clip_annotations)
+        self._predictions.extend(predictions)
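Note: the callback now buffers cheap inputs during validation and defers all matching to epoch end, instead of matching per batch. The data flow, stripped to its core (a sketch; buffer lifecycle management is elided, and match_all_predictions is the helper added above):

    from typing import List
    from soundevent import data

    class _MatchBuffer:
        def __init__(self):
            self._clip_annotations: List[data.ClipAnnotation] = []
            self._predictions: list = []  # List[List[RawPrediction]]

        def add_batch(self, clip_annotations, predictions):
            # Mirrors on_validation_batch_end: accumulate only.
            self._clip_annotations.extend(clip_annotations)
            self._predictions.extend(predictions)

        def finish(self, targets, config=None):
            # Mirrors on_validation_epoch_end: match everything at once.
            return match_all_predictions(
                self._clip_annotations,
                self._predictions,
                targets=targets,
                config=config,
            )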
@@ -6,7 +6,10 @@ from torch.utils.data import Dataset
 
 from batdetect2.typing import ClipperProtocol, TrainExample
 from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
-from batdetect2.typing.train import Augmentation, ClipLabeller
+from batdetect2.typing.train import (
+    Augmentation,
+    ClipLabeller,
+)
 
 __all__ = [
     "TrainingDataset",
@@ -75,3 +78,47 @@ class TrainingDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
+
+
+class ValidationDataset(Dataset):
+    def __init__(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        labeller: ClipLabeller,
+        audio_dir: Optional[data.PathLike] = None,
+    ):
+        self.clip_annotations = clip_annotations
+        self.labeller = labeller
+        self.preprocessor = preprocessor
+        self.audio_loader = audio_loader
+        self.audio_dir = audio_dir
+
+    def __len__(self):
+        return len(self.clip_annotations)
+
+    def __getitem__(self, idx) -> TrainExample:
+        clip_annotation = self.clip_annotations[idx]
+        clip = clip_annotation.clip
+
+        wav = self.audio_loader.load_clip(
+            clip_annotation.clip,
+            audio_dir=self.audio_dir,
+        )
+
+        wav_tensor = torch.tensor(wav).unsqueeze(0)
+
+        spectrogram = self.preprocessor(wav_tensor)
+
+        heatmaps = self.labeller(clip_annotation, spectrogram)
+
+        return TrainExample(
+            spec=spectrogram,
+            detection_heatmap=heatmaps.detection,
+            class_heatmap=heatmaps.classes,
+            size_heatmap=heatmaps.size,
+            idx=torch.tensor(idx),
+            start_time=torch.tensor(clip.start_time),
+            end_time=torch.tensor(clip.end_time),
+        )
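Note: unlike TrainingDataset, ValidationDataset indexes whole clips, so batch.idx maps straight back to dataset.clip_annotations with no subclip bookkeeping. A usage sketch (constructor dependencies assumed built elsewhere, e.g. by build_val_dataset):

    dataset = ValidationDataset(
        clip_annotations,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        labeller=labeller,
    )
    example = dataset[0]
    # The example's time bounds are the clip's own bounds.
    assert float(example.start_time) == clip_annotations[0].clip.start_time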
@@ -3,24 +3,6 @@
 This module is responsible for creating the target labels used for training
 BatDetect2 models. It converts sound event annotations for an audio clip into
 the specific multi-channel heatmap formats required by the neural network.
-
-It uses a pre-configured object adhering to the `TargetProtocol` (from
-`batdetect2.targets`) which encapsulates all the logic for filtering
-annotations, transforming tags, encoding class names, and mapping annotation
-geometry (ROIs) to target positions and sizes. This module then focuses on
-rendering this information onto the heatmap grids.
-
-The pipeline generates three core outputs for a given spectrogram:
-1. **Detection Heatmap**: Indicates presence/location of relevant sound events.
-2. **Class Heatmap**: Indicates location and class identity for specifically
-   classified events.
-3. **Size Heatmap**: Encodes the target dimensions (width, height) of events.
-
-The primary function generated by this module is a `ClipLabeller` (defined in
-`.types`), which takes a `ClipAnnotation` object and its corresponding
-spectrogram and returns the calculated `Heatmaps` tuple. The main configurable
-parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
-defined in `LabelConfig`.
 """
 
 from functools import partial
@@ -32,6 +14,7 @@ from loguru import logger
 from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
+from batdetect2.targets import iterate_encoded_sound_events
 from batdetect2.typing import (
     ClipLabeller,
     Heatmaps,
@@ -56,9 +39,6 @@ class LabelConfig(BaseConfig):
     Attributes
     ----------
     sigma : float, default=3.0
-        The standard deviation (in pixels/bins) of the Gaussian kernel applied
-        to smooth the detection and class heatmaps. Larger values create more
-        diffuse targets.
     """
 
     sigma: float = 2.0
@@ -70,28 +50,7 @@ def build_clip_labeler(
     max_freq: float,
     config: Optional[LabelConfig] = None,
 ) -> ClipLabeller:
-    """Construct the final clip labelling function.
-
-    This factory function prepares the callable that will perform the
-    end-to-end heatmap generation for a given clip and spectrogram during
-    training data loading. It takes the fully configured `targets` object and
-    the `LabelConfig` and binds them to the `generate_clip_label` function.
-
-    Parameters
-    ----------
-    targets : TargetProtocol
-        An initialized object conforming to the `TargetProtocol`, providing all
-        necessary methods for filtering, transforming, encoding, and ROI
-        mapping.
-    config : LabelConfig
-        Configuration object containing heatmap generation parameters.
-
-    Returns
-    -------
-    ClipLabeller
-        A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
-        (spectrogram) and returns the generated `Heatmaps`.
-    """
+    """Construct the final clip labelling function."""
     config = config or LabelConfig()
     logger.opt(lazy=True).debug(
         "Building clip labeler with config: \n{}",
@@ -119,37 +78,10 @@ def generate_heatmaps(
     target_sigma: float = 3.0,
     dtype=torch.float32,
 ) -> Heatmaps:
-    """Generate training heatmaps for a single annotated clip.
-
-    This function orchestrates the target generation process for one clip:
-    1. Filters and transforms sound events using `targets.filter` and
-       `targets.transform`.
-    2. Passes the resulting processed annotations, along with the spectrogram,
-       the `targets` object, and the Gaussian `sigma` from `config`, to the
-       core `generate_heatmaps` function.
-
-    Parameters
-    ----------
-    clip_annotation : data.ClipAnnotation
-        The complete annotation data for the audio clip, including the list
-        of `sound_events` to process.
-    spec : xr.DataArray
-        The spectrogram corresponding to the `clip_annotation`. Must have
-        'time' and 'frequency' dimensions/coordinates.
-    targets : TargetProtocol
-        The fully configured target definition object, providing methods for
-        filtering, transformation, encoding, and ROI mapping.
-    config : LabelConfig
-        Configuration object providing heatmap parameters (primarily `sigma`).
-
-    Returns
-    -------
-    Heatmaps
-        A NamedTuple containing the generated 'detection', 'classes', and 'size'
-        heatmaps for this clip.
-    """
+    """Generate training heatmaps for a single annotated clip."""
     logger.debug(
-        "Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
+        "Will generate heatmaps for clip annotation "
+        "{uuid} with {num} annotated sound events",
         uuid=clip_annotation.uuid,
         num=len(clip_annotation.sound_events),
     )
@@ -174,28 +106,10 @@ def generate_heatmaps(
     freqs = freqs.to(spec.device)
     times = times.to(spec.device)
 
-    for sound_event_annotation in clip_annotation.sound_events:
-        if not targets.filter(sound_event_annotation):
-            logger.debug(
-                "Sound event {sound_event} did not pass the filter. Tags: {tags}",
-                sound_event=sound_event_annotation,
-                tags=sound_event_annotation.tags,
-            )
-            continue
-
-        sound_event_annotation = targets.transform(sound_event_annotation)
-
-        geom = sound_event_annotation.sound_event.geometry
-        if geom is None:
-            logger.debug(
-                "Skipping annotation %s: missing geometry.",
-                sound_event_annotation.uuid,
-            )
-            continue
-
-        # Get the position of the sound event
-        (time, frequency), size = targets.encode_roi(sound_event_annotation)
-
+    for class_name, (time, frequency), size in iterate_encoded_sound_events(
+        clip_annotation.sound_events,
+        targets,
+    ):
         time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
         freq_index = map_to_pixels(frequency, height, min_freq, max_freq)
 
@@ -206,9 +120,7 @@ def generate_heatmaps(
             or freq_index >= height
         ):
             logger.debug(
-                "Skipping annotation %s: position outside spectrogram. "
-                "Pos: %s",
-                sound_event_annotation.uuid,
+                "Skipping annotation: position outside spectrogram. Pos: %s",
                 (time, frequency),
             )
             continue
@@ -222,20 +134,8 @@ def generate_heatmaps(
         )
         size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])
 
-        # Get the class name of the sound event
-        try:
-            class_name = targets.encode_class(sound_event_annotation)
-        except ValueError as e:
-            logger.warning(
-                "Skipping annotation %s: Unexpected error while encoding "
-                "class name %s",
-                sound_event_annotation.uuid,
-                e,
-            )
-            continue
-
+        # If the label is None skip the sound event
         if class_name is None:
-            # If the label is None skip the sound event
             continue
 
         class_index = targets.class_names.index(class_name)
@@ -1,6 +1,8 @@
+import io
 from typing import Annotated, Any, Literal, Optional, Union
 
-from lightning.pytorch.loggers import Logger
+import numpy as np
+from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
 from loguru import logger
 from pydantic import Field
 
@@ -140,3 +142,35 @@ def build_logger(config: LoggerConfig) -> Logger:
     creation_func = LOGGER_FACTORY[logger_type]
 
     return creation_func(config)
+
+
+def get_image_plotter(logger: Logger):
+    if isinstance(logger, TensorBoardLogger):
+
+        def plot_figure(name, figure, step):
+            return logger.experiment.add_figure(name, figure, step)
+
+        return plot_figure
+
+    if isinstance(logger, MLFlowLogger):
+
+        def plot_figure(name, figure, step):
+            image = _convert_figure_to_image(figure)
+            return logger.experiment.log_image(
+                run_id=logger.run_id,
+                image=image,
+                key=name,
+                step=step,
+            )
+
+        return plot_figure
+
+
+def _convert_figure_to_image(figure):
+    with io.BytesIO() as buff:
+        figure.savefig(buff, format="raw")
+        buff.seek(0)
+        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
+        w, h = figure.canvas.get_width_height()
+        im = data.reshape((int(h), int(w), -1))
+        return im
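Note: get_image_plotter falls through to an implicit None for logger types without a figure API, which is exactly what plot_examples checks for before plotting. Usage sketch (figure and step are assumptions):

    plotter = get_image_plotter(pl_module.logger)
    if plotter is not None:
        plotter("val/examples", figure, step)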
@@ -24,9 +24,7 @@ from batdetect2.train.augmentations import (
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
 from batdetect2.train.config import FullTrainingConfig, TrainingConfig
-from batdetect2.train.dataset import (
-    TrainingDataset,
-)
+from batdetect2.train.dataset import TrainingDataset, ValidationDataset
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
@@ -304,11 +302,11 @@ def build_val_dataset(
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
-) -> TrainingDataset:
+) -> ValidationDataset:
     logger.info("Building validation dataset...")
     config = config or TrainingConfig()
 
-    return TrainingDataset(
+    return ValidationDataset(
         clip_annotations,
         audio_loader=audio_loader,
         labeller=labeller,
@@ -11,13 +11,17 @@ __all__ = [
 
 @dataclass
 class MatchEvaluation:
-    match: data.Match
+    clip: data.Clip
 
+    sound_event_annotation: Optional[data.SoundEventAnnotation]
     gt_det: bool
     gt_class: Optional[str]
 
     pred_score: float
     pred_class_scores: Dict[str, float]
+    pred_geometry: Optional[data.Geometry]
+
+    affinity: float
 
     @property
     def pred_class(self) -> Optional[str]:
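Note: with data.Match gone from MatchEvaluation, the evaluation outcome is read off two optional fields, which is what the plotting assertions earlier in this diff rely on. A sketch of that reading:

    def classify_outcome(match: MatchEvaluation) -> str:
        if match.pred_geometry is None:
            return "false negative"    # annotation without a prediction
        if match.sound_event_annotation is None:
            return "false positive"    # prediction without an annotation
        # Both present: a matched pair (true positive or cross-trigger,
        # depending on whether the predicted class agrees).
        return "matched"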
@@ -77,7 +77,16 @@ class RawPrediction(NamedTuple):
     features: np.ndarray
 
 
-class Detections(NamedTuple):
+class DetectionsArray(NamedTuple):
+    scores: np.ndarray
+    sizes: np.ndarray
+    class_scores: np.ndarray
+    times: np.ndarray
+    frequencies: np.ndarray
+    features: np.ndarray
+
+
+class DetectionsTensor(NamedTuple):
     scores: torch.Tensor
     sizes: torch.Tensor
     class_scores: torch.Tensor
@@ -85,6 +94,16 @@ class Detections(NamedTuple):
     frequencies: torch.Tensor
     features: torch.Tensor
 
+    def numpy(self) -> DetectionsArray:
+        return DetectionsArray(
+            scores=self.scores.detach().numpy(),
+            sizes=self.sizes.detach().numpy(),
+            class_scores=self.class_scores.detach().numpy(),
+            times=self.times.detach().numpy(),
+            frequencies=self.frequencies.detach().numpy(),
+            features=self.features.detach().numpy(),
+        )
+
 
 @dataclass
 class BatDetect2Prediction:
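Note: DetectionsTensor.numpy() uses detach() without cpu(), so it assumes the tensors already live on the CPU when conversion happens (torch raises on CUDA tensors otherwise). Conversion sketch, with output, clips, and targets assumed in scope:

    tensor_detections = postprocessor.get_detections(output, clips)[0]
    array_detections = tensor_detections.numpy()   # -> DetectionsArray
    raw_predictions = to_raw_predictions(array_detections, targets=targets)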
@@ -95,10 +114,10 @@ class BatDetect2Prediction:
 class PostprocessorProtocol(Protocol):
     """Protocol defining the interface for the full postprocessing pipeline."""
 
-    def __call__(self, output: ModelOutput) -> List[Detections]: ...
+    def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ...
 
     def get_detections(
         self,
         output: ModelOutput,
         clips: Optional[List[data.Clip]] = None,
-    ) -> List[Detections]: ...
+    ) -> List[DetectionsTensor]: ...
@@ -12,8 +12,8 @@ that components responsible for these tasks can be interacted with consistently
 throughout BatDetect2.
 """
 
-from collections.abc import Callable, Iterable
-from typing import List, Optional, Protocol, Tuple
+from collections.abc import Callable
+from typing import List, Optional, Protocol
 
 import numpy as np
 from soundevent import data