You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

meta_graph.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. 
  2. import tensorflow as tf
  3. import math
  4. # Creates an inference graph.
  5. # Hidden 1
  6. images = tf.constant(1.2, tf.float32, shape=[100, 28])
  7. with tf.name_scope("hidden1"):
  8. weights = tf.Variable(
  9. tf.truncated_normal([28, 128],
  10. stddev=1.0 / math.sqrt(float(28))),
  11. name="weights")
  12. biases = tf.Variable(tf.zeros([128]),
  13. name="biases")
  14. hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
  15. # Hidden 2
  16. with tf.name_scope("hidden2"):
  17. weights = tf.Variable(
  18. tf.truncated_normal([128, 32],
  19. stddev=1.0 / math.sqrt(float(128))),
  20. name="weights")
  21. biases = tf.Variable(tf.zeros([32]),
  22. name="biases")
  23. hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
  24. # Linear
  25. with tf.name_scope("softmax_linear"):
  26. weights = tf.Variable(
  27. tf.truncated_normal([32, 10],
  28. stddev=1.0 / math.sqrt(float(32))),
  29. name="weights")
  30. biases = tf.Variable(tf.zeros([10]),
  31. name="biases")
  32. logits = tf.matmul(hidden2, weights) + biases
  33. tf.add_to_collection("logits", logits)
  34. init_all_op = tf.global_variables_initializer()
  35. with tf.Session() as sess:
  36. # Initializes all the variables.
  37. sess.run(init_all_op)
  38. # Runs to logit.
  39. sess.run(logits)
  40. # Creates a saver.
  41. saver0 = tf.train.Saver()
  42. saver0.save(sess, 'my-save-dir/my-model-10000')
  43. # Generates MetaGraphDef.
  44. saver0.export_meta_graph('my-save-dir/my-model-10000.meta')
  45. # Then later import it and extend it to a training graph.
  46. with tf.Session() as sess:
  47. new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
  48. new_saver.restore(sess, 'my-save-dir/my-model-10000')
  49. # Addes loss and train.
  50. labels = tf.constant(0, tf.int32, shape=[100], name="labels")
  51. batch_size = tf.size(labels)
  52. logits = tf.get_collection("logits")[0]
  53. loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
  54. logits=logits)
  55. tf.summary.scalar('loss', loss)
  56. # Creates the gradient descent optimizer with the given learning rate.
  57. optimizer = tf.train.GradientDescentOptimizer(0.01)
  58. # Runs train_op.
  59. train_op = optimizer.minimize(loss)
  60. sess.run(train_op)