You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BasicOperations.cs 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. using NumSharp;
  2. using System;
  3. using Tensorflow;
  4. using static Tensorflow.Binding;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// Basic Operations example using TensorFlow library.
      /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/1_Introduction/basic_operations.py
      /// </summary>
      /// <remarks>
      /// Runs four demos against the default graph, each in its own session:
      /// arithmetic on constants, arithmetic on int16 placeholders fed at run time,
      /// a 1x2 by 2x1 matrix multiplication, and a 3x3x3 by 3x3x2 batched matrix
      /// multiplication. Only the batched result is compared against a known answer;
      /// the other demos just print their output.
      /// </remarks>
      public class BasicOperations : IExample
      {
          public bool Enabled { get; set; } = true;

          public string Name => "Basic Operations";

          public bool IsImportingGraph { get; set; } = false;

          /// <summary>
          /// Runs every demo in order and prints each result to the console.
          /// </summary>
          /// <returns>
          /// <c>true</c> when the batched matrix multiplication yields the expected values.
          /// </returns>
          public bool Run()
          {
              RunConstantOperations();
              RunPlaceholderOperations();
              RunMatrixMultiplication();
              return RunBatchMatrixMultiplication();
          }

          /// <summary>Adds and multiplies the constants 2 and 3, printing both results.</summary>
          private static void RunConstantOperations()
          {
              // The value returned by the constructor represents the output
              // of the Constant op.
              var a = tf.constant(2);
              var b = tf.constant(3);

              // Launch the default graph. The session is a local so it is scoped to
              // (and disposed by) this using block instead of lingering in a field.
              using (var sess = tf.Session())
              {
                  Console.WriteLine("a=2, b=3");
                  Console.WriteLine($"Addition with constants: {sess.run(a + b)}");
                  Console.WriteLine($"Multiplication with constants: {sess.run(a * b)}");
              }
          }

          /// <summary>
          /// Adds and multiplies two int16 placeholders fed with 2 and 3, printing both results.
          /// </summary>
          private static void RunPlaceholderOperations()
          {
              // tf Graph input: placeholders whose values are supplied through the
              // feed list when the session runs.
              var a = tf.placeholder(tf.int16);
              var b = tf.placeholder(tf.int16);

              // Define some operations.
              var add = tf.add(a, b);
              var mul = tf.multiply(a, b);

              // Launch the default graph.
              using (var sess = tf.Session())
              {
                  // Fed values are cast to short to match the int16 placeholder dtype.
                  var feedDict = new FeedItem[]
                  {
                      new FeedItem(a, (short)2),
                      new FeedItem(b, (short)3)
                  };

                  // Run every operation with variable input.
                  Console.WriteLine($"Addition with variables: {sess.run(add, feedDict)}");
                  Console.WriteLine($"Multiplication with variables: {sess.run(mul, feedDict)}");
              }
          }

          /// <summary>
          /// Multiplies the 1x2 matrix [[3, 3]] by the 2x1 matrix [[2], [2]] and prints the product.
          /// </summary>
          private static void RunMatrixMultiplication()
          {
              // Matrix Multiplication from TensorFlow official tutorial.
              // Create a Constant op that produces a 1x2 matrix. The op is
              // added as a node to the default graph.
              var nd1 = np.array(3, 3).reshape(1, 2);
              var matrix1 = tf.constant(nd1);

              // Create another Constant that produces a 2x1 matrix.
              var nd2 = np.array(2, 2).reshape(2, 1);
              var matrix2 = tf.constant(nd2);

              // Create a Matmul op that takes 'matrix1' and 'matrix2' as inputs.
              // The returned value, 'product', represents the result of the matrix
              // multiplication.
              var product = tf.matmul(matrix1, matrix2);

              // Running 'product' asks the session for the matmul output. All inputs
              // needed by the op are run automatically, so this executes three ops in
              // the graph: the two constants and matmul.
              using (var sess = tf.Session())
              {
                  var result = sess.run(product);
                  Console.WriteLine(result.ToString()); // expected single value: 3*2 + 3*2 = 12 (integer inputs)
              }
          }

          /// <summary>
          /// Multiplies a 3x3x3 tensor by a 3x3x2 tensor batch-wise and prints the 3x3x2 result.
          /// </summary>
          /// <returns>
          /// <c>true</c> if the result, flattened to 18 elements, equals the precomputed expected values.
          /// </returns>
          private static bool RunBatchMatrixMultiplication()
          {
              // `BatchMatMul` is actually embedded into the `MatMul` operation on the tf.dll side. Every time we ask
              // for a multiplication between matrices with rank > 2, the first rank - 2 dimensions are checked to be consistent
              // across the two matrices and a common matrix multiplication is done on the residual 2 dimensions.
              //
              // np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(3, 3, 3)
              // array([[[1, 2, 3],
              //         [4, 5, 6],
              //         [7, 8, 9]],
              //
              //        [[1, 2, 3],
              //         [4, 5, 6],
              //         [7, 8, 9]],
              //
              //        [[1, 2, 3],
              //         [4, 5, 6],
              //         [7, 8, 9]]])
              var firstTensor = tf.convert_to_tensor(
                  np.reshape(
                      np.array<float>(1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9),
                      3, 3, 3));

              // np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0]).reshape(3, 3, 2)
              // array([[[0, 1],
              //         [0, 1],
              //         [0, 1]],
              //
              //        [[0, 1],
              //         [0, 0],
              //         [1, 0]],
              //
              //        [[1, 0],
              //         [1, 0],
              //         [1, 0]]])
              var secondTensor = tf.convert_to_tensor(
                  np.reshape(
                      np.array<float>(0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0),
                      3, 3, 2));

              var batchMul = tf.batch_matmul(firstTensor, secondTensor);

              // Expected result flattened to 18 elements (3 batches x 3 rows x 2 columns).
              var checkTensor = np.array<float>(0, 6, 0, 15, 0, 24, 3, 1, 6, 4, 9, 7, 6, 0, 15, 0, 24, 0);

              using (var sess = tf.Session())
              {
                  var result = sess.run(batchMul);
                  Console.WriteLine(result.ToString());
                  //
                  // ==> array([[[ 0,  6],
                  //             [ 0, 15],
                  //             [ 0, 24]],
                  //
                  //            [[ 3,  1],
                  //             [ 6,  4],
                  //             [ 9,  7]],
                  //
                  //            [[ 6,  0],
                  //             [15,  0],
                  //             [24,  0]]])

                  // Exact comparison is acceptable here: every element is a small
                  // integer that is exactly representable as a float.
                  return np.reshape(result, 18)
                      .array_equal(checkTensor);
              }
          }

          /// <summary>No data preparation is needed for this example.</summary>
          public void PrepareData()
          {
          }

          /// <summary>Not supported by this example.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public Graph ImportGraph()
          {
              throw new NotImplementedException();
          }

          /// <summary>Not supported by this example.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public Graph BuildGraph()
          {
              throw new NotImplementedException();
          }

          /// <summary>Not supported by this example.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public void Train(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>Not supported by this example.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public void Predict(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>Not supported by this example.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public void Test(Session sess)
          {
              throw new NotImplementedException();
          }
      }
  }