You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tutorial_dataflow.py 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. os.environ['TL_BACKEND'] = 'tensorflow'
  5. # os.environ['TL_BACKEND'] = 'mindspore'
  6. # os.environ['TL_BACKEND'] = 'paddle'
  7. import tensorlayer as tl
  8. from tensorlayer.layers import Module
  9. from tensorlayer.layers import Dense, Flatten
  10. from tensorlayer.vision.transforms import Normalize, Compose
  11. from tensorlayer.dataflow import Dataset, IterableDataset
  12. transform = Compose([Normalize(mean=[127.5], std=[127.5], data_format='HWC')])
  13. print('download training data and load training data')
  14. X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
  15. X_train = X_train * 255
  16. print('load finished')
  17. class mnistdataset(Dataset):
  18. def __init__(self, data=X_train, label=y_train, transform=transform):
  19. self.data = data
  20. self.label = label
  21. self.transform = transform
  22. def __getitem__(self, index):
  23. data = self.data[index].astype('float32')
  24. data = self.transform(data)
  25. label = self.label[index].astype('int64')
  26. return data, label
  27. def __len__(self):
  28. return len(self.data)
  29. class mnistdataset1(IterableDataset):
  30. def __init__(self, data=X_train, label=y_train, transform=transform):
  31. self.data = data
  32. self.label = label
  33. self.transform = transform
  34. def __iter__(self):
  35. for i in range(len(self.data)):
  36. data = self.data[i].astype('float32')
  37. data = self.transform(data)
  38. label = self.label[i].astype('int64')
  39. yield data, label
  40. class MLP(Module):
  41. def __init__(self):
  42. super(MLP, self).__init__()
  43. self.linear1 = Dense(n_units=120, in_channels=784, act=tl.ReLU)
  44. self.linear2 = Dense(n_units=84, in_channels=120, act=tl.ReLU)
  45. self.linear3 = Dense(n_units=10, in_channels=84)
  46. self.flatten = Flatten()
  47. def forward(self, x):
  48. x = self.flatten(x)
  49. x = self.linear1(x)
  50. x = self.linear2(x)
  51. x = self.linear3(x)
  52. return x
  53. train_dataset = mnistdataset1(data=X_train, label=y_train, transform=transform)
  54. train_dataset = tl.dataflow.FromGenerator(
  55. train_dataset, output_types=[tl.float32, tl.int64], column_names=['data', 'label']
  56. )
  57. train_loader = tl.dataflow.Dataloader(train_dataset, batch_size=128, shuffle=False)
  58. for i in train_loader:
  59. print(i[0].shape, i[1])

TensorLayer 3.0 是一款支持以多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。