Fix plotting after changes

mbsantiago 2025-08-25 19:07:12 +01:00
parent 1f26103f42
commit cc9e47b022
9 changed files with 53 additions and 70 deletions

View File

@@ -17,8 +17,6 @@ def plot_clip_annotation(
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     cmap: str = "gray",
     alpha: float = 1,
@@ -31,8 +29,6 @@ def plot_clip_annotation(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=cmap,
     )
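
This commit removes the add_colorbar and add_labels keywords from every plotting entry point (annotation, prediction, and the match plots below). A minimal sketch of a call against the narrowed plot_clip_annotation signature — the import path, the example.wav file, and the ClipAnnotation setup are assumptions, not shown in this diff:

from soundevent import data

from batdetect2.plotting import plot_clip_annotation  # import path assumed

# Build a bare annotation to plot (any readable audio file will do).
recording = data.Recording.from_file("example.wav")
clip = data.Clip(recording=recording, start_time=0.0, end_time=0.5)
clip_annotation = data.ClipAnnotation(clip=clip)

# Only keywords that survive the hunk above are passed; the colorbar
# and axis-label toggles are gone.
ax = plot_clip_annotation(clip_annotation, add_points=True, cmap="gray", alpha=0.8)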

View File

@@ -21,8 +21,6 @@ def plot_clip_prediction(
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_legend: bool = False,
     spec_cmap: str = "gray",
     linewidth: float = 1,
@@ -34,8 +32,6 @@ def plot_clip_prediction(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )

View File

@@ -1,13 +1,13 @@
 from typing import Optional, Tuple

 import matplotlib.pyplot as plt
+import torch
 from matplotlib.axes import Axes
 from soundevent import data

-from batdetect2.preprocess import (
-    PreprocessorProtocol,
-    get_default_preprocessor,
-)
+from batdetect2.plotting.common import plot_spectrogram
+from batdetect2.preprocess import build_audio_loader, get_default_preprocessor
+from batdetect2.typing import AudioLoader, PreprocessorProtocol

 __all__ = [
     "plot_clip",
@@ -16,12 +16,11 @@ __all__ = [
 def plot_clip(
     clip: data.Clip,
+    audio_loader: Optional[AudioLoader] = None,
     preprocessor: Optional[PreprocessorProtocol] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     spec_cmap: str = "gray",
 ) -> Axes:
     if ax is None:
@@ -30,15 +29,16 @@ def plot_clip(
     if preprocessor is None:
         preprocessor = get_default_preprocessor()

-    spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir)
+    if audio_loader is None:
+        audio_loader = build_audio_loader()

-    spec.plot(  # type: ignore
+    wav = torch.tensor(audio_loader.load_clip(clip, audio_dir=audio_dir))
+    spec = preprocessor(wav)
+
+    plot_spectrogram(
+        spec,
         ax=ax,
-        add_colorbar=add_colorbar,
         cmap=spec_cmap,
-        add_labels=add_labels,
+        vmin=spec.min().item(),
+        vmax=spec.max().item(),
     )
+
     return ax
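
The rewritten plot_clip splits audio loading from spectrogram computation: an AudioLoader fetches the waveform, the preprocessor is applied to the tensor directly, and rendering goes through the shared plot_spectrogram helper. A sketch of how the new signature is presumably called — the module path, the example.wav file, and the Clip construction are illustrative assumptions:

from soundevent import data

from batdetect2.plotting.clips import plot_clip  # module path assumed

recording = data.Recording.from_file("example.wav")  # any audio file
clip = data.Clip(recording=recording, start_time=0.0, end_time=0.5)

# Leaving audio_loader and preprocessor as None makes the function build
# defaults via build_audio_loader() / get_default_preprocessor().
ax = plot_clip(clip, spec_cmap="viridis")
ax.figure.savefig("clip_spectrogram.png")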

View File

@@ -40,8 +40,6 @@ def plot_matches(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     color_mapper: Optional[TagColorMapper] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -59,8 +57,6 @@ def plot_matches(
         ax=ax,
         figsize=figsize,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )
@@ -128,8 +124,6 @@ def plot_false_positive_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -160,8 +154,6 @@ def plot_false_positive_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )
@@ -196,8 +188,6 @@ def plot_false_negative_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -226,8 +216,6 @@ def plot_false_negative_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )
@@ -262,8 +250,6 @@ def plot_true_positive_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -294,8 +280,6 @@ def plot_true_positive_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )
@@ -343,8 +327,6 @@ def plot_cross_trigger_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -375,8 +357,6 @@ def plot_cross_trigger_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )

View File

@@ -67,7 +67,7 @@ class ValidationMetrics(Callback):
         for class_name, fig in plot_example_gallery(
             matches,
-            preprocessor=pl_module.preprocessor,
+            preprocessor=pl_module.model.preprocessor,
             n_examples=4,
         ):
             plotter(
@@ -94,7 +94,7 @@ class ValidationMetrics(Callback):
     ) -> None:
         matches = _match_all_collected_examples(
             self._matches,
-            pl_module.targets,
+            pl_module.model.targets,
             config=self.match_config,
         )
@@ -127,8 +127,8 @@ class ValidationMetrics(Callback):
                 batch,
                 outputs,
                 dataset=self.get_dataset(trainer),
-                postprocessor=pl_module.postprocessor,
-                targets=pl_module.targets,
+                postprocessor=pl_module.model.postprocessor,
+                targets=pl_module.model.targets,
             )
         )

View File

@@ -73,12 +73,12 @@ class LabeledDataset(Dataset):
         return load_preprocessed_example(self.filenames[idx])

     def get_clip_annotation(self, idx) -> data.ClipAnnotation:
-        item = np.load(self.filenames[idx])
-        return item["clip_annotation"]
+        item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
+        return item["clip_annotation"].tolist()


 def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
-    item = np.load(path)
+    item = np.load(path, mmap_mode="r+")
     return PreprocessedExample(
         audio=torch.tensor(item["audio"]),
         spectrogram=torch.tensor(item["spectrogram"]),
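
The allow_pickle=True / .tolist() pair is what recovers a Python object from an .npz archive: np.savez wraps a non-array value in a 0-d object array stored via pickle, and .tolist() (or .item()) unwraps it again on load. A self-contained illustration of that round trip:

import numpy as np

# Saving an arbitrary Python object into an .npz wraps it
# in a 0-d object array, serialized with pickle.
annotation = {"name": "example", "tags": ["bat"]}
np.savez("example.npz", clip_annotation=np.array(annotation, dtype=object))

# Loading pickled object arrays requires allow_pickle=True;
# .tolist() unwraps the 0-d array back into the original object.
item = np.load("example.npz", allow_pickle=True)
restored = item["clip_annotation"].tolist()
assert restored == annotation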

View File

@@ -12,6 +12,8 @@ __all__ = [
 class TrainingModule(L.LightningModule):
+    model: Model
+
     def __init__(
         self,
         model: Model,
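
This class-level annotation and the pl_module.model.* accesses in the callback hunks above are two sides of the same refactor: preprocessor, targets, and postprocessor now live on the wrapped Model rather than on the LightningModule. A minimal sketch of the assumed shape (the Model stand-in is illustrative, not batdetect2's real class):

import lightning as L
import torch


class Model(torch.nn.Module):
    # Illustrative stand-in; the real Model carries these helpers.
    preprocessor = None
    targets = None
    postprocessor = None


class TrainingModule(L.LightningModule):
    model: Model  # declared so type checkers see self.model

    def __init__(self, model: Model):
        super().__init__()
        self.model = model


# Callbacks then reach the helpers through the nested model, e.g.
# pl_module.model.preprocessor, pl_module.model.targets.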

View File

@@ -1,24 +1,4 @@
-"""Preprocesses datasets for BatDetect2 model training.
-
-This module provides functions to take a collection of annotated audio clips
-(`soundevent.data.ClipAnnotation`) and process them into the final format
-required for training a BatDetect2 model. This typically involves:
-
-1. Loading the relevant audio segment for each annotation using a configured
-   `PreprocessorProtocol`.
-2. Generating the corresponding input spectrogram using the
-   `PreprocessorProtocol`.
-3. Generating the target heatmaps (detection, classification, size) using a
-   configured `ClipLabeller` (which encapsulates the `TargetProtocol` logic).
-4. Packaging the input spectrogram, target heatmaps, and potentially the
-   processed audio waveform into an `xarray.Dataset`.
-5. Saving each processed `xarray.Dataset` to a separate file (typically NetCDF)
-   in an output directory.
-
-This offline preprocessing is often preferred for large datasets as it avoids
-computationally intensive steps during the actual training loop. The module
-includes utilities for parallel processing using `multiprocessing`.
-"""
+"""Preprocesses datasets for BatDetect2 model training."""

 import os
 from pathlib import Path

View File

@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 from typing import List, Optional

+import torch
 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
@@ -25,7 +26,12 @@ from batdetect2.train.dataset import (
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
 from batdetect2.train.losses import build_loss
-from batdetect2.typing import PreprocessorProtocol, TargetProtocol
+from batdetect2.typing import (
+    PreprocessorProtocol,
+    TargetProtocol,
+    TrainExample,
+)
+from batdetect2.utils.arrays import adjust_width

 __all__ = [
     "build_train_dataset",
@@ -53,7 +59,7 @@ def train(
     else:
         module = build_training_module(config)

-    trainer = build_trainer(config, targets=module.targets)
+    trainer = build_trainer(config, targets=module.model.targets)

     train_dataloader = build_train_loader(
         train_examples,
@@ -160,6 +166,7 @@ def build_train_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
+        collate_fn=_collate_fn,
     )
@@ -188,6 +195,28 @@ def build_val_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
+        collate_fn=_collate_fn,
     )
+
+
+def _collate_fn(batch: List[TrainExample]) -> TrainExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+
+    return TrainExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        detection_heatmap=torch.stack(
+            [adjust_width(item.detection_heatmap, max_width) for item in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(item.size_heatmap, max_width) for item in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(item.class_heatmap, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
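
_collate_fn pads every example in a batch to the widest spectrogram before stacking, so clips of unequal duration can share a batch. adjust_width itself is not shown in this diff; a minimal sketch of what it is assumed to do (crop or zero-pad the last axis to a target width):

import torch
import torch.nn.functional as F


def adjust_width(tensor: torch.Tensor, width: int) -> torch.Tensor:
    # Assumed behaviour: crop, or zero-pad on the right, along the last axis.
    current = tensor.shape[-1]
    if current == width:
        return tensor
    if current > width:
        return tensor[..., :width]  # drop extra time bins
    return F.pad(tensor, (0, width - current))


# e.g. widths 90 and 100 stack into a single (2, 1, 64, 100) batch:
a, b = torch.rand(1, 64, 90), torch.rand(1, 64, 100)
batch = torch.stack([adjust_width(t, 100) for t in (a, b)])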