Cleanup train preprocessing

mbsantiago 2025-08-25 18:37:46 +01:00
parent c80078feee
commit 1f26103f42
5 changed files with 82 additions and 82 deletions

@ -99,7 +99,7 @@ example-preprocess OPTIONS="":
--dataset-field datasets.train \
--config example_data/config.yaml \
{{OPTIONS}} \
example_data/datasets.yaml example_data/preprocessed
example_data/config.yaml example_data/preprocessed
# Train on example data.
example-train OPTIONS="":

@ -8,7 +8,7 @@ from batdetect2.typing import ClipperProtocol
from batdetect2.typing.train import PreprocessedExample
from batdetect2.utils.arrays import adjust_width
DEFAULT_TRAIN_CLIP_DURATION = 0.513
DEFAULT_TRAIN_CLIP_DURATION = 0.512
DEFAULT_MAX_EMPTY_CLIP = 0.1
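
Note: the new default of 0.512 s presumably makes a training clip map onto a fixed, power-of-two number of samples and spectrogram frames, which is consistent with the fixed-width padding collate_fn being dropped later in this commit. A minimal sketch of that arithmetic; the samplerate and hop size below are placeholder assumptions, not values taken from this diff:

    # Placeholder STFT settings (assumptions, not from this diff):
    # 256 kHz audio, 128-sample hop between frames.
    samplerate = 256_000
    hop = 128

    clip_duration = 0.512                        # new DEFAULT_TRAIN_CLIP_DURATION
    n_samples = int(clip_duration * samplerate)  # 131072 == 2**17 samples
    n_frames = n_samples // hop                  # 1024 frames, an exact fit
    print(n_samples, n_frames)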

@ -3,13 +3,12 @@ from typing import List, Optional, Sequence, Tuple
import numpy as np
import torch
import xarray as xr
from soundevent import data
from torch.utils.data import Dataset
from batdetect2.train.augmentations import Augmentation
from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.utils.tensors import adjust_width
from batdetect2.typing.train import PreprocessedExample
__all__ = [
"LabeledDataset",
@ -31,17 +30,18 @@ class LabeledDataset(Dataset):
return len(self.filenames)
def __getitem__(self, idx) -> TrainExample:
dataset = self.get_dataset(idx)
dataset, start_time, end_time = self.clipper.extract_clip(dataset)
example = self.get_example(idx)
example, start_time, end_time = self.clipper.extract_clip(example)
if self.augmentation:
dataset = self.augmentation(dataset)
example = self.augmentation(example)
return TrainExample(
spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
detection_heatmap=self.to_tensor(dataset["detection"]),
class_heatmap=self.to_tensor(dataset["class"]),
size_heatmap=self.to_tensor(dataset["size"]),
spec=example.spectrogram.unsqueeze(0),
detection_heatmap=example.detection_heatmap.unsqueeze(0),
class_heatmap=example.class_heatmap,
size_heatmap=example.size_heatmap,
idx=torch.tensor(idx),
start_time=torch.tensor(start_time),
end_time=torch.tensor(end_time),
@ -52,7 +52,7 @@ class LabeledDataset(Dataset):
cls,
directory: data.PathLike,
clipper: ClipperProtocol,
extension: str = ".nc",
extension: str = ".npz",
augmentation: Optional[Augmentation] = None,
):
return cls(
@ -61,55 +61,35 @@ class LabeledDataset(Dataset):
augmentation=augmentation,
)
def get_random_example(self) -> Tuple[xr.Dataset, float, float]:
def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
idx = np.random.randint(0, len(self))
dataset = self.get_dataset(idx)
dataset = self.get_example(idx)
dataset, start_time, end_time = self.clipper.extract_clip(dataset)
return dataset, start_time, end_time
def get_dataset(self, idx) -> xr.Dataset:
return xr.open_dataset(self.filenames[idx])
def get_example(self, idx) -> PreprocessedExample:
return load_preprocessed_example(self.filenames[idx])
def get_clip_annotation(self, idx) -> data.ClipAnnotation:
return data.ClipAnnotation.model_validate_json(
self.get_dataset(idx).attrs["clip_annotation"]
)
def to_tensor(
self,
array: xr.DataArray,
dtype=np.float32,
) -> torch.Tensor:
return torch.nan_to_num(
torch.tensor(array.values.astype(dtype)),
nan=0,
)
item = np.load(self.filenames[idx])
return item["clip_annotation"]
def collate_fn(batch: List[TrainExample]):
width = 512
return TrainExample(
spec=torch.stack([adjust_width(x.spec, width) for x in batch]),
detection_heatmap=torch.stack(
[adjust_width(x.detection_heatmap, width) for x in batch]
),
class_heatmap=torch.stack(
[adjust_width(x.class_heatmap, width) for x in batch]
),
size_heatmap=torch.stack(
[adjust_width(x.size_heatmap, width) for x in batch]
),
idx=torch.stack([x.idx for x in batch]),
start_time=torch.stack([x.start_time for x in batch]),
end_time=torch.stack([x.end_time for x in batch]),
def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
item = np.load(path)
return PreprocessedExample(
audio=torch.tensor(item["audio"]),
spectrogram=torch.tensor(item["spectrogram"]),
size_heatmap=torch.tensor(item["size_heatmap"]),
detection_heatmap=torch.tensor(item["detection_heatmap"]),
class_heatmap=torch.tensor(item["class_heatmap"]),
)
def list_preprocessed_files(
directory: data.PathLike, extension: str = ".nc"
directory: data.PathLike, extension: str = ".npz"
) -> List[Path]:
return list(Path(directory).glob(f"*{extension}"))
@ -123,9 +103,9 @@ class RandomExampleSource:
self.filenames = filenames
self.clipper = clipper
def __call__(self):
def __call__(self) -> PreprocessedExample:
index = int(np.random.randint(len(self.filenames)))
filename = self.filenames[index]
dataset = xr.open_dataset(filename)
example, _, _ = self.clipper.extract_clip(dataset)
example = load_preprocessed_example(filename)
example, _, _ = self.clipper.extract_clip(example)
return example
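
With this change the dataset module reads preprocessed examples from .npz files through load_preprocessed_example (a PreprocessedExample of torch tensors) instead of opening xarray datasets, and RandomExampleSource follows the same path. A short usage sketch based only on the signatures visible above; the clipper is an assumption and stands in for any ClipperProtocol instance built elsewhere:

    from batdetect2.train.dataset import (
        LabeledDataset,
        RandomExampleSource,
        list_preprocessed_files,
        load_preprocessed_example,
    )

    clipper = ...  # assumption: any ClipperProtocol instance built elsewhere

    files = list_preprocessed_files("example_data/preprocessed")  # *.npz by default now
    example = load_preprocessed_example(files[0])                 # PreprocessedExample of tensors

    dataset = LabeledDataset.from_directory(
        "example_data/preprocessed",
        clipper=clipper,
    )
    train_example = dataset[0]                      # TrainExample, already clipped
    random_example = RandomExampleSource(files, clipper)()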

@ -30,6 +30,7 @@ import torch.utils.data
from loguru import logger
from pydantic import Field
from soundevent import data
from tqdm import tqdm
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.datasets import Dataset
@ -132,34 +133,47 @@ class PreprocessingDataset(torch.utils.data.Dataset):
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
filename_fn: FilenameFn,
output_dir: Path,
force: bool = False,
):
self.clips = clips
self.audio_loader = audio_loader
self.preprocessor = preprocessor
self.labeller = labeller
self.filename_fn = filename_fn
self.output_dir = output_dir
self.force = force
def __getitem__(self, idx) -> dict:
def __getitem__(self, idx) -> int:
clip_annotation = self.clips[idx]
filename = self.filename_fn(clip_annotation)
path = self.output_dir / filename
if path.exists() and not self.force:
return idx
if not path.parent.exists():
path.parent.mkdir()
example = generate_train_example(
clip_annotation,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
labeller=self.labeller,
)
return {
"idx": idx,
"spectrogram": example.spectrogram,
"audio": example.audio,
"class_heatmap": example.class_heatmap,
"size_heatmap": example.size_heatmap,
"detection_heatmap": example.detection_heatmap,
}
save_example_to_file(example, clip_annotation, path)
return idx
def __len__(self) -> int:
return len(self.clips)
def _save_example_to_file(
def save_example_to_file(
example: PreprocessedExample,
clip_annotation: data.ClipAnnotation,
path: data.PathLike,
@ -177,7 +191,7 @@ def _save_example_to_file(
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
"""Generate a default output filename based on the annotation UUID."""
return f"{clip_annotation.uuid}.nc"
return f"{clip_annotation.uuid}"
def preprocess_annotations(
@ -212,24 +226,17 @@ def preprocess_annotations(
audio_loader=audio_loader,
preprocessor=preprocessor,
labeller=labeller,
output_dir=Path(output_dir),
filename_fn=filename_fn,
)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=None,
batch_size=1,
shuffle=False,
num_workers=max_workers,
prefetch_factor=16,
)
for batch in loader:
clip_annotation = dataset.clips[batch["idx"]]
filename = filename_fn(clip_annotation)
path = output_dir / filename
example = PreprocessedExample(
spectrogram=batch["spectrogram"],
audio=batch["audio"],
class_heatmap=batch["class_heatmap"],
size_heatmap=batch["size_heatmap"],
detection_heatmap=batch["detection_heatmap"],
)
_save_example_to_file(example, clip_annotation, path)
for _ in tqdm(loader, total=len(dataset)):
pass
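
Each worker now writes its example to disk inside __getitem__ (skipping paths that already exist unless force is set), so the DataLoader with batch_size=1 only parallelises the work and the main loop just drives the tqdm progress bar. The body of save_example_to_file is not shown in this diff; the sketch below is an assumption of what a matching writer could look like, using np.savez with the keys that load_preprocessed_example and get_clip_annotation read back. np.savez appends the .npz suffix itself, which fits _get_filename no longer adding an extension:

    import numpy as np

    def save_example_to_file_sketch(example, clip_annotation, path) -> None:
        # Hypothetical writer mirroring load_preprocessed_example in the dataset
        # module; np.asarray copes with tensors, DataArrays or plain arrays.
        np.savez(
            path,
            audio=np.asarray(example.audio),
            spectrogram=np.asarray(example.spectrogram),
            size_heatmap=np.asarray(example.size_heatmap),
            detection_heatmap=np.asarray(example.detection_heatmap),
            class_heatmap=np.asarray(example.class_heatmap),
            # Stored as JSON so the annotation can be recovered from the archive.
            clip_annotation=clip_annotation.model_dump_json(),
        )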

@ -21,7 +21,6 @@ from batdetect2.train.config import FullTrainingConfig, TrainingConfig
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
collate_fn,
)
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger
@ -58,7 +57,7 @@ def train(
train_dataloader = build_train_loader(
train_examples,
preprocessor=module.preprocessor,
preprocessor=module.model.preprocessor,
config=config.train,
num_workers=train_workers,
)
@ -66,6 +65,7 @@ def train(
val_dataloader = (
build_val_loader(
val_examples,
preprocessor=module.model.preprocessor,
config=config.train,
num_workers=val_workers,
)
@ -138,9 +138,11 @@ def build_trainer(
def build_train_loader(
train_examples: Sequence[data.PathLike],
preprocessor: PreprocessorProtocol,
config: TrainingConfig,
config: Optional[TrainingConfig] = None,
num_workers: Optional[int] = None,
) -> DataLoader:
config = config or TrainingConfig()
logger.info("Building training data loader...")
train_dataset = build_train_dataset(
train_examples,
@ -158,18 +160,21 @@ def build_train_loader(
batch_size=loader_conf.batch_size,
shuffle=loader_conf.shuffle,
num_workers=num_workers,
collate_fn=collate_fn,
)
def build_val_loader(
val_examples: Sequence[data.PathLike],
config: TrainingConfig,
preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None,
num_workers: Optional[int] = None,
):
config = config or TrainingConfig()
logger.info("Building validation data loader...")
val_dataset = build_val_dataset(
val_examples,
preprocessor=preprocessor,
config=config,
)
loader_conf = config.dataloaders.val
@ -183,7 +188,6 @@ def build_val_loader(
batch_size=loader_conf.batch_size,
shuffle=loader_conf.shuffle,
num_workers=num_workers,
collate_fn=collate_fn,
)
@ -195,7 +199,11 @@ def build_train_dataset(
logger.info("Building training dataset...")
config = config or TrainingConfig()
clipper = build_clipper(config.cliping, random=True)
clipper = build_clipper(
samplerate=preprocessor.samplerate,
config=config.cliping,
random=True,
)
random_example_source = RandomExampleSource(
list(examples),
@ -221,10 +229,15 @@ def build_train_dataset(
def build_val_dataset(
examples: Sequence[data.PathLike],
preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None,
train: bool = True,
) -> LabeledDataset:
logger.info("Building validation dataset...")
config = config or TrainingConfig()
clipper = build_clipper(config.cliping, random=train)
clipper = build_clipper(
samplerate=preprocessor.samplerate,
config=config.cliping,
random=train,
)
return LabeledDataset(examples, clipper=clipper)
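
After these changes the loader builders take the preprocessor explicitly (the clipper is built from preprocessor.samplerate), config falls back to TrainingConfig() when omitted, and the custom collate_fn is no longer wired in. A short usage sketch based on the signatures in this hunk; module, train_examples and val_examples are assumptions standing in for a TrainingModule whose model exposes a preprocessor and for sequences of paths to preprocessed .npz examples:

    # build_train_loader / build_val_loader as defined in the training module above.
    train_loader = build_train_loader(
        train_examples,
        preprocessor=module.model.preprocessor,
        num_workers=4,          # config omitted, so TrainingConfig() defaults apply
    )
    val_loader = build_val_loader(
        val_examples,
        preprocessor=module.model.preprocessor,
        num_workers=2,
    )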