Compare commits

...

3 Commits

Author      SHA1        Message                                                    Date
mbsantiago  71c2301c21  Independent preprocessor for generating validation plots  2025-08-31 23:04:16 +01:00
mbsantiago  d3d2a28130  Move detections array to cpu                               2025-08-31 22:59:06 +01:00
mbsantiago  5b9a5a968f  Refactor eval code                                         2025-08-31 22:57:02 +01:00
17 changed files with 285 additions and 357 deletions

View File

@ -2,14 +2,10 @@ from batdetect2.evaluate.config import (
EvaluationConfig,
load_evaluation_config,
)
from batdetect2.evaluate.match import (
match_predictions_and_annotations,
match_sound_events_and_raw_predictions,
)
from batdetect2.evaluate.match import match_predictions_and_annotations
__all__ = [
"EvaluationConfig",
"load_evaluation_config",
"match_predictions_and_annotations",
"match_sound_events_and_raw_predictions",
]

View File

@ -3,6 +3,7 @@ from dataclasses import dataclass, field
from typing import List, Literal, Optional, Protocol, Tuple
import numpy as np
from loguru import logger
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match
@ -10,10 +11,10 @@ from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
from batdetect2.typing import (
BatDetect2Prediction,
MatchEvaluation,
TargetProtocol,
)
from batdetect2.typing.postprocess import RawPrediction
MatchingStrategy = Literal["greedy", "optimal"]
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
@ -274,7 +275,7 @@ def greedy_match(
def match_sound_events_and_raw_predictions(
clip_annotation: data.ClipAnnotation,
raw_predictions: List[BatDetect2Prediction],
raw_predictions: List[RawPrediction],
targets: TargetProtocol,
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
@ -294,12 +295,11 @@ def match_sound_events_and_raw_predictions(
]
predicted_geometries = [
raw_prediction.raw.geometry for raw_prediction in raw_predictions
raw_prediction.geometry for raw_prediction in raw_predictions
]
scores = [
raw_prediction.raw.detection_score
for raw_prediction in raw_predictions
raw_prediction.detection_score for raw_prediction in raw_predictions
]
matches = []
@ -320,14 +320,20 @@ def match_sound_events_and_raw_predictions(
gt_det = target is not None
gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.raw.detection_score) if prediction else 0
pred_score = float(prediction.detection_score) if prediction else 0
pred_geometry = (
predicted_geometries[source_idx]
if source_idx is not None
else None
)
class_scores = (
{
str(class_name): float(score)
for class_name, score in zip(
targets.class_names,
prediction.raw.class_scores,
prediction.class_scores,
)
}
if prediction is not None
@ -336,17 +342,14 @@ def match_sound_events_and_raw_predictions(
matches.append(
MatchEvaluation(
match=data.Match(
source=None
if prediction is None
else prediction.sound_event_prediction,
target=target,
affinity=affinity,
),
clip=clip_annotation.clip,
sound_event_annotation=target,
gt_det=gt_det,
gt_class=gt_class,
pred_score=pred_score,
pred_class_scores=class_scores,
pred_geometry=pred_geometry,
affinity=affinity,
)
)
@ -418,6 +421,28 @@ def match_predictions_and_annotations(
return matches
def match_all_predictions(
clip_annotations: List[data.ClipAnnotation],
predictions: List[List[RawPrediction]],
targets: TargetProtocol,
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...")
return [
match
for clip_annotation, raw_predictions in zip(
clip_annotations,
predictions,
)
for match in match_sound_events_and_raw_predictions(
clip_annotation,
raw_predictions,
targets=targets,
config=config,
)
]
@dataclass
class ClassExamples:
false_positives: List[MatchEvaluation] = field(default_factory=list)
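Note: the new match_all_predictions helper pairs each clip annotation with its per-clip list of RawPredictions and flattens the resulting matches. A minimal usage sketch, assuming clip_annotations, predictions and targets are built elsewhere in the pipeline:

from batdetect2.evaluate.match import match_all_predictions

# clip_annotations: List[data.ClipAnnotation], one entry per validation clip
# predictions: List[List[RawPrediction]], aligned with clip_annotations
matches = match_all_predictions(
    clip_annotations,
    predictions,
    targets=targets,  # TargetProtocol instance
)  # config (MatchConfig) is optional; defaults apply when omitted
for match in matches:
    print(match.gt_class, match.pred_score, match.affinity)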

View File

@ -68,7 +68,7 @@ from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import Detections, PostprocessorProtocol
from batdetect2.typing.postprocess import DetectionsArray, PostprocessorProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
@ -122,7 +122,7 @@ class Model(LightningModule):
self.targets = targets
self.save_hyperparameters()
def forward(self, wav: torch.Tensor) -> List[Detections]:
def forward(self, wav: torch.Tensor) -> List[DetectionsArray]:
spec = self.preprocessor(wav)
outputs = self.detector(spec)
return self.postprocessor(outputs)

View File

@ -124,25 +124,21 @@ def plot_false_positive_match(
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
time_offset: float = 0,
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.match.source is not None
assert match.match.target is None
sound_event = match.match.source.sound_event
geometry = sound_event.geometry
assert geometry is not None
assert match.pred_geometry is not None
assert match.sound_event_annotation is None
start_time, _, _, high_freq = compute_bounds(geometry)
start_time, _, _, high_freq = compute_bounds(match.pred_geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2,
sound_event.recording.duration,
match.clip.end_time,
),
recording=sound_event.recording,
recording=match.clip.recording,
)
ax = plot_clip(
@ -154,11 +150,9 @@ def plot_false_positive_match(
spec_cmap=spec_cmap,
)
plot_prediction(
match.match.source,
plot.plot_geometry(
match.pred_geometry,
ax=ax,
time_offset=time_offset,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
@ -191,9 +185,9 @@ def plot_false_negative_match(
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.match.source is None
assert match.match.target is not None
sound_event = match.match.target.sound_event
assert match.pred_geometry is None
assert match.sound_event_annotation is not None
sound_event = match.sound_event_annotation.sound_event
geometry = sound_event.geometry
assert geometry is not None
@ -217,7 +211,7 @@ def plot_false_negative_match(
)
plot.plot_annotation(
match.match.target,
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
@ -255,9 +249,9 @@ def plot_true_positive_match(
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
assert match.match.source is not None
assert match.match.target is not None
sound_event = match.match.target.sound_event
assert match.sound_event_annotation is not None
assert match.pred_geometry is not None
sound_event = match.sound_event_annotation.sound_event
geometry = sound_event.geometry
assert geometry is not None
@ -281,7 +275,7 @@ def plot_true_positive_match(
)
plot.plot_annotation(
match.match.target,
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
@ -292,11 +286,9 @@ def plot_true_positive_match(
linestyle=annotation_linestyle,
)
plot_prediction(
match.match.source,
plot.plot_geometry(
match.pred_geometry,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
@ -332,9 +324,9 @@ def plot_cross_trigger_match(
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
assert match.match.source is not None
assert match.match.target is not None
sound_event = match.match.source.sound_event
assert match.sound_event_annotation is not None
assert match.pred_geometry is not None
sound_event = match.sound_event_annotation.sound_event
geometry = sound_event.geometry
assert geometry is not None
@ -358,7 +350,7 @@ def plot_cross_trigger_match(
)
plot.plot_annotation(
match.match.target,
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
@ -369,11 +361,9 @@ def plot_cross_trigger_match(
linestyle=annotation_linestyle,
)
plot_prediction(
match.match.source,
plot.plot_geometry(
match.pred_geometry,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
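The plotting helpers above now read everything from the MatchEvaluation record rather than the wrapped data.Match. A small hypothetical helper (describe_match is not part of this change set) summarising the invariants those assertions encode:

from batdetect2.typing import MatchEvaluation

def describe_match(match: MatchEvaluation) -> str:
    # Which plotting function applies depends on which fields are populated.
    if match.pred_geometry is not None and match.sound_event_annotation is not None:
        return "matched (true positive or cross-trigger)"
    if match.pred_geometry is not None:
        return "false positive"
    if match.sound_event_annotation is not None:
        return "false negative"
    return "unmatched"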

View File

@ -10,9 +10,9 @@ from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_detections_to_raw_predictions,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,
to_raw_predictions,
)
from batdetect2.postprocess.extraction import extract_prediction_tensor
from batdetect2.postprocess.nms import (
@ -24,7 +24,7 @@ from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
Detections,
DetectionsTensor,
PostprocessorProtocol,
RawPrediction,
)
@ -43,7 +43,7 @@ __all__ = [
"TOP_K_PER_SEC",
"build_postprocessor",
"convert_raw_predictions_to_clip_prediction",
"convert_detections_to_raw_predictions",
"to_raw_predictions",
"load_postprocess_config",
"non_max_suppression",
]
@ -168,7 +168,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
self.top_k_per_sec = top_k_per_sec
self.detection_threshold = detection_threshold
def forward(self, output: ModelOutput) -> List[Detections]:
def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
@ -192,7 +192,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
self,
output: ModelOutput,
clips: Optional[List[data.Clip]] = None,
) -> List[Detections]:
) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
@ -245,11 +245,8 @@ def get_raw_predictions(
"""
detections = postprocessor.get_detections(output, clips)
return [
convert_detections_to_raw_predictions(
dataset,
targets=targets,
)
for dataset in detections
to_raw_predictions(detection.numpy(), targets=targets)
for detection in detections
]
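For reference, the decode path now runs ModelOutput -> DetectionsTensor -> DetectionsArray -> RawPrediction. A sketch of calling get_raw_predictions (output, clips, targets and postprocessor are assumed to exist; the call shape follows its usage in the validation callback further down):

raw_predictions = get_raw_predictions(
    output,                      # ModelOutput from the detector
    clips=clips,                 # optional List[data.Clip]; maps detections into clip coordinates
    targets=targets,
    postprocessor=postprocessor,
)
# raw_predictions: List[List[RawPrediction]], one inner list per clip in the batch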

View File

@ -6,13 +6,13 @@ import numpy as np
from soundevent import data
from batdetect2.typing.postprocess import (
Detections,
DetectionsArray,
RawPrediction,
)
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"convert_detections_to_raw_predictions",
"to_raw_predictions",
"convert_raw_predictions_to_clip_prediction",
"convert_raw_prediction_to_sound_event_prediction",
"DEFAULT_CLASSIFICATION_THRESHOLD",
@ -27,19 +27,19 @@ decoding.
"""
def convert_detections_to_raw_predictions(
detections: Detections,
def to_raw_predictions(
detections: DetectionsArray,
targets: TargetProtocol,
) -> List[RawPrediction]:
predictions = []
for score, class_scores, time, freq, dims, feats in zip(
detections.scores.cpu().numpy(),
detections.class_scores.cpu().numpy(),
detections.times.cpu().numpy(),
detections.frequencies.cpu().numpy(),
detections.sizes.cpu().numpy(),
detections.features.cpu().numpy(),
detections.scores,
detections.class_scores,
detections.times,
detections.frequencies,
detections.sizes,
detections.features,
):
highest_scoring_class = targets.class_names[class_scores.argmax()]
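A toy input for the renamed function, as a sketch; the field shapes (one row per detection, one class_scores column per class, 32 features) are assumptions for illustration only:

import numpy as np

detections = DetectionsArray(
    scores=np.array([0.9, 0.4]),
    class_scores=np.array([[0.7, 0.2], [0.1, 0.8]]),
    times=np.array([0.12, 0.45]),
    frequencies=np.array([45_000.0, 32_000.0]),
    sizes=np.array([[0.01, 5_000.0], [0.02, 8_000.0]]),
    features=np.zeros((2, 32)),
)
predictions = to_raw_predictions(detections, targets=targets)  # targets assumed built elsewhere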

View File

@ -20,7 +20,10 @@ from typing import List, Optional, Tuple, Union
import torch
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.typing.postprocess import Detections, ModelOutput
from batdetect2.typing.postprocess import (
DetectionsTensor,
ModelOutput,
)
__all__ = [
"extract_prediction_tensor",
@ -32,7 +35,7 @@ def extract_prediction_tensor(
max_detections: int = 200,
threshold: Optional[float] = None,
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> List[Detections]:
) -> List[DetectionsTensor]:
detection_heatmap = non_max_suppression(
output.detection_probs.detach(),
kernel_size=nms_kernel_size,
@ -78,7 +81,7 @@ def extract_prediction_tensor(
class_scores = class_scores[mask]
predictions.append(
Detections(
DetectionsTensor(
scores=detection_scores,
sizes=sizes,
features=features,

View File

@ -20,7 +20,7 @@ import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing.postprocess import Detections
from batdetect2.typing.postprocess import DetectionsTensor
__all__ = [
"features_to_xarray",
@ -31,15 +31,15 @@ __all__ = [
def map_detection_to_clip(
detections: Detections,
detections: DetectionsTensor,
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
) -> Detections:
) -> DetectionsTensor:
duration = end_time - start_time
bandwidth = max_freq - min_freq
return Detections(
return DetectionsTensor(
scores=detections.scores,
sizes=detections.sizes,
features=detections.features,

View File

@ -21,7 +21,7 @@ configured processing steps. The main way to create a functional `Targets`
object is via the `build_targets` or `load_targets` functions.
"""
from typing import List, Optional
from typing import Iterable, List, Optional, Tuple
from loguru import logger
from pydantic import Field
@ -675,3 +675,24 @@ def load_targets(
term_registry=term_registry,
derivation_registry=derivation_registry,
)
def iterate_encoded_sound_events(
sound_events: Iterable[data.SoundEventAnnotation],
targets: TargetProtocol,
) -> Iterable[Tuple[Optional[str], Position, Size]]:
for sound_event in sound_events:
if not targets.filter(sound_event):
continue
geometry = sound_event.sound_event.geometry
if geometry is None:
continue
sound_event = targets.transform(sound_event)
class_name = targets.encode_class(sound_event)
position, size = targets.encode_roi(sound_event)
yield class_name, position, size
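A minimal sketch of consuming the new generator; clip_annotation and targets are assumed to be built elsewhere, and the unpacking mirrors its use in the label-generation code further down:

from batdetect2.targets import iterate_encoded_sound_events

for class_name, (time, frequency), size in iterate_encoded_sound_events(
    clip_annotation.sound_events,
    targets,
):
    # Events that fail targets.filter or lack geometry are skipped inside the
    # generator; class_name may be None for detection-only (unclassified) events.
    print(class_name, time, frequency, size)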

View File

@ -1,38 +1,34 @@
import io
from typing import List, Optional, Tuple
from typing import List, Optional
import numpy as np
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import Logger, TensorBoardLogger
from lightning.pytorch.loggers.mlflow import MLFlowLogger
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.match import (
MatchConfig,
match_sound_events_and_raw_predictions,
match_all_predictions,
)
from batdetect2.models import Model
from batdetect2.plotting.clips import PreprocessorProtocol
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess import get_sound_event_predictions
from batdetect2.train.dataset import TrainingDataset
from batdetect2.postprocess import get_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import get_image_plotter
from batdetect2.typing import (
BatDetect2Prediction,
MatchEvaluation,
MetricsProtocol,
ModelOutput,
TargetProtocol,
TrainExample,
)
from batdetect2.typing.models import ModelOutput
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.train import TrainExample
class ValidationMetrics(Callback):
def __init__(
self,
metrics: List[MetricsProtocol],
preprocessor: PreprocessorProtocol,
plot: bool = True,
match_config: Optional[MatchConfig] = None,
):
@ -43,17 +39,17 @@ class ValidationMetrics(Callback):
self.match_config = match_config
self.metrics = metrics
self.preprocessor = preprocessor
self.plot = plot
self._matches: List[
Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
] = []
self._clip_annotations: List[data.ClipAnnotation] = []
self._predictions: List[List[RawPrediction]] = []
def get_dataset(self, trainer: Trainer) -> TrainingDataset:
def get_dataset(self, trainer: Trainer) -> ValidationDataset:
dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, TrainingDataset)
assert isinstance(dataset, ValidationDataset)
return dataset
def plot_examples(
@ -61,14 +57,14 @@ class ValidationMetrics(Callback):
pl_module: LightningModule,
matches: List[MatchEvaluation],
):
plotter = _get_image_plotter(pl_module.logger) # type: ignore
plotter = get_image_plotter(pl_module.logger) # type: ignore
if plotter is None:
return
for class_name, fig in plot_example_gallery(
matches,
preprocessor=pl_module.model.preprocessor,
preprocessor=self.preprocessor,
n_examples=4,
):
plotter(
@ -93,9 +89,10 @@ class ValidationMetrics(Callback):
trainer: Trainer,
pl_module: LightningModule,
) -> None:
matches = _match_all_collected_examples(
self._matches,
pl_module.model.targets,
matches = match_all_predictions(
self._clip_annotations,
self._predictions,
targets=pl_module.model.targets,
config=self.match_config,
)
@ -123,133 +120,23 @@ class ValidationMetrics(Callback):
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
self._matches.extend(
_get_batch_clips_and_predictions(
batch,
outputs,
dataset=self.get_dataset(trainer),
model=pl_module.model,
)
)
postprocessor = pl_module.model.postprocessor
targets = pl_module.model.targets
dataset = self.get_dataset(trainer)
clip_annotations = [
dataset.clip_annotations[int(example_idx)]
for example_idx in batch.idx
]
def _get_batch_clips_and_predictions(
batch: TrainExample,
outputs: ModelOutput,
dataset: TrainingDataset,
model: Model,
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
clip_annotations = [
_get_subclip(
dataset.clip_annotations[int(example_id)],
start_time=start_time.item(),
end_time=end_time.item(),
targets=model.targets,
)
for example_id, start_time, end_time in zip(
batch.idx,
batch.start_time,
batch.end_time,
)
]
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
raw_predictions = get_sound_event_predictions(
outputs,
clips,
targets=model.targets,
postprocessor=model.postprocessor
)
return [
(clip_annotation, clip_predictions)
for clip_annotation, clip_predictions in zip(
clip_annotations, raw_predictions
)
]
def _match_all_collected_examples(
pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]],
targets: TargetProtocol,
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...")
return [
match
for clip_annotation, raw_predictions in pre_matches
for match in match_sound_events_and_raw_predictions(
clip_annotation,
raw_predictions,
predictions = get_raw_predictions(
outputs,
clips=[
clip_annotation.clip for clip_annotation in clip_annotations
],
targets=targets,
config=config,
postprocessor=postprocessor,
)
]
def _is_in_subclip(
sound_event_annotation: data.SoundEventAnnotation,
targets: TargetProtocol,
start_time: float,
end_time: float,
) -> bool:
(time, _), _ = targets.encode_roi(sound_event_annotation)
return start_time <= time <= end_time
def _get_subclip(
clip_annotation: data.ClipAnnotation,
start_time: float,
end_time: float,
targets: TargetProtocol,
) -> data.ClipAnnotation:
return data.ClipAnnotation(
clip=data.Clip(
recording=clip_annotation.clip.recording,
start_time=start_time,
end_time=end_time,
),
sound_events=[
sound_event_annotation
for sound_event_annotation in clip_annotation.sound_events
if _is_in_subclip(
sound_event_annotation,
targets,
start_time=start_time,
end_time=end_time,
)
],
)
def _get_image_plotter(logger: Logger):
if isinstance(logger, TensorBoardLogger):
def plot_figure(name, figure, step):
return logger.experiment.add_figure(name, figure, step)
return plot_figure
if isinstance(logger, MLFlowLogger):
def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure)
return logger.experiment.log_image(
run_id=logger.run_id,
image=image,
key=name,
step=step,
)
return plot_figure
def _convert_figure_to_image(figure):
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
w, h = figure.canvas.get_width_height()
im = data.reshape((int(h), int(w), -1))
return im
self._clip_annotations.extend(clip_annotations)
self._predictions.extend(predictions)
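How the refactored callback is wired, as a sketch following build_trainer_callbacks further down; the metric list is illustrative and match_config is optional:

callback = ValidationMetrics(
    metrics=[ClassificationAccuracy(class_names=targets.class_names)],
    preprocessor=preprocessor,   # now passed in explicitly instead of taken from the model
    match_config=match_config,
)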

View File

@ -6,7 +6,10 @@ from torch.utils.data import Dataset
from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.train import Augmentation, ClipLabeller
from batdetect2.typing.train import (
Augmentation,
ClipLabeller,
)
__all__ = [
"TrainingDataset",
@ -75,3 +78,47 @@ class TrainingDataset(Dataset):
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
class ValidationDataset(Dataset):
def __init__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
audio_dir: Optional[data.PathLike] = None,
):
self.clip_annotations = clip_annotations
self.labeller = labeller
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.audio_dir = audio_dir
def __len__(self):
return len(self.clip_annotations)
def __getitem__(self, idx) -> TrainExample:
clip_annotation = self.clip_annotations[idx]
clip = clip_annotation.clip
wav = self.audio_loader.load_clip(
clip_annotation.clip,
audio_dir=self.audio_dir,
)
wav_tensor = torch.tensor(wav).unsqueeze(0)
spectrogram = self.preprocessor(wav_tensor)
heatmaps = self.labeller(clip_annotation, spectrogram)
return TrainExample(
spec=spectrogram,
detection_heatmap=heatmaps.detection,
class_heatmap=heatmaps.classes,
size_heatmap=heatmaps.size,
idx=torch.tensor(idx),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
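Constructing the new dataset, as a sketch (the collaborators are assumed to come from the usual builders; audio_dir is optional):

val_dataset = ValidationDataset(
    clip_annotations,            # Sequence[data.ClipAnnotation]
    audio_loader=audio_loader,
    preprocessor=preprocessor,
    labeller=labeller,
)
example = val_dataset[0]         # TrainExample spanning the full annotated clip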

View File

@ -3,24 +3,6 @@
This module is responsible for creating the target labels used for training
BatDetect2 models. It converts sound event annotations for an audio clip into
the specific multi-channel heatmap formats required by the neural network.
It uses a pre-configured object adhering to the `TargetProtocol` (from
`batdetect2.targets`) which encapsulates all the logic for filtering
annotations, transforming tags, encoding class names, and mapping annotation
geometry (ROIs) to target positions and sizes. This module then focuses on
rendering this information onto the heatmap grids.
The pipeline generates three core outputs for a given spectrogram:
1. **Detection Heatmap**: Indicates presence/location of relevant sound events.
2. **Class Heatmap**: Indicates location and class identity for specifically
classified events.
3. **Size Heatmap**: Encodes the target dimensions (width, height) of events.
The primary function generated by this module is a `ClipLabeller` (defined in
`.types`), which takes a `ClipAnnotation` object and its corresponding
spectrogram and returns the calculated `Heatmaps` tuple. The main configurable
parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
defined in `LabelConfig`.
"""
from functools import partial
@ -32,6 +14,7 @@ from loguru import logger
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets import iterate_encoded_sound_events
from batdetect2.typing import (
ClipLabeller,
Heatmaps,
@ -56,9 +39,6 @@ class LabelConfig(BaseConfig):
Attributes
----------
sigma : float, default=3.0
The standard deviation (in pixels/bins) of the Gaussian kernel applied
to smooth the detection and class heatmaps. Larger values create more
diffuse targets.
"""
sigma: float = 2.0
@ -70,28 +50,7 @@ def build_clip_labeler(
max_freq: float,
config: Optional[LabelConfig] = None,
) -> ClipLabeller:
"""Construct the final clip labelling function.
This factory function prepares the callable that will perform the
end-to-end heatmap generation for a given clip and spectrogram during
training data loading. It takes the fully configured `targets` object and
the `LabelConfig` and binds them to the `generate_clip_label` function.
Parameters
----------
targets : TargetProtocol
An initialized object conforming to the `TargetProtocol`, providing all
necessary methods for filtering, transforming, encoding, and ROI
mapping.
config : LabelConfig
Configuration object containing heatmap generation parameters.
Returns
-------
ClipLabeller
A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
(spectrogram) and returns the generated `Heatmaps`.
"""
"""Construct the final clip labelling function."""
config = config or LabelConfig()
logger.opt(lazy=True).debug(
"Building clip labeler with config: \n{}",
@ -119,37 +78,10 @@ def generate_heatmaps(
target_sigma: float = 3.0,
dtype=torch.float32,
) -> Heatmaps:
"""Generate training heatmaps for a single annotated clip.
This function orchestrates the target generation process for one clip:
1. Filters and transforms sound events using `targets.filter` and
`targets.transform`.
2. Passes the resulting processed annotations, along with the spectrogram,
the `targets` object, and the Gaussian `sigma` from `config`, to the
core `generate_heatmaps` function.
Parameters
----------
clip_annotation : data.ClipAnnotation
The complete annotation data for the audio clip, including the list
of `sound_events` to process.
spec : xr.DataArray
The spectrogram corresponding to the `clip_annotation`. Must have
'time' and 'frequency' dimensions/coordinates.
targets : TargetProtocol
The fully configured target definition object, providing methods for
filtering, transformation, encoding, and ROI mapping.
config : LabelConfig
Configuration object providing heatmap parameters (primarily `sigma`).
Returns
-------
Heatmaps
A NamedTuple containing the generated 'detection', 'classes', and 'size'
heatmaps for this clip.
"""
"""Generate training heatmaps for a single annotated clip."""
logger.debug(
"Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
"Will generate heatmaps for clip annotation "
"{uuid} with {num} annotated sound events",
uuid=clip_annotation.uuid,
num=len(clip_annotation.sound_events),
)
@ -174,28 +106,10 @@ def generate_heatmaps(
freqs = freqs.to(spec.device)
times = times.to(spec.device)
for sound_event_annotation in clip_annotation.sound_events:
if not targets.filter(sound_event_annotation):
logger.debug(
"Sound event {sound_event} did not pass the filter. Tags: {tags}",
sound_event=sound_event_annotation,
tags=sound_event_annotation.tags,
)
continue
sound_event_annotation = targets.transform(sound_event_annotation)
geom = sound_event_annotation.sound_event.geometry
if geom is None:
logger.debug(
"Skipping annotation %s: missing geometry.",
sound_event_annotation.uuid,
)
continue
# Get the position of the sound event
(time, frequency), size = targets.encode_roi(sound_event_annotation)
for class_name, (time, frequency), size in iterate_encoded_sound_events(
clip_annotation.sound_events,
targets,
):
time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
freq_index = map_to_pixels(frequency, height, min_freq, max_freq)
@ -206,9 +120,7 @@ def generate_heatmaps(
or freq_index >= height
):
logger.debug(
"Skipping annotation %s: position outside spectrogram. "
"Pos: %s",
sound_event_annotation.uuid,
"Skipping annotation: position outside spectrogram. Pos: %s",
(time, frequency),
)
continue
@ -222,20 +134,8 @@ def generate_heatmaps(
)
size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])
# Get the class name of the sound event
try:
class_name = targets.encode_class(sound_event_annotation)
except ValueError as e:
logger.warning(
"Skipping annotation %s: Unexpected error while encoding "
"class name %s",
sound_event_annotation.uuid,
e,
)
continue
# If the label is None skip the sound event
if class_name is None:
# If the label is None skip the sound event
continue
class_index = targets.class_names.index(class_name)
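map_to_pixels itself is outside this diff; an assumed reference implementation consistent with the bounds check above (a plain linear mapping, without clamping):

import math

def map_to_pixels(value: float, size: int, start: float, end: float) -> int:
    # Assumed behaviour: maps a time/frequency coordinate onto a pixel index;
    # out-of-range indices are filtered by the range check in generate_heatmaps.
    return int(math.floor(size * (value - start) / (end - start)))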

View File

@ -1,6 +1,8 @@
import io
from typing import Annotated, Any, Literal, Optional, Union
from lightning.pytorch.loggers import Logger
import numpy as np
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
from loguru import logger
from pydantic import Field
@ -140,3 +142,35 @@ def build_logger(config: LoggerConfig) -> Logger:
creation_func = LOGGER_FACTORY[logger_type]
return creation_func(config)
def get_image_plotter(logger: Logger):
if isinstance(logger, TensorBoardLogger):
def plot_figure(name, figure, step):
return logger.experiment.add_figure(name, figure, step)
return plot_figure
if isinstance(logger, MLFlowLogger):
def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure)
return logger.experiment.log_image(
run_id=logger.run_id,
image=image,
key=name,
step=step,
)
return plot_figure
def _convert_figure_to_image(figure):
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
w, h = figure.canvas.get_width_height()
im = data.reshape((int(h), int(w), -1))
return im
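Usage sketch for the relocated helper; the figure name and step value are illustrative:

plotter = get_image_plotter(pl_module.logger)
if plotter is not None:  # only TensorBoard and MLflow loggers are handled
    plotter("val/example_gallery", figure, step)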

View File

@ -24,9 +24,7 @@ from batdetect2.train.augmentations import (
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
from batdetect2.train.dataset import (
TrainingDataset,
)
from batdetect2.train.dataset import TrainingDataset, ValidationDataset
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger
@ -128,7 +126,9 @@ def build_training_module(
def build_trainer_callbacks(
targets: TargetProtocol, config: EvaluationConfig
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
config: EvaluationConfig,
) -> List[Callback]:
return [
ModelCheckpoint(
@ -144,6 +144,7 @@ def build_trainer_callbacks(
),
ClassificationAccuracy(class_names=targets.class_names),
],
preprocessor=preprocessor,
match_config=config.match,
),
]
@ -165,7 +166,11 @@ def build_trainer(
return Trainer(
**trainer_conf.model_dump(exclude_none=True),
logger=train_logger,
callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
callbacks=build_trainer_callbacks(
targets,
config=conf.evaluation,
preprocessor=build_preprocessor(conf.preprocess),
),
)
@ -304,11 +309,11 @@ def build_val_dataset(
labeller: ClipLabeller,
preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None,
) -> TrainingDataset:
) -> ValidationDataset:
logger.info("Building validation dataset...")
config = config or TrainingConfig()
return TrainingDataset(
return ValidationDataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,

View File

@ -11,13 +11,17 @@ __all__ = [
@dataclass
class MatchEvaluation:
match: data.Match
clip: data.Clip
sound_event_annotation: Optional[data.SoundEventAnnotation]
gt_det: bool
gt_class: Optional[str]
pred_score: float
pred_class_scores: Dict[str, float]
pred_geometry: Optional[data.Geometry]
affinity: float
@property
def pred_class(self) -> Optional[str]:

View File

@ -77,7 +77,16 @@ class RawPrediction(NamedTuple):
features: np.ndarray
class Detections(NamedTuple):
class DetectionsArray(NamedTuple):
scores: np.ndarray
sizes: np.ndarray
class_scores: np.ndarray
times: np.ndarray
frequencies: np.ndarray
features: np.ndarray
class DetectionsTensor(NamedTuple):
scores: torch.Tensor
sizes: torch.Tensor
class_scores: torch.Tensor
@ -85,6 +94,16 @@ class Detections(NamedTuple):
frequencies: torch.Tensor
features: torch.Tensor
def numpy(self) -> DetectionsArray:
return DetectionsArray(
scores=self.scores.detach().cpu().numpy(),
sizes=self.sizes.detach().cpu().numpy(),
class_scores=self.class_scores.detach().cpu().numpy(),
times=self.times.detach().cpu().numpy(),
frequencies=self.frequencies.detach().cpu().numpy(),
features=self.features.detach().cpu().numpy(),
)
@dataclass
class BatDetect2Prediction:
@ -95,10 +114,10 @@ class BatDetect2Prediction:
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline."""
def __call__(self, output: ModelOutput) -> List[Detections]: ...
def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ...
def get_detections(
self,
output: ModelOutput,
clips: Optional[List[data.Clip]] = None,
) -> List[Detections]: ...
) -> List[DetectionsTensor]: ...

View File

@ -12,8 +12,8 @@ that components responsible for these tasks can be interacted with consistently
throughout BatDetect2.
"""
from collections.abc import Callable, Iterable
from typing import List, Optional, Protocol, Tuple
from collections.abc import Callable
from typing import List, Optional, Protocol
import numpy as np
from soundevent import data