Browse Source

Supported for forced singlethreading

- Separated multithreading related methods to classname.threading.cs partial file
- ops: Added enforce_singlethreading(), enforce_multithreading()
tags/v0.20
Eli Belash 5 years ago
parent
commit
a4dbb4ed3d
9 changed files with 319 additions and 70 deletions
  1. +1
    -2
      src/TensorFlowNET.Core/Sessions/Session.cs
  2. +0
    -60
      src/TensorFlowNET.Core/ops.cs
  3. +152
    -0
      src/TensorFlowNET.Core/ops.threading.cs
  4. +2
    -4
      src/TensorFlowNET.Core/tensorflow.cs
  5. +53
    -0
      src/TensorFlowNET.Core/tensorflow.threading.cs
  6. +107
    -0
      test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs
  7. +1
    -4
      test/TensorFlowNET.UnitTest/MultithreadingTests.cs
  8. +3
    -0
      test/TensorFlowNET.UnitTest/UnitTest.csproj
  9. BIN
      test/TensorFlowNET.UnitTest/Utilities/models/example1/saved_model.pb

+ 1
- 2
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -37,8 +37,7 @@ namespace Tensorflow

public Session as_default()
{
tf._defaultSessionFactory.Value = this;
return this;
return ops.set_default_session(this);
}

[MethodImpl(MethodImplOptions.NoOptimization)]


+ 0
- 60
src/TensorFlowNET.Core/ops.cs View File

@@ -28,10 +28,6 @@ namespace Tensorflow
{
public partial class ops
{
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack());

public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value;

public static int tensor_id(Tensor tensor)
{
return tensor.Id;
@@ -78,53 +74,6 @@ namespace Tensorflow
return get_default_graph().get_collection_ref<T>(key);
}

/// <summary>
/// Returns the default graph for the current thread.
///
/// The returned graph will be the innermost graph on which a
/// `Graph.as_default()` context has been entered, or a global default
/// graph if none has been explicitly created.
///
/// NOTE: The default graph is a property of the current thread.If you
/// create a new thread, and wish to use the default graph in that
/// thread, you must explicitly add a `with g.as_default():` in that
/// thread's function.
/// </summary>
/// <returns></returns>
public static Graph get_default_graph()
{
//TODO: original source indicates there should be a _default_graph_stack!
//return _default_graph_stack.get_default()
return default_graph_stack.get_controller();
}

public static Graph set_default_graph(Graph graph)
{
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack!
default_graph_stack.set_controller(graph);
return default_graph_stack.get_controller();
}

/// <summary>
/// Clears the default graph stack and resets the global default graph.
///
/// NOTE: The default graph is a property of the current thread.This
/// function applies only to the current thread.Calling this function while
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
/// after calling this function will result in undefined behavior.
/// </summary>
/// <returns></returns>
public static void reset_default_graph()
{
//TODO: original source indicates there should be a _default_graph_stack!
//if (!_default_graph_stack.is_cleared())
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
// "nested graphs. If you need a cleared graph, " +
// "exit the nesting and create a new graph.");
default_graph_stack.reset();
}

public static Graph _get_graph_from_inputs(params Tensor[] op_input_list)
=> _get_graph_from_inputs(op_input_list: op_input_list, graph: null);

@@ -399,15 +348,6 @@ namespace Tensorflow
return session.run(tensor, feed_dict);
}

/// <summary>
/// Returns the default session for the current thread.
/// </summary>
/// <returns>The default `Session` being used in the current thread.</returns>
public static Session get_default_session()
{
return tf.defaultSession;
}

/// <summary>
/// Prepends name scope to a name.
/// </summary>


+ 152
- 0
src/TensorFlowNET.Core/ops.threading.cs View File

@@ -0,0 +1,152 @@
using System.Threading;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow
{
public partial class ops
{
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack());
private static volatile Session _singleSesson;
private static volatile DefaultGraphStack _singleGraphStack;
private static readonly object _threadingLock = new object();

public static DefaultGraphStack default_graph_stack
{
get
{
if (!isSingleThreaded)
return _defaultGraphFactory.Value;

if (_singleGraphStack == null)
{
lock (_threadingLock)
{
if (_singleGraphStack == null)
_singleGraphStack = new DefaultGraphStack();
}
}

return _singleGraphStack;
}
}

private static bool isSingleThreaded = false;

/// <summary>
/// Does this library ignore different thread accessing.
/// </summary>
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading </remarks>
public static bool IsSingleThreaded
{
get => isSingleThreaded;
set
{
if (value)
enforce_singlethreading();
else
enforce_multithreading();
}
}

/// <summary>
/// Forces the library to ignore different thread accessing.
/// </summary>
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a multithreaded manner</remarks>
public static void enforce_singlethreading()
{
isSingleThreaded = true;
}

/// <summary>
/// Forces the library to provide a separate <see cref="Session"/> and <see cref="Graph"/> to every different thread accessing.
/// </summary>
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a singlethreaded manner</remarks>
public static void enforce_multithreading()
{
isSingleThreaded = false;
}

/// <summary>
/// Returns the default session for the current thread.
/// </summary>
/// <returns>The default `Session` being used in the current thread.</returns>
public static Session get_default_session()
{
if (!isSingleThreaded)
return tf.defaultSession;

if (_singleSesson == null)
{
lock (_threadingLock)
{
if (_singleSesson == null)
_singleSesson = new Session();
}
}

return _singleSesson;
}

/// <summary>
/// Returns the default session for the current thread.
/// </summary>
/// <returns>The default `Session` being used in the current thread.</returns>
public static Session set_default_session(Session sess)
{
if (!isSingleThreaded)
return tf.defaultSession = sess;

lock (_threadingLock)
{
_singleSesson = sess;
}

return _singleSesson;
}

/// <summary>
/// Returns the default graph for the current thread.
///
/// The returned graph will be the innermost graph on which a
/// `Graph.as_default()` context has been entered, or a global default
/// graph if none has been explicitly created.
///
/// NOTE: The default graph is a property of the current thread.If you
/// create a new thread, and wish to use the default graph in that
/// thread, you must explicitly add a `with g.as_default():` in that
/// thread's function.
/// </summary>
/// <returns></returns>
public static Graph get_default_graph()
{
//return _default_graph_stack.get_default()
return default_graph_stack.get_controller();
}

public static Graph set_default_graph(Graph graph)
{
default_graph_stack.set_controller(graph);
return default_graph_stack.get_controller();
}

/// <summary>
/// Clears the default graph stack and resets the global default graph.
///
/// NOTE: The default graph is a property of the current thread.This
/// function applies only to the current thread.Calling this function while
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
/// after calling this function will result in undefined behavior.
/// </summary>
/// <returns></returns>
public static void reset_default_graph()
{
//if (!_default_graph_stack.is_cleared())
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
// "nested graphs. If you need a cleared graph, " +
// "exit the nesting and create a new graph.");
default_graph_stack.reset();
}
}
}

