diff --git a/batdetect2/postprocess/remapping.py b/batdetect2/postprocess/remapping.py index 51560ea..7112046 100644 --- a/batdetect2/postprocess/remapping.py +++ b/batdetect2/postprocess/remapping.py @@ -84,7 +84,7 @@ def features_to_xarray( freqs = np.linspace(min_freq, max_freq, height, endpoint=False) return xr.DataArray( - data=features.detach().numpy(), + data=features.detach().cpu().numpy(), dims=[ Dimensions.feature.value, Dimensions.frequency.value, @@ -157,7 +157,7 @@ def detection_to_xarray( freqs = np.linspace(min_freq, max_freq, height, endpoint=False) return xr.DataArray( - data=detection.squeeze(dim=0).detach().numpy(), + data=detection.squeeze(dim=0).detach().cpu().numpy(), dims=[ Dimensions.frequency.value, Dimensions.time.value, @@ -233,7 +233,7 @@ def classification_to_xarray( freqs = np.linspace(min_freq, max_freq, height, endpoint=False) return xr.DataArray( - data=classes.detach().numpy(), + data=classes.detach().cpu().numpy(), dims=[ "category", Dimensions.frequency.value, @@ -302,7 +302,7 @@ def sizes_to_xarray( freqs = np.linspace(min_freq, max_freq, height, endpoint=False) return xr.DataArray( - data=sizes.detach().numpy(), + data=sizes.detach().cpu().numpy(), dims=[ "dimension", Dimensions.frequency.value, diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py index e586429..59a429a 100644 --- a/batdetect2/train/dataset.py +++ b/batdetect2/train/dataset.py @@ -9,6 +9,7 @@ from torch.utils.data import Dataset from batdetect2.train.augmentations import Augmentation from batdetect2.train.types import ClipperProtocol, TrainExample +from batdetect2.utils.tensors import adjust_width __all__ = [ "LabeledDataset", @@ -87,6 +88,26 @@ class LabeledDataset(Dataset): ) +def collate_fn(batch: List[TrainExample]): + width = 512 + + return TrainExample( + spec=torch.stack([adjust_width(x.spec, width) for x in batch]), + detection_heatmap=torch.stack( + [adjust_width(x.detection_heatmap, width) for x in batch] + ), + class_heatmap=torch.stack( + [adjust_width(x.class_heatmap, width) for x in batch] + ), + size_heatmap=torch.stack( + [adjust_width(x.size_heatmap, width) for x in batch] + ), + idx=torch.stack([x.idx for x in batch]), + start_time=torch.stack([x.start_time for x in batch]), + end_time=torch.stack([x.end_time for x in batch]), + ) + + def list_preprocessed_files( directory: data.PathLike, extension: str = ".nc" ) -> List[Path]: diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py index 0368d10..2abed50 100644 --- a/batdetect2/train/preprocess.py +++ b/batdetect2/train/preprocess.py @@ -276,8 +276,11 @@ def preprocess_single_annotation( labeller=labeller, ) except Exception as error: - raise RuntimeError( - f"Failed to process annotation: {clip_annotation.uuid}" - ) from error + logger.error( + "Failed to process annotation: {uuid}. Error {error}", + uuid=clip_annotation.uuid, + error=error, + ) + return _save_xr_dataset_to_file(sample, path) diff --git a/batdetect2/train/train.py b/batdetect2/train/train.py index 04d0b23..6a71204 100644 --- a/batdetect2/train/train.py +++ b/batdetect2/train/train.py @@ -17,7 +17,11 @@ from batdetect2.train.augmentations import ( ) from batdetect2.train.clips import build_clipper from batdetect2.train.config import TrainingConfig -from batdetect2.train.dataset import LabeledDataset, RandomExampleSource +from batdetect2.train.dataset import ( + LabeledDataset, + RandomExampleSource, + collate_fn, +) from batdetect2.train.lightning import TrainingModule from batdetect2.train.losses import build_loss @@ -88,6 +92,7 @@ def train( batch_size=config.batch_size, shuffle=True, num_workers=train_workers, + collate_fn=collate_fn, ) val_dataloader = None @@ -101,6 +106,7 @@ def train( batch_size=config.batch_size, shuffle=False, num_workers=val_workers, + collate_fn=collate_fn, ) trainer.fit(