diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs index eec7f7f8..c1b033ae 100644 --- a/src/TensorFlowNET.Core/APIs/tf.graph.cs +++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs @@ -20,8 +20,8 @@ namespace Tensorflow { public partial class tensorflow { - public graph_util_impl graph_util => new graph_util_impl(); - public GraphTransformer graph_transforms => new GraphTransformer(); + public graph_util_impl graph_util { get; } = new graph_util_impl(); + public GraphTransformer graph_transforms { get; } = new GraphTransformer(); public GraphKeys GraphKeys { get; } = new GraphKeys(); public void reset_default_graph() diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 833da09b..5d02c027 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -171,7 +171,7 @@ namespace Tensorflow.Contexts public void reset_context() { - ops.reset_uid(); + // ops.reset_uid(); // tf.defaultSession = null; ops.reset_default_graph(); context_switches.Clear(); diff --git a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs index 0eee3cdb..622b0071 100644 --- a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs +++ b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs @@ -14,9 +14,8 @@ limitations under the License. ******************************************************************************/ +using System; using System.Collections.Generic; -using System.Linq; -using static Tensorflow.Binding; namespace Tensorflow { @@ -25,19 +24,14 @@ namespace Tensorflow /// public class DefaultGraphStack { - private readonly Stack _stack = new Stack(); - Graph _global_default_graph; + Stack _stack = new Stack(); public Graph get_default() { - if (_stack.Count > 0) - return _stack.Peek(); - else if (_global_default_graph != null) - return _global_default_graph; - else - _global_default_graph = new Graph(); + if (_stack.Count == 0) + _stack.Push(new Graph()); - return _global_default_graph; + return _stack.Peek(); } public Graph get_controller(Graph g) @@ -61,7 +55,6 @@ namespace Tensorflow public void reset() { _stack.Clear(); - _global_default_graph = null; } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs index d186b400..0a2c8c44 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs @@ -107,7 +107,7 @@ namespace Tensorflow.NumPy if (tensor.Handle == null) { if (tf.executing_eagerly()) - tensor = tf.defaultSession.eval(tensor); + tensor = tf.get_default_session().eval(tensor); } return new NDArray(tensor, tf.executing_eagerly()); diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs index e70d6a0e..7417476e 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs @@ -38,7 +38,7 @@ namespace Tensorflow.NumPy { if (_handle is null) { - tensor = tf.defaultSession.eval(tensor); + tensor = tf.get_default_session().eval(tensor); _handle = tensor.Handle; } diff --git a/src/TensorFlowNET.Core/Variables/EagerResourceDeleter.cs b/src/TensorFlowNET.Core/Variables/EagerResourceDeleter.cs index 85f7dc59..8f3685cc 100644 --- a/src/TensorFlowNET.Core/Variables/EagerResourceDeleter.cs +++ b/src/TensorFlowNET.Core/Variables/EagerResourceDeleter.cs @@ -23,7 +23,7 @@ namespace Tensorflow.Variables { // gen_resource_variable_ops.destroy_resource_op(_tensor, ignore_lookup_error: true); - tf.device(_handle_device); + // tf.device(_handle_device); tf.Runner.TFE_Execute(tf.Context, _handle_device, "DestroyResourceOp", new[] { _tensor }, new object[] { "ignore_lookup_error", true }, 0); diff --git a/src/TensorFlowNET.Core/ops.threading.cs b/src/TensorFlowNET.Core/ops.threading.cs index f52dbcae..6c6476a5 100644 --- a/src/TensorFlowNET.Core/ops.threading.cs +++ b/src/TensorFlowNET.Core/ops.threading.cs @@ -1,70 +1,15 @@ -using System.Threading; +using System; +using System.Threading; using static Tensorflow.Binding; namespace Tensorflow { public partial class ops { - private static readonly ThreadLocal _defaultGraphFactory = new ThreadLocal(() => new DefaultGraphStack()); - private static volatile Session _singleSesson; - private static volatile DefaultGraphStack _singleGraphStack; - private static readonly object _threadingLock = new object(); - - public static DefaultGraphStack default_graph_stack - { - get - { - if (!isSingleThreaded) - return _defaultGraphFactory.Value; - - if (_singleGraphStack == null) - { - lock (_threadingLock) - { - if (_singleGraphStack == null) - _singleGraphStack = new DefaultGraphStack(); - } - } - - return _singleGraphStack; - } - } - - private static bool isSingleThreaded = false; - - /// - /// Does this library ignore different thread accessing. - /// - /// https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading - public static bool IsSingleThreaded - { - get => isSingleThreaded; - set - { - if (value) - enforce_singlethreading(); - else - enforce_multithreading(); - } - } - - /// - /// Forces the library to ignore different thread accessing. - /// - /// https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading

Note that this discards any sessions and graphs used in a multithreaded manner
- public static void enforce_singlethreading() - { - isSingleThreaded = true; - } - - /// - /// Forces the library to provide a separate and to every different thread accessing. - /// - /// https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading

Note that this discards any sessions and graphs used in a singlethreaded manner
- public static void enforce_multithreading() - { - isSingleThreaded = false; - } + [ThreadStatic] + static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); + [ThreadStatic] + static Session defaultSession; /// /// Returns the default session for the current thread. @@ -72,19 +17,10 @@ namespace Tensorflow /// The default `Session` being used in the current thread. public static Session get_default_session() { - if (!isSingleThreaded) - return tf.defaultSession; + if (defaultSession == null) + defaultSession = new Session(tf.get_default_graph()); - if (_singleSesson == null) - { - lock (_threadingLock) - { - if (_singleSesson == null) - _singleSesson = new Session(); - } - } - - return _singleSesson; + return defaultSession; } /// @@ -93,15 +29,8 @@ namespace Tensorflow /// The default `Session` being used in the current thread. public static Session set_default_session(Session sess) { - if (!isSingleThreaded) - return tf.defaultSession = sess; - - lock (_threadingLock) - { - _singleSesson = sess; - } - - return _singleSesson; + defaultSession = sess; + return sess; } /// @@ -118,10 +47,18 @@ namespace Tensorflow /// /// public static Graph get_default_graph() - => default_graph_stack.get_default(); + { + if (default_graph_stack == null) + default_graph_stack = new DefaultGraphStack(); + return default_graph_stack.get_default(); + } public static Graph set_default_graph(Graph g) - => default_graph_stack.get_controller(g); + { + if (default_graph_stack == null) + default_graph_stack = new DefaultGraphStack(); + return default_graph_stack.get_controller(g); + } /// /// Clears the default graph stack and resets the global default graph. @@ -135,6 +72,8 @@ namespace Tensorflow /// public static void reset_default_graph() { + if (default_graph_stack == null) + return; //if (!_default_graph_stack.is_cleared()) // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + // "nested graphs. If you need a cleared graph, " + @@ -143,7 +82,11 @@ namespace Tensorflow } public static Graph peak_default_graph() - => default_graph_stack.peak_controller(); + { + if (default_graph_stack == null) + default_graph_stack = new DefaultGraphStack(); + return default_graph_stack.peak_controller(); + } public static void pop_graph() => default_graph_stack.pop(); diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index a80b2007..8a2c78a7 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -16,6 +16,7 @@ using Serilog; using Serilog.Core; +using System.Threading; using Tensorflow.Contexts; using Tensorflow.Eager; using Tensorflow.Gradients; @@ -38,12 +39,18 @@ namespace Tensorflow public TF_DataType chars = TF_DataType.TF_STRING; public TF_DataType @string = TF_DataType.TF_STRING; - public Status Status; public OpDefLibrary OpDefLib; - public Context Context; - public IEagerRunner Runner; public Logger Logger; + ThreadLocal _status = new ThreadLocal(() => new Status()); + public Status Status => _status.Value; + + ThreadLocal _context = new ThreadLocal(() => new Context()); + public Context Context => _context.Value; + + ThreadLocal _runner = new ThreadLocal(() => new EagerRunner()); + public IEagerRunner Runner => _runner.Value; + public tensorflow() { Logger = new LoggerConfiguration() @@ -51,12 +58,8 @@ namespace Tensorflow .WriteTo.Console() .CreateLogger(); - Status = new Status(); - Context = new Context(); OpDefLib = new OpDefLibrary(); - ConstructThreadingObjects(); InitGradientEnvironment(); - Runner = new EagerRunner(); } public string VERSION => c_api.StringPiece(c_api.TF_Version()); diff --git a/src/TensorFlowNET.Core/tensorflow.threading.cs b/src/TensorFlowNET.Core/tensorflow.threading.cs deleted file mode 100644 index c1be5d90..00000000 --- a/src/TensorFlowNET.Core/tensorflow.threading.cs +++ /dev/null @@ -1,53 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System.Runtime.CompilerServices; -using System.Threading; - -namespace Tensorflow -{ - public partial class tensorflow - { - protected ThreadLocal defaultSessionFactory; - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void ConstructThreadingObjects() - { - defaultSessionFactory = new ThreadLocal(() => new Session()); - } - - public Session defaultSession - { - get - { - if (!ops.IsSingleThreaded) - return defaultSessionFactory.Value; - - return ops.get_default_session(); - } - internal set - { - if (!ops.IsSingleThreaded) - { - defaultSessionFactory.Value = value; - return; - } - - ops.set_default_session(value); - } - } - } -} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs index b2fe5747..eb14a795 100644 --- a/src/TensorFlowNET.Keras/KerasInterface.cs +++ b/src/TensorFlowNET.Keras/KerasInterface.cs @@ -12,6 +12,7 @@ using Tensorflow.Keras.Models; using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Saving; using Tensorflow.Keras.Utils; +using System.Threading; namespace Tensorflow.Keras { diff --git a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs index 024a8fc5..c7b9157b 100644 --- a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs @@ -9,6 +9,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; using Tensorflow.Functions; +using System.Threading; namespace Tensorflow.Keras.Layers { @@ -40,24 +41,24 @@ namespace Tensorflow.Keras.Layers return MakOp(inputs); } - ConcreteFunction function; + ThreadLocal function = new ThreadLocal(); Tensors DeFunCall(Tensors inputs) { - if(function == null) + if (function.Value == null) { - function = new ConcreteFunction(name); - function.Enter(); + function.Value = new ConcreteFunction(name); + function.Value.Enter(); int i = 0; var graph_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: $"defun_inputs_{i++}")).ToArray(); var graph_outputs = MakOp(graph_inputs); graph_outputs = mark_as_return(graph_outputs); - function.ToGraph(graph_inputs, graph_outputs); - function.Exit(); + function.Value.ToGraph(graph_inputs, graph_outputs); + function.Value.Exit(); } - var outputs = function.FilteredCall(inputs); + var outputs = function.Value.FilteredCall(inputs); return outputs; } diff --git a/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs index 91dc84b2..f657acc7 100644 --- a/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs +++ b/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs @@ -24,14 +24,12 @@ namespace TensorFlowNET.UnitTest { Assert.IsNull(tf.peak_default_graph()); - using (var sess = tf.Session()) - { - var default_graph = tf.get_default_graph(); - var sess_graph = sess.graph; - Assert.IsNotNull(default_graph); - Assert.IsNotNull(sess_graph); - Assert.AreEqual(default_graph, sess_graph); - } + using var sess = tf.Session(); + var default_graph = tf.get_default_graph(); + var sess_graph = sess.graph; + Assert.IsNotNull(default_graph); + Assert.IsNotNull(sess_graph); + Assert.AreEqual(default_graph, sess_graph); } } @@ -47,14 +45,12 @@ namespace TensorFlowNET.UnitTest { Assert.IsNull(tf.peak_default_graph()); //tf.Session created an other graph - using (var sess = tf.Session()) - { - var default_graph = tf.get_default_graph(); - var sess_graph = sess.graph; - Assert.IsNotNull(default_graph); - Assert.IsNotNull(sess_graph); - Assert.AreEqual(default_graph, sess_graph); - } + using var sess = tf.Session(); + var default_graph = tf.get_default_graph(); + var sess_graph = sess.graph; + Assert.IsNotNull(default_graph); + Assert.IsNotNull(sess_graph); + Assert.AreEqual(default_graph, sess_graph); } } @@ -73,20 +69,12 @@ namespace TensorFlowNET.UnitTest beforehand.as_default(); Assert.IsNotNull(tf.peak_default_graph()); - using (var sess = tf.Session()) - { - var default_graph = tf.peak_default_graph(); - var sess_graph = sess.graph; - Assert.IsNotNull(default_graph); - Assert.IsNotNull(sess_graph); - Assert.AreEqual(default_graph, sess_graph); - - Console.WriteLine($"{tid}-{default_graph.graph_key}"); - - //var result = sess.run(new object[] {g, a}); - //var actualDeriv = result[0].GetData()[0]; - //var actual = result[1].GetData()[0]; - } + using var sess = tf.Session(); + var default_graph = tf.peak_default_graph(); + var sess_graph = sess.graph; + Assert.IsNotNull(default_graph); + Assert.IsNotNull(sess_graph); + Assert.AreEqual(default_graph, sess_graph); } } @@ -114,13 +102,10 @@ namespace TensorFlowNET.UnitTest //the core method void Core(int tid) { - using (var sess = tf.Session()) + using var sess = tf.Session(); + for (int i = 0; i < 100; i++) { - Tensor t = null; - for (int i = 0; i < 100; i++) - { - t = new Tensor(1); - } + var t = new Tensor(1); } } } @@ -134,12 +119,10 @@ namespace TensorFlowNET.UnitTest void Core(int tid) { //tf.Session created an other graph - using (var sess = tf.Session()) + using var sess = tf.Session(); + for (int i = 0; i < 100; i++) { - for (int i = 0; i < 100; i++) - { - var t = new Tensor(new int[] { 1, 2, 3 }); - } + var t = new Tensor(new int[] { 1, 2, 3 }); } } } @@ -147,23 +130,23 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void SessionRun() { - MultiThreadedUnitTestExecuter.Run(8, Core); + MultiThreadedUnitTestExecuter.Run(2, Core); //the core method void Core(int tid) { + tf.compat.v1.disable_eager_execution(); + var graph = tf.Graph().as_default(); + //graph is created automatically to perform create these operations var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); var math = a1 + a2; + using var sess = tf.Session(graph); for (int i = 0; i < 100; i++) { - var graph = tf.get_default_graph(); - using (var sess = tf.Session(graph)) - { - var result = sess.run(math); - Assert.AreEqual(result[0], 5f); - } + var result = sess.run(math); + Assert.AreEqual(result[0], 5f); } } } @@ -176,17 +159,18 @@ namespace TensorFlowNET.UnitTest //the core method void Core(int tid) { - using (var sess = tf.Session()) - { - Assert.IsNotNull(tf.get_default_graph()); - //graph is created automatically to perform create these operations - var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); - var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); - var math = a1 + a2; + tf.compat.v1.disable_eager_execution(); + var graph = tf.Graph().as_default(); - var result = sess.run(math); - Assert.AreEqual(result[0], 5f); - } + using var sess = tf.Session(graph); + Assert.IsNotNull(tf.get_default_graph()); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); + var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); + var math = a1 + a2; + + var result = sess.run(math); + Assert.AreEqual(result[0], 5f); } } @@ -198,14 +182,12 @@ namespace TensorFlowNET.UnitTest //the core method void Core(int tid) { - using (var sess = tf.Session()) - { - Assert.IsNotNull(tf.get_default_graph()); - //graph is created automatically to perform create these operations - var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); - var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); - var math = a1 + a2; - } + using var sess = tf.Session(); + Assert.IsNotNull(tf.get_default_graph()); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); + var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); + var math = a1 + a2; } } @@ -234,6 +216,10 @@ namespace TensorFlowNET.UnitTest void Core(int tid) { Assert.IsNull(tf.peak_default_graph()); + + tf.compat.v1.disable_eager_execution(); + var graph = tf.Graph().as_default(); + //graph is created automatically to perform create these operations var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }, name: "ConstantK"); @@ -248,7 +234,6 @@ namespace TensorFlowNET.UnitTest private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/"); [Ignore] - [TestMethod] public void TF_GraphOperationByName_FromModel() { MultiThreadedUnitTestExecuter.Run(8, Core); diff --git a/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs b/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs new file mode 100644 index 00000000..cc8ac451 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs @@ -0,0 +1,95 @@ +using System; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using System.Threading.Tasks; +using Tensorflow.NumPy; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace TensorFlowNET.Keras.UnitTest +{ + [TestClass] + public class MultiThreads + { + [TestMethod] + public void Test1() + { + //Arrange + string savefile = "mymodel.h5"; + var model1 = BuildModel(); + model1.save_weights(savefile); + var model2 = BuildModel(); + + //act + model1.load_weights(savefile); + model2.load_weights(savefile); + + } + + [TestMethod] + public void Test2() + { + //Arrange + string savefile = "mymodel2.h5"; + var model1 = BuildModel(); + model1.save_weights(savefile); + model1 = BuildModel(); //recreate model + + //act + model1.load_weights(savefile); + + } + + [TestMethod] + public void Test3Multithreading() + { + //Arrange + string savefile = "mymodel3.h5"; + var model = BuildModel(); + model.save_weights(savefile); + + //Sanity check without multithreading + for (int i = 0; i < 2; i++) + { + Functional clone = BuildModel(); + clone.load_weights(savefile); + + //Predict something + clone.predict(np.array(new float[,] { { 0, 0 } })); + } //works + + //act + ParallelOptions parallelOptions = new ParallelOptions(); + parallelOptions.MaxDegreeOfParallelism = 1; + var input = np.array(new float[,] { { 0, 0 } }); + Parallel.For(0, 1, parallelOptions, i => + { + var clone = BuildModel(); + clone.load_weights(savefile); + //Predict something + clone.predict(input); + }); + } + + Functional BuildModel() + { + tf.Context.reset_context(); + var inputs = keras.Input(shape: 2); + + // 1st dense layer + var DenseLayer = keras.layers.Dense(1, activation: keras.activations.Sigmoid); + var outputs = DenseLayer.Apply(inputs); + + // build keras model + Functional model = keras.Model(inputs, outputs, name: Guid.NewGuid().ToString()); + // show model summary + model.summary(); + + // compile keras model into tensorflow's static graph + model.compile(loss: keras.losses.MeanSquaredError(name: Guid.NewGuid().ToString()), + optimizer: keras.optimizers.Adam(name: Guid.NewGuid().ToString()), + metrics: new[] { "accuracy" }); + return model; + } + } +} diff --git a/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs b/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs index 796ace6c..e15100a0 100644 --- a/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs +++ b/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs @@ -16,7 +16,6 @@ namespace TensorFlowNET.UnitTest /// Initializes a new instance of the class. public EnforcedSinglethreadingTests() { - ops.IsSingleThreaded = true; } [TestMethod, Ignore("Has to be tested manually.")] @@ -24,8 +23,6 @@ namespace TensorFlowNET.UnitTest { lock (_singlethreadLocker) { - ops.IsSingleThreaded.Should().BeTrue(); - ops.uid(); //increment id by one //the core method