You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TrainSaverTest.cs 3.4 kB

6 years ago
(line-number gutter: lines 1–106)
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Text;
  6. using Tensorflow;
  7. using static Tensorflow.Python;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Exploratory checks for exporting/importing graphs and saving checkpoints through
      /// <c>tf.train</c>.
      /// </summary>
      /// <remarks>
      /// Every method is marked <c>[TestMethod]</c> so MSTest discovers it, but also
      /// <c>[Ignore]</c>d: each one reads or writes hard-coded local paths
      /// (<c>/tmp</c>, <c>C:/tmp</c>, <c>mobilenet</c>) and asserts nothing, so it is only
      /// meaningful when run by hand. Remove the <c>Ignore</c> to run one manually.
      /// </remarks>
      [TestClass]
      public class TrainSaverTest
      {
          /// <summary>
          /// Adds an int variable named "my_variable" to the graph, then writes the session's
          /// graph to <c>/tmp/my-model/train1.pbtxt</c>.
          /// </summary>
          [TestMethod, Ignore("Manual: writes to /tmp/my-model and asserts nothing.")]
          public void ExportGraph()
          {
              // Created only so the exported graph contains a variable; the handle is not needed.
              tf.Variable(0, name: "my_variable");

              // Scope the session with `with`, as the other methods in this class do, instead of
              // leaving a bare tf.Session() unreleased (assumes `with` disposes its argument).
              with(tf.Session(), sess =>
              {
                  tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt");
              });
          }

          /// <summary>
          /// Imports the meta graph stored at <c>C:/tmp/my-model.meta</c> inside a session.
          /// </summary>
          [TestMethod, Ignore("Manual: requires C:/tmp/my-model.meta (Windows-style path) and asserts nothing.")]
          public void ImportGraph()
          {
              with(tf.Session(), sess =>
              {
                  // The returned saver was never used; the import call itself is what is exercised.
                  tf.train.import_meta_graph("C:/tmp/my-model.meta");
              });
          }

          /// <summary>
          /// Loads the saved model from the relative directory <c>mobilenet</c> into a session.
          /// </summary>
          [TestMethod, Ignore("Manual: requires a ./mobilenet saved model and asserts nothing.")]
          public void ImportSavedModel()
          {
              with(Session.LoadFromSavedModel("mobilenet"), sess =>
              {
                  // Intentionally empty: loading the model is the whole check.
              });
          }

          /// <summary>
          /// Imports <c>mobilenet/saved_model.pb</c> into a freshly created <see cref="Graph"/>.
          /// </summary>
          [TestMethod, Ignore("Manual: requires mobilenet/saved_model.pb and asserts nothing.")]
          public void ImportGraphDefFromPbFile()
          {
              var g = new Graph();
              // NOTE(review): the return value (presumably a success flag/status) was stored in an
              // unused local and is still not checked; consider asserting on it.
              g.Import("mobilenet/saved_model.pb");
          }

          /// <summary>
          /// Creates one variable, initializes it, and saves all variables to the checkpoint
          /// prefix <c>/tmp/model1.ckpt</c>, printing the path returned by the saver.
          /// </summary>
          [TestMethod, Ignore("Manual: writes /tmp/model1.ckpt* and asserts nothing.")]
          public void Save1()
          {
              // Created only so the saver has a variable to save; the handle is not needed.
              tf.Variable(0, name: "save1");
              var init_op = tf.global_variables_initializer();
              // Add ops to save and restore all the variables.
              var saver = tf.train.Saver();

              with(tf.Session(), sess =>
              {
                  sess.run(init_op);
                  // Save the variables to disk.
                  var save_path = saver.save(sess, "/tmp/model1.ckpt");
                  Console.WriteLine($"Model saved in path: {save_path}");
              });
          }

          /// <summary>
          /// Creates zero-initialized variables <c>v1</c> (length 3) and <c>v2</c> (length 5),
          /// increments <c>v1</c> and decrements <c>v2</c> once, then saves all variables to the
          /// checkpoint prefix <c>/tmp/model2.ckpt</c>.
          /// </summary>
          [TestMethod, Ignore("Manual: writes /tmp/model2.ckpt* and asserts nothing.")]
          public void Save2()
          {
              var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
              var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer);
              var inc_v1 = v1.assign(v1 + 1.0f);
              var dec_v2 = v2.assign(v2 - 1.0f);
              // Add an op to initialize the variables.
              var init_op = tf.global_variables_initializer();
              // Add ops to save and restore all the variables.
              var saver = tf.train.Saver();

              with(tf.Session(), sess =>
              {
                  sess.run(init_op);
                  // Do some work with the model: after one run each, v1 holds 1s and v2 holds -1s.
                  // NOTE(review): op.run() takes no session, so it presumably uses the session
                  // entered by `with` as the default one — confirm.
                  inc_v1.op.run();
                  dec_v2.op.run();
                  // Save the variables to disk.
                  var save_path = saver.save(sess, "/tmp/model2.ckpt");
                  Console.WriteLine($"Model saved in path: {save_path}");
              });
          }
      }
  }