- Separated multithreading related methods to classname.threading.cs partial file - ops: Added enforce_singlethreading(), enforce_multithreading()tags/v0.20
@@ -37,8 +37,7 @@ namespace Tensorflow | |||||
public Session as_default() | public Session as_default() | ||||
{ | { | ||||
tf._defaultSessionFactory.Value = this; | |||||
return this; | |||||
return ops.set_default_session(this); | |||||
} | } | ||||
[MethodImpl(MethodImplOptions.NoOptimization)] | [MethodImpl(MethodImplOptions.NoOptimization)] | ||||
@@ -28,10 +28,6 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class ops | public partial class ops | ||||
{ | { | ||||
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||||
public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value; | |||||
public static int tensor_id(Tensor tensor) | public static int tensor_id(Tensor tensor) | ||||
{ | { | ||||
return tensor.Id; | return tensor.Id; | ||||
@@ -78,53 +74,6 @@ namespace Tensorflow | |||||
return get_default_graph().get_collection_ref<T>(key); | return get_default_graph().get_collection_ref<T>(key); | ||||
} | } | ||||
/// <summary> | |||||
/// Returns the default graph for the current thread. | |||||
/// | |||||
/// The returned graph will be the innermost graph on which a | |||||
/// `Graph.as_default()` context has been entered, or a global default | |||||
/// graph if none has been explicitly created. | |||||
/// | |||||
/// NOTE: The default graph is a property of the current thread.If you | |||||
/// create a new thread, and wish to use the default graph in that | |||||
/// thread, you must explicitly add a `with g.as_default():` in that | |||||
/// thread's function. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public static Graph get_default_graph() | |||||
{ | |||||
//TODO: original source indicates there should be a _default_graph_stack! | |||||
//return _default_graph_stack.get_default() | |||||
return default_graph_stack.get_controller(); | |||||
} | |||||
public static Graph set_default_graph(Graph graph) | |||||
{ | |||||
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | |||||
default_graph_stack.set_controller(graph); | |||||
return default_graph_stack.get_controller(); | |||||
} | |||||
/// <summary> | |||||
/// Clears the default graph stack and resets the global default graph. | |||||
/// | |||||
/// NOTE: The default graph is a property of the current thread.This | |||||
/// function applies only to the current thread.Calling this function while | |||||
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined | |||||
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects | |||||
/// after calling this function will result in undefined behavior. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public static void reset_default_graph() | |||||
{ | |||||
//TODO: original source indicates there should be a _default_graph_stack! | |||||
//if (!_default_graph_stack.is_cleared()) | |||||
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||||
// "nested graphs. If you need a cleared graph, " + | |||||
// "exit the nesting and create a new graph."); | |||||
default_graph_stack.reset(); | |||||
} | |||||
public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | ||||
=> _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | ||||
@@ -399,15 +348,6 @@ namespace Tensorflow | |||||
return session.run(tensor, feed_dict); | return session.run(tensor, feed_dict); | ||||
} | } | ||||
/// <summary> | |||||
/// Returns the default session for the current thread. | |||||
/// </summary> | |||||
/// <returns>The default `Session` being used in the current thread.</returns> | |||||
public static Session get_default_session() | |||||
{ | |||||
return tf.defaultSession; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Prepends name scope to a name. | /// Prepends name scope to a name. | ||||
/// </summary> | /// </summary> | ||||
@@ -0,0 +1,152 @@ | |||||
using System.Threading; | |||||
using Tensorflow.Util; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class ops | |||||
{ | |||||
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||||
private static volatile Session _singleSesson; | |||||
private static volatile DefaultGraphStack _singleGraphStack; | |||||
private static readonly object _threadingLock = new object(); | |||||
public static DefaultGraphStack default_graph_stack | |||||
{ | |||||
get | |||||
{ | |||||
if (!isSingleThreaded) | |||||
return _defaultGraphFactory.Value; | |||||
if (_singleGraphStack == null) | |||||
{ | |||||
lock (_threadingLock) | |||||
{ | |||||
if (_singleGraphStack == null) | |||||
_singleGraphStack = new DefaultGraphStack(); | |||||
} | |||||
} | |||||
return _singleGraphStack; | |||||
} | |||||
} | |||||
private static bool isSingleThreaded = false; | |||||
/// <summary> | |||||
/// Does this library ignore different thread accessing. | |||||
/// </summary> | |||||
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading </remarks> | |||||
public static bool IsSingleThreaded | |||||
{ | |||||
get => isSingleThreaded; | |||||
set | |||||
{ | |||||
if (value) | |||||
enforce_singlethreading(); | |||||
else | |||||
enforce_multithreading(); | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// Forces the library to ignore different thread accessing. | |||||
/// </summary> | |||||
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a multithreaded manner</remarks> | |||||
public static void enforce_singlethreading() | |||||
{ | |||||
isSingleThreaded = true; | |||||
} | |||||
/// <summary> | |||||
/// Forces the library to provide a separate <see cref="Session"/> and <see cref="Graph"/> to every different thread accessing. | |||||
/// </summary> | |||||
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a singlethreaded manner</remarks> | |||||
public static void enforce_multithreading() | |||||
{ | |||||
isSingleThreaded = false; | |||||
} | |||||
/// <summary> | |||||
/// Returns the default session for the current thread. | |||||
/// </summary> | |||||
/// <returns>The default `Session` being used in the current thread.</returns> | |||||
public static Session get_default_session() | |||||
{ | |||||
if (!isSingleThreaded) | |||||
return tf.defaultSession; | |||||
if (_singleSesson == null) | |||||
{ | |||||
lock (_threadingLock) | |||||
{ | |||||
if (_singleSesson == null) | |||||
_singleSesson = new Session(); | |||||
} | |||||
} | |||||
return _singleSesson; | |||||
} | |||||
/// <summary> | |||||
/// Returns the default session for the current thread. | |||||
/// </summary> | |||||
/// <returns>The default `Session` being used in the current thread.</returns> | |||||
public static Session set_default_session(Session sess) | |||||
{ | |||||
if (!isSingleThreaded) | |||||
return tf.defaultSession = sess; | |||||
lock (_threadingLock) | |||||
{ | |||||
_singleSesson = sess; | |||||
} | |||||
return _singleSesson; | |||||
} | |||||
/// <summary> | |||||
/// Returns the default graph for the current thread. | |||||
/// | |||||
/// The returned graph will be the innermost graph on which a | |||||
/// `Graph.as_default()` context has been entered, or a global default | |||||
/// graph if none has been explicitly created. | |||||
/// | |||||
/// NOTE: The default graph is a property of the current thread.If you | |||||
/// create a new thread, and wish to use the default graph in that | |||||
/// thread, you must explicitly add a `with g.as_default():` in that | |||||
/// thread's function. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public static Graph get_default_graph() | |||||
{ | |||||
//return _default_graph_stack.get_default() | |||||
return default_graph_stack.get_controller(); | |||||
} | |||||
public static Graph set_default_graph(Graph graph) | |||||
{ | |||||
default_graph_stack.set_controller(graph); | |||||
return default_graph_stack.get_controller(); | |||||
} | |||||
/// <summary> | |||||
/// Clears the default graph stack and resets the global default graph. | |||||
/// | |||||
/// NOTE: The default graph is a property of the current thread.This | |||||
/// function applies only to the current thread.Calling this function while | |||||
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined | |||||
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects | |||||
/// after calling this function will result in undefined behavior. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public static void reset_default_graph() | |||||
{ | |||||
//if (!_default_graph_stack.is_cleared()) | |||||
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||||
// "nested graphs. If you need a cleared graph, " + | |||||
// "exit the nesting and create a new graph."); | |||||
default_graph_stack.reset(); | |||||
} | |||||
} | |||||
} |
@@ -21,8 +21,6 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class tensorflow : IObjectLife | public partial class tensorflow : IObjectLife | ||||
{ | { | ||||
protected internal readonly ThreadLocal<Session> _defaultSessionFactory; | |||||
public TF_DataType @byte = TF_DataType.TF_UINT8; | public TF_DataType @byte = TF_DataType.TF_UINT8; | ||||
public TF_DataType @sbyte = TF_DataType.TF_INT8; | public TF_DataType @sbyte = TF_DataType.TF_INT8; | ||||
public TF_DataType int16 = TF_DataType.TF_INT16; | public TF_DataType int16 = TF_DataType.TF_INT16; | ||||
@@ -40,10 +38,10 @@ namespace Tensorflow | |||||
public tensorflow() | public tensorflow() | ||||
{ | { | ||||
_defaultSessionFactory = new ThreadLocal<Session>(() => new Session()); | |||||
_constructThreadingObjects(); | |||||
} | } | ||||
public Session defaultSession => _defaultSessionFactory.Value; | |||||
public RefVariable Variable<T>(T data, | public RefVariable Variable<T>(T data, | ||||
bool trainable = true, | bool trainable = true, | ||||
@@ -0,0 +1,53 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System.Runtime.CompilerServices; | |||||
using System.Threading; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class tensorflow : IObjectLife | |||||
{ | |||||
protected ThreadLocal<Session> _defaultSessionFactory; | |||||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
public void _constructThreadingObjects() | |||||
{ | |||||
_defaultSessionFactory = new ThreadLocal<Session>(() => new Session()); | |||||
} | |||||
public Session defaultSession | |||||
{ | |||||
get | |||||
{ | |||||
if (!ops.IsSingleThreaded) | |||||
return _defaultSessionFactory.Value; | |||||
return ops.get_default_session(); | |||||
} | |||||
internal set | |||||
{ | |||||
if (!ops.IsSingleThreaded) | |||||
{ | |||||
_defaultSessionFactory.Value = value; | |||||
return; | |||||
} | |||||
ops.set_default_session(value); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,107 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.IO; | |||||
using System.Linq; | |||||
using System.Runtime.InteropServices; | |||||
using System.Threading; | |||||
using FluentAssertions; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using NumSharp; | |||||
using Tensorflow; | |||||
using Tensorflow.Util; | |||||
using static Tensorflow.Binding; | |||||
namespace TensorFlowNET.UnitTest | |||||
{ | |||||
[TestClass] | |||||
public class EnforcedSinglethreadingTests : CApiTest | |||||
{ | |||||
private static readonly object _singlethreadLocker = new object(); | |||||
/// <summary>Initializes a new instance of the <see cref="T:System.Object" /> class.</summary> | |||||
public EnforcedSinglethreadingTests() | |||||
{ | |||||
ops.IsSingleThreaded = true; | |||||
} | |||||
[TestMethod, Ignore("Has to be tested manually.")] | |||||
public void SessionCreation() | |||||
{ | |||||
lock (_singlethreadLocker) | |||||
{ | |||||
ops.IsSingleThreaded.Should().BeTrue(); | |||||
ops.uid(); //increment id by one | |||||
//the core method | |||||
tf.peak_default_graph().Should().BeNull(); | |||||
using (var sess = tf.Session()) | |||||
{ | |||||
var default_graph = tf.peak_default_graph(); | |||||
var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||||
sess_graph.Should().NotBeNull(); | |||||
default_graph.Should().NotBeNull() | |||||
.And.BeEquivalentTo(sess_graph); | |||||
var (graph, session) = Parallely(() => (tf.get_default_graph(), tf.get_default_session())); | |||||
graph.Should().BeEquivalentTo(default_graph); | |||||
session.Should().BeEquivalentTo(sess); | |||||
} | |||||
} | |||||
} | |||||
T Parallely<T>(Func<T> fnc) | |||||
{ | |||||
var mrh = new ManualResetEventSlim(); | |||||
T ret = default; | |||||
Exception e = default; | |||||
new Thread(() => | |||||
{ | |||||
try | |||||
{ | |||||
ret = fnc(); | |||||
} catch (Exception ee) | |||||
{ | |||||
e = ee; | |||||
throw; | |||||
} finally | |||||
{ | |||||
mrh.Set(); | |||||
} | |||||
}).Start(); | |||||
if (!Debugger.IsAttached) | |||||
mrh.Wait(10000).Should().BeTrue(); | |||||
else | |||||
mrh.Wait(-1); | |||||
e.Should().BeNull(e?.ToString()); | |||||
return ret; | |||||
} | |||||
void Parallely(Action fnc) | |||||
{ | |||||
var mrh = new ManualResetEventSlim(); | |||||
Exception e = default; | |||||
new Thread(() => | |||||
{ | |||||
try | |||||
{ | |||||
fnc(); | |||||
} catch (Exception ee) | |||||
{ | |||||
e = ee; | |||||
throw; | |||||
} finally | |||||
{ | |||||
mrh.Set(); | |||||
} | |||||
}).Start(); | |||||
mrh.Wait(10000).Should().BeTrue(); | |||||
e.Should().BeNull(e.ToString()); | |||||
} | |||||
} | |||||
} |
@@ -283,14 +283,11 @@ namespace TensorFlowNET.UnitTest | |||||
} | } | ||||
} | } | ||||
private static string modelPath = "./model/"; | |||||
private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/"); | |||||
[TestMethod] | [TestMethod] | ||||
public void TF_GraphOperationByName_FromModel() | public void TF_GraphOperationByName_FromModel() | ||||
{ | { | ||||
if (!Directory.Exists(modelPath)) | |||||
return; | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | MultiThreadedUnitTestExecuter.Run(8, Core); | ||||
//the core method | //the core method | ||||
@@ -43,6 +43,9 @@ | |||||
<None Update="model\saved_model.pb"> | <None Update="model\saved_model.pb"> | ||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | ||||
</None> | </None> | ||||
<None Update="Utilities\models\example1\saved_model.pb"> | |||||
<CopyToOutputDirectory>Always</CopyToOutputDirectory> | |||||
</None> | |||||
</ItemGroup> | </ItemGroup> | ||||
</Project> | </Project> |