You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tf_MLP.py 1.1 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334
  1. import numpy as np
  2. import tensorflow as tf
  3. def tf_fc(x, shape, with_relu=True):
  4. weight = tf.Variable(np.random.normal(
  5. scale=0.1, size=shape).astype(np.float32))
  6. bias = tf.Variable(np.random.normal(
  7. scale=0.1, size=shape[-1:]).astype(np.float32))
  8. x = tf.matmul(x, weight) + bias
  9. if with_relu:
  10. x = tf.nn.relu(x)
  11. return x
  12. def tf_mlp(x, y_, num_class=10):
  13. '''
  14. MLP model in TensorFlow, for CIFAR dataset.
  15. Parameters:
  16. x: Variable(tensorflow.python.framework.ops.Tensor), shape (N, dims)
  17. y_: Variable(tensorflow.python.framework.ops.Tensor), shape (N, num_classes)
  18. Return:
  19. loss: Variable(tensorflow.python.framework.ops.Tensor), shape (1,)
  20. y: Variable(tensorflow.python.framework.ops.Tensor), shape (N, num_classes)
  21. '''
  22. print("Building MLP model in tensorflow...")
  23. x = tf_fc(x, (3072, 256), with_relu=True)
  24. x = tf_fc(x, (256, 256), with_relu=True)
  25. y = tf_fc(x, (256, num_class), with_relu=False)
  26. loss = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
  27. loss = tf.reduce_mean(loss)
  28. return loss, y

分布式深度学习系统

Contributors (1)