diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index caa669d3..c60a49c1 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -37,8 +37,7 @@ namespace Tensorflow public Session as_default() { - tf._defaultSessionFactory.Value = this; - return this; + return ops.set_default_session(this); } [MethodImpl(MethodImplOptions.NoOptimization)] diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 02417594..633a9bf7 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -28,10 +28,6 @@ namespace Tensorflow { public partial class ops { - private static readonly ThreadLocal _defaultGraphFactory = new ThreadLocal(() => new DefaultGraphStack()); - - public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value; - public static int tensor_id(Tensor tensor) { return tensor.Id; @@ -78,53 +74,6 @@ namespace Tensorflow return get_default_graph().get_collection_ref(key); } - /// - /// Returns the default graph for the current thread. - /// - /// The returned graph will be the innermost graph on which a - /// `Graph.as_default()` context has been entered, or a global default - /// graph if none has been explicitly created. - /// - /// NOTE: The default graph is a property of the current thread.If you - /// create a new thread, and wish to use the default graph in that - /// thread, you must explicitly add a `with g.as_default():` in that - /// thread's function. - /// - /// - public static Graph get_default_graph() - { - //TODO: original source indicates there should be a _default_graph_stack! - //return _default_graph_stack.get_default() - return default_graph_stack.get_controller(); - } - - public static Graph set_default_graph(Graph graph) - { - //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! - default_graph_stack.set_controller(graph); - return default_graph_stack.get_controller(); - } - - /// - /// Clears the default graph stack and resets the global default graph. - /// - /// NOTE: The default graph is a property of the current thread.This - /// function applies only to the current thread.Calling this function while - /// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined - /// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects - /// after calling this function will result in undefined behavior. - /// - /// - public static void reset_default_graph() - { - //TODO: original source indicates there should be a _default_graph_stack! - //if (!_default_graph_stack.is_cleared()) - // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + - // "nested graphs. If you need a cleared graph, " + - // "exit the nesting and create a new graph."); - default_graph_stack.reset(); - } - public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); @@ -399,15 +348,6 @@ namespace Tensorflow return session.run(tensor, feed_dict); } - /// - /// Returns the default session for the current thread. - /// - /// The default `Session` being used in the current thread. - public static Session get_default_session() - { - return tf.defaultSession; - } - /// /// Prepends name scope to a name. /// diff --git a/src/TensorFlowNET.Core/ops.threading.cs b/src/TensorFlowNET.Core/ops.threading.cs new file mode 100644 index 00000000..f8796596 --- /dev/null +++ b/src/TensorFlowNET.Core/ops.threading.cs @@ -0,0 +1,152 @@ +using System.Threading; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class ops + { + private static readonly ThreadLocal _defaultGraphFactory = new ThreadLocal(() => new DefaultGraphStack()); + private static volatile Session _singleSesson; + private static volatile DefaultGraphStack _singleGraphStack; + private static readonly object _threadingLock = new object(); + + public static DefaultGraphStack default_graph_stack + { + get + { + if (!isSingleThreaded) + return _defaultGraphFactory.Value; + + if (_singleGraphStack == null) + { + lock (_threadingLock) + { + if (_singleGraphStack == null) + _singleGraphStack = new DefaultGraphStack(); + } + } + + return _singleGraphStack; + } + } + + private static bool isSingleThreaded = false; + + /// + /// Does this library ignore different thread accessing. + /// + /// https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading + public static bool IsSingleThreaded + { + get => isSingleThreaded; + set + { + if (value) + enforce_singlethreading(); + else + enforce_multithreading(); + } + } + + /// + /// Forces the library to ignore different thread accessing. + /// + /// https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading

Note that this discards any sessions and graphs used in a multithreaded manner
+ public static void enforce_singlethreading() + { + isSingleThreaded = true; + } + + /// + /// Forces the library to provide a separate and to every different thread accessing. + /// + /// https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading

Note that this discards any sessions and graphs used in a singlethreaded manner
+ public static void enforce_multithreading() + { + isSingleThreaded = false; + } + + /// + /// Returns the default session for the current thread. + /// + /// The default `Session` being used in the current thread. + public static Session get_default_session() + { + if (!isSingleThreaded) + return tf.defaultSession; + + if (_singleSesson == null) + { + lock (_threadingLock) + { + if (_singleSesson == null) + _singleSesson = new Session(); + } + } + + return _singleSesson; + } + + /// + /// Returns the default session for the current thread. + /// + /// The default `Session` being used in the current thread. + public static Session set_default_session(Session sess) + { + if (!isSingleThreaded) + return tf.defaultSession = sess; + + lock (_threadingLock) + { + _singleSesson = sess; + } + + return _singleSesson; + } + + /// + /// Returns the default graph for the current thread. + /// + /// The returned graph will be the innermost graph on which a + /// `Graph.as_default()` context has been entered, or a global default + /// graph if none has been explicitly created. + /// + /// NOTE: The default graph is a property of the current thread.If you + /// create a new thread, and wish to use the default graph in that + /// thread, you must explicitly add a `with g.as_default():` in that + /// thread's function. + /// + /// + public static Graph get_default_graph() + { + //return _default_graph_stack.get_default() + return default_graph_stack.get_controller(); + } + + public static Graph set_default_graph(Graph graph) + { + default_graph_stack.set_controller(graph); + return default_graph_stack.get_controller(); + } + + /// + /// Clears the default graph stack and resets the global default graph. + /// + /// NOTE: The default graph is a property of the current thread.This + /// function applies only to the current thread.Calling this function while + /// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined + /// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects + /// after calling this function will result in undefined behavior. + /// + /// + public static void reset_default_graph() + { + //if (!_default_graph_stack.is_cleared()) + // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + + // "nested graphs. If you need a cleared graph, " + + // "exit the nesting and create a new graph."); + default_graph_stack.reset(); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 6ccf55f5..a42297b2 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -21,8 +21,6 @@ namespace Tensorflow { public partial class tensorflow : IObjectLife { - protected internal readonly ThreadLocal _defaultSessionFactory; - public TF_DataType @byte = TF_DataType.TF_UINT8; public TF_DataType @sbyte = TF_DataType.TF_INT8; public TF_DataType int16 = TF_DataType.TF_INT16; @@ -40,10 +38,10 @@ namespace Tensorflow public tensorflow() { - _defaultSessionFactory = new ThreadLocal(() => new Session()); + _constructThreadingObjects(); } - public Session defaultSession => _defaultSessionFactory.Value; + public RefVariable Variable(T data, bool trainable = true, diff --git a/src/TensorFlowNET.Core/tensorflow.threading.cs b/src/TensorFlowNET.Core/tensorflow.threading.cs new file mode 100644 index 00000000..33e925fd --- /dev/null +++ b/src/TensorFlowNET.Core/tensorflow.threading.cs @@ -0,0 +1,53 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Runtime.CompilerServices; +using System.Threading; + +namespace Tensorflow +{ + public partial class tensorflow : IObjectLife + { + protected ThreadLocal _defaultSessionFactory; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void _constructThreadingObjects() + { + _defaultSessionFactory = new ThreadLocal(() => new Session()); + } + + public Session defaultSession + { + get + { + if (!ops.IsSingleThreaded) + return _defaultSessionFactory.Value; + + return ops.get_default_session(); + } + internal set + { + if (!ops.IsSingleThreaded) + { + _defaultSessionFactory.Value = value; + return; + } + + ops.set_default_session(value); + } + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs b/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs new file mode 100644 index 00000000..b7efc116 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Threading; +using FluentAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; +using Tensorflow; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class EnforcedSinglethreadingTests : CApiTest + { + private static readonly object _singlethreadLocker = new object(); + + /// Initializes a new instance of the class. + public EnforcedSinglethreadingTests() + { + ops.IsSingleThreaded = true; + } + + [TestMethod, Ignore("Has to be tested manually.")] + public void SessionCreation() + { + lock (_singlethreadLocker) + { + ops.IsSingleThreaded.Should().BeTrue(); + + ops.uid(); //increment id by one + + //the core method + tf.peak_default_graph().Should().BeNull(); + + using (var sess = tf.Session()) + { + var default_graph = tf.peak_default_graph(); + var sess_graph = sess.GetPrivate("_graph"); + sess_graph.Should().NotBeNull(); + default_graph.Should().NotBeNull() + .And.BeEquivalentTo(sess_graph); + + var (graph, session) = Parallely(() => (tf.get_default_graph(), tf.get_default_session())); + + graph.Should().BeEquivalentTo(default_graph); + session.Should().BeEquivalentTo(sess); + } + } + } + + T Parallely(Func fnc) + { + var mrh = new ManualResetEventSlim(); + T ret = default; + Exception e = default; + new Thread(() => + { + try + { + ret = fnc(); + } catch (Exception ee) + { + e = ee; + throw; + } finally + { + mrh.Set(); + } + }).Start(); + + if (!Debugger.IsAttached) + mrh.Wait(10000).Should().BeTrue(); + else + mrh.Wait(-1); + e.Should().BeNull(e?.ToString()); + return ret; + } + + void Parallely(Action fnc) + { + var mrh = new ManualResetEventSlim(); + Exception e = default; + new Thread(() => + { + try + { + fnc(); + } catch (Exception ee) + { + e = ee; + throw; + } finally + { + mrh.Set(); + } + }).Start(); + + mrh.Wait(10000).Should().BeTrue(); + e.Should().BeNull(e.ToString()); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs index f4f3f141..adae4fad 100644 --- a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs +++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs @@ -283,14 +283,11 @@ namespace TensorFlowNET.UnitTest } } - private static string modelPath = "./model/"; + private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/"); [TestMethod] public void TF_GraphOperationByName_FromModel() { - if (!Directory.Exists(modelPath)) - return; - MultiThreadedUnitTestExecuter.Run(8, Core); //the core method diff --git a/test/TensorFlowNET.UnitTest/UnitTest.csproj b/test/TensorFlowNET.UnitTest/UnitTest.csproj index 58420b0a..6ff87f9d 100644 --- a/test/TensorFlowNET.UnitTest/UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/UnitTest.csproj @@ -43,6 +43,9 @@ PreserveNewest + + Always + diff --git a/test/TensorFlowNET.UnitTest/Utilities/models/example1/saved_model.pb b/test/TensorFlowNET.UnitTest/Utilities/models/example1/saved_model.pb new file mode 100644 index 00000000..f37debb5 Binary files /dev/null and b/test/TensorFlowNET.UnitTest/Utilities/models/example1/saved_model.pb differ