You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TrainSaverTest.cs 3.3 kB

6 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using static Tensorflow.Binding;
  5. namespace TensorFlowNET.UnitTest.Basics
  6. {
  7. [TestClass]
  8. public class TrainSaverTest
  9. {
  10. public void ExportGraph()
  11. {
  12. var v = tf.Variable(0, name: "my_variable");
  13. var sess = tf.Session();
  14. tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt");
  15. }
  16. public void ImportGraph()
  17. {
  18. using (var sess = tf.Session())
  19. {
  20. var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta");
  21. }
  22. //tf.train.export_meta_graph(filename: "linear_regression.meta.bin");
  23. // import meta
  24. /*tf.train.import_meta_graph("linear_regression.meta.bin");
  25. var cost = graph.OperationByName("truediv").output;
  26. var pred = graph.OperationByName("Add").output;
  27. var optimizer = graph.OperationByName("GradientDescent");
  28. var X = graph.OperationByName("Placeholder").output;
  29. var Y = graph.OperationByName("Placeholder_1").output;
  30. var W = graph.OperationByName("weight").output;
  31. var b = graph.OperationByName("bias").output;*/
  32. /*var text = JsonConvert.SerializeObject(graph, new JsonSerializerSettings
  33. {
  34. Formatting = Formatting.Indented
  35. });*/
  36. }
  37. public void ImportSavedModel()
  38. {
  39. Session.LoadFromSavedModel("mobilenet");
  40. }
  41. public void ImportGraphDefFromPbFile()
  42. {
  43. var g = new Graph();
  44. var status = g.Import("mobilenet/saved_model.pb");
  45. }
  46. public void Save1()
  47. {
  48. var w1 = tf.Variable(0, name: "save1");
  49. var init_op = tf.global_variables_initializer();
  50. // Add ops to save and restore all the variables.
  51. var saver = tf.train.Saver();
  52. using (var sess = tf.Session())
  53. {
  54. sess.run(init_op);
  55. // Save the variables to disk.
  56. var save_path = saver.save(sess, "/tmp/model1.ckpt");
  57. Console.WriteLine($"Model saved in path: {save_path}");
  58. }
  59. }
  60. public void Save2()
  61. {
  62. var v1 = tf.compat.v1.get_variable("v1", shape: new Shape(3), initializer: tf.zeros_initializer);
  63. var v2 = tf.compat.v1.get_variable("v2", shape: new Shape(5), initializer: tf.zeros_initializer);
  64. var inc_v1 = v1.assign(v1.AsTensor() + 1.0f);
  65. var dec_v2 = v2.assign(v2.AsTensor() - 1.0f);
  66. // Add an op to initialize the variables.
  67. var init_op = tf.global_variables_initializer();
  68. // Add ops to save and restore all the variables.
  69. var saver = tf.train.Saver();
  70. using (var sess = tf.Session())
  71. {
  72. sess.run(init_op);
  73. // o some work with the model.
  74. inc_v1.op.run();
  75. dec_v2.op.run();
  76. // Save the variables to disk.
  77. var save_path = saver.save(sess, "/tmp/model2.ckpt");
  78. Console.WriteLine($"Model saved in path: {save_path}");
  79. }
  80. }
  81. }
  82. }