Fixed labeled dataset indexing error when loading

This commit is contained in:
mbsantiago 2024-07-16 01:30:52 +01:00
parent 335a05d51a
commit 4973cfcc5f

View File

@ -43,7 +43,7 @@ class LabeledDataset(Dataset):
return len(self.filenames) return len(self.filenames)
def __getitem__(self, idx) -> TrainExample: def __getitem__(self, idx) -> TrainExample:
data = self.load(self.filenames[idx]) data = self.load(idx)
return TrainExample( return TrainExample(
spec=data["spectrogram"], spec=data["spectrogram"],
detection_heatmap=data["detection"], detection_heatmap=data["detection"],
@ -56,11 +56,12 @@ class LabeledDataset(Dataset):
def from_directory(cls, directory: PathLike, extension: str = ".nc"): def from_directory(cls, directory: PathLike, extension: str = ".nc"):
return cls(get_files(directory, extension)) return cls(get_files(directory, extension))
def load(self, filename: PathLike) -> Dict[str, torch.Tensor]: def load(self, idx) -> Dict[str, torch.Tensor]:
dataset = self.get_dataset(filename) dataset = self.get_dataset(idx)
spectrogram = torch.tensor(dataset["spectrogram"].values).unsqueeze(0)
return { return {
"spectrogram": spectrogram, "spectrogram": torch.tensor(
dataset["spectrogram"].values
).unsqueeze(0),
"detection": torch.tensor(dataset["detection"].values), "detection": torch.tensor(dataset["detection"].values),
"class": torch.tensor(dataset["class"].values), "class": torch.tensor(dataset["class"].values),
"size": torch.tensor(dataset["size"].values), "size": torch.tensor(dataset["size"].values),