You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tf_LeNet.py 1.7 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import numpy as np
  2. import tensorflow as tf
  def tf_conv_pool(x, in_channel, out_channel):
      """Apply a bias-free 5x5 convolution, a ReLU, then a 2x2 max-pool.

      Parameters:
          x: input tensor in TensorFlow's default NHWC layout with
              ``in_channel`` channels.
          in_channel: number of input channels.
          out_channel: number of filters (output channels).
      Return:
          The pooled activation tensor. SAME padding with stride 1 keeps
          H and W; the stride-2 VALID pool then halves them (floor).
      """
      # Sample the initial kernel in (out, in, kh, kw) order, then permute it
      # to the (kh, kw, in, out) filter layout tf.nn.conv2d expects. Keep this
      # draw shape/order unchanged so the NumPy random stream yields the same
      # values (presumably to match other framework implementations -- confirm).
      init_oihw = np.random.normal(
          scale=0.1, size=(out_channel, in_channel, 5, 5))
      init_hwio = init_oihw.transpose([2, 3, 1, 0]).astype(np.float32)
      kernel = tf.Variable(init_hwio)

      unit_step = [1, 1, 1, 1]
      pool_window = [1, 2, 2, 1]  # pool over H and W only
      conv_out = tf.nn.conv2d(x, kernel, padding='SAME', strides=unit_step)
      activated = tf.nn.relu(conv_out)
      return tf.nn.max_pool(activated, ksize=pool_window,
                            padding='VALID', strides=pool_window)
  def tf_fc(x, shape, with_relu=True):
      """Fully connected layer computing ``x @ W + b``, optionally with ReLU.

      Parameters:
          x: input tensor, assumed (N, shape[0]) -- confirm at call sites.
          shape: (in_features, out_features) of the weight matrix; the bias
              has length ``shape[-1]``.
          with_relu: when truthy, apply ReLU to the affine output.
      Return:
          The layer's output tensor.
      """
      # Draw the weight before the bias so the NumPy random stream is consumed
      # in the same order.
      w_init = np.random.normal(scale=0.1, size=shape).astype(np.float32)
      b_init = np.random.normal(scale=0.1, size=shape[-1:]).astype(np.float32)
      weight = tf.Variable(w_init)
      bias = tf.Variable(b_init)

      affine = tf.matmul(x, weight) + bias
      return tf.nn.relu(affine) if with_relu else affine
  def tf_lenet(x, y_, num_classes=10):
      """Build a LeNet graph in TensorFlow for MNIST-sized (28x28) inputs.

      Parameters:
          x: tf.Tensor holding 28*28 = 784 values per example, e.g. shape
              (N, 784). It is reshaped to (N, 28, 28, 1) NHWC.
          y_: tf.Tensor of shape (N, num_classes) with the target class
              distribution for each example (e.g. one-hot labels).
          num_classes: width of the output layer. Defaults to 10, which was
              previously hard-coded (MNIST digits).
      Return:
          loss: scalar tf.Tensor (shape ()), the softmax cross-entropy
              averaged over the batch.
          y: tf.Tensor of shape (N, num_classes), the unnormalized logits.
      """
      print('Building LeNet model in tensorflow...')
      x = tf.reshape(x, [-1, 28, 28, 1])
      # Each conv/pool stage keeps H and W (SAME conv, stride 1) and then
      # halves them (2x2 pool, stride 2): 28 -> 14 -> 7, hence 7*7*16 below.
      x = tf_conv_pool(x, 1, 6)
      x = tf_conv_pool(x, 6, 16)
      # Move channels ahead of H/W before flattening, so features are ordered
      # channel-major (presumably to match a channel-first reference
      # implementation's flatten order -- confirm).
      x = tf.transpose(x, [0, 3, 1, 2])
      x = tf.reshape(x, (-1, 7*7*16))
      x = tf_fc(x, (7*7*16, 120), with_relu=True)
      x = tf_fc(x, (120, 84), with_relu=True)
      y = tf_fc(x, (84, num_classes), with_relu=False)
      # Per-example cross-entropy of shape (N,), then reduce_mean with no axis
      # yields a rank-0 scalar.
      loss = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
      loss = tf.reduce_mean(loss)
      return loss, y