You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

TrainSaverTest.cs 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Text;
  6. using Tensorflow;
  7. namespace TensorFlowNET.UnitTest
  8. {
  9. [TestClass]
  10. public class TrainSaverTest : Python
  11. {
  12. [TestMethod]
  13. public void ExportGraph()
  14. {
  15. var v = tf.Variable(0, name: "my_variable");
  16. var sess = tf.Session();
  17. tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt");
  18. }
  19. [TestMethod]
  20. public void ImportGraph()
  21. {
  22. with<Session>(tf.Session(), sess =>
  23. {
  24. var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta");
  25. });
  26. }
  27. [TestMethod]
  28. public void ImportSavedModel()
  29. {
  30. with<Session>(Session.LoadFromSavedModel("mobilenet"), sess =>
  31. {
  32. });
  33. }
  34. [TestMethod]
  35. public void ImportGraphDefFromPbFile()
  36. {
  37. var g = new Graph();
  38. var status = g.Import("mobilenet/saved_model.pb");
  39. }
  40. [TestMethod]
  41. public void Save1()
  42. {
  43. var w1 = tf.Variable(0, name: "save1");
  44. var init_op = tf.global_variables_initializer();
  45. // Add ops to save and restore all the variables.
  46. var saver = tf.train.Saver();
  47. with<Session>(tf.Session(), sess =>
  48. {
  49. sess.run(init_op);
  50. // Save the variables to disk.
  51. var save_path = saver.save(sess, "/tmp/model1.ckpt");
  52. Console.WriteLine($"Model saved in path: {save_path}");
  53. });
  54. }
  55. [TestMethod]
  56. public void Save2()
  57. {
  58. var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
  59. var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer);
  60. var inc_v1 = v1.assign(v1 + 1.0f);
  61. var dec_v2 = v2.assign(v2 - 1.0f);
  62. // Add an op to initialize the variables.
  63. var init_op = tf.global_variables_initializer();
  64. // Add ops to save and restore all the variables.
  65. var saver = tf.train.Saver();
  66. with<Session>(tf.Session(), sess =>
  67. {
  68. sess.run(init_op);
  69. // o some work with the model.
  70. inc_v1.op.run();
  71. dec_v2.op.run();
  72. // Save the variables to disk.
  73. var save_path = saver.save(sess, "/tmp/model2.ckpt");
  74. Console.WriteLine($"Model saved in path: {save_path}");
  75. });
  76. }
  77. }
  78. }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。