You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

TrainSaverTest.cs 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  using Microsoft.VisualStudio.TestTools.UnitTesting;
  using System;
  using System.Collections.Generic;
  using System.Text;
  using Tensorflow;

  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Smoke tests for graph export/import, SavedModel loading and checkpoint
      /// saving through the <c>tf.train</c> API.
      /// </summary>
      /// <remarks>
      /// None of these tests contain assertions: each one passes as long as the
      /// TensorFlow calls it makes do not throw.
      /// NOTE(review): the tests read and write fixed locations on the host file
      /// system ("/tmp/...", "C:/tmp/..." and a relative "mobilenet" directory),
      /// so their outcome depends on the machine they run on.
      /// </remarks>
      [TestClass]
      public class TrainSaverTest : Python
      {
          /// <summary>
          /// Creates a single variable in the default graph and writes that graph
          /// as text to "/tmp/my-model/train1.pbtxt".
          /// </summary>
          [TestMethod]
          public void ExportGraph()
          {
              // Creating the variable is what puts a node into the graph being exported.
              var v = tf.Variable(0, name: "my_variable");

              // Scope the session with with<Session>, as every other test in this class
              // does, instead of creating it and never releasing it.
              with<Session>(tf.Session(), sess =>
              {
                  tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt");
              });
          }

          /// <summary>
          /// Imports a meta graph from "C:/tmp/my-model.meta" inside a session.
          /// </summary>
          /// <remarks>
          /// NOTE(review): nothing in this class writes that file (ExportGraph writes a
          /// .pbtxt under "/tmp/my-model"), so it must already exist on the machine.
          /// </remarks>
          [TestMethod]
          public void ImportGraph()
          {
              with<Session>(tf.Session(), sess =>
              {
                  var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta");
              });
          }

          /// <summary>
          /// Loads a SavedModel from the relative directory "mobilenet" and does nothing
          /// else with the resulting session.
          /// </summary>
          [TestMethod]
          public void ImportSavedModel()
          {
              with<Session>(Session.LoadFromSavedModel("mobilenet"), sess =>
              {
              });
          }

          /// <summary>
          /// Initializes one variable and saves all variables to the checkpoint
          /// path "/tmp/model1.ckpt", printing the path returned by the saver.
          /// </summary>
          [TestMethod]
          public void Save1()
          {
              var w1 = tf.Variable(0, name: "save1");
              var init_op = tf.global_variables_initializer();

              // Add ops to save and restore all the variables.
              var saver = tf.train.Saver();

              with<Session>(tf.Session(), sess =>
              {
                  sess.run(init_op);

                  // Save the variables to disk.
                  var save_path = saver.save(sess, "/tmp/model1.ckpt");
                  Console.WriteLine($"Model saved in path: {save_path}");
              });
          }

          /// <summary>
          /// Creates two zero-initialized variables (shapes 3 and 5), increments v1 and
          /// decrements v2 once, then saves all variables to "/tmp/model2.ckpt".
          /// </summary>
          [TestMethod]
          public void Save2()
          {
              var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
              var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer);
              var inc_v1 = v1.assign(v1 + 1.0f);
              var dec_v2 = v2.assign(v2 - 1.0f);

              // Add an op to initialize the variables.
              var init_op = tf.global_variables_initializer();

              // Add ops to save and restore all the variables.
              var saver = tf.train.Saver();

              with<Session>(tf.Session(), sess =>
              {
                  sess.run(init_op);

                  // Do some work with the model.
                  inc_v1.op.run();
                  dec_v2.op.run();

                  // Save the variables to disk.
                  var save_path = saver.save(sess, "/tmp/model2.ckpt");
                  Console.WriteLine($"Model saved in path: {save_path}");
              });
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。