Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)
Commit: Cleanup train preprocessing
Parent commit: c80078feee
This commit: 1f26103f42
Files changed: justfile (2 lines), plus the training and preprocessing modules
shown in the hunks below (paths other than the justfile were not preserved in
this capture).
[file 1: justfile]

@@ -99,7 +99,7 @@ example-preprocess OPTIONS="":
     --dataset-field datasets.train \
     --config example_data/config.yaml \
     {{OPTIONS}} \
-    example_data/datasets.yaml example_data/preprocessed
+    example_data/config.yaml example_data/preprocessed
 
 # Train on example data.
 example-train OPTIONS="":
[file 2: path not captured; a clip-handling module, judging by its imports]

@@ -8,7 +8,7 @@ from batdetect2.typing import ClipperProtocol
 from batdetect2.typing.train import PreprocessedExample
 from batdetect2.utils.arrays import adjust_width
 
-DEFAULT_TRAIN_CLIP_DURATION = 0.513
+DEFAULT_TRAIN_CLIP_DURATION = 0.512
 DEFAULT_MAX_EMPTY_CLIP = 0.1
 
 
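The tweak from 0.513 to 0.512 seconds looks deliberate: at the 256 kHz working
sample rate batdetect2 typically assumes (an assumption here, not stated in
this hunk), 0.512 s is an exact power of two in samples, which plays nicely
with FFT hop sizes and gives a fixed spectrogram width. A quick check:

    # Why 0.512 s is convenient, assuming a 256 kHz working rate.
    SAMPLERATE = 256_000

    assert round(0.512 * SAMPLERATE) == 131_072 == 2**17  # exact power of two
    assert round(0.513 * SAMPLERATE) == 131_328           # awkward count

This also lines up with the hard-coded width of 512 that the old collate_fn
padded to (see the dataset hunks below), which is plausibly why the custom
collate could be dropped later in this commit.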
[file 3: evidently batdetect2/train/dataset.py, given the import hunk in the
last file below]

@@ -3,13 +3,12 @@ from typing import List, Optional, Sequence, Tuple
 
 import numpy as np
 import torch
-import xarray as xr
 from soundevent import data
 from torch.utils.data import Dataset
 
 from batdetect2.train.augmentations import Augmentation
 from batdetect2.typing import ClipperProtocol, TrainExample
-from batdetect2.utils.tensors import adjust_width
+from batdetect2.typing.train import PreprocessedExample
 
 __all__ = [
     "LabeledDataset",
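PreprocessedExample is imported here but defined elsewhere
(batdetect2.typing.train), so its shape has to be inferred from usage. A
minimal sketch consistent with every access in this commit, with hypothetical
field comments:

    from dataclasses import dataclass

    import torch


    @dataclass
    class PreprocessedExample:
        # Sketch only: the real definition lives in batdetect2.typing.train.
        audio: torch.Tensor              # waveform for the clip
        spectrogram: torch.Tensor        # network input, (freq, time)
        detection_heatmap: torch.Tensor  # detection targets
        class_heatmap: torch.Tensor      # per-class targets
        size_heatmap: torch.Tensor       # size-regression targets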
@@ -31,17 +30,18 @@ class LabeledDataset(Dataset):
         return len(self.filenames)
 
     def __getitem__(self, idx) -> TrainExample:
-        dataset = self.get_dataset(idx)
-        dataset, start_time, end_time = self.clipper.extract_clip(dataset)
+        example = self.get_example(idx)
+
+        example, start_time, end_time = self.clipper.extract_clip(example)
 
         if self.augmentation:
-            dataset = self.augmentation(dataset)
+            example = self.augmentation(example)
 
         return TrainExample(
-            spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
-            detection_heatmap=self.to_tensor(dataset["detection"]),
-            class_heatmap=self.to_tensor(dataset["class"]),
-            size_heatmap=self.to_tensor(dataset["size"]),
+            spec=example.spectrogram.unsqueeze(0),
+            detection_heatmap=example.detection_heatmap.unsqueeze(0),
+            class_heatmap=example.class_heatmap,
+            size_heatmap=example.size_heatmap,
             idx=torch.tensor(idx),
             start_time=torch.tensor(start_time),
             end_time=torch.tensor(end_time),
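For readers skimming the unsqueeze(0) calls: they prepend a channel axis so
the 2-D spectrogram and detection heatmap reach the model in the
(channel, freq, time) layout convolutional layers expect. Illustrative shapes
only:

    import torch

    spec = torch.rand(128, 512)   # (freq, time); sizes are placeholders
    spec = spec.unsqueeze(0)      # (1, 128, 512): adds the channel axis
    assert spec.shape == (1, 128, 512)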
@@ -52,7 +52,7 @@ class LabeledDataset(Dataset):
         cls,
         directory: data.PathLike,
         clipper: ClipperProtocol,
-        extension: str = ".nc",
+        extension: str = ".npz",
         augmentation: Optional[Augmentation] = None,
     ):
         return cls(
@@ -61,55 +61,35 @@ class LabeledDataset(Dataset):
             augmentation=augmentation,
         )
 
-    def get_random_example(self) -> Tuple[xr.Dataset, float, float]:
+    def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
         idx = np.random.randint(0, len(self))
-        dataset = self.get_dataset(idx)
+        dataset = self.get_example(idx)
 
         dataset, start_time, end_time = self.clipper.extract_clip(dataset)
 
         return dataset, start_time, end_time
 
-    def get_dataset(self, idx) -> xr.Dataset:
-        return xr.open_dataset(self.filenames[idx])
+    def get_example(self, idx) -> PreprocessedExample:
+        return load_preprocessed_example(self.filenames[idx])
 
     def get_clip_annotation(self, idx) -> data.ClipAnnotation:
-        return data.ClipAnnotation.model_validate_json(
-            self.get_dataset(idx).attrs["clip_annotation"]
-        )
-
-    def to_tensor(
-        self,
-        array: xr.DataArray,
-        dtype=np.float32,
-    ) -> torch.Tensor:
-        return torch.nan_to_num(
-            torch.tensor(array.values.astype(dtype)),
-            nan=0,
-        )
+        item = np.load(self.filenames[idx])
+        return item["clip_annotation"]
 
 
-def collate_fn(batch: List[TrainExample]):
-    width = 512
-
-    return TrainExample(
-        spec=torch.stack([adjust_width(x.spec, width) for x in batch]),
-        detection_heatmap=torch.stack(
-            [adjust_width(x.detection_heatmap, width) for x in batch]
-        ),
-        class_heatmap=torch.stack(
-            [adjust_width(x.class_heatmap, width) for x in batch]
-        ),
-        size_heatmap=torch.stack(
-            [adjust_width(x.size_heatmap, width) for x in batch]
-        ),
-        idx=torch.stack([x.idx for x in batch]),
-        start_time=torch.stack([x.start_time for x in batch]),
-        end_time=torch.stack([x.end_time for x in batch]),
+def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
+    item = np.load(path)
+    return PreprocessedExample(
+        audio=torch.tensor(item["audio"]),
+        spectrogram=torch.tensor(item["spectrogram"]),
+        size_heatmap=torch.tensor(item["size_heatmap"]),
+        detection_heatmap=torch.tensor(item["detection_heatmap"]),
+        class_heatmap=torch.tensor(item["class_heatmap"]),
     )
 
 
 def list_preprocessed_files(
-    directory: data.PathLike, extension: str = ".nc"
+    directory: data.PathLike, extension: str = ".npz"
 ) -> List[Path]:
     return list(Path(directory).glob(f"*{extension}"))
 
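The on-disk format moves from netCDF (.nc via xarray) to NumPy .npz archives.
The matching writer, save_example_to_file, is renamed later in this commit but
its body is not part of this capture, so the following round-trip is a sketch
of what it plausibly does; np.savez_compressed and the JSON serialization are
assumptions, not committed code:

    import numpy as np


    def save_example_sketch(example, clip_annotation, path):
        # Hypothetical writer mirroring load_preprocessed_example above:
        # each tensor becomes a named array in one .npz archive, and the
        # annotation rides along as a JSON string.
        np.savez_compressed(
            path,
            audio=example.audio.numpy(),
            spectrogram=example.spectrogram.numpy(),
            detection_heatmap=example.detection_heatmap.numpy(),
            class_heatmap=example.class_heatmap.numpy(),
            size_heatmap=example.size_heatmap.numpy(),
            clip_annotation=clip_annotation.model_dump_json(),
        )

One caveat: np.load hands back arrays, so get_clip_annotation above returns
whatever was stored under "clip_annotation" (a 0-d string array under this
sketch) rather than a parsed data.ClipAnnotation; the exact behavior depends
on the real writer.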
@@ -123,9 +103,9 @@ class RandomExampleSource:
         self.filenames = filenames
         self.clipper = clipper
 
-    def __call__(self):
+    def __call__(self) -> PreprocessedExample:
         index = int(np.random.randint(len(self.filenames)))
         filename = self.filenames[index]
-        dataset = xr.open_dataset(filename)
-        example, _, _ = self.clipper.extract_clip(dataset)
+        example = load_preprocessed_example(filename)
+        example, _, _ = self.clipper.extract_clip(example)
         return example
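Both LabeledDataset and RandomExampleSource now route through
load_preprocessed_example, and both lean on the same ClipperProtocol contract:
extract_clip takes a PreprocessedExample and returns a (possibly cropped)
example plus the window's start and end times. A trivial conforming clipper,
for illustration only:

    from typing import Tuple

    DEFAULT_TRAIN_CLIP_DURATION = 0.512  # matches the constant above


    class IdentityClipper:
        # Pass-through stand-in for any ClipperProtocol implementation.
        def extract_clip(self, example) -> Tuple[object, float, float]:
            return example, 0.0, DEFAULT_TRAIN_CLIP_DURATION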
[file 4: path not captured; the preprocessing module defining
preprocess_annotations]

@@ -30,6 +30,7 @@ import torch.utils.data
 from loguru import logger
 from pydantic import Field
 from soundevent import data
+from tqdm import tqdm
 
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.data.datasets import Dataset
@@ -132,34 +133,47 @@ class PreprocessingDataset(torch.utils.data.Dataset):
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         labeller: ClipLabeller,
+        filename_fn: FilenameFn,
+        output_dir: Path,
+        force: bool = False,
     ):
         self.clips = clips
         self.audio_loader = audio_loader
         self.preprocessor = preprocessor
         self.labeller = labeller
+        self.filename_fn = filename_fn
+        self.output_dir = output_dir
+        self.force = force
 
-    def __getitem__(self, idx) -> dict:
+    def __getitem__(self, idx) -> int:
         clip_annotation = self.clips[idx]
+
+        filename = self.filename_fn(clip_annotation)
+
+        path = self.output_dir / filename
+
+        if path.exists() and not self.force:
+            return idx
+
+        if not path.parent.exists():
+            path.parent.mkdir()
+
         example = generate_train_example(
             clip_annotation,
             audio_loader=self.audio_loader,
             preprocessor=self.preprocessor,
             labeller=self.labeller,
         )
-        return {
-            "idx": idx,
-            "spectrogram": example.spectrogram,
-            "audio": example.audio,
-            "class_heatmap": example.class_heatmap,
-            "size_heatmap": example.size_heatmap,
-            "detection_heatmap": example.detection_heatmap,
-        }
+
+        save_example_to_file(example, clip_annotation, path)
+
+        return idx
 
     def __len__(self) -> int:
         return len(self.clips)
 
 
-def _save_example_to_file(
+def save_example_to_file(
     example: PreprocessedExample,
     clip_annotation: data.ClipAnnotation,
     path: data.PathLike,
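The notable design move here: __getitem__ no longer returns tensors at all.
Each item does its own preprocessing, writes its own file, and returns just
the index, which turns the DataLoader's worker pool into a generic parallel
map with built-in resume (the force flag skips existing files). The pattern in
isolation, stripped of batdetect2 specifics:

    from pathlib import Path

    import numpy as np
    import torch.utils.data


    class SaveToDiskDataset(torch.utils.data.Dataset):
        # Sketch of the "side-effect dataset" pattern: workers persist
        # results themselves; the main process only drains indices.
        def __init__(self, items, output_dir: Path):
            self.items = items
            self.output_dir = output_dir

        def __len__(self) -> int:
            return len(self.items)

        def __getitem__(self, idx: int) -> int:
            out = self.output_dir / f"{idx}.npz"
            if not out.exists():  # cheap resume, like `force` above
                np.savez(out, value=np.asarray(self.items[idx]) ** 2)
            return idx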
@@ -177,7 +191,7 @@ def _save_example_to_file(
 
 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
     """Generate a default output filename based on the annotation UUID."""
-    return f"{clip_annotation.uuid}.nc"
+    return f"{clip_annotation.uuid}"
 
 
 def preprocess_annotations(
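Dropping the ".nc" suffix leaves the default filename as the bare UUID. That
is consistent with handing the path to np.savez, which appends ".npz" itself
when the suffix is missing (the link to the writer is an inference, since
save_example_to_file's body is not captured):

    import tempfile
    from pathlib import Path

    import numpy as np

    # np.savez appends ".npz" when the target path lacks it:
    with tempfile.TemporaryDirectory() as tmp:
        np.savez(Path(tmp) / "some-uuid", x=np.zeros(3))
        assert (Path(tmp) / "some-uuid.npz").exists()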
@@ -212,24 +226,17 @@ def preprocess_annotations(
         audio_loader=audio_loader,
         preprocessor=preprocessor,
         labeller=labeller,
+        output_dir=Path(output_dir),
+        filename_fn=filename_fn,
     )
 
     loader = torch.utils.data.DataLoader(
         dataset,
-        batch_size=None,
+        batch_size=1,
         shuffle=False,
         num_workers=max_workers,
+        prefetch_factor=16,
     )
 
-    for batch in loader:
-        clip_annotation = dataset.clips[batch["idx"]]
-        filename = filename_fn(clip_annotation)
-        path = output_dir / filename
-        example = PreprocessedExample(
-            spectrogram=batch["spectrogram"],
-            audio=batch["audio"],
-            class_heatmap=batch["class_heatmap"],
-            size_heatmap=batch["size_heatmap"],
-            detection_heatmap=batch["detection_heatmap"],
-        )
-        _save_example_to_file(example, clip_annotation, path)
+    for _ in tqdm(loader, total=len(dataset)):
+        pass
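Two loader details are worth noting. batch_size moves from None to 1, so the
default collate merely wraps each returned index, and prefetch_factor=16 keeps
16 batches queued per worker. PyTorch raises a ValueError if prefetch_factor
is set while num_workers is 0, so this code assumes max_workers >= 1; a
defensive variant would gate it, as in this sketch (make_loader is
hypothetical, not part of the commit):

    import torch.utils.data as tud


    def make_loader(dataset, max_workers: int) -> tud.DataLoader:
        # Only pass prefetch_factor when worker processes exist.
        kwargs = dict(batch_size=1, shuffle=False, num_workers=max_workers)
        if max_workers > 0:
            kwargs["prefetch_factor"] = 16
        return tud.DataLoader(dataset, **kwargs)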
[file 5: path not captured; the training entry point defining train() and the
loader builders]

@@ -21,7 +21,6 @@ from batdetect2.train.config import FullTrainingConfig, TrainingConfig
 from batdetect2.train.dataset import (
     LabeledDataset,
     RandomExampleSource,
-    collate_fn,
 )
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
@@ -58,7 +57,7 @@ def train(
 
     train_dataloader = build_train_loader(
         train_examples,
-        preprocessor=module.preprocessor,
+        preprocessor=module.model.preprocessor,
         config=config.train,
         num_workers=train_workers,
     )
@@ -66,6 +65,7 @@ def train(
     val_dataloader = (
         build_val_loader(
             val_examples,
+            preprocessor=module.model.preprocessor,
             config=config.train,
             num_workers=val_workers,
         )
@@ -138,9 +138,11 @@ def build_trainer(
 def build_train_loader(
     train_examples: Sequence[data.PathLike],
     preprocessor: PreprocessorProtocol,
-    config: TrainingConfig,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
+    config = config or TrainingConfig()
+
     logger.info("Building training data loader...")
     train_dataset = build_train_dataset(
         train_examples,
@@ -158,18 +160,21 @@ def build_train_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
-        collate_fn=collate_fn,
     )
 
 
 def build_val_loader(
     val_examples: Sequence[data.PathLike],
-    config: TrainingConfig,
+    preprocessor: PreprocessorProtocol,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ):
+    config = config or TrainingConfig()
+
     logger.info("Building validation data loader...")
     val_dataset = build_val_dataset(
         val_examples,
+        preprocessor=preprocessor,
         config=config,
     )
     loader_conf = config.dataloaders.val
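Both loaders drop the custom collate_fn. That is only safe because every
stored example now has identical tensor shapes (fixed-width clips), letting
torch's default collate stack the TrainExample fields on its own. A toy
demonstration of that behavior, with a NamedTuple standing in for TrainExample
(whose actual kind is not shown in this capture):

    from typing import NamedTuple

    import torch
    from torch.utils.data import default_collate


    class Item(NamedTuple):  # stand-in for TrainExample
        spec: torch.Tensor
        idx: torch.Tensor


    batch = [Item(torch.rand(1, 128, 512), torch.tensor(i)) for i in range(4)]
    out = default_collate(batch)
    assert out.spec.shape == (4, 1, 128, 512)  # stacked along new batch dim
    assert out.idx.tolist() == [0, 1, 2, 3]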
@@ -183,7 +188,6 @@ def build_val_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
-        collate_fn=collate_fn,
     )
 
 
@@ -195,7 +199,11 @@ def build_train_dataset(
     logger.info("Building training dataset...")
     config = config or TrainingConfig()
 
-    clipper = build_clipper(config.cliping, random=True)
+    clipper = build_clipper(
+        samplerate=preprocessor.samplerate,
+        config=config.cliping,
+        random=True,
+    )
 
     random_example_source = RandomExampleSource(
         list(examples),
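Threading the preprocessor's sample rate into build_clipper suggests clip
windows are now sized in samples rather than seconds, which is exactly what
guarantees the uniform widths the default collate relies on. Under that
reading (an inference; build_clipper's body is not in this capture), the core
conversion is one line:

    # Illustrative seconds-to-samples conversion, with assumed values.
    duration_s = 0.512     # DEFAULT_TRAIN_CLIP_DURATION after this commit
    samplerate = 256_000   # e.g. preprocessor.samplerate
    assert round(duration_s * samplerate) == 131_072  # same for every clip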
@@ -221,10 +229,15 @@ def build_train_dataset(
 
 def build_val_dataset(
     examples: Sequence[data.PathLike],
+    preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
     train: bool = True,
 ) -> LabeledDataset:
     logger.info("Building validation dataset...")
     config = config or TrainingConfig()
-    clipper = build_clipper(config.cliping, random=train)
+    clipper = build_clipper(
+        samplerate=preprocessor.samplerate,
+        config=config.cliping,
+        random=train,
+    )
     return LabeledDataset(examples, clipper=clipper)