Add collate fn

This commit is contained in:
Santiago Martinez Balvanera 2025-06-19 00:46:34 +01:00
parent 84a13c65a7
commit 434fc652a2
3 changed files with 34 additions and 4 deletions

View File

@ -9,6 +9,7 @@ from torch.utils.data import Dataset
from batdetect2.train.augmentations import Augmentation from batdetect2.train.augmentations import Augmentation
from batdetect2.train.types import ClipperProtocol, TrainExample from batdetect2.train.types import ClipperProtocol, TrainExample
from batdetect2.utils.tensors import adjust_width
__all__ = [ __all__ = [
"LabeledDataset", "LabeledDataset",
@ -87,6 +88,26 @@ class LabeledDataset(Dataset):
) )
def collate_fn(batch: List[TrainExample]):
width = 512
return TrainExample(
spec=torch.stack([adjust_width(x.spec, width) for x in batch]),
detection_heatmap=torch.stack(
[adjust_width(x.detection_heatmap, width) for x in batch]
),
class_heatmap=torch.stack(
[adjust_width(x.class_heatmap, width) for x in batch]
),
size_heatmap=torch.stack(
[adjust_width(x.size_heatmap, width) for x in batch]
),
idx=torch.stack([x.idx for x in batch]),
start_time=torch.stack([x.start_time for x in batch]),
end_time=torch.stack([x.end_time for x in batch]),
)
def list_preprocessed_files( def list_preprocessed_files(
directory: data.PathLike, extension: str = ".nc" directory: data.PathLike, extension: str = ".nc"
) -> List[Path]: ) -> List[Path]:

View File

@ -276,8 +276,11 @@ def preprocess_single_annotation(
labeller=labeller, labeller=labeller,
) )
except Exception as error: except Exception as error:
raise RuntimeError( logger.error(
f"Failed to process annotation: {clip_annotation.uuid}" "Failed to process annotation: {uuid}. Error {error}",
) from error uuid=clip_annotation.uuid,
error=error,
)
return
_save_xr_dataset_to_file(sample, path) _save_xr_dataset_to_file(sample, path)

View File

@ -17,7 +17,11 @@ from batdetect2.train.augmentations import (
) )
from batdetect2.train.clips import build_clipper from batdetect2.train.clips import build_clipper
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import LabeledDataset, RandomExampleSource from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
collate_fn,
)
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import build_loss from batdetect2.train.losses import build_loss
@ -88,6 +92,7 @@ def train(
batch_size=config.batch_size, batch_size=config.batch_size,
shuffle=True, shuffle=True,
num_workers=train_workers, num_workers=train_workers,
collate_fn=collate_fn,
) )
val_dataloader = None val_dataloader = None
@ -101,6 +106,7 @@ def train(
batch_size=config.batch_size, batch_size=config.batch_size,
shuffle=False, shuffle=False,
num_workers=val_workers, num_workers=val_workers,
collate_fn=collate_fn,
) )
trainer.fit( trainer.fit(