You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tutorial_mnist_simple.py 2.4 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # The same set of code can switch the backend with one line
  4. import os
  5. # os.environ['TL_BACKEND'] = 'tensorflow'
  6. # os.environ['TL_BACKEND'] = 'mindspore'
  7. os.environ['TL_BACKEND'] = 'paddle'
  8. import tensorlayer as tl
  9. from tensorlayer.layers import Module
  10. from tensorlayer.layers import Dense, Dropout
  11. from tensorlayer.dataflow import Dataset
  12. X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 784))
  13. class mnistdataset(Dataset):
  14. def __init__(self, data=X_train, label=y_train):
  15. self.data = data
  16. self.label = label
  17. def __getitem__(self, index):
  18. data = self.data[index].astype('float32')
  19. label = self.label[index].astype('int64')
  20. return data, label
  21. def __len__(self):
  22. return len(self.data)
  23. class CustomModel(Module):
  24. def __init__(self):
  25. super(CustomModel, self).__init__()
  26. self.dropout1 = Dropout(keep=0.8)
  27. self.dense1 = Dense(n_units=800, act=tl.ReLU, in_channels=784)
  28. self.dropout2 = Dropout(keep=0.8)
  29. self.dense2 = Dense(n_units=800, act=tl.ReLU, in_channels=800)
  30. self.dropout3 = Dropout(keep=0.8)
  31. self.dense3 = Dense(n_units=10, act=tl.ReLU, in_channels=800)
  32. def forward(self, x, foo=None):
  33. z = self.dropout1(x)
  34. z = self.dense1(z)
  35. z = self.dropout2(z)
  36. z = self.dense2(z)
  37. z = self.dropout3(z)
  38. out = self.dense3(z)
  39. if foo is not None:
  40. out = tl.ops.relu(out)
  41. return out
  42. MLP = CustomModel()
  43. n_epoch = 50
  44. batch_size = 128
  45. print_freq = 2
  46. train_weights = MLP.trainable_weights
  47. optimizer = tl.optimizers.Momentum(0.05, 0.9)
  48. metric = tl.metric.Accuracy()
  49. loss_fn = tl.cost.softmax_cross_entropy_with_logits
  50. train_dataset = mnistdataset(data=X_train, label=y_train)
  51. train_dataset = tl.dataflow.FromGenerator(
  52. train_dataset, output_types=[tl.float32, tl.int64], column_names=['data', 'label']
  53. )
  54. train_loader = tl.dataflow.Dataloader(train_dataset, batch_size=batch_size, shuffle=True)
  55. model = tl.models.Model(network=MLP, loss_fn=loss_fn, optimizer=optimizer, metrics=metric)
  56. model.train(n_epoch=n_epoch, train_dataset=train_loader, print_freq=print_freq, print_train_batch=False)
  57. model.save_weights('./model.npz', format='npz_dict')
  58. model.load_weights('./model.npz', format='npz_dict')

TensorLayer 3.0 是一款支持以多种深度学习框架作为计算后端的深度学习库。计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。