Fix plotting after changes

This commit is contained in:
parent 1f26103f42
commit cc9e47b022
@@ -17,8 +17,6 @@ def plot_clip_annotation(
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    add_colorbar: bool = False,
    add_labels: bool = False,
    add_points: bool = False,
    cmap: str = "gray",
    alpha: float = 1,

@@ -31,8 +29,6 @@ def plot_clip_annotation(
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        spec_cmap=cmap,
    )

@@ -21,8 +21,6 @@ def plot_clip_prediction(
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    add_colorbar: bool = False,
    add_labels: bool = False,
    add_legend: bool = False,
    spec_cmap: str = "gray",
    linewidth: float = 1,

@@ -34,8 +32,6 @@ def plot_clip_prediction(
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        spec_cmap=spec_cmap,
    )

@@ -1,13 +1,13 @@
from typing import Optional, Tuple

import matplotlib.pyplot as plt
+import torch
from matplotlib.axes import Axes
from soundevent import data

-from batdetect2.preprocess import (
-    PreprocessorProtocol,
-    get_default_preprocessor,
-)
+from batdetect2.plotting.common import plot_spectrogram
+from batdetect2.preprocess import build_audio_loader, get_default_preprocessor
+from batdetect2.typing import AudioLoader, PreprocessorProtocol

__all__ = [
    "plot_clip",

@@ -16,12 +16,11 @@ __all__ = [

def plot_clip(
    clip: data.Clip,
+    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    add_colorbar: bool = False,
    add_labels: bool = False,
    spec_cmap: str = "gray",
) -> Axes:
    if ax is None:

@@ -30,15 +29,16 @@ def plot_clip(
    if preprocessor is None:
        preprocessor = get_default_preprocessor()

-    spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir)
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    wav = torch.tensor(audio_loader.load_clip(clip, audio_dir=audio_dir))
+    spec = preprocessor(wav)

-    spec.plot(  # type: ignore
+    plot_spectrogram(
+        spec,
        ax=ax,
        add_colorbar=add_colorbar,
        cmap=spec_cmap,
        add_labels=add_labels,
+        vmin=spec.min().item(),
+        vmax=spec.max().item(),
    )

    return ax

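For context, a minimal sketch of how the reworked plot_clip is meant to be called after this change. The file name and clip bounds are illustrative; Recording.from_file and Clip are soundevent's usual API, and plot_clip is assumed in scope:

    from soundevent import data

    recording = data.Recording.from_file("example.wav")  # hypothetical file
    clip = data.Clip(recording=recording, start_time=0.0, end_time=0.5)

    # audio_loader and preprocessor fall back to build_audio_loader() and
    # get_default_preprocessor() when omitted, per the branches above.
    ax = plot_clip(clip, add_colorbar=True, spec_cmap="magma")
    ax.figure.savefig("clip.png")
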
@@ -40,8 +40,6 @@ def plot_matches(
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    color_mapper: Optional[TagColorMapper] = None,
    add_colorbar: bool = False,
    add_labels: bool = False,
    add_points: bool = False,
    fill: bool = False,
    spec_cmap: str = "gray",

@@ -59,8 +57,6 @@ def plot_matches(
        ax=ax,
        figsize=figsize,
        audio_dir=audio_dir,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        spec_cmap=spec_cmap,
    )

@@ -128,8 +124,6 @@ def plot_false_positive_match(
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    duration: float = DEFAULT_DURATION,
    add_colorbar: bool = False,
    add_labels: bool = False,
    add_points: bool = False,
    fill: bool = False,
    spec_cmap: str = "gray",

@@ -160,8 +154,6 @@ def plot_false_positive_match(
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        spec_cmap=spec_cmap,
    )

@@ -196,8 +188,6 @@ def plot_false_negative_match(
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    duration: float = DEFAULT_DURATION,
    add_colorbar: bool = False,
    add_labels: bool = False,
    add_points: bool = False,
    fill: bool = False,
    spec_cmap: str = "gray",

@@ -226,8 +216,6 @@ def plot_false_negative_match(
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        spec_cmap=spec_cmap,
    )

@@ -262,8 +250,6 @@ def plot_true_positive_match(
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    duration: float = DEFAULT_DURATION,
    add_colorbar: bool = False,
    add_labels: bool = False,
    add_points: bool = False,
    fill: bool = False,
    spec_cmap: str = "gray",

@@ -294,8 +280,6 @@ def plot_true_positive_match(
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        spec_cmap=spec_cmap,
    )

@@ -343,8 +327,6 @@ def plot_cross_trigger_match(
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    duration: float = DEFAULT_DURATION,
    add_colorbar: bool = False,
    add_labels: bool = False,
    add_points: bool = False,
    fill: bool = False,
    spec_cmap: str = "gray",

@@ -375,8 +357,6 @@ def plot_cross_trigger_match(
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        spec_cmap=spec_cmap,
    )

@@ -67,7 +67,7 @@ class ValidationMetrics(Callback):

        for class_name, fig in plot_example_gallery(
            matches,
-            preprocessor=pl_module.preprocessor,
+            preprocessor=pl_module.model.preprocessor,
            n_examples=4,
        ):
            plotter(

@@ -94,7 +94,7 @@ class ValidationMetrics(Callback):
    ) -> None:
        matches = _match_all_collected_examples(
            self._matches,
-            pl_module.targets,
+            pl_module.model.targets,
            config=self.match_config,
        )

@@ -127,8 +127,8 @@ class ValidationMetrics(Callback):
                batch,
                outputs,
                dataset=self.get_dataset(trainer),
-                postprocessor=pl_module.postprocessor,
-                targets=pl_module.targets,
+                postprocessor=pl_module.model.postprocessor,
+                targets=pl_module.model.targets,
            )
        )

@@ -73,12 +73,12 @@ class LabeledDataset(Dataset):
        return load_preprocessed_example(self.filenames[idx])

    def get_clip_annotation(self, idx) -> data.ClipAnnotation:
-        item = np.load(self.filenames[idx])
-        return item["clip_annotation"]
+        item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
+        return item["clip_annotation"].tolist()


def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
-    item = np.load(path)
+    item = np.load(path, mmap_mode="r+")
    return PreprocessedExample(
        audio=torch.tensor(item["audio"]),
        spectrogram=torch.tensor(item["spectrogram"]),

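A minimal sketch of the round-trip the np.load change above enables (file name and key are illustrative): a Python object stored in an .npz archive comes back as a 0-d object array, so allow_pickle=True is needed to unpickle it, and .tolist() unwraps the original object:

    import numpy as np

    annotation = {"clip": "recording_001.wav", "events": [(0.1, 0.2)]}
    np.savez("example.npz", clip_annotation=np.array(annotation, dtype=object))

    item = np.load("example.npz", allow_pickle=True)
    assert item["clip_annotation"].tolist() == annotation  # original dict again
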
@@ -12,6 +12,8 @@ __all__ = [


class TrainingModule(L.LightningModule):
+    model: Model
+
    def __init__(
        self,
        model: Model,

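The pl_module.* to pl_module.model.* edits earlier in this commit follow from this new attribute: the preprocessor, targets, and postprocessor now live on the wrapped Model rather than on the LightningModule. A sketch of the resulting access pattern (class and attribute names are from the diff; the method body is assumed):

    class ValidationMetrics(Callback):
        def on_validation_epoch_end(self, trainer, pl_module):
            # Model-level helpers are reached through the wrapped model:
            targets = pl_module.model.targets                # was: pl_module.targets
            postprocessor = pl_module.model.postprocessor    # was: pl_module.postprocessor
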
@@ -1,24 +1,4 @@
-"""Preprocesses datasets for BatDetect2 model training.
-
-This module provides functions to take a collection of annotated audio clips
-(`soundevent.data.ClipAnnotation`) and process them into the final format
-required for training a BatDetect2 model. This typically involves:
-
-1. Loading the relevant audio segment for each annotation using a configured
-   `PreprocessorProtocol`.
-2. Generating the corresponding input spectrogram using the
-   `PreprocessorProtocol`.
-3. Generating the target heatmaps (detection, classification, size) using a
-   configured `ClipLabeller` (which encapsulates the `TargetProtocol` logic).
-4. Packaging the input spectrogram, target heatmaps, and potentially the
-   processed audio waveform into an `xarray.Dataset`.
-5. Saving each processed `xarray.Dataset` to a separate file (typically NetCDF)
-   in an output directory.
-
-This offline preprocessing is often preferred for large datasets as it avoids
-computationally intensive steps during the actual training loop. The module
-includes utilities for parallel processing using `multiprocessing`.
-"""
+"""Preprocesses datasets for BatDetect2 model training."""

import os
from pathlib import Path

@@ -1,6 +1,7 @@
from collections.abc import Sequence
from typing import List, Optional

+import torch
from lightning import Trainer
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from loguru import logger

@@ -25,7 +26,12 @@ from batdetect2.train.dataset import (
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger
from batdetect2.train.losses import build_loss
-from batdetect2.typing import PreprocessorProtocol, TargetProtocol
+from batdetect2.typing import (
+    PreprocessorProtocol,
+    TargetProtocol,
+    TrainExample,
+)
+from batdetect2.utils.arrays import adjust_width

__all__ = [
    "build_train_dataset",

@@ -53,7 +59,7 @@ def train(
    else:
        module = build_training_module(config)

-    trainer = build_trainer(config, targets=module.targets)
+    trainer = build_trainer(config, targets=module.model.targets)

    train_dataloader = build_train_loader(
        train_examples,

@@ -160,6 +166,7 @@ def build_train_loader(
        batch_size=loader_conf.batch_size,
        shuffle=loader_conf.shuffle,
        num_workers=num_workers,
+        collate_fn=_collate_fn,
    )

@@ -188,6 +195,28 @@ def build_val_loader(
        batch_size=loader_conf.batch_size,
        shuffle=loader_conf.shuffle,
        num_workers=num_workers,
+        collate_fn=_collate_fn,
    )
+
+
+def _collate_fn(batch: List[TrainExample]) -> TrainExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return TrainExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        detection_heatmap=torch.stack(
+            [adjust_width(item.detection_heatmap, max_width) for item in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(item.size_heatmap, max_width) for item in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(item.class_heatmap, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )

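The new collate function assumes adjust_width (imported above from batdetect2.utils.arrays) pads or crops each tensor along its last (time) axis so variable-width examples can be stacked into one batch. A sketch of what such a helper plausibly looks like, not the repository's actual implementation:

    import torch
    import torch.nn.functional as F

    def adjust_width(x: torch.Tensor, width: int) -> torch.Tensor:
        current = x.shape[-1]
        if current > width:
            return x[..., :width]  # crop extra columns on the right
        if current < width:
            return F.pad(x, (0, width - current))  # zero-pad on the right
        return x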