|
- from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
-
- if _NEED_IMPORT_JITTOR:
- import jittor as jt
- from jittor.dataset import Dataset
- else:
- from fastNLP.core.utils.dummy_class import DummyClass as Dataset
-
- class JittorNormalDataset(Dataset):
- def __init__(self, num_of_data=100, **kwargs):
- super(JittorNormalDataset, self).__init__(**kwargs)
- self._data = list(range(num_of_data))
- self.set_attrs(total_len=num_of_data)
-
- def __getitem__(self, item):
- return self._data[item]
-
- class JittorNormalXYDataset(Dataset):
- """
- 可以被输入到分类模型中的普通数据集
- """
- def __init__(self, num_of_data=1000, **kwargs):
- super(JittorNormalXYDataset, self).__init__(**kwargs)
- self.num_of_data = num_of_data
- self._data = list(range(num_of_data))
- self.set_attrs(total_len=num_of_data)
-
- def __getitem__(self, item):
- return {
- "x": jt.Var([self._data[item]]),
- "y": jt.Var([self._data[item]])
- }
-
- class JittorArgMaxDataset(Dataset):
- def __init__(self, num_samples, num_features, **kwargs):
- super(JittorArgMaxDataset, self).__init__(**kwargs)
- self.x = jt.randn(num_samples, num_features)
- self.y = self.x.argmax(dim=-1)
- self.set_attrs(total_len=num_samples)
-
- def __getitem__(self, item):
- return {"x": self.x[item], "y": self.y[item]}
-
- if __name__ == "__main__":
- dataset = JittorNormalDataset()
- print(len(dataset))
|