Cleanup train preprocessing

mbsantiago 2025-08-25 18:37:46 +01:00
parent c80078feee
commit 1f26103f42
5 changed files with 82 additions and 82 deletions

File 1 of 5: justfile (example data recipes)

@@ -99,7 +99,7 @@ example-preprocess OPTIONS="":
     --dataset-field datasets.train \
     --config example_data/config.yaml \
     {{OPTIONS}} \
-    example_data/datasets.yaml example_data/preprocessed
+    example_data/config.yaml example_data/preprocessed

 # Train on example data.
 example-train OPTIONS="":

File 2 of 5: clip settings (training clip duration defaults)

@@ -8,7 +8,7 @@ from batdetect2.typing import ClipperProtocol
 from batdetect2.typing.train import PreprocessedExample
 from batdetect2.utils.arrays import adjust_width

-DEFAULT_TRAIN_CLIP_DURATION = 0.513
+DEFAULT_TRAIN_CLIP_DURATION = 0.512
 DEFAULT_MAX_EMPTY_CLIP = 0.1
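A plausible reading of the duration change, not stated in the commit itself: 0.512 s makes a training clip map to exactly 512 spectrogram frames, the width the removed collate_fn (see file 3 below) used to pad every example to. A minimal sketch of the arithmetic, where samplerate and hop_length are assumptions and only the 512-frame target comes from this diff:

import numpy as np

samplerate = 256_000   # assumption: a typical bat-call target rate
hop_length = 256       # assumption: 1 ms worth of samples per STFT hop
duration = 0.512       # the new DEFAULT_TRAIN_CLIP_DURATION

n_samples = int(duration * samplerate)  # 131072 == 2**17, a power of two
n_frames = n_samples // hop_length      # exactly 512 frames
print(n_samples, n_frames)              # -> 131072 512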

File 3 of 5: batdetect2/train/dataset.py

@@ -3,13 +3,12 @@ from typing import List, Optional, Sequence, Tuple

 import numpy as np
 import torch
-import xarray as xr
 from soundevent import data
 from torch.utils.data import Dataset

 from batdetect2.train.augmentations import Augmentation
 from batdetect2.typing import ClipperProtocol, TrainExample
-from batdetect2.utils.tensors import adjust_width
+from batdetect2.typing.train import PreprocessedExample

 __all__ = [
     "LabeledDataset",
@@ -31,17 +30,18 @@ class LabeledDataset(Dataset):
         return len(self.filenames)

     def __getitem__(self, idx) -> TrainExample:
-        dataset = self.get_dataset(idx)
-        dataset, start_time, end_time = self.clipper.extract_clip(dataset)
+        example = self.get_example(idx)
+        example, start_time, end_time = self.clipper.extract_clip(example)

         if self.augmentation:
-            dataset = self.augmentation(dataset)
+            example = self.augmentation(example)

         return TrainExample(
-            spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
-            detection_heatmap=self.to_tensor(dataset["detection"]),
-            class_heatmap=self.to_tensor(dataset["class"]),
-            size_heatmap=self.to_tensor(dataset["size"]),
+            spec=example.spectrogram.unsqueeze(0),
+            detection_heatmap=example.detection_heatmap.unsqueeze(0),
+            class_heatmap=example.class_heatmap,
+            size_heatmap=example.size_heatmap,
             idx=torch.tensor(idx),
             start_time=torch.tensor(start_time),
             end_time=torch.tensor(end_time),
@@ -52,7 +52,7 @@ class LabeledDataset(Dataset):
         cls,
         directory: data.PathLike,
         clipper: ClipperProtocol,
-        extension: str = ".nc",
+        extension: str = ".npz",
         augmentation: Optional[Augmentation] = None,
     ):
         return cls(
@@ -61,55 +61,35 @@ class LabeledDataset(Dataset):
             augmentation=augmentation,
         )

-    def get_random_example(self) -> Tuple[xr.Dataset, float, float]:
+    def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
         idx = np.random.randint(0, len(self))
-        dataset = self.get_dataset(idx)
+        dataset = self.get_example(idx)
         dataset, start_time, end_time = self.clipper.extract_clip(dataset)
         return dataset, start_time, end_time

-    def get_dataset(self, idx) -> xr.Dataset:
-        return xr.open_dataset(self.filenames[idx])
+    def get_example(self, idx) -> PreprocessedExample:
+        return load_preprocessed_example(self.filenames[idx])

     def get_clip_annotation(self, idx) -> data.ClipAnnotation:
-        return data.ClipAnnotation.model_validate_json(
-            self.get_dataset(idx).attrs["clip_annotation"]
-        )
-
-    def to_tensor(
-        self,
-        array: xr.DataArray,
-        dtype=np.float32,
-    ) -> torch.Tensor:
-        return torch.nan_to_num(
-            torch.tensor(array.values.astype(dtype)),
-            nan=0,
-        )
+        item = np.load(self.filenames[idx])
+        return item["clip_annotation"]


-def collate_fn(batch: List[TrainExample]):
-    width = 512
-
-    return TrainExample(
-        spec=torch.stack([adjust_width(x.spec, width) for x in batch]),
-        detection_heatmap=torch.stack(
-            [adjust_width(x.detection_heatmap, width) for x in batch]
-        ),
-        class_heatmap=torch.stack(
-            [adjust_width(x.class_heatmap, width) for x in batch]
-        ),
-        size_heatmap=torch.stack(
-            [adjust_width(x.size_heatmap, width) for x in batch]
-        ),
-        idx=torch.stack([x.idx for x in batch]),
-        start_time=torch.stack([x.start_time for x in batch]),
-        end_time=torch.stack([x.end_time for x in batch]),
-    )
+def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
+    item = np.load(path)
+    return PreprocessedExample(
+        audio=torch.tensor(item["audio"]),
+        spectrogram=torch.tensor(item["spectrogram"]),
+        size_heatmap=torch.tensor(item["size_heatmap"]),
+        detection_heatmap=torch.tensor(item["detection_heatmap"]),
+        class_heatmap=torch.tensor(item["class_heatmap"]),
+    )


 def list_preprocessed_files(
-    directory: data.PathLike, extension: str = ".nc"
+    directory: data.PathLike, extension: str = ".npz"
 ) -> List[Path]:
     return list(Path(directory).glob(f"*{extension}"))
@@ -123,9 +103,9 @@ class RandomExampleSource:
         self.filenames = filenames
         self.clipper = clipper

-    def __call__(self):
+    def __call__(self) -> PreprocessedExample:
         index = int(np.random.randint(len(self.filenames)))
         filename = self.filenames[index]
-        dataset = xr.open_dataset(filename)
-        example, _, _ = self.clipper.extract_clip(dataset)
+        example = load_preprocessed_example(filename)
+        example, _, _ = self.clipper.extract_clip(example)
         return example

File 4 of 5: train preprocessing module (PreprocessingDataset, preprocess_annotations)

@@ -30,6 +30,7 @@ import torch.utils.data
 from loguru import logger
 from pydantic import Field
 from soundevent import data
+from tqdm import tqdm

 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.data.datasets import Dataset
@@ -132,34 +133,47 @@ class PreprocessingDataset(torch.utils.data.Dataset):
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         labeller: ClipLabeller,
+        filename_fn: FilenameFn,
+        output_dir: Path,
+        force: bool = False,
     ):
         self.clips = clips
         self.audio_loader = audio_loader
         self.preprocessor = preprocessor
         self.labeller = labeller
+        self.filename_fn = filename_fn
+        self.output_dir = output_dir
+        self.force = force

-    def __getitem__(self, idx) -> dict:
+    def __getitem__(self, idx) -> int:
         clip_annotation = self.clips[idx]

+        filename = self.filename_fn(clip_annotation)
+        path = self.output_dir / filename
+
+        if path.exists() and not self.force:
+            return idx
+
+        if not path.parent.exists():
+            path.parent.mkdir()
+
         example = generate_train_example(
             clip_annotation,
             audio_loader=self.audio_loader,
             preprocessor=self.preprocessor,
             labeller=self.labeller,
         )
-        return {
-            "idx": idx,
-            "spectrogram": example.spectrogram,
-            "audio": example.audio,
-            "class_heatmap": example.class_heatmap,
-            "size_heatmap": example.size_heatmap,
-            "detection_heatmap": example.detection_heatmap,
-        }
+
+        save_example_to_file(example, clip_annotation, path)
+
+        return idx

     def __len__(self) -> int:
         return len(self.clips)


-def _save_example_to_file(
+def save_example_to_file(
     example: PreprocessedExample,
     clip_annotation: data.ClipAnnotation,
     path: data.PathLike,
@@ -177,7 +191,7 @@ def _save_example_to_file(

 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
     """Generate a default output filename based on the annotation UUID."""
-    return f"{clip_annotation.uuid}.nc"
+    return f"{clip_annotation.uuid}"


 def preprocess_annotations(
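The bare-UUID default filename is reconciled with the "*.npz" glob in list_preprocessed_files by NumPy itself: np.savez appends ".npz" to a target name that lacks the extension. A one-line illustration (the UUID here is a placeholder):

import numpy as np

np.savez("0b1d2c3e-example-uuid", spectrogram=np.zeros((128, 512)))
# -> the file written to disk is "0b1d2c3e-example-uuid.npz"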
@@ -212,24 +226,17 @@ def preprocess_annotations(
         audio_loader=audio_loader,
         preprocessor=preprocessor,
         labeller=labeller,
+        output_dir=Path(output_dir),
+        filename_fn=filename_fn,
     )

     loader = torch.utils.data.DataLoader(
         dataset,
-        batch_size=None,
+        batch_size=1,
         shuffle=False,
         num_workers=max_workers,
+        prefetch_factor=16,
     )

-    for batch in loader:
-        clip_annotation = dataset.clips[batch["idx"]]
-        filename = filename_fn(clip_annotation)
-        path = output_dir / filename
-        example = PreprocessedExample(
-            spectrogram=batch["spectrogram"],
-            audio=batch["audio"],
-            class_heatmap=batch["class_heatmap"],
-            size_heatmap=batch["size_heatmap"],
-            detection_heatmap=batch["detection_heatmap"],
-        )
-        _save_example_to_file(example, clip_annotation, path)
+    for _ in tqdm(loader, total=len(dataset)):
+        pass
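The new loop repurposes DataLoader as a parallel worker pool: __getitem__ does the work (preprocess one clip, write one file) and returns only the index, so the main process just drains the iterator to drive tqdm. A self-contained sketch of the pattern, with hypothetical names (process_and_save, jobs) standing in for the real pipeline:

import torch.utils.data
from tqdm import tqdm

def process_and_save(job) -> None:
    """Hypothetical stand-in for generate_train_example + save_example_to_file."""
    pass

class SideEffectDataset(torch.utils.data.Dataset):
    def __init__(self, jobs):
        self.jobs = jobs

    def __len__(self) -> int:
        return len(self.jobs)

    def __getitem__(self, idx: int) -> int:
        process_and_save(self.jobs[idx])  # the side effect is the point
        return idx                        # tiny payload, nothing to collate

jobs = list(range(100))  # placeholder work items
loader = torch.utils.data.DataLoader(
    SideEffectDataset(jobs),
    batch_size=1,
    num_workers=4,  # parallelism comes from the loader workers
)
for _ in tqdm(loader, total=len(jobs)):
    pass  # the work happens in the workers; this loop only tracks progress

This also explains why the custom collate_fn could be dropped from training below: nothing variable-width crosses the worker boundary here, and the train-time dataset now yields fixed-size clips.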

File 5 of 5: training entry point (train, loader and dataset builders)

@@ -21,7 +21,6 @@ from batdetect2.train.config import FullTrainingConfig, TrainingConfig
 from batdetect2.train.dataset import (
     LabeledDataset,
     RandomExampleSource,
-    collate_fn,
 )
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
@@ -58,7 +57,7 @@ def train(
     train_dataloader = build_train_loader(
         train_examples,
-        preprocessor=module.preprocessor,
+        preprocessor=module.model.preprocessor,
         config=config.train,
         num_workers=train_workers,
     )
@@ -66,6 +65,7 @@ def train(
     val_dataloader = (
         build_val_loader(
             val_examples,
+            preprocessor=module.model.preprocessor,
             config=config.train,
             num_workers=val_workers,
         )
@@ -138,9 +138,11 @@ def build_trainer(
 def build_train_loader(
     train_examples: Sequence[data.PathLike],
     preprocessor: PreprocessorProtocol,
-    config: TrainingConfig,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
+    config = config or TrainingConfig()
+
     logger.info("Building training data loader...")
     train_dataset = build_train_dataset(
         train_examples,
@@ -158,18 +160,21 @@ def build_train_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
-        collate_fn=collate_fn,
     )


 def build_val_loader(
     val_examples: Sequence[data.PathLike],
-    config: TrainingConfig,
+    preprocessor: PreprocessorProtocol,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ):
+    config = config or TrainingConfig()
+
     logger.info("Building validation data loader...")
     val_dataset = build_val_dataset(
         val_examples,
+        preprocessor=preprocessor,
         config=config,
     )
     loader_conf = config.dataloaders.val
@@ -183,7 +188,6 @@ def build_val_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
-        collate_fn=collate_fn,
     )
@@ -195,7 +199,11 @@ def build_train_dataset(
     logger.info("Building training dataset...")
     config = config or TrainingConfig()

-    clipper = build_clipper(config.cliping, random=True)
+    clipper = build_clipper(
+        samplerate=preprocessor.samplerate,
+        config=config.cliping,
+        random=True,
+    )

     random_example_source = RandomExampleSource(
         list(examples),
@@ -221,10 +229,15 @@ def build_train_dataset(
 def build_val_dataset(
     examples: Sequence[data.PathLike],
+    preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
     train: bool = True,
 ) -> LabeledDataset:
     logger.info("Building validation dataset...")
     config = config or TrainingConfig()

-    clipper = build_clipper(config.cliping, random=train)
+    clipper = build_clipper(
+        samplerate=preprocessor.samplerate,
+        config=config.cliping,
+        random=train,
+    )
     return LabeledDataset(examples, clipper=clipper)
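A hedged usage sketch of the reworked builders. Here module stands in for a loaded TrainingModule and the example paths for preprocessed .npz files; everything else follows the signatures in the diff above:

train_loader = build_train_loader(
    train_examples,
    preprocessor=module.model.preprocessor,  # the preprocessor now hangs off the model
)
val_loader = build_val_loader(
    val_examples,
    preprocessor=module.model.preprocessor,  # needed so the clipper knows the samplerate
)
# config is now optional in both builders: TrainingConfig() is used when omitted.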