You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tutorial_mnist_simple.py 2.1 kB

4 years ago
Lines 1–68
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""TensorLayer 3 tutorial: train a dropout MLP classifier on MNIST.

The computation backend is selected through the ``TL_BACKEND`` environment
variable below (TensorFlow here; a MindSpore alternative is left commented out).
"""
import numpy as np
import os
# Choose the backend *before* `import tensorlayer`. Presumably tensorlayer reads
# TL_BACKEND at import time, so setting it afterwards would have no effect — confirm.
os.environ['TL_BACKEND'] = 'tensorflow'
# os.environ['TL_BACKEND'] = 'mindspore'
import tensorlayer as tl
from tensorlayer.layers import Module
from tensorlayer.layers import Dense, Dropout
# Load MNIST with each image reshaped to a flat 784-feature vector (shape=(-1, 784)).
# This matches in_channels=784 of the network's first Dense layer.
# The six arrays are the train / validation / test inputs and labels.
# NOTE(review): this may download the dataset on first run — confirm the storage path.
X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 784))
  11. class CustomModel(Module):
  12. def __init__(self):
  13. super(CustomModel, self).__init__()
  14. self.dropout1 = Dropout(keep=0.8)
  15. self.dense1 = Dense(n_units=800, act=tl.ReLU, in_channels=784)
  16. self.dropout2 = Dropout(keep=0.8)
  17. self.dense2 = Dense(n_units=800, act=tl.ReLU, in_channels=800)
  18. self.dropout3 = Dropout(keep=0.8)
  19. self.dense3 = Dense(n_units=10, act=tl.ReLU, in_channels=800)
  20. def forward(self, x, foo=None):
  21. z = self.dropout1(x)
  22. z = self.dense1(z)
  23. # z = self.bn(z)
  24. z = self.dropout2(z)
  25. z = self.dense2(z)
  26. z = self.dropout3(z)
  27. out = self.dense3(z)
  28. if foo is not None:
  29. out = tl.ops.relu(out)
  30. return out
  31. def generator_train():
  32. inputs = X_train
  33. targets = y_train
  34. if len(inputs) != len(targets):
  35. raise AssertionError("The length of inputs and targets should be equal")
  36. for _input, _target in zip(inputs, targets):
  37. yield (_input, np.array(_target))
  38. MLP = CustomModel()
  39. n_epoch = 50
  40. batch_size = 128
  41. print_freq = 2
  42. shuffle_buffer_size = 128
  43. train_weights = MLP.trainable_weights
  44. optimizer = tl.optimizers.Momentum(0.05, 0.9)
  45. train_ds = tl.dataflow.FromGenerator(
  46. generator_train, output_types=(tl.float32, tl.int32) , column_names=['data', 'label']
  47. )
  48. train_ds = tl.dataflow.Shuffle(train_ds,shuffle_buffer_size)
  49. train_ds = tl.dataflow.Batch(train_ds,batch_size)
  50. model = tl.models.Model(network=MLP, loss_fn=tl.cost.cross_entropy, optimizer=optimizer)
  51. model.train(n_epoch=n_epoch, train_dataset=train_ds, print_freq=print_freq, print_train_batch=False)
  52. model.save_weights('./model.npz', format='npz_dict')
  53. model.load_weights('./model.npz', format='npz_dict')

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。