You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MultiThreadedUnitTestExecuter.cs 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. using System;
  2. using System.Diagnostics;
  3. using System.Threading;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// A workload executed by one thread of a <see cref="MultiThreadedUnitTestExecuter"/>.
      /// </summary>
      /// <param name="threadid">Zero-based index of the thread that runs the workload.</param>
      public delegate void MultiThreadedTestDelegate(int threadid);

      /// <summary>
      /// Creates a synchronized eco-system of running code: every worker thread is started,
      /// held at a common gate, released together, and awaited before failures are reported.
      /// </summary>
      public class MultiThreadedUnitTestExecuter : IDisposable
      {
          /// <summary>Number of worker threads used by <see cref="Run(MultiThreadedTestDelegate[])"/>.</summary>
          public int ThreadCount { get; }

          /// <summary>The worker threads created by the most recent run (one slot per thread index).</summary>
          public Thread[] Threads { get; }

          /// <summary>
          /// The exception caught from each thread's workload during the most recent run,
          /// or <c>null</c> for threads whose workload completed normally.
          /// </summary>
          public Exception[] Exceptions { get; }

          // Released once by every worker when it is about to wait on the start gate.
          private readonly SemaphoreSlim barrier_threadstarted;

          // The start gate: set by the coordinating thread once all workers are ready.
          private readonly ManualResetEventSlim barrier_corestart;

          // Released once by every worker when its workload has finished (successfully or not).
          private readonly SemaphoreSlim done_barrier2;

          /// <summary>Optional callback invoked with this instance after all workloads succeeded.</summary>
          public Action<MultiThreadedUnitTestExecuter> PostRun { get; set; }

          #region Static

          /// <summary>Runs <paramref name="workload"/> on <paramref name="threadCount"/> threads.</summary>
          /// <exception cref="ArgumentNullException"><paramref name="workload"/> is null.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
          [DebuggerHidden]
          public static void Run(int threadCount, MultiThreadedTestDelegate workload)
          {
              if (workload == null) throw new ArgumentNullException(nameof(workload));
              if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));

              // The executer owns wait handles, so dispose it once the run is over.
              using (var executer = new MultiThreadedUnitTestExecuter(threadCount))
              {
                  executer.Run(workload);
              }
          }

          /// <summary>Runs the given <paramref name="workloads"/> on <paramref name="threadCount"/> threads.</summary>
          /// <exception cref="ArgumentNullException"><paramref name="workloads"/> is null.</exception>
          /// <exception cref="ArgumentException"><paramref name="workloads"/> is empty.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
          [DebuggerHidden]
          public static void Run(int threadCount, params MultiThreadedTestDelegate[] workloads)
          {
              if (workloads == null) throw new ArgumentNullException(nameof(workloads));
              if (workloads.Length == 0) throw new ArgumentException("Value cannot be an empty collection.", nameof(workloads));
              if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));

              using (var executer = new MultiThreadedUnitTestExecuter(threadCount))
              {
                  executer.Run(workloads);
              }
          }

          /// <summary>
          /// Runs <paramref name="workload"/> on <paramref name="threadCount"/> threads and then
          /// invokes <paramref name="postRun"/> with the executer.
          /// </summary>
          /// <exception cref="ArgumentNullException"><paramref name="workload"/> or <paramref name="postRun"/> is null.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
          [DebuggerHidden]
          public static void Run(int threadCount, MultiThreadedTestDelegate workload, Action<MultiThreadedUnitTestExecuter> postRun)
          {
              if (workload == null) throw new ArgumentNullException(nameof(workload));
              if (postRun == null) throw new ArgumentNullException(nameof(postRun));
              if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));

              using (var executer = new MultiThreadedUnitTestExecuter(threadCount))
              {
                  executer.PostRun = postRun;
                  executer.Run(workload);
              }
          }

          #endregion

          /// <summary>
          /// Initializes a new instance of the <see cref="MultiThreadedUnitTestExecuter"/> class
          /// that runs workloads on <paramref name="threadCount"/> threads.
          /// </summary>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
          public MultiThreadedUnitTestExecuter(int threadCount)
          {
              if (threadCount <= 0)
                  throw new ArgumentOutOfRangeException(nameof(threadCount));

              ThreadCount = threadCount;
              Threads = new Thread[ThreadCount];
              Exceptions = new Exception[ThreadCount];
              done_barrier2 = new SemaphoreSlim(0, threadCount);
              barrier_corestart = new ManualResetEventSlim();
              barrier_threadstarted = new SemaphoreSlim(0, threadCount);
          }

          /// <summary>
          /// Runs the workloads on <see cref="ThreadCount"/> threads, blocks until all of them have
          /// finished, rethrows the first recorded failure and otherwise invokes <see cref="PostRun"/>.
          /// Thread <c>i</c> executes <c>workloads[i % workloads.Length]</c>.
          /// </summary>
          /// <param name="workloads">The workloads to execute.</param>
          /// <exception cref="ArgumentNullException"><paramref name="workloads"/> is null.</exception>
          /// <exception cref="ArgumentException"><paramref name="workloads"/> is empty.</exception>
          /// <exception cref="InvalidOperationException">The number of workloads does not fit the thread count.</exception>
          /// <exception cref="Exception">A workload threw; the original exception is the inner exception.</exception>
          [DebuggerHidden]
          public void Run(params MultiThreadedTestDelegate[] workloads)
          {
              if (workloads == null)
                  throw new ArgumentNullException(nameof(workloads));
              // An empty array used to slip through the modulo check below and fail later with
              // DivideByZeroException / IndexOutOfRangeException; reject it up front instead.
              if (workloads.Length == 0)
                  throw new ArgumentException("Value cannot be an empty collection.", nameof(workloads));
              // NOTE(review): this accepts any multiple of ThreadCount, although only the first
              // ThreadCount workloads are ever executed - confirm whether that is intended.
              if (workloads.Length != 1 && workloads.Length % ThreadCount != 0)
                  throw new InvalidOperationException($"Run method must accept either 1 workload or n-threads workloads. Got {workloads.Length} workloads.");

              // Drop failures recorded by a previous run on this instance.
              Array.Clear(Exceptions, 0, Exceptions.Length);

              if (ThreadCount == 1)
              {
                  // Single thread: no start gate needed, just run and wait.
                  var single = new Thread(() =>
                  {
                      try
                      {
                          workloads[0](0);
                      }
                      catch (Exception e)
                      {
                          if (Debugger.IsAttached)
                              throw;
                          Exceptions[0] = e;
                      }
                      finally
                      {
                          done_barrier2.Release(1);
                      }
                  });
                  Threads[0] = single;
                  single.Start();

                  done_barrier2.Wait();
                  single.Join();

                  if (Exceptions[0] != null)
                      throw new Exception("Thread 0 has failed: ", Exceptions[0]);

                  PostRun?.Invoke(this);
                  return;
              }

              //thread core
              void ThreadCore(MultiThreadedTestDelegate core, int threadid)
              {
                  barrier_threadstarted.Release(1);
                  barrier_corestart.Wait();

                  //workload
                  try
                  {
                      core(threadid);
                  }
                  catch (Exception e)
                  {
                      if (Debugger.IsAttached)
                          throw;
                      // Record the failure BEFORE the finally block signals completion; storing it
                      // after the release lets the coordinating thread miss the exception.
                      Exceptions[threadid] = e;
                  }
                  finally
                  {
                      done_barrier2.Release(1);
                  }
              }

              //initialize all threads (a single workload is shared by every thread)
              for (int i = 0; i < ThreadCount; i++)
              {
                  var i_local = i;
                  var workload = workloads[i_local % workloads.Length];
                  Threads[i] = new Thread(() => ThreadCore(workload, i_local));
              }

              // Close the gate in case an earlier run on this instance left it open.
              barrier_corestart.Reset();

              //run all threads
              for (int i = 0; i < ThreadCount; i++) Threads[i].Start();

              //wait for threads to be started and ready
              for (int i = 0; i < ThreadCount; i++) barrier_threadstarted.Wait();

              //signal threads to start
              barrier_corestart.Set();

              //wait for threads to finish
              for (int i = 0; i < ThreadCount; i++) done_barrier2.Wait();

              // Make sure every worker has fully exited before results are inspected
              // (and before a caller disposes the wait handles).
              for (int i = 0; i < ThreadCount; i++) Threads[i].Join();

              //handle fails
              for (int i = 0; i < ThreadCount; i++)
                  if (Exceptions[i] != null)
                      throw new Exception($"Thread {i} has failed: ", Exceptions[i]);

              //checks after ended
              PostRun?.Invoke(this);
          }

          /// <summary>Releases the wait handles owned by this executer.</summary>
          public void Dispose()
          {
              barrier_threadstarted.Dispose();
              barrier_corestart.Dispose();
              done_barrier2.Dispose();
          }
      }
  }