You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MultithreadingTests.cs 9.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using Tensorflow;
  8. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Runs graph, session and tensor creation concurrently through
      /// <c>MultiThreadedUnitTestExecuter.Run(threadCount, body)</c>. Each test defines a local
      /// <c>Core(int tid)</c> function that is the per-worker body; <c>tid</c> is presumably the
      /// worker index supplied by the executer (TODO confirm).
      /// Most workers assert that <c>tf.peak_default_graph()</c> is null on entry, i.e. the tests
      /// expect every worker to start without a default graph.
      /// </summary>
      [TestClass]
      public class MultithreadingTests : GraphModeTestBase
      {
          /// <summary>Creating a session on 8 workers binds it to the worker's default graph.</summary>
          [TestMethod]
          public void SessionCreation()
          {
              ops.uid(); // increment id by one before the workers start
              MultiThreadedUnitTestExecuter.Run(8, Core);

              // Per-worker body.
              void Core(int tid) => AssertNewSessionUsesDefaultGraph();
          }

          /// <summary>Same as <see cref="SessionCreation"/> but with 16 workers.</summary>
          [TestMethod]
          public void SessionCreation_x2()
          {
              ops.uid(); // increment id by one before the workers start
              MultiThreadedUnitTestExecuter.Run(16, Core);

              // Per-worker body.
              void Core(int tid) => AssertNewSessionUsesDefaultGraph();
          }

          /// <summary>
          /// Shared body of the SessionCreation tests: asserts the calling thread has no default
          /// graph yet, then that a new session's graph equals <c>tf.get_default_graph()</c>.
          /// </summary>
          private static void AssertNewSessionUsesDefaultGraph()
          {
              Assert.IsNull(tf.peak_default_graph());

              // tf.Session() creates another graph, which must be the one reported as default.
              using (var sess = tf.Session())
              {
                  var defaultGraph = tf.get_default_graph();
                  var sessGraph = sess.graph;
                  Assert.IsNotNull(defaultGraph);
                  Assert.IsNotNull(sessGraph);
                  Assert.AreEqual(defaultGraph, sessGraph);
              }
          }

          /// <summary>
          /// A graph obtained and made default before the session is created is the graph the
          /// session uses.
          /// </summary>
          [TestMethod]
          public void GraphCreation()
          {
              ops.uid(); // increment id by one before the workers start
              MultiThreadedUnitTestExecuter.Run(8, Core);

              // Per-worker body.
              void Core(int tid)
              {
                  Assert.IsNull(tf.peak_default_graph());
                  var beforehand = tf.get_default_graph(); // this should create the default graph automatically
                  beforehand.as_default();
                  Assert.IsNotNull(tf.peak_default_graph());

                  using (var sess = tf.Session())
                  {
                      var defaultGraph = tf.peak_default_graph();
                      var sessGraph = sess.graph;
                      Assert.IsNotNull(defaultGraph);
                      Assert.IsNotNull(sessGraph);
                      Assert.AreEqual(defaultGraph, sessGraph);
                      Console.WriteLine($"{tid}-{defaultGraph.graph_key}");
                  }
              }
          }

          /// <summary>Baseline: unmanaged alloc/free pairs from 8 workers.</summary>
          [TestMethod]
          public void Marshal_AllocHGlobal()
          {
              MultiThreadedUnitTestExecuter.Run(8, Core);

              // Per-worker body.
              void Core(int tid)
              {
                  for (int i = 0; i < 100; i++)
                  {
                      Marshal.FreeHGlobal(Marshal.AllocHGlobal(sizeof(int)));
                  }
              }
          }

          /// <summary>Creates 100 scalar tensors per worker inside a session.</summary>
          [TestMethod]
          public void TensorCreation()
          {
              MultiThreadedUnitTestExecuter.Run(8, Core);

              // Per-worker body.
              void Core(int tid)
              {
                  using (var sess = tf.Session())
                  {
                      for (int i = 0; i < 100; i++)
                      {
                          // Dispose each tensor right away rather than leaving it for the GC.
                          var tensor = new Tensor(1);
                          tensor.Dispose();
                      }
                  }
              }
          }

          /// <summary>Creates 100 int-array tensors per worker inside a session.</summary>
          [TestMethod]
          public void TensorCreation_Array()
          {
              MultiThreadedUnitTestExecuter.Run(8, Core);

              // Per-worker body.
              void Core(int tid)
              {
                  // tf.Session() creates another graph for this worker.
                  using (var sess = tf.Session())
                  {
                      for (int i = 0; i < 100; i++)
                      {
                          // Dispose each tensor right away rather than leaving it for the GC.
                          var tensor = new Tensor(new int[] { 1, 2, 3 });
                          tensor.Dispose();
                      }
                  }
              }
          }

          /// <summary>
          /// Builds 2 + 3 on the worker's default graph, then runs it 100 times, each time in a
          /// fresh session over that graph.
          /// </summary>
          [TestMethod]
          public void SessionRun()
          {
              MultiThreadedUnitTestExecuter.Run(8, Core);

              // Per-worker body.
              void Core(int tid)
              {
                  Assert.IsNull(tf.peak_default_graph());

                  // The default graph is created automatically to hold these operations.
                  var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
                  var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
                  var math = a1 + a2;

                  for (int i = 0; i < 100; i++)
                  {
                      var graph = tf.get_default_graph();
                      using (var sess = tf.Session(graph))
                      {
                          var result = sess.run(math);
                          // MSTest convention: expected value first, actual second.
                          Assert.AreEqual(5f, result[0]);
                      }
                  }
              }
          }

          /// <summary>Builds and runs 2 + 3 while a session is open.</summary>
          [TestMethod]
          public void SessionRun_InsideSession()
          {
              MultiThreadedUnitTestExecuter.Run(8, Core);

              // Per-worker body.
              void Core(int tid)
              {
                  using (var sess = tf.Session())
                  {
                      Assert.IsNotNull(tf.get_default_graph());

                      // The graph is created automatically to hold these operations.
                      var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
                      var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
                      var math = a1 + a2;

                      var result = sess.run(math);
                      // MSTest convention: expected value first, actual second.
                      Assert.AreEqual(5f, result[0]);
                  }
              }
          }

          /// <summary>Only builds 2 + 3 inside an open session; nothing is run.</summary>
          [TestMethod]
          public void SessionRun_Initialization()
          {
              MultiThreadedUnitTestExecuter.Run(8, Core);

              // Per-worker body.
              void Core(int tid)
              {
                  using (var sess = tf.Session())
                  {
                      Assert.IsNotNull(tf.get_default_graph());

                      // The graph is created automatically to hold these operations.
                      var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
                      var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
                      var math = a1 + a2;
                  }
              }
          }

          /// <summary>Only builds 2 + 3 with no session open; nothing is run.</summary>
          [TestMethod]
          public void SessionRun_Initialization_OutsideSession()
          {
              MultiThreadedUnitTestExecuter.Run(8, Core);

              // Per-worker body.
              void Core(int tid)
              {
                  Assert.IsNull(tf.peak_default_graph());

                  // The default graph is created automatically to hold these operations.
                  var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
                  var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
                  var math = a1 + a2;
              }
          }

          /// <summary>Looks up a named constant in the worker's default graph 100 times.</summary>
          [TestMethod]
          public void TF_GraphOperationByName()
          {
              MultiThreadedUnitTestExecuter.Run(8, Core);

              // Per-worker body.
              void Core(int tid)
              {
                  Assert.IsNull(tf.peak_default_graph());

                  // The default graph is created automatically to hold these operations.
                  var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
                  var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }, name: "ConstantK");
                  var math = a1 + a2;

                  for (int i = 0; i < 100; i++)
                  {
                      var op = tf.get_default_graph().OperationByName("ConstantK");
                      Assert.IsNotNull(op);
                  }
              }
          }

          private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/");

          /// <summary>
          /// Repeatedly loads the saved model at <see cref="modelPath"/> and runs its
          /// "softmax_tensor" output with zero-filled inputs fed to "sp" and "fuel".
          /// Ignored by default (requires the model files on disk).
          /// </summary>
          [Ignore]
          [TestMethod]
          public void TF_GraphOperationByName_FromModel()
          {
              MultiThreadedUnitTestExecuter.Run(8, Core);

              // Per-worker body.
              void Core(int tid)
              {
                  var inputs = new[] { "sp", "fuel" };

                  for (int j = 0; j < 100; j++)
                  {
                      // Dispose every loaded session; the original leaked 100 sessions per worker.
                      using (var sess = Session.LoadFromSavedModel(modelPath).as_default())
                      {
                          var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray();
                          var outp = sess.graph.OperationByName("softmax_tensor").output;

                          for (var i = 0; i < 8; i++)
                          {
                              var data = new float[96];
                              var feeds = new FeedItem[inp.Length];
                              for (int f = 0; f < inp.Length; f++)
                              {
                                  feeds[f] = new FeedItem(inp[f], new NDArray(data));
                              }
                              sess.run(outp, feeds);
                          }
                      }
                  }
              }
          }
      }
  }