You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MultiThreadedUnitTestExecuter.cs 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. using System;
  2. using System.Diagnostics;
  3. using System.Threading;
  4. using Microsoft.VisualStudio.TestTools.UnitTesting;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// A unit of work run by <see cref="MultiThreadedUnitTestExecuter"/> on one of its threads.
      /// </summary>
      /// <param name="threadid">Zero-based index of the thread running the workload (0 .. ThreadCount - 1).</param>
      public delegate void MultiThreadedTestDelegate(int threadid);

      /// <summary>
      /// Runs workloads concurrently on a fixed number of dedicated threads. All threads are started first and
      /// then released together, so the workloads begin as close to simultaneously as possible. After every
      /// thread has finished, the failure of the lowest-numbered failing thread (if any) is rethrown; otherwise
      /// <see cref="PostRun"/> is invoked.
      /// </summary>
      public class MultiThreadedUnitTestExecuter : IDisposable
      {
          /// <summary>Number of threads each <c>Run</c> call starts.</summary>
          public int ThreadCount { get; }

          /// <summary>Threads created by the most recent <c>Run</c> call, indexed by thread id.</summary>
          public Thread[] Threads { get; }

          /// <summary>
          /// Exception thrown by each thread's workload during the most recent <c>Run</c>, or null if that thread succeeded.
          /// </summary>
          public Exception[] Exceptions { get; }

          // Released once by each thread when it is running and about to wait for the start signal.
          private readonly SemaphoreSlim barrier_threadstarted;
          // Set once all threads are ready; releases them together.
          private readonly ManualResetEventSlim barrier_corestart;
          // Released once by each thread when its workload has finished (successfully or not).
          private readonly SemaphoreSlim done_barrier2;

          /// <summary>Optional check invoked on the calling thread after a run in which no thread failed.</summary>
          public Action<MultiThreadedUnitTestExecuter> PostRun { get; set; }

          #region Static

          /// <summary>Runs <paramref name="workload"/> on <paramref name="threadCount"/> threads at once.</summary>
          /// <exception cref="ArgumentNullException"><paramref name="workload"/> is null.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
          /// <exception cref="Exception">A thread's workload threw; its exception is the InnerException.</exception>
          [DebuggerHidden]
          public static void Run(int threadCount, MultiThreadedTestDelegate workload)
          {
              if (workload == null) throw new ArgumentNullException(nameof(workload));
              if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));

              // Dispose the executer so its synchronization primitives are released.
              using (var executer = new MultiThreadedUnitTestExecuter(threadCount))
              {
                  executer.Run(workload);
              }
          }

          /// <summary>
          /// Runs <paramref name="workloads"/> on <paramref name="threadCount"/> threads at once
          /// (see the instance <c>Run</c> for how workloads are assigned to threads).
          /// </summary>
          /// <exception cref="ArgumentNullException"><paramref name="workloads"/> is null.</exception>
          /// <exception cref="ArgumentException"><paramref name="workloads"/> is empty.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
          /// <exception cref="Exception">A thread's workload threw; its exception is the InnerException.</exception>
          [DebuggerHidden]
          public static void Run(int threadCount, params MultiThreadedTestDelegate[] workloads)
          {
              if (workloads == null) throw new ArgumentNullException(nameof(workloads));
              if (workloads.Length == 0) throw new ArgumentException("Value cannot be an empty collection.", nameof(workloads));
              if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));

              using (var executer = new MultiThreadedUnitTestExecuter(threadCount))
              {
                  executer.Run(workloads);
              }
          }

          /// <summary>
          /// Runs <paramref name="workload"/> on <paramref name="threadCount"/> threads at once, then invokes
          /// <paramref name="postRun"/> if no thread failed.
          /// </summary>
          /// <exception cref="ArgumentNullException"><paramref name="workload"/> or <paramref name="postRun"/> is null.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
          /// <exception cref="Exception">A thread's workload threw; its exception is the InnerException.</exception>
          [DebuggerHidden]
          public static void Run(int threadCount, MultiThreadedTestDelegate workload, Action<MultiThreadedUnitTestExecuter> postRun)
          {
              if (workload == null) throw new ArgumentNullException(nameof(workload));
              if (postRun == null) throw new ArgumentNullException(nameof(postRun));
              if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));

              using (var executer = new MultiThreadedUnitTestExecuter(threadCount) { PostRun = postRun })
              {
                  executer.Run(workload);
              }
          }

          #endregion

          /// <summary>Initializes an executer that runs workloads on <paramref name="threadCount"/> threads.</summary>
          /// <param name="threadCount">Number of threads to start per <c>Run</c> call; must be positive.</param>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
          public MultiThreadedUnitTestExecuter(int threadCount)
          {
              if (threadCount <= 0)
                  throw new ArgumentOutOfRangeException(nameof(threadCount));

              ThreadCount = threadCount;
              Threads = new Thread[ThreadCount];
              Exceptions = new Exception[ThreadCount];
              done_barrier2 = new SemaphoreSlim(0, threadCount);
              barrier_corestart = new ManualResetEventSlim();
              barrier_threadstarted = new SemaphoreSlim(0, threadCount);
          }

          /// <summary>
          /// Runs the given workloads on <see cref="ThreadCount"/> threads, releasing all threads at the same
          /// moment, and blocks until every thread has exited.
          /// </summary>
          /// <param name="workloads">
          /// Either a single workload, run by every thread, or a multiple of <see cref="ThreadCount"/> workloads,
          /// where thread <c>i</c> runs <c>workloads[i]</c>.
          /// </param>
          /// <exception cref="ArgumentNullException"><paramref name="workloads"/> is null.</exception>
          /// <exception cref="ArgumentException"><paramref name="workloads"/> is empty.</exception>
          /// <exception cref="InvalidOperationException">The number of workloads is neither 1 nor a multiple of <see cref="ThreadCount"/>.</exception>
          /// <exception cref="Exception">A workload threw; the lowest-numbered failing thread's exception is the InnerException.</exception>
          [DebuggerHidden]
          public void Run(params MultiThreadedTestDelegate[] workloads)
          {
              if (workloads == null)
                  throw new ArgumentNullException(nameof(workloads));
              // An empty array would otherwise reach `i % workloads.Length` below (division by zero).
              if (workloads.Length == 0)
                  throw new ArgumentException("Value cannot be an empty collection.", nameof(workloads));
              // NOTE(review): any multiple of ThreadCount is accepted, but only the first ThreadCount workloads
              // are ever run (thread i runs workloads[i]); any extra workloads are silently ignored.
              if (workloads.Length != 1 && workloads.Length % ThreadCount != 0)
                  throw new InvalidOperationException($"Run method must accept either 1 workload or n-threads workloads. Got {workloads.Length} workloads.");

              // Clear state left by a previous Run on this instance: stale failures must not be reported again,
              // and the start gate must hold the new threads until all of them are ready.
              Array.Clear(Exceptions, 0, Exceptions.Length);
              barrier_corestart.Reset();

              // Thread body: signal readiness, wait for the common start signal, then run the workload.
              void ThreadCore(MultiThreadedTestDelegate core, int threadid)
              {
                  barrier_threadstarted.Release(1);
                  barrier_corestart.Wait();
                  try
                  {
                      core(threadid);
                  }
                  catch (Exception e)
                  {
                      // With a debugger attached the exception escapes on the worker thread
                      // (presumably so the debugger stops at the failure) instead of being recorded.
                      if (Debugger.IsAttached)
                          throw;
                      // Record the failure BEFORE signalling completion, so the caller, which reads
                      // Exceptions after waiting on done_barrier2, always observes it.
                      Exceptions[threadid] = e;
                  }
                  finally
                  {
                      done_barrier2.Release(1);
                  }
              }

              for (int i = 0; i < ThreadCount; i++)
              {
                  var id = i; // per-iteration copy so each closure captures its own thread id
                  // With a single workload, i % 1 == 0, so every thread shares workloads[0].
                  var workload = workloads[i % workloads.Length];
                  Threads[i] = new Thread(() => ThreadCore(workload, id));
              }

              foreach (var thread in Threads)
              {
                  thread.Start();
              }

              // Wait until every thread is running and parked on the start gate, then release them together.
              for (int i = 0; i < ThreadCount; i++)
              {
                  barrier_threadstarted.Wait();
              }
              barrier_corestart.Set();

              // Wait for every workload to finish (this also rebalances the semaphore for reuse),
              // then for the threads themselves to exit.
              for (int i = 0; i < ThreadCount; i++)
              {
                  done_barrier2.Wait();
              }
              foreach (var thread in Threads)
              {
                  thread.Join();
              }

              // Report the lowest-numbered failing thread. The thrown type stays System.Exception
              // because existing callers may match on that exact type.
              for (int i = 0; i < ThreadCount; i++)
              {
                  if (Exceptions[i] != null)
                      throw new Exception($"Thread {i} has failed: ", Exceptions[i]);
              }

              // Checks after all threads have ended successfully.
              PostRun?.Invoke(this);
          }

          /// <summary>Disposes the synchronization primitives. Threads are not joined or aborted here.</summary>
          public void Dispose()
          {
              barrier_threadstarted.Dispose();
              barrier_corestart.Dispose();
              done_barrier2.Dispose();
          }
      }
  }