diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs index eb14a795..02362a55 100644 --- a/src/TensorFlowNET.Keras/KerasInterface.cs +++ b/src/TensorFlowNET.Keras/KerasInterface.cs @@ -25,7 +25,8 @@ namespace Tensorflow.Keras public LossesApi losses { get; } = new LossesApi(); public Activations activations { get; } = new Activations(); public Preprocessing preprocessing { get; } = new Preprocessing(); - public BackendImpl backend { get; } = new BackendImpl(); + ThreadLocal _backend = new ThreadLocal(() => new BackendImpl()); + public BackendImpl backend => _backend.Value; public OptimizerApi optimizers { get; } = new OptimizerApi(); public MetricsApi metrics { get; } = new MetricsApi(); public ModelsApi models { get; } = new ModelsApi(); diff --git a/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs b/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs index cc8ac451..555154d7 100644 --- a/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs @@ -60,9 +60,9 @@ namespace TensorFlowNET.Keras.UnitTest //act ParallelOptions parallelOptions = new ParallelOptions(); - parallelOptions.MaxDegreeOfParallelism = 1; + parallelOptions.MaxDegreeOfParallelism = 8; var input = np.array(new float[,] { { 0, 0 } }); - Parallel.For(0, 1, parallelOptions, i => + Parallel.For(0, 8, parallelOptions, i => { var clone = BuildModel(); clone.load_weights(savefile);