Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)

Commit: cc9e47b022 ("Fix plotting after changes")
Parent: 1f26103f42
```diff
@@ -17,8 +17,6 @@ def plot_clip_annotation(
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     cmap: str = "gray",
     alpha: float = 1,
@@ -31,8 +29,6 @@ def plot_clip_annotation(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=cmap,
     )
 
```
```diff
@@ -21,8 +21,6 @@ def plot_clip_prediction(
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_legend: bool = False,
     spec_cmap: str = "gray",
     linewidth: float = 1,
@@ -34,8 +32,6 @@ def plot_clip_prediction(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )
 
```
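Both `plot_clip_annotation` and `plot_clip_prediction` lose their `add_colorbar` and `add_labels` keyword arguments here, so any caller still passing them will now get a `TypeError`. Callers who want the old decorations back can add them with plain Matplotlib after plotting; a minimal sketch, assuming those flags previously toggled axis labels and a colorbar (the random stand-in image is only there to make the sketch self-contained):

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
# In real use the image would come from plot_clip_annotation(clip_annotation, ax=ax);
# a random array stands in here so the sketch runs on its own.
mesh = ax.imshow(np.random.rand(128, 256), aspect="auto", origin="lower", cmap="gray")
ax.set_xlabel("Time (s)")         # roughly what add_labels=True used to opt into (assumed)
ax.set_ylabel("Frequency (kHz)")
fig.colorbar(mesh, ax=ax)         # roughly what add_colorbar=True used to opt into (assumed)
plt.show()
```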
```diff
@@ -1,13 +1,13 @@
 from typing import Optional, Tuple
 
 import matplotlib.pyplot as plt
+import torch
 from matplotlib.axes import Axes
 from soundevent import data
 
-from batdetect2.preprocess import (
-    PreprocessorProtocol,
-    get_default_preprocessor,
-)
+from batdetect2.plotting.common import plot_spectrogram
+from batdetect2.preprocess import build_audio_loader, get_default_preprocessor
+from batdetect2.typing import AudioLoader, PreprocessorProtocol
 
 __all__ = [
     "plot_clip",
@@ -16,12 +16,11 @@ __all__ = [
 
 def plot_clip(
     clip: data.Clip,
+    audio_loader: Optional[AudioLoader] = None,
     preprocessor: Optional[PreprocessorProtocol] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     spec_cmap: str = "gray",
 ) -> Axes:
     if ax is None:
@@ -30,15 +29,16 @@ def plot_clip(
     if preprocessor is None:
         preprocessor = get_default_preprocessor()
 
-    spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir)
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
 
-    spec.plot(  # type: ignore
+    wav = torch.tensor(audio_loader.load_clip(clip, audio_dir=audio_dir))
+    spec = preprocessor(wav)
+
+    plot_spectrogram(
+        spec,
         ax=ax,
-        add_colorbar=add_colorbar,
         cmap=spec_cmap,
-        add_labels=add_labels,
-        vmin=spec.min().item(),
-        vmax=spec.max().item(),
     )
 
     return ax
```
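The rewritten `plot_clip` splits the pipeline into three explicit steps: an `AudioLoader` reads the waveform for the clip, the preprocessor (now called directly on a tensor) produces the spectrogram, and the shared `plot_spectrogram` helper draws it. The fixed `vmin`/`vmax` pinning is gone, so value scaling is presumably left to `plot_spectrogram`. A minimal sketch of the same call path outside the library, assuming `clip` is a `soundevent.data.Clip` pointing at a readable recording (the helper names are exactly the ones imported in this diff):

```python
import matplotlib.pyplot as plt
import torch

from batdetect2.plotting.common import plot_spectrogram
from batdetect2.preprocess import build_audio_loader, get_default_preprocessor

def show_clip(clip, audio_dir=None):
    """Replicate the three steps the patched plot_clip performs."""
    audio_loader = build_audio_loader()
    preprocessor = get_default_preprocessor()

    wav = torch.tensor(audio_loader.load_clip(clip, audio_dir=audio_dir))  # 1. load audio
    spec = preprocessor(wav)                                               # 2. compute spectrogram

    _, ax = plt.subplots()
    plot_spectrogram(spec, ax=ax, cmap="gray")                             # 3. render
    plt.show()
```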
```diff
@@ -40,8 +40,6 @@ def plot_matches(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     color_mapper: Optional[TagColorMapper] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -59,8 +57,6 @@ def plot_matches(
         ax=ax,
         figsize=figsize,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )
 
@@ -128,8 +124,6 @@ def plot_false_positive_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -160,8 +154,6 @@ def plot_false_positive_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )
 
@@ -196,8 +188,6 @@ def plot_false_negative_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -226,8 +216,6 @@ def plot_false_negative_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )
 
@@ -262,8 +250,6 @@ def plot_true_positive_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -294,8 +280,6 @@ def plot_true_positive_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )
 
@@ -343,8 +327,6 @@ def plot_cross_trigger_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -375,8 +357,6 @@ def plot_cross_trigger_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )
 
```
```diff
@@ -67,7 +67,7 @@ class ValidationMetrics(Callback):
 
         for class_name, fig in plot_example_gallery(
             matches,
-            preprocessor=pl_module.preprocessor,
+            preprocessor=pl_module.model.preprocessor,
             n_examples=4,
         ):
             plotter(
@@ -94,7 +94,7 @@ class ValidationMetrics(Callback):
     ) -> None:
         matches = _match_all_collected_examples(
             self._matches,
-            pl_module.targets,
+            pl_module.model.targets,
             config=self.match_config,
         )
 
@@ -127,8 +127,8 @@ class ValidationMetrics(Callback):
                 batch,
                 outputs,
                 dataset=self.get_dataset(trainer),
-                postprocessor=pl_module.postprocessor,
-                targets=pl_module.targets,
+                postprocessor=pl_module.model.postprocessor,
+                targets=pl_module.model.targets,
             )
         )
 
```
```diff
@@ -73,12 +73,12 @@ class LabeledDataset(Dataset):
         return load_preprocessed_example(self.filenames[idx])
 
     def get_clip_annotation(self, idx) -> data.ClipAnnotation:
-        item = np.load(self.filenames[idx])
-        return item["clip_annotation"]
+        item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
+        return item["clip_annotation"].tolist()
 
 
 def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
-    item = np.load(path)
+    item = np.load(path, mmap_mode="r+")
     return PreprocessedExample(
         audio=torch.tensor(item["audio"]),
         spectrogram=torch.tensor(item["spectrogram"]),
```
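The `allow_pickle=True` / `.tolist()` pair is needed because `clip_annotation` is an arbitrary Python object: when such an object goes through `np.savez`, NumPy pickles it into a 0-dimensional object array, loading it back requires opting into unpickling, and `.tolist()` (or `.item()`) unwraps the scalar array into the original object. A self-contained reproduction of that round trip (the dict is a stand-in for the real `ClipAnnotation`):

```python
import io

import numpy as np

buf = io.BytesIO()
# Saving a plain Python object wraps it in a 0-d array of dtype=object.
np.savez(buf, clip_annotation={"uuid": "abc", "tags": ["Myotis"]})
buf.seek(0)

item = np.load(buf, allow_pickle=True)  # without allow_pickle=True this raises ValueError
wrapped = item["clip_annotation"]
print(wrapped.shape, wrapped.dtype)     # () object
print(wrapped.tolist())                 # the original dict again
```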
```diff
@@ -12,6 +12,8 @@ __all__ = [
 
 
 class TrainingModule(L.LightningModule):
+    model: Model
+
     def __init__(
         self,
         model: Model,
```
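The new `model: Model` class attribute documents the structural change behind the `pl_module.model.*` accesses in the callback hunks above: `preprocessor`, `targets`, and `postprocessor` now live on the wrapped `Model` rather than on the LightningModule itself. A minimal sketch of the implied layout (the stub classes are illustrative, not the repository's definitions):

```python
import lightning as L

class Model:
    """Illustrative stub: the real Model bundles the network with its helpers."""

    def __init__(self, preprocessor, targets, postprocessor):
        self.preprocessor = preprocessor
        self.targets = targets
        self.postprocessor = postprocessor

class TrainingModule(L.LightningModule):
    model: Model  # the single place the sub-components now live

    def __init__(self, model: Model):
        super().__init__()
        self.model = model

# Callbacks therefore reach the helpers through the model:
#   pl_module.model.preprocessor
#   pl_module.model.targets
#   pl_module.model.postprocessor
```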
```diff
@@ -1,24 +1,4 @@
-"""Preprocesses datasets for BatDetect2 model training.
-
-This module provides functions to take a collection of annotated audio clips
-(`soundevent.data.ClipAnnotation`) and process them into the final format
-required for training a BatDetect2 model. This typically involves:
-
-1. Loading the relevant audio segment for each annotation using a configured
-   `PreprocessorProtocol`.
-2. Generating the corresponding input spectrogram using the
-   `PreprocessorProtocol`.
-3. Generating the target heatmaps (detection, classification, size) using a
-   configured `ClipLabeller` (which encapsulates the `TargetProtocol` logic).
-4. Packaging the input spectrogram, target heatmaps, and potentially the
-   processed audio waveform into an `xarray.Dataset`.
-5. Saving each processed `xarray.Dataset` to a separate file (typically NetCDF)
-   in an output directory.
-
-This offline preprocessing is often preferred for large datasets as it avoids
-computationally intensive steps during the actual training loop. The module
-includes utilities for parallel processing using `multiprocessing`.
-"""
+"""Preprocesses datasets for BatDetect2 model training."""
 
 import os
 from pathlib import Path
```
```diff
@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 from typing import List, Optional
 
+import torch
 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
@@ -25,7 +26,12 @@ from batdetect2.train.dataset import (
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
 from batdetect2.train.losses import build_loss
-from batdetect2.typing import PreprocessorProtocol, TargetProtocol
+from batdetect2.typing import (
+    PreprocessorProtocol,
+    TargetProtocol,
+    TrainExample,
+)
+from batdetect2.utils.arrays import adjust_width
 
 __all__ = [
     "build_train_dataset",
@@ -53,7 +59,7 @@ def train(
     else:
        module = build_training_module(config)
 
-    trainer = build_trainer(config, targets=module.targets)
+    trainer = build_trainer(config, targets=module.model.targets)
 
     train_dataloader = build_train_loader(
         train_examples,
@@ -160,6 +166,7 @@ def build_train_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
+        collate_fn=_collate_fn,
     )
 
 
@@ -188,6 +195,28 @@ def build_val_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
+        collate_fn=_collate_fn,
+    )
+
+
+def _collate_fn(batch: List[TrainExample]) -> TrainExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return TrainExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        detection_heatmap=torch.stack(
+            [adjust_width(item.detection_heatmap, max_width) for item in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(item.size_heatmap, max_width) for item in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(item.class_heatmap, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
     )
 
 
```
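The new `_collate_fn` makes batches work with variable-duration clips: it finds the widest spectrogram in the batch, brings every tensor to that width with `adjust_width`, and only then stacks. `adjust_width` is imported from `batdetect2.utils.arrays` but its body is not shown in this diff; a right-padding implementation like the following would be consistent with how it is used here (an assumption, not the repository's code):

```python
import torch
import torch.nn.functional as F

def adjust_width(x: torch.Tensor, width: int) -> torch.Tensor:
    """Pad (or crop) the last dimension of x to exactly `width` columns.

    Assumed behavior: zero-pad on the right when too narrow, crop when too wide.
    """
    current = x.shape[-1]
    if current < width:
        return F.pad(x, (0, width - current))  # (left, right) padding of the last dim
    return x[..., :width]

# Stacking two spectrograms of unequal width, as _collate_fn does:
a, b = torch.rand(1, 128, 100), torch.rand(1, 128, 87)
batch = torch.stack([adjust_width(t, 100) for t in (a, b)])
print(batch.shape)  # torch.Size([2, 1, 128, 100])
```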