mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Fixed labeled dataset indexing error when loading
This commit is contained in:
parent
335a05d51a
commit
4973cfcc5f
@ -43,7 +43,7 @@ class LabeledDataset(Dataset):
|
||||
return len(self.filenames)
|
||||
|
||||
def __getitem__(self, idx) -> TrainExample:
|
||||
data = self.load(self.filenames[idx])
|
||||
data = self.load(idx)
|
||||
return TrainExample(
|
||||
spec=data["spectrogram"],
|
||||
detection_heatmap=data["detection"],
|
||||
@ -56,11 +56,12 @@ class LabeledDataset(Dataset):
|
||||
def from_directory(cls, directory: PathLike, extension: str = ".nc"):
|
||||
return cls(get_files(directory, extension))
|
||||
|
||||
def load(self, filename: PathLike) -> Dict[str, torch.Tensor]:
|
||||
dataset = self.get_dataset(filename)
|
||||
spectrogram = torch.tensor(dataset["spectrogram"].values).unsqueeze(0)
|
||||
def load(self, idx) -> Dict[str, torch.Tensor]:
|
||||
dataset = self.get_dataset(idx)
|
||||
return {
|
||||
"spectrogram": spectrogram,
|
||||
"spectrogram": torch.tensor(
|
||||
dataset["spectrogram"].values
|
||||
).unsqueeze(0),
|
||||
"detection": torch.tensor(dataset["detection"].values),
|
||||
"class": torch.tensor(dataset["class"].values),
|
||||
"size": torch.tensor(dataset["size"].values),
|
||||
|
Loading…
Reference in New Issue
Block a user