- Separated multithreading related methods to classname.threading.cs partial file - ops: Added enforce_singlethreading(), enforce_multithreading()tags/v0.20
@@ -37,8 +37,7 @@ namespace Tensorflow | |||
public Session as_default() | |||
{ | |||
tf._defaultSessionFactory.Value = this; | |||
return this; | |||
return ops.set_default_session(this); | |||
} | |||
[MethodImpl(MethodImplOptions.NoOptimization)] | |||
@@ -28,10 +28,6 @@ namespace Tensorflow | |||
{ | |||
public partial class ops | |||
{ | |||
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||
public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value; | |||
public static int tensor_id(Tensor tensor) | |||
{ | |||
return tensor.Id; | |||
@@ -78,53 +74,6 @@ namespace Tensorflow | |||
return get_default_graph().get_collection_ref<T>(key); | |||
} | |||
/// <summary> | |||
/// Returns the default graph for the current thread. | |||
/// | |||
/// The returned graph will be the innermost graph on which a | |||
/// `Graph.as_default()` context has been entered, or a global default | |||
/// graph if none has been explicitly created. | |||
/// | |||
/// NOTE: The default graph is a property of the current thread.If you | |||
/// create a new thread, and wish to use the default graph in that | |||
/// thread, you must explicitly add a `with g.as_default():` in that | |||
/// thread's function. | |||
/// </summary> | |||
/// <returns></returns> | |||
public static Graph get_default_graph() | |||
{ | |||
//TODO: original source indicates there should be a _default_graph_stack! | |||
//return _default_graph_stack.get_default() | |||
return default_graph_stack.get_controller(); | |||
} | |||
public static Graph set_default_graph(Graph graph) | |||
{ | |||
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | |||
default_graph_stack.set_controller(graph); | |||
return default_graph_stack.get_controller(); | |||
} | |||
/// <summary> | |||
/// Clears the default graph stack and resets the global default graph. | |||
/// | |||
/// NOTE: The default graph is a property of the current thread.This | |||
/// function applies only to the current thread.Calling this function while | |||
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined | |||
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects | |||
/// after calling this function will result in undefined behavior. | |||
/// </summary> | |||
/// <returns></returns> | |||
public static void reset_default_graph() | |||
{ | |||
//TODO: original source indicates there should be a _default_graph_stack! | |||
//if (!_default_graph_stack.is_cleared()) | |||
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||
// "nested graphs. If you need a cleared graph, " + | |||
// "exit the nesting and create a new graph."); | |||
default_graph_stack.reset(); | |||
} | |||
public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | |||
=> _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | |||
@@ -399,15 +348,6 @@ namespace Tensorflow | |||
return session.run(tensor, feed_dict); | |||
} | |||
/// <summary> | |||
/// Returns the default session for the current thread. | |||
/// </summary> | |||
/// <returns>The default `Session` being used in the current thread.</returns> | |||
public static Session get_default_session() | |||
{ | |||
return tf.defaultSession; | |||
} | |||
/// <summary> | |||
/// Prepends name scope to a name. | |||
/// </summary> | |||
@@ -0,0 +1,152 @@ | |||
using System.Threading; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public partial class ops | |||
{ | |||
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||
private static volatile Session _singleSesson; | |||
private static volatile DefaultGraphStack _singleGraphStack; | |||
private static readonly object _threadingLock = new object(); | |||
public static DefaultGraphStack default_graph_stack | |||
{ | |||
get | |||
{ | |||
if (!isSingleThreaded) | |||
return _defaultGraphFactory.Value; | |||
if (_singleGraphStack == null) | |||
{ | |||
lock (_threadingLock) | |||
{ | |||
if (_singleGraphStack == null) | |||
_singleGraphStack = new DefaultGraphStack(); | |||
} | |||
} | |||
return _singleGraphStack; | |||
} | |||
} | |||
private static bool isSingleThreaded = false; | |||
/// <summary> | |||
/// Does this library ignore different thread accessing. | |||
/// </summary> | |||
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading </remarks> | |||
public static bool IsSingleThreaded | |||
{ | |||
get => isSingleThreaded; | |||
set | |||
{ | |||
if (value) | |||
enforce_singlethreading(); | |||
else | |||
enforce_multithreading(); | |||
} | |||
} | |||
/// <summary> | |||
/// Forces the library to ignore different thread accessing. | |||
/// </summary> | |||
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a multithreaded manner</remarks> | |||
public static void enforce_singlethreading() | |||
{ | |||
isSingleThreaded = true; | |||
} | |||
/// <summary> | |||
/// Forces the library to provide a separate <see cref="Session"/> and <see cref="Graph"/> to every different thread accessing. | |||
/// </summary> | |||
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a singlethreaded manner</remarks> | |||
public static void enforce_multithreading() | |||
{ | |||
isSingleThreaded = false; | |||
} | |||
/// <summary> | |||
/// Returns the default session for the current thread. | |||
/// </summary> | |||
/// <returns>The default `Session` being used in the current thread.</returns> | |||
public static Session get_default_session() | |||
{ | |||
if (!isSingleThreaded) | |||
return tf.defaultSession; | |||
if (_singleSesson == null) | |||
{ | |||
lock (_threadingLock) | |||
{ | |||
if (_singleSesson == null) | |||
_singleSesson = new Session(); | |||
} | |||
} | |||
return _singleSesson; | |||
} | |||
/// <summary> | |||
/// Returns the default session for the current thread. | |||
/// </summary> | |||
/// <returns>The default `Session` being used in the current thread.</returns> | |||
public static Session set_default_session(Session sess) | |||
{ | |||
if (!isSingleThreaded) | |||
return tf.defaultSession = sess; | |||
lock (_threadingLock) | |||
{ | |||
_singleSesson = sess; | |||
} | |||
return _singleSesson; | |||
} | |||
/// <summary> | |||
/// Returns the default graph for the current thread. | |||
/// | |||
/// The returned graph will be the innermost graph on which a | |||
/// `Graph.as_default()` context has been entered, or a global default | |||
/// graph if none has been explicitly created. | |||
/// | |||
/// NOTE: The default graph is a property of the current thread.If you | |||
/// create a new thread, and wish to use the default graph in that | |||
/// thread, you must explicitly add a `with g.as_default():` in that | |||
/// thread's function. | |||
/// </summary> | |||
/// <returns></returns> | |||
public static Graph get_default_graph() | |||
{ | |||
//return _default_graph_stack.get_default() | |||
return default_graph_stack.get_controller(); | |||
} | |||
public static Graph set_default_graph(Graph graph) | |||
{ | |||
default_graph_stack.set_controller(graph); | |||
return default_graph_stack.get_controller(); | |||
} | |||
/// <summary> | |||
/// Clears the default graph stack and resets the global default graph. | |||
/// | |||
/// NOTE: The default graph is a property of the current thread.This | |||
/// function applies only to the current thread.Calling this function while | |||
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined | |||
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects | |||
/// after calling this function will result in undefined behavior. | |||
/// </summary> | |||
/// <returns></returns> | |||
public static void reset_default_graph() | |||
{ | |||
//if (!_default_graph_stack.is_cleared()) | |||
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||
// "nested graphs. If you need a cleared graph, " + | |||
// "exit the nesting and create a new graph."); | |||
default_graph_stack.reset(); | |||
} | |||
} | |||
} |
@@ -21,8 +21,6 @@ namespace Tensorflow | |||
{ | |||
public partial class tensorflow : IObjectLife | |||
{ | |||
protected internal readonly ThreadLocal<Session> _defaultSessionFactory; | |||
public TF_DataType @byte = TF_DataType.TF_UINT8; | |||
public TF_DataType @sbyte = TF_DataType.TF_INT8; | |||
public TF_DataType int16 = TF_DataType.TF_INT16; | |||
@@ -40,10 +38,10 @@ namespace Tensorflow | |||
public tensorflow() | |||
{ | |||
_defaultSessionFactory = new ThreadLocal<Session>(() => new Session()); | |||
_constructThreadingObjects(); | |||
} | |||
public Session defaultSession => _defaultSessionFactory.Value; | |||
public RefVariable Variable<T>(T data, | |||
bool trainable = true, | |||
@@ -0,0 +1,53 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System.Runtime.CompilerServices; | |||
using System.Threading; | |||
namespace Tensorflow | |||
{ | |||
public partial class tensorflow : IObjectLife | |||
{ | |||
protected ThreadLocal<Session> _defaultSessionFactory; | |||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
public void _constructThreadingObjects() | |||
{ | |||
_defaultSessionFactory = new ThreadLocal<Session>(() => new Session()); | |||
} | |||
public Session defaultSession | |||
{ | |||
get | |||
{ | |||
if (!ops.IsSingleThreaded) | |||
return _defaultSessionFactory.Value; | |||
return ops.get_default_session(); | |||
} | |||
internal set | |||
{ | |||
if (!ops.IsSingleThreaded) | |||
{ | |||
_defaultSessionFactory.Value = value; | |||
return; | |||
} | |||
ops.set_default_session(value); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,107 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Runtime.InteropServices; | |||
using System.Threading; | |||
using FluentAssertions; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using NumSharp; | |||
using Tensorflow; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
[TestClass] | |||
public class EnforcedSinglethreadingTests : CApiTest | |||
{ | |||
private static readonly object _singlethreadLocker = new object(); | |||
/// <summary>Initializes a new instance of the <see cref="T:System.Object" /> class.</summary> | |||
public EnforcedSinglethreadingTests() | |||
{ | |||
ops.IsSingleThreaded = true; | |||
} | |||
[TestMethod, Ignore("Has to be tested manually.")] | |||
public void SessionCreation() | |||
{ | |||
lock (_singlethreadLocker) | |||
{ | |||
ops.IsSingleThreaded.Should().BeTrue(); | |||
ops.uid(); //increment id by one | |||
//the core method | |||
tf.peak_default_graph().Should().BeNull(); | |||
using (var sess = tf.Session()) | |||
{ | |||
var default_graph = tf.peak_default_graph(); | |||
var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||
sess_graph.Should().NotBeNull(); | |||
default_graph.Should().NotBeNull() | |||
.And.BeEquivalentTo(sess_graph); | |||
var (graph, session) = Parallely(() => (tf.get_default_graph(), tf.get_default_session())); | |||
graph.Should().BeEquivalentTo(default_graph); | |||
session.Should().BeEquivalentTo(sess); | |||
} | |||
} | |||
} | |||
T Parallely<T>(Func<T> fnc) | |||
{ | |||
var mrh = new ManualResetEventSlim(); | |||
T ret = default; | |||
Exception e = default; | |||
new Thread(() => | |||
{ | |||
try | |||
{ | |||
ret = fnc(); | |||
} catch (Exception ee) | |||
{ | |||
e = ee; | |||
throw; | |||
} finally | |||
{ | |||
mrh.Set(); | |||
} | |||
}).Start(); | |||
if (!Debugger.IsAttached) | |||
mrh.Wait(10000).Should().BeTrue(); | |||
else | |||
mrh.Wait(-1); | |||
e.Should().BeNull(e?.ToString()); | |||
return ret; | |||
} | |||
void Parallely(Action fnc) | |||
{ | |||
var mrh = new ManualResetEventSlim(); | |||
Exception e = default; | |||
new Thread(() => | |||
{ | |||
try | |||
{ | |||
fnc(); | |||
} catch (Exception ee) | |||
{ | |||
e = ee; | |||
throw; | |||
} finally | |||
{ | |||
mrh.Set(); | |||
} | |||
}).Start(); | |||
mrh.Wait(10000).Should().BeTrue(); | |||
e.Should().BeNull(e.ToString()); | |||
} | |||
} | |||
} |
@@ -283,14 +283,11 @@ namespace TensorFlowNET.UnitTest | |||
} | |||
} | |||
private static string modelPath = "./model/"; | |||
private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/"); | |||
[TestMethod] | |||
public void TF_GraphOperationByName_FromModel() | |||
{ | |||
if (!Directory.Exists(modelPath)) | |||
return; | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
//the core method | |||
@@ -43,6 +43,9 @@ | |||
<None Update="model\saved_model.pb"> | |||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
</None> | |||
<None Update="Utilities\models\example1\saved_model.pb"> | |||
<CopyToOutputDirectory>Always</CopyToOutputDirectory> | |||
</None> | |||
</ItemGroup> | |||
</Project> |