mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Add collate fn
This commit is contained in:
parent
84a13c65a7
commit
434fc652a2
@ -9,6 +9,7 @@ from torch.utils.data import Dataset
|
||||
|
||||
from batdetect2.train.augmentations import Augmentation
|
||||
from batdetect2.train.types import ClipperProtocol, TrainExample
|
||||
from batdetect2.utils.tensors import adjust_width
|
||||
|
||||
__all__ = [
|
||||
"LabeledDataset",
|
||||
@ -87,6 +88,26 @@ class LabeledDataset(Dataset):
|
||||
)
|
||||
|
||||
|
||||
def collate_fn(batch: List[TrainExample]):
|
||||
width = 512
|
||||
|
||||
return TrainExample(
|
||||
spec=torch.stack([adjust_width(x.spec, width) for x in batch]),
|
||||
detection_heatmap=torch.stack(
|
||||
[adjust_width(x.detection_heatmap, width) for x in batch]
|
||||
),
|
||||
class_heatmap=torch.stack(
|
||||
[adjust_width(x.class_heatmap, width) for x in batch]
|
||||
),
|
||||
size_heatmap=torch.stack(
|
||||
[adjust_width(x.size_heatmap, width) for x in batch]
|
||||
),
|
||||
idx=torch.stack([x.idx for x in batch]),
|
||||
start_time=torch.stack([x.start_time for x in batch]),
|
||||
end_time=torch.stack([x.end_time for x in batch]),
|
||||
)
|
||||
|
||||
|
||||
def list_preprocessed_files(
|
||||
directory: data.PathLike, extension: str = ".nc"
|
||||
) -> List[Path]:
|
||||
|
@ -276,8 +276,11 @@ def preprocess_single_annotation(
|
||||
labeller=labeller,
|
||||
)
|
||||
except Exception as error:
|
||||
raise RuntimeError(
|
||||
f"Failed to process annotation: {clip_annotation.uuid}"
|
||||
) from error
|
||||
logger.error(
|
||||
"Failed to process annotation: {uuid}. Error {error}",
|
||||
uuid=clip_annotation.uuid,
|
||||
error=error,
|
||||
)
|
||||
return
|
||||
|
||||
_save_xr_dataset_to_file(sample, path)
|
||||
|
@ -17,7 +17,11 @@ from batdetect2.train.augmentations import (
|
||||
)
|
||||
from batdetect2.train.clips import build_clipper
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.dataset import LabeledDataset, RandomExampleSource
|
||||
from batdetect2.train.dataset import (
|
||||
LabeledDataset,
|
||||
RandomExampleSource,
|
||||
collate_fn,
|
||||
)
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.train.losses import build_loss
|
||||
|
||||
@ -88,6 +92,7 @@ def train(
|
||||
batch_size=config.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=train_workers,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
val_dataloader = None
|
||||
@ -101,6 +106,7 @@ def train(
|
||||
batch_size=config.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=val_workers,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
|
Loading…
Reference in New Issue
Block a user