diff --git a/justfile b/justfile
index b43a4dc..17abb8c 100644
--- a/justfile
+++ b/justfile
@@ -99,7 +99,7 @@ example-preprocess OPTIONS="":
     --dataset-field datasets.train \
     --config example_data/config.yaml \
     {{OPTIONS}} \
-    example_data/datasets.yaml example_data/preprocessed
+    example_data/config.yaml example_data/preprocessed

 # Train on example data.
 example-train OPTIONS="":
diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py
index acf6a95..befbf06 100644
--- a/src/batdetect2/train/clips.py
+++ b/src/batdetect2/train/clips.py
@@ -8,7 +8,7 @@ from batdetect2.typing import ClipperProtocol
 from batdetect2.typing.train import PreprocessedExample
 from batdetect2.utils.arrays import adjust_width

-DEFAULT_TRAIN_CLIP_DURATION = 0.513
+DEFAULT_TRAIN_CLIP_DURATION = 0.512

 DEFAULT_MAX_EMPTY_CLIP = 0.1
diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py
index 1e67f23..928e687 100644
--- a/src/batdetect2/train/dataset.py
+++ b/src/batdetect2/train/dataset.py
@@ -3,13 +3,12 @@ from typing import List, Optional, Sequence, Tuple

 import numpy as np
 import torch
-import xarray as xr
 from soundevent import data
 from torch.utils.data import Dataset

 from batdetect2.train.augmentations import Augmentation
 from batdetect2.typing import ClipperProtocol, TrainExample
-from batdetect2.utils.tensors import adjust_width
+from batdetect2.typing.train import PreprocessedExample

 __all__ = [
     "LabeledDataset",
@@ -31,17 +30,18 @@ class LabeledDataset(Dataset):
         return len(self.filenames)

     def __getitem__(self, idx) -> TrainExample:
-        dataset = self.get_dataset(idx)
-        dataset, start_time, end_time = self.clipper.extract_clip(dataset)
+        example = self.get_example(idx)
+
+        example, start_time, end_time = self.clipper.extract_clip(example)

         if self.augmentation:
-            dataset = self.augmentation(dataset)
+            example = self.augmentation(example)

         return TrainExample(
-            spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
-            detection_heatmap=self.to_tensor(dataset["detection"]),
-            class_heatmap=self.to_tensor(dataset["class"]),
-            size_heatmap=self.to_tensor(dataset["size"]),
+            spec=example.spectrogram.unsqueeze(0),
+            detection_heatmap=example.detection_heatmap.unsqueeze(0),
+            class_heatmap=example.class_heatmap,
+            size_heatmap=example.size_heatmap,
             idx=torch.tensor(idx),
             start_time=torch.tensor(start_time),
             end_time=torch.tensor(end_time),
@@ -52,7 +52,7 @@
         cls,
         directory: data.PathLike,
         clipper: ClipperProtocol,
-        extension: str = ".nc",
+        extension: str = ".npz",
         augmentation: Optional[Augmentation] = None,
     ):
         return cls(
@@ -61,55 +61,35 @@
-    def get_random_example(self) -> Tuple[xr.Dataset, float, float]:
+    def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
         idx = np.random.randint(0, len(self))
-        dataset = self.get_dataset(idx)
+        dataset = self.get_example(idx)
         dataset, start_time, end_time = self.clipper.extract_clip(dataset)
         return dataset, start_time, end_time

-    def get_dataset(self, idx) -> xr.Dataset:
-        return xr.open_dataset(self.filenames[idx])
+    def get_example(self, idx) -> PreprocessedExample:
+        return load_preprocessed_example(self.filenames[idx])

     def get_clip_annotation(self, idx) -> data.ClipAnnotation:
-        return data.ClipAnnotation.model_validate_json(
-            self.get_dataset(idx).attrs["clip_annotation"]
-        )
-
-    def to_tensor(
-        self,
-        array: xr.DataArray,
-        dtype=np.float32,
-    ) -> torch.Tensor:
-        return torch.nan_to_num(
-            torch.tensor(array.values.astype(dtype)),
-            nan=0,
-        )
+        item = np.load(self.filenames[idx])
+        return item["clip_annotation"]


-def collate_fn(batch: List[TrainExample]):
-    width = 512
-
-    return TrainExample(
-        spec=torch.stack([adjust_width(x.spec, width) for x in batch]),
-        detection_heatmap=torch.stack(
-            [adjust_width(x.detection_heatmap, width) for x in batch]
-        ),
-        class_heatmap=torch.stack(
-            [adjust_width(x.class_heatmap, width) for x in batch]
-        ),
-        size_heatmap=torch.stack(
-            [adjust_width(x.size_heatmap, width) for x in batch]
-        ),
-        idx=torch.stack([x.idx for x in batch]),
-        start_time=torch.stack([x.start_time for x in batch]),
-        end_time=torch.stack([x.end_time for x in batch]),
+def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
+    item = np.load(path)
+    return PreprocessedExample(
+        audio=torch.tensor(item["audio"]),
+        spectrogram=torch.tensor(item["spectrogram"]),
+        size_heatmap=torch.tensor(item["size_heatmap"]),
+        detection_heatmap=torch.tensor(item["detection_heatmap"]),
+        class_heatmap=torch.tensor(item["class_heatmap"]),
     )


 def list_preprocessed_files(
-    directory: data.PathLike, extension: str = ".nc"
+    directory: data.PathLike, extension: str = ".npz"
 ) -> List[Path]:
     return list(Path(directory).glob(f"*{extension}"))
@@ -123,9 +103,9 @@ class RandomExampleSource:
         self.filenames = filenames
         self.clipper = clipper

-    def __call__(self):
+    def __call__(self) -> PreprocessedExample:
         index = int(np.random.randint(len(self.filenames)))
         filename = self.filenames[index]
-        dataset = xr.open_dataset(filename)
-        example, _, _ = self.clipper.extract_clip(dataset)
+        example = load_preprocessed_example(filename)
+        example, _, _ = self.clipper.extract_clip(example)
         return example
diff --git a/src/batdetect2/train/preprocess.py b/src/batdetect2/train/preprocess.py
index ebdd489..fdc0c70 100644
--- a/src/batdetect2/train/preprocess.py
+++ b/src/batdetect2/train/preprocess.py
@@ -30,6 +30,7 @@ import torch.utils.data
 from loguru import logger
 from pydantic import Field
 from soundevent import data
+from tqdm import tqdm

 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.data.datasets import Dataset
@@ -132,34 +133,47 @@ class PreprocessingDataset(torch.utils.data.Dataset):
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         labeller: ClipLabeller,
+        filename_fn: FilenameFn,
+        output_dir: Path,
+        force: bool = False,
     ):
         self.clips = clips
         self.audio_loader = audio_loader
         self.preprocessor = preprocessor
         self.labeller = labeller
+        self.filename_fn = filename_fn
+        self.output_dir = output_dir
+        self.force = force

-    def __getitem__(self, idx) -> dict:
+    def __getitem__(self, idx) -> int:
         clip_annotation = self.clips[idx]
+
+        filename = self.filename_fn(clip_annotation)
+
+        path = self.output_dir / filename
+
+        if path.exists() and not self.force:
+            return idx
+
+        if not path.parent.exists():
+            path.parent.mkdir()
+
         example = generate_train_example(
             clip_annotation,
             audio_loader=self.audio_loader,
             preprocessor=self.preprocessor,
             labeller=self.labeller,
         )
-        return {
-            "idx": idx,
-            "spectrogram": example.spectrogram,
-            "audio": example.audio,
-            "class_heatmap": example.class_heatmap,
-            "size_heatmap": example.size_heatmap,
-            "detection_heatmap": example.detection_heatmap,
-        }
+
+        save_example_to_file(example, clip_annotation, path)
+
+        return idx

     def __len__(self) -> int:
         return len(self.clips)


-def _save_example_to_file(
+def save_example_to_file(
     example: PreprocessedExample,
     clip_annotation: data.ClipAnnotation,
     path: data.PathLike,
@@ -177,7 +191,7 @@ def _save_example_to_file(

 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
     """Generate a default output filename based on the annotation UUID."""
-    return f"{clip_annotation.uuid}.nc"
+    return f"{clip_annotation.uuid}"


 def preprocess_annotations(
@@ -212,24 +226,17 @@ def preprocess_annotations(
         audio_loader=audio_loader,
         preprocessor=preprocessor,
         labeller=labeller,
+        output_dir=Path(output_dir),
+        filename_fn=filename_fn,
     )

     loader = torch.utils.data.DataLoader(
         dataset,
-        batch_size=None,
+        batch_size=1,
         shuffle=False,
         num_workers=max_workers,
+        prefetch_factor=16,
     )

-    for batch in loader:
-        clip_annotation = dataset.clips[batch["idx"]]
-        filename = filename_fn(clip_annotation)
-        path = output_dir / filename
-        example = PreprocessedExample(
-            spectrogram=batch["spectrogram"],
-            audio=batch["audio"],
-            class_heatmap=batch["class_heatmap"],
-            size_heatmap=batch["size_heatmap"],
-            detection_heatmap=batch["detection_heatmap"],
-        )
-        _save_example_to_file(example, clip_annotation, path)
+    for _ in tqdm(loader, total=len(dataset)):
+        pass
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index 11d403f..a6fda60 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -21,7 +21,6 @@ from batdetect2.train.config import FullTrainingConfig, TrainingConfig
 from batdetect2.train.dataset import (
     LabeledDataset,
     RandomExampleSource,
-    collate_fn,
 )
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
@@ -58,7 +57,7 @@ def train(

     train_dataloader = build_train_loader(
         train_examples,
-        preprocessor=module.preprocessor,
+        preprocessor=module.model.preprocessor,
         config=config.train,
         num_workers=train_workers,
     )
@@ -66,6 +65,7 @@ def train(
     val_dataloader = (
         build_val_loader(
             val_examples,
+            preprocessor=module.model.preprocessor,
             config=config.train,
             num_workers=val_workers,
         )
@@ -138,9 +138,11 @@ def build_trainer(
 def build_train_loader(
     train_examples: Sequence[data.PathLike],
     preprocessor: PreprocessorProtocol,
-    config: TrainingConfig,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
+    config = config or TrainingConfig()
+
     logger.info("Building training data loader...")
     train_dataset = build_train_dataset(
         train_examples,
@@ -158,18 +160,21 @@
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
-        collate_fn=collate_fn,
     )


 def build_val_loader(
     val_examples: Sequence[data.PathLike],
-    config: TrainingConfig,
+    preprocessor: PreprocessorProtocol,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ):
+    config = config or TrainingConfig()
+
     logger.info("Building validation data loader...")
     val_dataset = build_val_dataset(
         val_examples,
+        preprocessor=preprocessor,
         config=config,
     )
     loader_conf = config.dataloaders.val
@@ -183,7 +188,6 @@
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
-        collate_fn=collate_fn,
     )


@@ -195,7 +199,11 @@ def build_train_dataset(
     logger.info("Building training dataset...")
     config = config or TrainingConfig()

-    clipper = build_clipper(config.cliping, random=True)
+    clipper = build_clipper(
+        samplerate=preprocessor.samplerate,
+        config=config.cliping,
+        random=True,
+    )

     random_example_source = RandomExampleSource(
         list(examples),
@@ -221,10 +229,15 @@
 def build_val_dataset(
     examples: Sequence[data.PathLike],
+    preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
     train: bool = True,
 ) -> LabeledDataset:
     logger.info("Building validation dataset...")
     config = config or TrainingConfig()

-    clipper = build_clipper(config.cliping, random=train)
+    clipper = build_clipper(
+        samplerate=preprocessor.samplerate,
+        config=config.cliping,
+        random=train,
+    )

     return LabeledDataset(examples, clipper=clipper)
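
The patch above replaces the xarray/NetCDF (.nc) training-example format with plain NumPy .npz archives. A minimal round-trip sketch follows; the writer side is an assumption, since the body of save_example_to_file is not shown in this diff — the keys simply mirror what load_preprocessed_example reads, plus the "clip_annotation" entry that get_clip_annotation consumes, and all shapes are dummy placeholders.

# Hypothetical .npz round-trip; keys mirror load_preprocessed_example above.
import numpy as np
import torch

path = "example.npz"  # placeholder; real files are named after the annotation UUID

np.savez(
    path,
    audio=np.zeros(131_072, dtype=np.float32),                # dummy waveform (e.g. 0.512 s at 256 kHz)
    spectrogram=np.zeros((128, 512), dtype=np.float32),       # dummy spectrogram
    detection_heatmap=np.zeros((128, 512), dtype=np.float32),
    class_heatmap=np.zeros((6, 128, 512), dtype=np.float32),  # dummy class count
    size_heatmap=np.zeros((2, 128, 512), dtype=np.float32),
    clip_annotation="{}",  # serialized ClipAnnotation JSON (assumed)
)

item = np.load(path)
spec = torch.tensor(item["spectrogram"])  # as in load_preprocessed_example

# np.savez stores the string as a 0-d unicode array, so reading the annotation
# back requires unwrapping it first:
annotation_json = item["clip_annotation"].item()

One consequence worth noting: because of that 0-d array wrapping, get_clip_annotation as written returns a NumPy scalar rather than a parsed data.ClipAnnotation, so callers would still need to unwrap the value and pass it through model_validate_json.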
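The custom collate_fn, which padded every tensor to a fixed width of 512 before stacking, is deleted rather than replaced. Presumably this is safe because DEFAULT_TRAIN_CLIP_DURATION is now exactly 0.512 s, so every extracted clip yields tensors of identical shape and torch's default collation can stack them directly. A minimal sketch, using a hypothetical NamedTuple as a stand-in for TrainExample (the real fields may differ):

from typing import NamedTuple

import torch
from torch.utils.data import default_collate  # torch >= 1.11


class Example(NamedTuple):
    """Stand-in for batdetect2's TrainExample."""

    spec: torch.Tensor
    idx: torch.Tensor


# Four examples with identical shapes, as produced by fixed-duration clips.
batch = [
    Example(spec=torch.rand(1, 128, 512), idx=torch.tensor(i))
    for i in range(4)
]

# default_collate recurses into NamedTuples and stacks each field.
collated = default_collate(batch)
print(collated.spec.shape)  # torch.Size([4, 1, 128, 512])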