diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py index f51c507..19d1dc2 100644 --- a/batdetect2/train/dataset.py +++ b/batdetect2/train/dataset.py @@ -43,7 +43,7 @@ class LabeledDataset(Dataset): return len(self.filenames) def __getitem__(self, idx) -> TrainExample: - data = self.load(self.filenames[idx]) + data = self.load(idx) return TrainExample( spec=data["spectrogram"], detection_heatmap=data["detection"], @@ -56,11 +56,12 @@ class LabeledDataset(Dataset): def from_directory(cls, directory: PathLike, extension: str = ".nc"): return cls(get_files(directory, extension)) - def load(self, filename: PathLike) -> Dict[str, torch.Tensor]: - dataset = self.get_dataset(filename) - spectrogram = torch.tensor(dataset["spectrogram"].values).unsqueeze(0) + def load(self, idx) -> Dict[str, torch.Tensor]: + dataset = self.get_dataset(idx) return { - "spectrogram": spectrogram, + "spectrogram": torch.tensor( + dataset["spectrogram"].values + ).unsqueeze(0), "detection": torch.tensor(dataset["detection"].values), "class": torch.tensor(dataset["class"].values), "size": torch.tensor(dataset["size"].values),