Fixed labeled dataset indexing error when loading

This commit is contained in:
mbsantiago 2024-07-16 01:30:52 +01:00
parent 335a05d51a
commit 4973cfcc5f

View File

@ -43,7 +43,7 @@ class LabeledDataset(Dataset):
return len(self.filenames)
def __getitem__(self, idx) -> TrainExample:
data = self.load(self.filenames[idx])
data = self.load(idx)
return TrainExample(
spec=data["spectrogram"],
detection_heatmap=data["detection"],
@ -56,11 +56,12 @@ class LabeledDataset(Dataset):
def from_directory(cls, directory: PathLike, extension: str = ".nc"):
return cls(get_files(directory, extension))
def load(self, filename: PathLike) -> Dict[str, torch.Tensor]:
dataset = self.get_dataset(filename)
spectrogram = torch.tensor(dataset["spectrogram"].values).unsqueeze(0)
def load(self, idx) -> Dict[str, torch.Tensor]:
dataset = self.get_dataset(idx)
return {
"spectrogram": spectrogram,
"spectrogram": torch.tensor(
dataset["spectrogram"].values
).unsqueeze(0),
"detection": torch.tensor(dataset["detection"].values),
"class": torch.tensor(dataset["class"].values),
"size": torch.tensor(dataset["size"].values),