using System;
using System.Diagnostics;
using System.Threading;
namespace TensorFlowNET.UnitTest
{
    /// <summary>
    /// A unit of work executed by <see cref="MultiThreadedUnitTestExecuter"/> on one of its threads.
    /// </summary>
    /// <param name="threadid">Zero-based index of the thread running the workload.</param>
    public delegate void MultiThreadedTestDelegate(int threadid);

    /// <summary>
    /// Creates a synchronized eco-system of running code: every workload runs on its own dedicated
    /// thread, all threads first signal that they are ready, and then they are released together by a
    /// single start signal. Failures are collected per thread and rethrown on the calling thread.
    /// </summary>
    public class MultiThreadedUnitTestExecuter : IDisposable
    {
        /// <summary>Number of threads each run uses.</summary>
        public int ThreadCount { get; }

        /// <summary>The threads created by the most recent run, indexed by thread id.</summary>
        public Thread[] Threads { get; }

        /// <summary>
        /// The exception captured from each thread during the most recent run, indexed by thread id;
        /// <c>null</c> for threads whose workload completed without throwing.
        /// </summary>
        public Exception[] Exceptions { get; }

        // Released once by each worker when it is running and about to wait for the start signal.
        private readonly SemaphoreSlim barrier_threadstarted;
        // Set by the coordinating thread to release all waiting workers at the same time.
        private readonly ManualResetEventSlim barrier_corestart;
        // Released once by each worker after its workload finished (successfully or not).
        private readonly SemaphoreSlim done_barrier2;

        /// <summary>Optional callback invoked on the calling thread after a run in which no thread failed.</summary>
        public Action PostRun { get; set; }

        #region Static

        /// <summary>Runs <paramref name="workload"/> concurrently on <paramref name="threadCount"/> threads.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="workload"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
        [DebuggerHidden]
        public static void Run(int threadCount, MultiThreadedTestDelegate workload)
        {
            if (workload == null) throw new ArgumentNullException(nameof(workload));
            if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));
            using (var executer = new MultiThreadedUnitTestExecuter(threadCount))
            {
                executer.Run(workload);
            }
        }

        /// <summary>
        /// Runs <paramref name="workloads"/> concurrently on <paramref name="threadCount"/> threads.
        /// Thread <c>i</c> runs <c>workloads[i % workloads.Length]</c>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="workloads"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="workloads"/> is empty.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
        [DebuggerHidden]
        public static void Run(int threadCount, params MultiThreadedTestDelegate[] workloads)
        {
            if (workloads == null) throw new ArgumentNullException(nameof(workloads));
            if (workloads.Length == 0) throw new ArgumentException("Value cannot be an empty collection.", nameof(workloads));
            if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));
            using (var executer = new MultiThreadedUnitTestExecuter(threadCount))
            {
                executer.Run(workloads);
            }
        }

        /// <summary>
        /// Runs <paramref name="workload"/> concurrently on <paramref name="threadCount"/> threads and then
        /// invokes <paramref name="postRun"/> if no thread failed.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="workload"/> or <paramref name="postRun"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
        [DebuggerHidden]
        public static void Run(int threadCount, MultiThreadedTestDelegate workload, Action postRun)
        {
            if (workload == null) throw new ArgumentNullException(nameof(workload));
            if (postRun == null) throw new ArgumentNullException(nameof(postRun));
            if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));
            using (var executer = new MultiThreadedUnitTestExecuter(threadCount) { PostRun = postRun })
            {
                executer.Run(workload);
            }
        }

        #endregion

        /// <summary>Initializes a new instance of the <see cref="MultiThreadedUnitTestExecuter"/> class.</summary>
        /// <param name="threadCount">Number of threads each run uses; must be positive.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCount"/> is not positive.</exception>
        public MultiThreadedUnitTestExecuter(int threadCount)
        {
            if (threadCount <= 0)
                throw new ArgumentOutOfRangeException(nameof(threadCount));
            ThreadCount = threadCount;
            Threads = new Thread[ThreadCount];
            Exceptions = new Exception[ThreadCount];
            done_barrier2 = new SemaphoreSlim(0, threadCount);
            barrier_corestart = new ManualResetEventSlim();
            barrier_threadstarted = new SemaphoreSlim(0, threadCount);
        }

        /// <summary>
        /// Runs the workloads on <see cref="ThreadCount"/> threads, releases them simultaneously, waits for
        /// all of them to finish, and rethrows the first (lowest thread id) failure. Thread <c>i</c> runs
        /// <c>workloads[i % workloads.Length]</c>, so a single workload is shared by all threads.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="workloads"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="workloads"/> is empty.</exception>
        /// <exception cref="InvalidOperationException">
        /// More than one workload was given and the count is not a multiple of <see cref="ThreadCount"/>.
        /// </exception>
        /// <exception cref="Exception">A workload threw; the original exception is the inner exception.</exception>
        [DebuggerHidden]
        public void Run(params MultiThreadedTestDelegate[] workloads)
        {
            if (workloads == null)
                throw new ArgumentNullException(nameof(workloads));
            // Without this check an empty array would crash below with IndexOutOfRange/DivideByZero.
            if (workloads.Length == 0)
                throw new ArgumentException("Value cannot be an empty collection.", nameof(workloads));
            if (workloads.Length != 1 && workloads.Length % ThreadCount != 0)
                throw new InvalidOperationException($"Run method must accept either 1 workload or n-threads workloads. Got {workloads.Length} workloads.");

            // Reset per-run state so that an instance can be run more than once: the start gate must be
            // closed again, and failures from a previous run must not be reported for this one.
            barrier_corestart.Reset();
            Array.Clear(Exceptions, 0, Exceptions.Length);

            // Worker body. The exception is stored BEFORE done_barrier2 is released (catch runs before
            // finally), so the coordinator is guaranteed to observe it once its wait returns.
            void ThreadCore(MultiThreadedTestDelegate core, int threadid)
            {
                barrier_threadstarted.Release(1);
                barrier_corestart.Wait();
                try
                {
                    core(threadid);
                }
                catch (Exception e)
                {
                    // Let the debugger break on the original throw site instead of collecting it.
                    if (Debugger.IsAttached)
                        throw;
                    Exceptions[threadid] = e;
                }
                finally
                {
                    done_barrier2.Release(1);
                }
            }

            // Initialize all threads; with a single workload, i % 1 == 0 selects it for every thread.
            for (int i = 0; i < ThreadCount; i++)
            {
                var threadid = i;
                var workload = workloads[threadid % workloads.Length];
                Threads[threadid] = new Thread(() => ThreadCore(workload, threadid));
            }

            // Run all threads.
            for (int i = 0; i < ThreadCount; i++) Threads[i].Start();
            // Wait for threads to be started and ready.
            for (int i = 0; i < ThreadCount; i++) barrier_threadstarted.Wait();
            // Signal threads to start.
            barrier_corestart.Set();
            // Wait for threads to finish.
            for (int i = 0; i < ThreadCount; i++) done_barrier2.Wait();

            // Handle failures.
            for (int i = 0; i < ThreadCount; i++)
            {
                if (Exceptions[i] != null)
                    throw new Exception($"Thread {i} has failed: ", Exceptions[i]);
            }

            // Checks after all threads ended successfully.
            PostRun?.Invoke();
        }

        /// <summary>Releases the synchronization primitives owned by this instance.</summary>
        public void Dispose()
        {
            barrier_threadstarted.Dispose();
            barrier_corestart.Dispose();
            done_barrier2.Dispose();
        }
    }
}