+ 2
- 4
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -21,8 +21,6 @@ namespace Tensorflow
{
public partial class tensorflow : IObjectLife
{
protected internal readonly ThreadLocal<Session> _defaultSessionFactory;

public TF_DataType @byte = TF_DataType.TF_UINT8;
public TF_DataType @sbyte = TF_DataType.TF_INT8;
public TF_DataType int16 = TF_DataType.TF_INT16;
@@ -40,10 +38,10 @@ namespace Tensorflow

public tensorflow()
{
_defaultSessionFactory = new ThreadLocal<Session>(() => new Session());
_constructThreadingObjects();
}

public Session defaultSession => _defaultSessionFactory.Value;

public RefVariable Variable<T>(T data,
bool trainable = true,


+ 53
- 0
src/TensorFlowNET.Core/tensorflow.threading.cs View File

@@ -0,0 +1,53 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System.Runtime.CompilerServices;
using System.Threading;

namespace Tensorflow
{
public partial class tensorflow : IObjectLife
{
protected ThreadLocal<Session> _defaultSessionFactory;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void _constructThreadingObjects()
{
_defaultSessionFactory = new ThreadLocal<Session>(() => new Session());
}

public Session defaultSession
{
get
{
if (!ops.IsSingleThreaded)
return _defaultSessionFactory.Value;

return ops.get_default_session();
}
internal set
{
if (!ops.IsSingleThreaded)
{
_defaultSessionFactory.Value = value;
return;
}

ops.set_default_session(value);
}
}
}
}

+ 107
- 0
test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs View File

@@ -0,0 +1,107 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class EnforcedSinglethreadingTests : CApiTest
{
private static readonly object _singlethreadLocker = new object();

/// <summary>Initializes a new instance of the <see cref="T:System.Object" /> class.</summary>
public EnforcedSinglethreadingTests()
{
ops.IsSingleThreaded = true;
}

[TestMethod, Ignore("Has to be tested manually.")]
public void SessionCreation()
{
lock (_singlethreadLocker)
{
ops.IsSingleThreaded.Should().BeTrue();

ops.uid(); //increment id by one

//the core method
tf.peak_default_graph().Should().BeNull();

using (var sess = tf.Session())
{
var default_graph = tf.peak_default_graph();
var sess_graph = sess.GetPrivate<Graph>("_graph");
sess_graph.Should().NotBeNull();
default_graph.Should().NotBeNull()
.And.BeEquivalentTo(sess_graph);

var (graph, session) = Parallely(() => (tf.get_default_graph(), tf.get_default_session()));

graph.Should().BeEquivalentTo(default_graph);
session.Should().BeEquivalentTo(sess);
}
}
}

T Parallely<T>(Func<T> fnc)
{
var mrh = new ManualResetEventSlim();
T ret = default;
Exception e = default;
new Thread(() =>
{
try
{
ret = fnc();
} catch (Exception ee)
{
e = ee;
throw;
} finally
{
mrh.Set();
}
}).Start();

if (!Debugger.IsAttached)
mrh.Wait(10000).Should().BeTrue();
else
mrh.Wait(-1);
e.Should().BeNull(e?.ToString());
return ret;
}

void Parallely(Action fnc)
{
var mrh = new ManualResetEventSlim();
Exception e = default;
new Thread(() =>
{
try
{
fnc();
} catch (Exception ee)
{
e = ee;
throw;
} finally
{
mrh.Set();
}
}).Start();

mrh.Wait(10000).Should().BeTrue();
e.Should().BeNull(e.ToString());
}
}
}

+ 1
- 4
test/TensorFlowNET.UnitTest/MultithreadingTests.cs View File

@@ -283,14 +283,11 @@ namespace TensorFlowNET.UnitTest
}
}

private static string modelPath = "./model/";
private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/");

[TestMethod]
public void TF_GraphOperationByName_FromModel()
{
if (!Directory.Exists(modelPath))
return;

MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method


+ 3
- 0
test/TensorFlowNET.UnitTest/UnitTest.csproj View File

@@ -43,6 +43,9 @@
<None Update="model\saved_model.pb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Utilities\models\example1\saved_model.pb">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project>

BIN
test/TensorFlowNET.UnitTest/Utilities/models/example1/saved_model.pb View File


Loading…
Cancel
Save