You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MultithreadingTests.cs 9.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using Tensorflow;
  8. using static Tensorflow.Binding;
  9. namespace TensorFlowNET.UnitTest
  10. {
  11. [TestClass]
  12. public class MultithreadingTests : GraphModeTestBase
  13. {
  /// <summary>
  /// Verifies that, inside a freshly opened <c>tf.Session()</c>, the graph returned by
  /// <c>tf.get_default_graph()</c> equals the graph exposed by the session itself.
  /// The check is handed to <c>MultiThreadedUnitTestExecuter.Run</c> with a count of 8
  /// (presumably the number of concurrent workers).
  /// </summary>
  [TestMethod]
  public void SessionCreation()
  {
      // Worker executed by the executer; tid is supplied by it and not used here.
      void VerifySessionGraph(int tid)
      {
          // Before any session exists, no default graph should be observable.
          Assert.IsNull(tf.peak_default_graph());

          using (var session = tf.Session())
          {
              var currentDefault = tf.get_default_graph();
              var ownedBySession = session.graph;

              Assert.IsNotNull(currentDefault);
              Assert.IsNotNull(ownedBySession);
              Assert.AreEqual(currentDefault, ownedBySession);
          }
      }

      // Advance the ops uid counter by one before the workers start.
      ops.uid();

      MultiThreadedUnitTestExecuter.Run(8, VerifySessionGraph);
  }
  /// <summary>
  /// Same check as <c>SessionCreation</c>, but passes 16 instead of 8 to
  /// <c>MultiThreadedUnitTestExecuter.Run</c> (presumably doubling the worker count):
  /// the default graph seen inside a new session must equal that session's graph.
  /// </summary>
  [TestMethod]
  public void SessionCreation_x2()
  {
      // Advance the ops uid counter by one before the workers start.
      ops.uid();

      MultiThreadedUnitTestExecuter.Run(16, VerifySessionGraph);

      // Worker executed by the executer; tid is supplied by it and not used here.
      void VerifySessionGraph(int tid)
      {
          Assert.IsNull(tf.peak_default_graph());

          // Original author's note: tf.Session() creates another graph.
          using (var session = tf.Session())
          {
              var currentDefault = tf.get_default_graph();
              var ownedBySession = session.graph;

              Assert.IsNotNull(currentDefault);
              Assert.IsNotNull(ownedBySession);
              Assert.AreEqual(currentDefault, ownedBySession);
          }
      }
  }
  /// <summary>
  /// Verifies that a default graph obtained through <c>tf.get_default_graph()</c> and
  /// activated with <c>as_default()</c> is the one a subsequently opened session reports.
  /// Each worker also prints its id and the graph's <c>graph_key</c> to the console.
  /// </summary>
  [TestMethod]
  public void GraphCreation()
  {
      // Worker executed by the executer; tid identifies the worker in the console output.
      void VerifyGraphThenSession(int tid)
      {
          // Nothing should be registered as the default graph yet.
          Assert.IsNull(tf.peak_default_graph());

          // Per the original author's note, get_default_graph() is expected to create
          // the default graph on demand; as_default() is then invoked on it.
          tf.get_default_graph().as_default();
          Assert.IsNotNull(tf.peak_default_graph());

          using (var session = tf.Session())
          {
              var peeked = tf.peak_default_graph();
              var ownedBySession = session.graph;

              Assert.IsNotNull(peeked);
              Assert.IsNotNull(ownedBySession);
              Assert.AreEqual(peeked, ownedBySession);

              Console.WriteLine($"{tid}-{peeked.graph_key}");
          }
      }

      // Advance the ops uid counter by one before the workers start.
      ops.uid();

      MultiThreadedUnitTestExecuter.Run(8, VerifyGraphThenSession);
  }
  /// <summary>
  /// Repeatedly allocates and immediately frees an int-sized unmanaged block with
  /// <c>Marshal.AllocHGlobal</c> / <c>Marshal.FreeHGlobal</c>, 100 times per worker.
  /// The test makes no assertions; it passes when no worker throws.
  /// </summary>
  [TestMethod]
  public void Marshal_AllocHGlobal()
  {
      // Worker executed by the executer; tid is supplied by it and not used here.
      void AllocateAndFree(int tid)
      {
          var remaining = 100;
          while (remaining > 0)
          {
              var block = Marshal.AllocHGlobal(sizeof(int));
              Marshal.FreeHGlobal(block);
              remaining--;
          }
      }

      MultiThreadedUnitTestExecuter.Run(8, AllocateAndFree);
  }
  /// <summary>
  /// Constructs 100 scalar tensors (<c>new Tensor(1)</c>) per worker while a session is open.
  /// The test makes no assertions; it passes when no worker throws.
  /// </summary>
  [TestMethod]
  public void TensorCreation()
  {
      MultiThreadedUnitTestExecuter.Run(8, CreateScalarTensors);

      // Worker executed by the executer; tid is supplied by it and not used here.
      void CreateScalarTensors(int tid)
      {
          using (var session = tf.Session())
          {
              // Only the most recently created tensor stays referenced by this local.
              Tensor latest = null;
              var created = 0;
              while (created < 100)
              {
                  latest = new Tensor(1);
                  created++;
              }
          }
      }
  }
  /// <summary>
  /// Constructs 100 tensors from a three-element int array per worker while a session is open.
  /// The test makes no assertions; it passes when no worker throws.
  /// </summary>
  [TestMethod]
  public void TensorCreation_Array()
  {
      // Worker executed by the executer; tid is supplied by it and not used here.
      void CreateArrayTensors(int tid)
      {
          // Original author's note: tf.Session() creates another graph.
          using (var session = tf.Session())
          {
              for (var remaining = 100; remaining > 0; remaining--)
              {
                  // A fresh source array is allocated for every tensor.
                  var tensor = new Tensor(new int[] { 1, 2, 3 });
              }
          }
      }

      MultiThreadedUnitTestExecuter.Run(8, CreateArrayTensors);
  }
  /// <summary>
  /// Builds the expression 2f + 3f from two constants once per worker, then evaluates it
  /// 100 times, each time in a new session created over the current default graph, and
  /// checks the first element of the result against 5f.
  /// </summary>
  [TestMethod]
  public void SessionRun()
  {
      MultiThreadedUnitTestExecuter.Run(8, EvaluateSumRepeatedly);

      // Worker executed by the executer; tid is supplied by it and not used here.
      void EvaluateSumRepeatedly(int tid)
      {
          // Original author's note: a graph is created automatically to hold these operations.
          var two = tf.constant(new[] { 2f }, shape: new[] { 1 });
          var three = tf.constant(new[] { 3f }, shape: new[] { 1 });
          var sum = two + three;

          var iteration = 0;
          while (iteration < 100)
          {
              // The default graph is looked up again on every iteration.
              using (var session = tf.Session(tf.get_default_graph()))
              {
                  var evaluated = session.run(sum);

                  // NOTE(review): the computed value is passed first and the literal second,
                  // the reverse of MSTest's (expected, actual) order. Left untouched because
                  // swapping could change which Equals implementation is used - confirm first.
                  Assert.AreEqual(evaluated[0], 5f);
              }

              iteration++;
          }
      }
  }
  /// <summary>
  /// Opens a session first, then builds 2f + 3f from two constants inside it, runs the
  /// expression once and checks the first element of the result against 5f.
  /// </summary>
  [TestMethod]
  public void SessionRun_InsideSession()
  {
      // Worker executed by the executer; tid is supplied by it and not used here.
      void BuildAndEvaluate(int tid)
      {
          using (var session = tf.Session())
          {
              Assert.IsNotNull(tf.get_default_graph());

              // Original author's note: a graph is created automatically to hold these operations.
              var two = tf.constant(new[] { 2f }, shape: new[] { 1 });
              var three = tf.constant(new[] { 3f }, shape: new[] { 1 });
              var sum = two + three;

              var evaluated = session.run(sum);

              // NOTE(review): argument order is (computed, literal), the reverse of MSTest's
              // (expected, actual); kept as-is because the Equals call direction may matter.
              Assert.AreEqual(evaluated[0], 5f);
          }
      }

      MultiThreadedUnitTestExecuter.Run(8, BuildAndEvaluate);
  }
  /// <summary>
  /// Opens a session, confirms a default graph is available, and builds (but never runs)
  /// the expression 2f + 3f from two constants. Only graph construction is exercised.
  /// </summary>
  [TestMethod]
  public void SessionRun_Initialization()
  {
      MultiThreadedUnitTestExecuter.Run(8, BuildWithoutRunning);

      // Worker executed by the executer; tid is supplied by it and not used here.
      void BuildWithoutRunning(int tid)
      {
          using (var session = tf.Session())
          {
              Assert.IsNotNull(tf.get_default_graph());

              // Original author's note: a graph is created automatically to hold these operations.
              var two = tf.constant(new[] { 2f }, shape: new[] { 1 });
              var three = tf.constant(new[] { 3f }, shape: new[] { 1 });

              // The sum is constructed only; it is never passed to session.run.
              var sum = two + three;
          }
      }
  }
  /// <summary>
  /// Without opening any session, asserts that no default graph is visible and then builds
  /// (but never runs) the expression 2f + 3f from two constants.
  /// </summary>
  [TestMethod]
  public void SessionRun_Initialization_OutsideSession()
  {
      // Worker executed by the executer; tid is supplied by it and not used here.
      void BuildWithoutSession(int tid)
      {
          Assert.IsNull(tf.peak_default_graph());

          // Original author's note: a graph is created automatically to hold these operations.
          var two = tf.constant(new[] { 2f }, shape: new[] { 1 });
          var three = tf.constant(new[] { 3f }, shape: new[] { 1 });

          // The sum is constructed only; nothing evaluates it.
          var sum = two + three;
      }

      MultiThreadedUnitTestExecuter.Run(8, BuildWithoutSession);
  }
  /// <summary>
  /// Builds a small graph containing a constant named "ConstantK" and then looks that
  /// operation up by name on the default graph 100 times per worker. The lookup result
  /// is not asserted on; the test passes when no worker throws.
  /// </summary>
  [TestMethod]
  public void TF_GraphOperationByName()
  {
      MultiThreadedUnitTestExecuter.Run(8, LookupNamedConstant);

      // Worker executed by the executer; tid is supplied by it and not used here.
      void LookupNamedConstant(int tid)
      {
          Assert.IsNull(tf.peak_default_graph());

          // Original author's note: a graph is created automatically to hold these operations.
          var unnamed = tf.constant(new[] { 2f }, shape: new[] { 1 });
          var named = tf.constant(new[] { 3f }, shape: new[] { 1 }, name: "ConstantK");
          var sum = unnamed + named;

          var lookups = 0;
          while (lookups < 100)
          {
              // The default graph is fetched again for every lookup.
              var found = tf.get_default_graph().OperationByName("ConstantK");
              lookups++;
          }
      }
  }
  /// <summary>
  /// Absolute path of the saved-model directory used by
  /// <c>TF_GraphOperationByName_FromModel</c>; the relative path is resolved once,
  /// when the static field is initialized.
  /// </summary>
  private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/");

  /// <summary>
  /// Loads a session from the saved model 100 times per worker, resolves the "sp" and "fuel"
  /// input operations and the "softmax_tensor" output by name, and runs the output 8 times
  /// with zero-filled 96-element float feeds. Currently excluded from runs via <c>[Ignore]</c>.
  /// </summary>
  [Ignore]
  [TestMethod]
  public void TF_GraphOperationByName_FromModel()
  {
      MultiThreadedUnitTestExecuter.Run(8, Core);

      // Worker executed by the executer; tid is supplied by it and not used here.
      void Core(int tid)
      {
          Console.WriteLine();
          for (int j = 0; j < 100; j++)
          {
              // Dispose each loaded session before loading the next one; previously every
              // iteration left its session undisposed, unlike the other tests in this class.
              using (var sess = Session.LoadFromSavedModel(modelPath).as_default())
              {
                  var inputs = new[] { "sp", "fuel" };
                  var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray();
                  var outp = sess.graph.OperationByName("softmax_tensor").output;

                  for (var i = 0; i < 8; i++)
                  {
                      // new float[96] is zero-initialized; both inputs are fed the same values.
                      var data = new float[96];
                      FeedItem[] feeds = new FeedItem[2];
                      for (int f = 0; f < 2; f++)
                      {
                          feeds[f] = new FeedItem(inp[f], new NDArray(data));
                      }

                      sess.run(outp, feeds);
                  }
              }
          }
      }
  }
  243. }
  244. }