Mirror of https://github.com/macaodha/batdetect2.git
Synced 2026-01-10 17:19:34 +01:00
Cleanup train preprocessing
This commit is contained in:
parent c80078feee
commit 1f26103f42

Changed files: justfile (2 lines changed), plus the train dataset, preprocessing, and training modules shown in the hunks below.
justfile:

@@ -99,7 +99,7 @@ example-preprocess OPTIONS="":
     --dataset-field datasets.train \
     --config example_data/config.yaml \
     {{OPTIONS}} \
-    example_data/datasets.yaml example_data/preprocessed
+    example_data/config.yaml example_data/preprocessed
 
 # Train on example data.
 example-train OPTIONS="":
@@ -8,7 +8,7 @@ from batdetect2.typing import ClipperProtocol
 from batdetect2.typing.train import PreprocessedExample
 from batdetect2.utils.arrays import adjust_width
 
-DEFAULT_TRAIN_CLIP_DURATION = 0.513
+DEFAULT_TRAIN_CLIP_DURATION = 0.512
 DEFAULT_MAX_EMPTY_CLIP = 0.1
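
Note — the commit does not say why the duration changed, but a plausible reading is that 0.512 s gives a power-of-two number of samples at the 256 kHz rates common for full-spectrum bat audio, while 0.513 s does not. A quick check; the samplerate here is an assumption, not taken from the diff:

    SAMPLERATE = 256_000  # assumed; typical for bat recordings
    print(0.512 * SAMPLERATE)  # 131072.0 == 2**17 samples
    print(0.513 * SAMPLERATE)  # 131328.0, not a power of two
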
@@ -3,13 +3,12 @@ from typing import List, Optional, Sequence, Tuple
 
 import numpy as np
 import torch
-import xarray as xr
 from soundevent import data
 from torch.utils.data import Dataset
 
 from batdetect2.train.augmentations import Augmentation
 from batdetect2.typing import ClipperProtocol, TrainExample
-from batdetect2.utils.tensors import adjust_width
+from batdetect2.typing.train import PreprocessedExample
 
 __all__ = [
     "LabeledDataset",
@@ -31,17 +30,18 @@ class LabeledDataset(Dataset):
         return len(self.filenames)
 
     def __getitem__(self, idx) -> TrainExample:
-        dataset = self.get_dataset(idx)
-        dataset, start_time, end_time = self.clipper.extract_clip(dataset)
+        example = self.get_example(idx)
+
+        example, start_time, end_time = self.clipper.extract_clip(example)
 
         if self.augmentation:
-            dataset = self.augmentation(dataset)
+            example = self.augmentation(example)
 
         return TrainExample(
-            spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
-            detection_heatmap=self.to_tensor(dataset["detection"]),
-            class_heatmap=self.to_tensor(dataset["class"]),
-            size_heatmap=self.to_tensor(dataset["size"]),
+            spec=example.spectrogram.unsqueeze(0),
+            detection_heatmap=example.detection_heatmap.unsqueeze(0),
+            class_heatmap=example.class_heatmap,
+            size_heatmap=example.size_heatmap,
             idx=torch.tensor(idx),
             start_time=torch.tensor(start_time),
             end_time=torch.tensor(end_time),
@@ -52,7 +52,7 @@ class LabeledDataset(Dataset):
         cls,
         directory: data.PathLike,
         clipper: ClipperProtocol,
-        extension: str = ".nc",
+        extension: str = ".npz",
         augmentation: Optional[Augmentation] = None,
     ):
         return cls(
@@ -61,55 +61,35 @@ class LabeledDataset(Dataset):
             augmentation=augmentation,
         )
 
-    def get_random_example(self) -> Tuple[xr.Dataset, float, float]:
+    def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
         idx = np.random.randint(0, len(self))
-        dataset = self.get_dataset(idx)
+        dataset = self.get_example(idx)
 
         dataset, start_time, end_time = self.clipper.extract_clip(dataset)
 
         return dataset, start_time, end_time
 
-    def get_dataset(self, idx) -> xr.Dataset:
-        return xr.open_dataset(self.filenames[idx])
+    def get_example(self, idx) -> PreprocessedExample:
+        return load_preprocessed_example(self.filenames[idx])
 
     def get_clip_annotation(self, idx) -> data.ClipAnnotation:
-        return data.ClipAnnotation.model_validate_json(
-            self.get_dataset(idx).attrs["clip_annotation"]
-        )
-
-    def to_tensor(
-        self,
-        array: xr.DataArray,
-        dtype=np.float32,
-    ) -> torch.Tensor:
-        return torch.nan_to_num(
-            torch.tensor(array.values.astype(dtype)),
-            nan=0,
-        )
+        item = np.load(self.filenames[idx])
+        return item["clip_annotation"]
 
 
-def collate_fn(batch: List[TrainExample]):
-    width = 512
-
-    return TrainExample(
-        spec=torch.stack([adjust_width(x.spec, width) for x in batch]),
-        detection_heatmap=torch.stack(
-            [adjust_width(x.detection_heatmap, width) for x in batch]
-        ),
-        class_heatmap=torch.stack(
-            [adjust_width(x.class_heatmap, width) for x in batch]
-        ),
-        size_heatmap=torch.stack(
-            [adjust_width(x.size_heatmap, width) for x in batch]
-        ),
-        idx=torch.stack([x.idx for x in batch]),
-        start_time=torch.stack([x.start_time for x in batch]),
-        end_time=torch.stack([x.end_time for x in batch]),
+def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
+    item = np.load(path)
+    return PreprocessedExample(
+        audio=torch.tensor(item["audio"]),
+        spectrogram=torch.tensor(item["spectrogram"]),
+        size_heatmap=torch.tensor(item["size_heatmap"]),
+        detection_heatmap=torch.tensor(item["detection_heatmap"]),
+        class_heatmap=torch.tensor(item["class_heatmap"]),
     )
 
 
 def list_preprocessed_files(
-    directory: data.PathLike, extension: str = ".nc"
+    directory: data.PathLike, extension: str = ".npz"
 ) -> List[Path]:
     return list(Path(directory).glob(f"*{extension}"))
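
Note — the storage switch from .nc (xarray/netCDF) to .npz (plain numpy archives) is the heart of this cleanup. A minimal sketch of the round trip load_preprocessed_example expects; the key names mirror the loader above, but the shapes are invented for illustration and the save side is assumed to use np.savez:

    import numpy as np
    import torch

    np.savez(
        "example.npz",
        audio=np.zeros(131072, dtype=np.float32),
        spectrogram=np.zeros((128, 512), dtype=np.float32),
        detection_heatmap=np.zeros((128, 512), dtype=np.float32),
        class_heatmap=np.zeros((17, 128, 512), dtype=np.float32),
        size_heatmap=np.zeros((2, 128, 512), dtype=np.float32),
    )

    item = np.load("example.npz")             # lazy NpzFile handle
    spec = torch.tensor(item["spectrogram"])  # same conversion as above
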
@@ -123,9 +103,9 @@ class RandomExampleSource:
         self.filenames = filenames
         self.clipper = clipper
 
-    def __call__(self):
+    def __call__(self) -> PreprocessedExample:
         index = int(np.random.randint(len(self.filenames)))
         filename = self.filenames[index]
-        dataset = xr.open_dataset(filename)
-        example, _, _ = self.clipper.extract_clip(dataset)
+        example = load_preprocessed_example(filename)
+        example, _, _ = self.clipper.extract_clip(example)
         return example
@@ -30,6 +30,7 @@ import torch.utils.data
 from loguru import logger
 from pydantic import Field
 from soundevent import data
+from tqdm import tqdm
 
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.data.datasets import Dataset
@@ -132,34 +133,47 @@ class PreprocessingDataset(torch.utils.data.Dataset):
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         labeller: ClipLabeller,
+        filename_fn: FilenameFn,
+        output_dir: Path,
+        force: bool = False,
     ):
         self.clips = clips
         self.audio_loader = audio_loader
         self.preprocessor = preprocessor
         self.labeller = labeller
+        self.filename_fn = filename_fn
+        self.output_dir = output_dir
+        self.force = force
 
-    def __getitem__(self, idx) -> dict:
+    def __getitem__(self, idx) -> int:
         clip_annotation = self.clips[idx]
+
+        filename = self.filename_fn(clip_annotation)
+
+        path = self.output_dir / filename
+
+        if path.exists() and not self.force:
+            return idx
+
+        if not path.parent.exists():
+            path.parent.mkdir()
+
         example = generate_train_example(
             clip_annotation,
             audio_loader=self.audio_loader,
             preprocessor=self.preprocessor,
             labeller=self.labeller,
         )
-        return {
-            "idx": idx,
-            "spectrogram": example.spectrogram,
-            "audio": example.audio,
-            "class_heatmap": example.class_heatmap,
-            "size_heatmap": example.size_heatmap,
-            "detection_heatmap": example.detection_heatmap,
-        }
+
+        save_example_to_file(example, clip_annotation, path)
+
+        return idx
 
     def __len__(self) -> int:
         return len(self.clips)
 
 
-def _save_example_to_file(
+def save_example_to_file(
     example: PreprocessedExample,
     clip_annotation: data.ClipAnnotation,
     path: data.PathLike,
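
Note — the reworked __getitem__ turns PreprocessingDataset into a write-through job queue: each item computes one example, saves it to disk, and returns only its index, so a multi-worker DataLoader parallelizes preprocessing with no tensors shipped back to the main process. A self-contained sketch of the pattern; the names are illustrative, not the real batdetect2 API:

    from pathlib import Path
    import torch.utils.data

    class SideEffectDataset(torch.utils.data.Dataset):
        """Each __getitem__ writes one output file and returns the index."""

        def __init__(self, items, output_dir: Path, force: bool = False):
            self.items = items
            self.output_dir = output_dir
            self.force = force
            output_dir.mkdir(parents=True, exist_ok=True)

        def __len__(self) -> int:
            return len(self.items)

        def __getitem__(self, idx: int) -> int:
            path = self.output_dir / f"{idx}.txt"
            if path.exists() and not self.force:    # skip finished work
                return idx
            path.write_text(str(self.items[idx]))   # the "preprocessing"
            return idx
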
@@ -177,7 +191,7 @@ def _save_example_to_file(
 
 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
     """Generate a default output filename based on the annotation UUID."""
-    return f"{clip_annotation.uuid}.nc"
+    return f"{clip_annotation.uuid}"
 
 
 def preprocess_annotations(
@@ -212,24 +226,17 @@
         audio_loader=audio_loader,
         preprocessor=preprocessor,
         labeller=labeller,
+        output_dir=Path(output_dir),
+        filename_fn=filename_fn,
     )
 
     loader = torch.utils.data.DataLoader(
         dataset,
-        batch_size=None,
+        batch_size=1,
         shuffle=False,
         num_workers=max_workers,
+        prefetch_factor=16,
     )
 
-    for batch in loader:
-        clip_annotation = dataset.clips[batch["idx"]]
-        filename = filename_fn(clip_annotation)
-        path = output_dir / filename
-        example = PreprocessedExample(
-            spectrogram=batch["spectrogram"],
-            audio=batch["audio"],
-            class_heatmap=batch["class_heatmap"],
-            size_heatmap=batch["size_heatmap"],
-            detection_heatmap=batch["detection_heatmap"],
-        )
-        _save_example_to_file(example, clip_annotation, path)
+    for _ in tqdm(loader, total=len(dataset)):
+        pass
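
Note — two details here. The loop body is empty because the work now happens inside __getitem__; iterating the loader merely drives the workers, and tqdm reports progress. And batch_size changes meaning: with batch_size=None the loader yields __getitem__'s return value as-is, while batch_size=1 runs the default collate and wraps it. A tiny demonstration:

    import torch.utils.data as tud

    ds = list(range(4))  # any map-style dataset of ints
    print(next(iter(tud.DataLoader(ds, batch_size=None))))  # 0
    print(next(iter(tud.DataLoader(ds, batch_size=1))))     # tensor([0])
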
@@ -21,7 +21,6 @@ from batdetect2.train.config import FullTrainingConfig, TrainingConfig
 from batdetect2.train.dataset import (
     LabeledDataset,
     RandomExampleSource,
-    collate_fn,
 )
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
@@ -58,7 +57,7 @@ def train(
 
     train_dataloader = build_train_loader(
         train_examples,
-        preprocessor=module.preprocessor,
+        preprocessor=module.model.preprocessor,
         config=config.train,
         num_workers=train_workers,
     )
@@ -66,6 +65,7 @@ def train(
     val_dataloader = (
         build_val_loader(
             val_examples,
+            preprocessor=module.model.preprocessor,
             config=config.train,
             num_workers=val_workers,
         )
@@ -138,9 +138,11 @@ def build_trainer(
 def build_train_loader(
     train_examples: Sequence[data.PathLike],
     preprocessor: PreprocessorProtocol,
-    config: TrainingConfig,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
+    config = config or TrainingConfig()
+
     logger.info("Building training data loader...")
     train_dataset = build_train_dataset(
         train_examples,
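
Note — making config optional with an in-body fallback is a small ergonomic change: callers may omit the argument and get defaults. A minimal sketch of the pattern; the stand-in class is hypothetical (the real TrainingConfig is a pydantic model, whose instances are always truthy, so `or` is safe here):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class TrainingConfigSketch:
        batch_size: int = 32

    def build_loader(config: Optional[TrainingConfigSketch] = None) -> int:
        config = config or TrainingConfigSketch()  # fall back to defaults
        return config.batch_size

    assert build_loader() == 32
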
@@ -158,18 +160,21 @@ build_train_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
-        collate_fn=collate_fn,
     )
 
 
 def build_val_loader(
     val_examples: Sequence[data.PathLike],
-    config: TrainingConfig,
+    preprocessor: PreprocessorProtocol,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ):
+    config = config or TrainingConfig()
+
     logger.info("Building validation data loader...")
     val_dataset = build_val_dataset(
         val_examples,
+        preprocessor=preprocessor,
         config=config,
     )
     loader_conf = config.dataloaders.val
@@ -183,7 +188,6 @@ def build_val_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
-        collate_fn=collate_fn,
     )
 
 
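
Note — dropping collate_fn from both loaders relies on every stored example already having a fixed width, so PyTorch's default collate can simply stack the tensors; the old custom collate (which padded each field with adjust_width) is no longer needed. A sketch of what the default now does, with an invented shape (default_collate is exposed under torch.utils.data in recent PyTorch):

    import torch
    from torch.utils.data import default_collate

    batch = [torch.zeros(1, 128, 512) for _ in range(4)]  # fixed-size specs
    stacked = default_collate(batch)
    assert stacked.shape == (4, 1, 128, 512)
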
@@ -195,7 +199,11 @@ def build_train_dataset(
     logger.info("Building training dataset...")
     config = config or TrainingConfig()
 
-    clipper = build_clipper(config.cliping, random=True)
+    clipper = build_clipper(
+        samplerate=preprocessor.samplerate,
+        config=config.cliping,
+        random=True,
+    )
 
     random_example_source = RandomExampleSource(
         list(examples),
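
Note — threading the preprocessor's samplerate into build_clipper makes sense once clips are cut from raw arrays: a duration in seconds only determines a length in samples when the samplerate is known. Illustrative arithmetic; the function name is hypothetical, not the real API:

    def clip_length_samples(duration: float, samplerate: int) -> int:
        return int(round(duration * samplerate))

    assert clip_length_samples(0.512, 256_000) == 131_072  # 2**17
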
@@ -221,10 +229,15 @@ def build_train_dataset(
 
 def build_val_dataset(
     examples: Sequence[data.PathLike],
+    preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
     train: bool = True,
 ) -> LabeledDataset:
     logger.info("Building validation dataset...")
     config = config or TrainingConfig()
-    clipper = build_clipper(config.cliping, random=train)
+    clipper = build_clipper(
+        samplerate=preprocessor.samplerate,
+        config=config.cliping,
+        random=train,
+    )
     return LabeledDataset(examples, clipper=clipper)