|
- from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
-
- if _NEED_IMPORT_ONEFLOW:
- import oneflow
- from oneflow.utils.data import Dataset
- else:
- from fastNLP.core.utils.dummy_class import DummyClass as Dataset
-
-
class OneflowNormalDataset(Dataset):
    """Minimal map-style dataset whose i-th sample is the stored integer ``i``.

    :param num_of_data: number of samples; the stored samples are
        ``0, 1, ..., num_of_data - 1`` and ``len()`` reports ``num_of_data``.
    """

    def __init__(self, num_of_data=1000):
        self.num_of_data = num_of_data
        # Materialize the samples as a plain list so indexing follows
        # ordinary ``list`` semantics.
        self._data = [idx for idx in range(num_of_data)]

    def __len__(self):
        return self.num_of_data

    def __getitem__(self, item):
        samples = self._data
        return samples[item]
-
class OneflowNormalXYDataset(Dataset):
    """
    A plain dataset that can be fed into a classification model.

    Each sample is a dict with keys ``"x"`` and ``"y"``; both hold a
    one-element ``oneflow.float`` tensor containing the stored integer for
    that index (the integers are ``0 .. num_of_data - 1``).

    :param num_of_data: number of samples reported by ``len()``.
    """

    def __init__(self, num_of_data=1000):
        self.num_of_data = num_of_data
        self._data = list(range(num_of_data))

    def __len__(self):
        return self.num_of_data

    def __getitem__(self, item):
        value = self._data[item]
        # x and y are two independently constructed tensors with the same value.
        features = oneflow.tensor([value], dtype=oneflow.float)
        target = oneflow.tensor([value], dtype=oneflow.float)
        return {"x": features, "y": target}
-
-
class OneflowArgMaxDataset(Dataset):
    """Synthetic classification dataset whose label is the argmax of the features.

    Features are random integers drawn with
    ``oneflow.randint(low=-100, high=100, ...)`` and cast to float. They come
    from a private ``oneflow.Generator`` seeded with ``seed``, so the data is
    reproducible for a given seed and does not consume the global RNG.

    :param data_num: number of samples.
    :param feature_dimension: number of features per sample; also used as the
        number of classes (``num_labels``).
    :param seed: seed for the private generator used to draw the features.
    """

    def __init__(self, data_num=1000, feature_dimension=10, seed=0):
        self.num_labels = feature_dimension
        self.feature_dimension = feature_dimension
        self.data_num = data_num
        self.seed = seed

        g = oneflow.Generator()
        # Seed from the ``seed`` argument; it was previously hard-coded to
        # 1000, so different seeds silently produced identical data.
        g.manual_seed(seed)
        self.x = oneflow.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()
        # Element [1] of ``oneflow.max(..., dim=-1)`` is presumably the argmax
        # indices (torch-compatible API); they serve as the class labels.
        self.y = oneflow.max(self.x, dim=-1)[1]

    def __len__(self):
        return self.data_num

    def __getitem__(self, item):
        return {"x": self.x[item], "y": self.y[item]}
|