Source code for neuralop.data.datasets.tensor_dataset

from torch.utils.data.dataset import Dataset


class TensorDataset(Dataset):
    """Map-style dataset that pairs an input tensor with a target tensor.

    Sample ``i`` is the ``i``-th entry (indexing the first dimension) of
    ``x`` and of ``y``. Each entry is optionally passed through its own
    transform. The pair is returned as a dict with keys ``'x'`` and ``'y'``.

    Parameters
    ----------
    x : tensor-like
        Inputs; must support ``.size(0)`` and indexing (e.g. a torch.Tensor).
    y : tensor-like
        Targets; ``y.size(0)`` must equal ``x.size(0)``.
    transform_x : callable, optional
        If not None, applied to each ``x`` sample when it is fetched.
    transform_y : callable, optional
        If not None, applied to each ``y`` sample when it is fetched.

    Raises
    ------
    AssertionError
        If ``x`` and ``y`` differ in first-dimension size. The check is an
        ``assert``, so it is skipped when Python runs with ``-O``.
    """

    def __init__(self, x, y, transform_x=None, transform_y=None):
        assert x.size(0) == y.size(0), "Size mismatch between tensors"
        self.x, self.y = x, y
        self.transform_x, self.transform_y = transform_x, transform_y

    @staticmethod
    def _maybe_transform(transform, sample):
        """Return ``transform(sample)``, or ``sample`` itself if no transform."""
        if transform is None:
            return sample
        return transform(sample)

    def __getitem__(self, index):
        """Return ``{'x': ..., 'y': ...}`` for sample ``index``."""
        # Index both tensors before applying either transform.
        raw_x = self.x[index]
        raw_y = self.y[index]
        return {
            'x': self._maybe_transform(self.transform_x, raw_x),
            'y': self._maybe_transform(self.transform_y, raw_y),
        }

    def __len__(self):
        """Number of samples: the first-dimension size of ``x``."""
        return self.x.size(0)
class GeneralTensorDataset(Dataset):
    """Map-style dataset over any number of sample-aligned tensors.

    ``sets[j]`` is paired with ``transforms[j]``. Indexing returns the
    ``index``-th entry of every tensor, each passed through its transform
    when that transform is not None.

    Parameters
    ----------
    sets : sequence of tensor-like
        Tensors indexed along their first dimension. When more than one is
        given, all must share the same ``.size(0)``.
    transforms : sequence of (callable or None)
        One entry per tensor in ``sets``; None means "no transform".

    Returns (from ``__getitem__``)
    ------------------------------
    list or item
        With more than one tensor, a list of the per-tensor items in ``sets``
        order. With exactly one tensor, that single item (not wrapped).

    Raises
    ------
    AssertionError
        If ``len(sets) != len(transforms)``, or if the tensors differ in
        first-dimension size. These are ``assert`` checks and are skipped
        under ``python -O``.
    ValueError
        If ``sets`` is empty. Such a dataset would have no defined length,
        and every access would fail on ``sets[0]``.
    """

    def __init__(self, sets, transforms):
        assert len(sets) == len(transforms), "Size mismatch between number of tensors and transforms"
        # Fail fast here instead of letting __len__/__getitem__ raise an
        # unrelated IndexError on sets[0] later.
        if len(sets) == 0:
            raise ValueError("GeneralTensorDataset requires at least one tensor")
        self.n = len(sets)
        # Only multi-tensor datasets need the alignment check; a lone tensor
        # is not required to expose .size() at construction time.
        if self.n > 1:
            n_samples = sets[0].size(0)
            for other in sets[1:]:
                assert other.size(0) == n_samples, "Size mismatch between tensors"
        self.sets = sets
        self.transforms = transforms

    @staticmethod
    def _apply(transform, item):
        """Return ``transform(item)``, or ``item`` unchanged when transform is None."""
        return item if transform is None else transform(item)

    def __getitem__(self, index):
        """Return the (transformed) ``index``-th sample of every tensor."""
        if self.n > 1:
            # Each tensor is indexed and then transformed before moving to the next.
            return [
                self._apply(transform, tensor[index])
                for tensor, transform in zip(self.sets, self.transforms)
            ]
        # Single tensor: return the bare item rather than a one-element list.
        return self._apply(self.transforms[0], self.sets[0][index])

    def __len__(self):
        """Number of samples: the first-dimension size of ``sets[0]``."""
        return self.sets[0].size(0)